diff --git a/.github/workflows/integrate_llvm.yml b/.github/workflows/integrate_llvm.yml new file mode 100644 index 000000000..85ece9095 --- /dev/null +++ b/.github/workflows/integrate_llvm.yml @@ -0,0 +1,112 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Copyright (c) 2024. + +name: Auto Integrate LLVM +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + schedule: + # At minute 0 past hour 1. (see https://crontab.guru) + - cron: '00 01 * * *' + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + update-dep: + + name: "Integrate LLVM and send PR" + + runs-on: ubuntu-latest + + permissions: + contents: write + pull-requests: write + + steps: + - name: "Check out repository" + uses: actions/checkout@v4.2.2 + + - name: Cache LLVM clone + id: cache-llvm + uses: actions/cache@v4 + with: + path: /tmp/llvm-project + key: cache-llvm-project + + - name: "Get filtered llvm-project" + shell: bash + id: get-llvm-project + run: | + + sudo apt install git-filter-repo + + HERE=$(pwd) + + pushd /tmp + echo "cache-hit ${{ steps.cache-llvm.outputs.cache-hit }}" + # https://github.com/actions/cache/issues/1566 + if [ "${{ steps.cache-llvm.outputs.cache-hit }}" == "" ]; then + git clone https://github.com/llvm/llvm-project.git -v + fi + + pushd llvm-project + + git pull origin main + echo "LLVM_SHA_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT + bash $HERE/filter-llvm.sh + + popd + popd + + - name: "Rebase on top of llvm-project" + shell: bash + id: rebase-llvm-project + run: | + + 
git config user.email "github-actions[bot]@users.noreply.github.com" + git config user.name "github-actions[bot]" + git pull /tmp/llvm-project main --rebase + + - name: Generate token + uses: actions/create-github-app-token@v1 + id: generate-token + with: + app-id: ${{ secrets.BUMP_LLVM_CREATE_PR_APP_ID }} + private-key: ${{ secrets.BUMP_LLVM_CREATE_PR_APP_PRIVATE_KEY }} + + - name: "Create Pull Request" + id: cpr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ steps.generate-token.outputs.token }} + commit-message: "[LLVM] Integrate to ${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + title: "[LLVM] Integrate to ${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + body: "Integrate LLVM to https://github.com/llvm/llvm-project/commit/${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + author: 'github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>' + base: main + branch: update-llvm + delete-branch: true + + - name: Enable auto-merge + if: steps.cpr.outputs.pull-request-operation == 'created' + uses: peter-evans/enable-pull-request-automerge@v3 + with: + token: ${{ steps.generate-token.outputs.token }} + pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} + + - name: Auto approve + if: steps.cpr.outputs.pull-request-operation == 'created' + run: gh pr review --approve "${{ steps.cpr.outputs.pull-request-number }}" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..1b17bee17 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +llvm-project \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md new file mode 100644 index 000000000..9bdede378 --- /dev/null +++ b/README.md @@ -0,0 +1,45 @@ +# TL;DR: + +In this repo: + +``` +$ git clone git@github.com:llvm/llvm-project.git /tmp/llvm-project +$ pushd llvm-project +$ IFS= read -r -d '' x [-!#-'*+/-9=?A-Z^-~]+(\.[-!#-'*+/-9=?A-Z^-~]+)*|\"([]!#-[^-~ \t]|(\\[\t -~]))+\")@(?P[-!#-'*+/-9=?A-Z^-~]+(\.[-!#-'*+/-9=?A-Z^-~]+)*|\[[\t -Z^-~]*])"); +rr = re.compile(r"@(?P[a-z\d](?:[a-z\d]|-(?=[a-z\d])){0,38})(?=\b)"); + +message = message.decode("utf-8"); +message = r.sub(r"\g$$$\g", message); +message = rr.sub(r"\g", message); +message = message.replace("$$$", "@"); + +return message.encode("utf-8") diff --git a/filter-llvm.sh b/filter-llvm.sh new file mode 100755 index 000000000..3ba6ff8d9 --- /dev/null +++ b/filter-llvm.sh @@ -0,0 +1,24 @@ +#!/usr/bin/env bash + +#rm -rf .git +#rm -rf mlir +#git init +#git checkout -b main +#git remote add -f origin /tmp/llvm-project +#git pull origin main + +IFS= read -r -d '' x +#endif #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" @@ -29,19 +39,51 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" +#include "mlir-c/Rewrite.h" + +// The 'mlir' Python package is relocatable and supports co-existing in multiple +// projects. Each project must define its outer package prefix with this define +// in order to provide proper isolation and local name resolution. +// The default is for the upstream "import mlir" package layout. +// Note that this prefix is internally stringified, allowing it to be passed +// unquoted on the compiler command line without shell quote escaping issues. +#ifndef MLIR_PYTHON_PACKAGE_PREFIX +#define MLIR_PYTHON_PACKAGE_PREFIX mlir. 
+#endif -#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" +// Makes a fully-qualified name relative to the MLIR python package. +#define MLIR_PYTHON_STRINGIZE(s) #s +#define MLIR_PYTHON_STRINGIZE_ARG(arg) MLIR_PYTHON_STRINGIZE(arg) +#define MAKE_MLIR_PYTHON_QUALNAME(local) \ + MLIR_PYTHON_STRINGIZE_ARG(MLIR_PYTHON_PACKAGE_PREFIX) local + +#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR \ + MAKE_MLIR_PYTHON_QUALNAME("ir.AffineExpr._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_AFFINE_MAP \ + MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_ATTRIBUTE \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_BLOCK MAKE_MLIR_PYTHON_QUALNAME("ir.Block._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_CONTEXT \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Context._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY \ + MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry._CAPIPtr") #define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE \ - "mlir.execution_engine.ExecutionEngine._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr" + MAKE_MLIR_PYTHON_QUALNAME("execution_engine.ExecutionEngine._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_INTEGER_SET \ + MAKE_MLIR_PYTHON_QUALNAME("ir.IntegerSet._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_LOCATION \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Location._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_MODULE \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Module._CAPIPtr") +#define 
MLIR_PYTHON_CAPSULE_OPERATION \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Operation._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_TYPE MAKE_MLIR_PYTHON_QUALNAME("ir.Type._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_PASS_MANAGER \ + MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_TYPEID \ + MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr") /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers @@ -67,10 +109,43 @@ * delineated). */ #define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate" +/** Attribute on MLIR Python objects that expose a function for downcasting the + * corresponding Python object to a subclass if the object is in fact a subclass + * (Concrete or mlir_type_subclass) of ir.Type. The signature of the function + * is: def maybe_downcast(self) -> object where the resulting object will + * (possibly) be an instance of the subclass. + */ +#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast" + +/** Attribute on main C extension module (_mlir) that corresponds to the + * type caster registration binding. The signature of the function is: + * def register_type_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a typeCaster (register_type_caster is meant to be used as a + * decorator from python), and where replace indicates the typeCaster should + * replace any existing registered type casters (such as those for upstream + * ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type) + * -> SubClassTypeT where SubClassTypeT indicates the result should be a + * subclass (inherit from) ir.Type. + */ +#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster" + +/** Attribute on main C extension module (_mlir) that corresponds to the + * value caster registration binding. 
The signature of the function is: + * def register_value_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a valueCaster (register_value_caster is meant to be used as + * a decorator, from python), and where replace indicates the valueCaster should + * replace any existing registered value casters. The interface of the + * valueCaster is: def value_caster(ir.Value) -> SubClassValueT where + * SubClassValueT indicates the result should be a subclass (inherit from) + * ir.Value. + */ +#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster" + /// Gets a void* from a wrapped struct. Needed because const cast is different /// between C/C++. #ifdef __cplusplus -#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) const_cast(object.ptr) +#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) \ + (const_cast((object).ptr)) #else #define MLIR_PYTHON_GET_WRAPPED_POINTER(object) (void *)(object.ptr) #endif @@ -117,6 +192,23 @@ static inline MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule) { return attr; } +/** Creates a capsule object encapsulating the raw C-API MlirBlock. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the module in any way. */ +static inline PyObject *mlirPythonBlockToCapsule(MlirBlock block) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(block), + MLIR_PYTHON_CAPSULE_BLOCK, NULL); +} + +/** Extracts an MlirBlock from a capsule as produced from + * mlirPythonBlockToCapsule. If the capsule is not of the right type, then + * a null pass manager is returned (as checked via mlirBlockIsNull). */ +static inline MlirBlock mlirPythonCapsuleToBlock(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_BLOCK); + MlirBlock block = {ptr}; + return block; +} + /** Creates a capsule object encapsulating the raw C-API MlirContext. 
* The returned capsule does not extend or affect ownership of any Python * objects that reference the context in any way. @@ -135,6 +227,28 @@ static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) { return context; } +/** Creates a capsule object encapsulating the raw C-API MlirDialectRegistry. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the context in any way. + */ +static inline PyObject * +mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry) { + return PyCapsule_New(registry.ptr, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY, + NULL); +} + +/** Extracts an MlirDialectRegistry from a capsule as produced from + * mlirPythonDialectRegistryToCapsule. If the capsule is not of the right type, + * then a null context is returned (as checked via mlirContextIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirDialectRegistry +mlirPythonCapsuleToDialectRegistry(PyObject *capsule) { + void *ptr = + PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY); + MlirDialectRegistry registry = {ptr}; + return registry; +} + /** Creates a capsule object encapsulating the raw C-API MlirLocation. * The returned capsule does not extend or affect ownership of any Python * objects that reference the location in any way. */ @@ -171,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) { return module; } +/** Creates a capsule object encapsulating the raw C-API + * MlirFrozenRewritePatternSet. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the module in any way. 
*/ +static inline PyObject * +mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm), + MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL); +} + +/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from + * mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the + * right type, then a null module is returned. */ +static inline MlirFrozenRewritePatternSet +mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER); + MlirFrozenRewritePatternSet pm = {ptr}; + return pm; +} + /** Creates a capsule object encapsulating the raw C-API MlirPassManager. * The returned capsule does not extend or affect ownership of any Python * objects that reference the module in any way. */ @@ -207,6 +341,25 @@ static inline MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule) { return op; } +/** Creates a capsule object encapsulating the raw C-API MlirTypeID. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the type in any way. + */ +static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID), + MLIR_PYTHON_CAPSULE_TYPEID, NULL); +} + +/** Extracts an MlirTypeID from a capsule as produced from + * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then + * a null type is returned (as checked via mlirTypeIDIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID); + MlirTypeID typeID = {ptr}; + return typeID; +} + /** Creates a capsule object encapsulating the raw C-API MlirType. 
* The returned capsule does not extend or affect ownership of any Python * objects that reference the type in any way. @@ -285,6 +438,25 @@ mlirPythonCapsuleToExecutionEngine(PyObject *capsule) { return jit; } +/** Creates a capsule object encapsulating the raw C-API MlirValue. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the operation in any way. + */ +static inline PyObject *mlirPythonValueToCapsule(MlirValue value) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(value), + MLIR_PYTHON_CAPSULE_VALUE, NULL); +} + +/** Extracts an MlirValue from a capsule as produced from + * mlirPythonValueToCapsule. If the capsule is not of the right type, then a + * null type is returned (as checked via mlirValueIsNull). In such a case, the + * Python APIs will have already set an error. */ +static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_VALUE); + MlirValue value = {ptr}; + return value; +} + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 29df9cf60..1d0edf9ea 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -16,12 +16,22 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" #include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { #endif +/// Returns an empty attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void); + +//===----------------------------------------------------------------------===// +// Location attribute. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsALocation(MlirAttribute attr); + //===----------------------------------------------------------------------===// // Affine map attribute. 
//===----------------------------------------------------------------------===// @@ -36,6 +46,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map); /// Returns the affine map wrapped in the given affine map attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr); +/// Returns the typeID of an AffineMap attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -55,6 +68,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an Array attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -80,6 +96,9 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name); +/// Returns the typeID of a Dictionary attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -106,6 +125,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, /// the value as double. MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr); +/// Returns the typeID of a Float attribute. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -122,9 +144,20 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value); /// Returns the value stored in the given integer attribute, assuming the value -/// fits into a 64-bit integer. +/// is of signless type and fits into a signed 64-bit integer. MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr); +/// Returns the value stored in the given integer attribute, assuming the value +/// is of signed type and fits into a signed 64-bit integer. +MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr); + +/// Returns the value stored in the given integer attribute, assuming the value +/// is of unsigned type and fits into an unsigned 64-bit integer. +MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); + +/// Returns the typeID of an Integer attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -145,6 +178,17 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr); /// Checks whether the given attribute is an integer set attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +/// Creates an integer set attribute wrapping the given set. The attribute +/// belongs to the same context as the integer set. +MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set); + +/// Returns the integer set wrapped in the given integer set attribute. +MLIR_CAPI_EXPORTED MlirIntegerSet +mlirIntegerSetAttrGetValue(MlirAttribute attr); + +/// Returns the typeID of an IntegerSet attribute. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -168,6 +212,9 @@ mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr); /// the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr); +/// Returns the typeID of an Opaque attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -189,6 +236,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type, /// long as the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a String attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -222,6 +272,13 @@ mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an SymbolRef attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void); + +/// Creates a DisctinctAttr with the referenced attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirDisctinctAttrCreate(MlirAttribute referencedAttr); + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. 
//===----------------------------------------------------------------------===// @@ -253,6 +310,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type); /// Returns the type stored in the given type attribute. MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a Type attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -263,6 +323,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr); /// Creates a unit attribute in the given context. MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx); +/// Returns the typeID of a Unit attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -285,6 +348,63 @@ mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs); /// shaped type and use its sizes to build a multi-dimensional index. MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr); +//===----------------------------------------------------------------------===// +// Dense array attribute. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void); + +/// Checks whether the given attribute is a dense array attribute. 
+MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr); + +/// Create a dense array attribute with the given elements. +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, + intptr_t size, + int const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, + intptr_t size, + int8_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, + intptr_t size, + int16_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, + intptr_t size, + int32_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, + intptr_t size, + int64_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, + intptr_t size, + float const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, + intptr_t size, + double const *values); + +/// Get the size of a dense array. +MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr); + +/// Get an element of a dense array. 
+MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr, + intptr_t pos); + //===----------------------------------------------------------------------===// // Dense elements attribute. //===----------------------------------------------------------------------===// @@ -298,11 +418,31 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); +/// Returns the typeID of an DenseIntOrFPElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void); + /// Creates a dense elements attribute with the given Shaped type and elements /// in the same context as the type. MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( MlirType shapedType, intptr_t numElements, MlirAttribute const *elements); +/// Creates a dense elements attribute with the given Shaped type and elements +/// populated from a packed, row-major opaque buffer of contents. +/// +/// The format of the raw buffer is a densely packed array of values that +/// can be bitcast to the storage format of the element type specified. +/// Types that are not byte aligned will be: +/// - For bitwidth > 1: Rounded up to the next byte. 
+/// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to +/// the linear order of the shape type from MSB to LSB, padded to on the +/// right. +/// +/// A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255) +/// will be interpreted as a splat. User code should be prepared for additional, +/// conformant patterns to be identified as splats in the future. +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet( + MlirType shapedType, size_t rawBufferSize, const void *rawBuffer); + /// Creates a dense elements attribute with the given Shaped type containing a /// single replicated element (splat). MLIR_CAPI_EXPORTED MlirAttribute @@ -310,6 +450,10 @@ mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element); MLIR_CAPI_EXPORTED MlirAttribute +mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element); +MLIR_CAPI_EXPORTED MlirAttribute +mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element); @@ -327,6 +471,14 @@ mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element); /// data element type. 
MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolGet( MlirType shapedType, intptr_t numElements, const int *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get( + MlirType shapedType, intptr_t numElements, const uint8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get( + MlirType shapedType, intptr_t numElements, const int8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get( + MlirType shapedType, intptr_t numElements, const uint16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get( + MlirType shapedType, intptr_t numElements, const int16_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get( MlirType shapedType, intptr_t numElements, const uint32_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get( @@ -339,6 +491,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloatGet( MlirType shapedType, intptr_t numElements, const float *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet( MlirType shapedType, intptr_t numElements, const double *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get( + MlirType shapedType, intptr_t numElements, const uint16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloat16Get( + MlirType shapedType, intptr_t numElements, const uint16_t *elements); /// Creates a dense elements attribute with the given shaped type from string /// elements. 
@@ -361,6 +517,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED int8_t +mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED uint8_t +mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED uint32_t @@ -380,6 +540,14 @@ mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr); /// contained by the given dense elements attribute. MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED uint8_t +mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int16_t +mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint16_t +mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint32_t @@ -388,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint64_t +mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED double @@ -400,13 +570,92 @@ MLIR_CAPI_EXPORTED const void * mlirDenseElementsAttrGetRawData(MlirAttribute attr); //===----------------------------------------------------------------------===// -// Opaque elements attribute. +// Resource blob attributes. 
//===----------------------------------------------------------------------===// -// TODO: expose Dialect to the bindings and implement accessors here. +MLIR_CAPI_EXPORTED bool +mlirAttributeIsADenseResourceElements(MlirAttribute attr); + +/// Unlike the typed accessors below, constructs the attribute with a raw +/// data buffer and no type/alignment checking. Use a more strongly typed +/// accessor if possible. If dataIsMutable is false, then an immutable +/// AsmResourceBlob will be created and that passed data contents will be +/// treated as const. +/// If the deleter is non NULL, then it will be called when the data buffer +/// can no longer be accessed (passing userData to it). +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, + size_t dataAlignment, bool dataIsMutable, + void (*deleter)(void *userData, const void *data, size_t size, + size_t align), + void *userData); + +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t 
numElements, + const uint32_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int32_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint64_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int64_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const float *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const double *elements); -/// Checks whether the given attribute is an opaque elements attribute. -MLIR_CAPI_EXPORTED bool mlirAttributeIsAOpaqueElements(MlirAttribute attr); +/// Returns the pos-th value (flat contiguous indexing) of a specific type +/// contained by the given dense resource elements attribute. 
+MLIR_CAPI_EXPORTED bool +mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int8_t +mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint8_t +mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int16_t +mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint16_t +mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int32_t +mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint32_t +mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int64_t +mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint64_t +mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED float +mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED double +mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); //===----------------------------------------------------------------------===// // Sparse elements attribute. @@ -433,6 +682,35 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); +/// Returns the typeID of a SparseElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void); + +//===----------------------------------------------------------------------===// +// Strided layout attribute. +//===----------------------------------------------------------------------===// + +// Checks wheather the given attribute is a strided layout attribute. +MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr); + +// Creates a strided layout attribute from given strides and offset. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, + const int64_t *strides); + +// Returns the offset in the given strided layout layout attribute. +MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr); + +// Returns the number of strides in the given strided layout attribute. +MLIR_CAPI_EXPORTED intptr_t +mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr); + +// Returns the pos-th stride stored in the given strided layout attribute. +MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, + intptr_t pos); + +/// Returns the typeID of a StridedLayout attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index a706c58ef..c981bfd09 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -22,6 +22,9 @@ extern "C" { // Integer types. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Integer type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void); + /// Checks whether the given type is an integer type. MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type); @@ -56,6 +59,9 @@ MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type); // Index type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Index type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void); + /// Checks whether the given type is an index type. MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type); @@ -67,6 +73,125 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// +/// Checks whether the given type is a floating-point type. 
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type); + +/// Returns the bitwidth of a floating-point type. +MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type); + +/// Returns the typeID of an Float4E2M1FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void); + +/// Checks whether the given type is an f4E2M1FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type); + +/// Creates an f4E2M1FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float6E2M3FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void); + +/// Checks whether the given type is an f6E2M3FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type); + +/// Creates an f6E2M3FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float6E3M2FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void); + +/// Checks whether the given type is an f6E3M2FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type); + +/// Creates an f6E3M2FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E5M2 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); + +/// Checks whether the given type is an f8E5M2 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); + +/// Creates an f8E5M2 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E4M3 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void); + +/// Checks whether the given type is an f8E4M3 type. 
+MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type); + +/// Creates an f8E4M3 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E4M3FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void); + +/// Checks whether the given type is an f8E4M3FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); + +/// Creates an f8E4M3FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E5M2FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void); + +/// Checks whether the given type is an f8E5M2FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); + +/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E4M3FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void); + +/// Checks whether the given type is an f8E4M3FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); + +/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E4M3B11FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void); + +/// Checks whether the given type is an f8E4M3B11FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); + +/// Creates an f8E4M3B11FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E3M4 type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void); + +/// Checks whether the given type is an f8E3M4 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type); + +/// Creates an f8E3M4 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx); + +/// Returns the typeID of an Float8E8M0FNU type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void); + +/// Checks whether the given type is an f8E8M0FNU type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type); + +/// Creates an f8E8M0FNU type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx); + +/// Returns the typeID of an BFloat16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); @@ -74,6 +199,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx); +/// Returns the typeID of an Float16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void); + /// Checks whether the given type is an f16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type); @@ -81,6 +209,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx); +/// Returns the typeID of an Float32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void); + /// Checks whether the given type is an f32 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type); @@ -88,6 +219,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx); +/// Returns the typeID of an Float64 type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void); + /// Checks whether the given type is an f64 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type); @@ -95,10 +229,23 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx); +/// Returns the typeID of a TF32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void); + +/// Checks whether the given type is an TF32 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type); + +/// Creates a TF32 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx); + //===----------------------------------------------------------------------===// // None type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an None type. +MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void); + /// Checks whether the given type is a None type. MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type); @@ -110,6 +257,9 @@ MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx); // Complex type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Complex type. +MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void); + /// Checks whether the given type is a Complex type. MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type); @@ -139,9 +289,12 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type); /// Checks whether the given shaped type has a static shape. MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type); -/// Checks wither the dim-th dimension of the given shaped type is dynamic. +/// Checks whether the dim-th dimension of the given shaped type is dynamic. 
MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim); +/// Checks whether the dim-th dimension of the given shaped type is static. +MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim); + /// Returns the dim-th dimension of the given ranked shaped type. MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim); @@ -150,14 +303,34 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, /// in shaped types. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); +/// Checks whether the given shaped type dimension value is statically-sized. +MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size); + +/// Returns the value indicating a dynamic size in a shaped type. Prefer +/// mlirShapedTypeIsDynamicSize and mlirShapedTypeIsStaticSize to direct +/// comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void); + /// Checks whether the given value is used as a placeholder for dynamic strides /// and offsets in shaped types. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); +/// Checks whether the given dimension value of a stride or an offset is +/// statically-sized. +MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val); + +/// Returns the value indicating a dynamic stride or offset in a shaped type. +/// Prefer mlirShapedTypeIsDynamicStrideOrOffset and +/// mlirShapedTypeIsStaticStrideOrOffset to direct comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); + //===----------------------------------------------------------------------===// // Vector type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Vector type. +MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void); + /// Checks whether the given type is a Vector type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type); @@ -175,6 +348,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, const int64_t *shape, MlirType elementType); +/// Creates a scalable vector type with the shape identified by its rank and +/// dimensions. A subset of dimensions may be marked as scalable via the +/// corresponding flag list, which is expected to have as many entries as the +/// rank of the vector. The vector is created in the same context as the element +/// type. +MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank, + const int64_t *shape, + const bool *scalable, + MlirType elementType); + +/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType +/// on illegal arguments, emitting appropriate diagnostics. +MLIR_CAPI_EXPORTED +MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, + const bool *scalable, + MlirType elementType); + +/// Checks whether the given vector type is scalable, i.e., has at least one +/// scalable dimension. +MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type); + +/// Checks whether the "dim"-th dimension of the given vector is scalable. +MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, + intptr_t dim); + //===----------------------------------------------------------------------===// // Ranked / Unranked Tensor type. //===----------------------------------------------------------------------===// @@ -182,23 +381,36 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, /// Checks whether the given type is a Tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type); +/// Returns the typeID of an RankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void); + /// Checks whether the given type is a ranked tensor type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type); +/// Returns the typeID of an UnrankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void); + /// Checks whether the given type is an unranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type); -/// Creates a tensor type of a fixed rank with the given shape and element type -/// in the same context as the element type. The type is owned by the context. +/// Creates a tensor type of a fixed rank with the given shape, element type, +/// and optional encoding in the same context as the element type. The type is +/// owned by the context. Tensor types without any specific encoding field +/// should assign mlirAttributeGetNull() to this parameter. MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, - MlirType elementType); + MlirType elementType, + MlirAttribute encoding); /// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on /// illegal arguments, emitting appropriate diagnostics. -MLIR_CAPI_EXPORTED MlirType -mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, - const int64_t *shape, MlirType elementType); +MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked( + MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, + MlirAttribute encoding); + +/// Gets the 'encoding' attribute from the ranked tensor type, returning a null +/// attribute if none. +MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type); /// Creates an unranked tensor type with the given element type in the same /// context as the element type. The type is owned by the context. @@ -213,67 +425,83 @@ mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType); // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an MemRef type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void); + /// Checks whether the given type is a MemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type); +/// Returns the typeID of an UnrankedMemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void); + /// Checks whether the given type is an UnrankedMemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type); /// Creates a MemRef type with the given rank and shape, a potentially empty /// list of affine layout maps, the given memory space and element type, in the /// same context as element type. The type is owned by the context. -MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( - MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType, + intptr_t rank, + const int64_t *shape, + MlirAttribute layout, + MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace); + MlirAttribute layout, MlirAttribute memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, /// i.e. represents a default row-major contiguous memref. The type is owned by /// the context. 
-MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType, - intptr_t rank, - const int64_t *shape, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + const int64_t *shape, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace); + MlirAttribute memorySpace); /// Creates an Unranked MemRef type with the given element type and in the given /// memory space. The type is owned by the context of element type. -MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace); /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( - MlirLocation loc, MlirType elementType, unsigned memorySpace); + MlirLocation loc, MlirType elementType, MlirAttribute memorySpace); -/// Returns the number of affine layout maps in the given MemRef type. -MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); +/// Returns the layout of the given MemRef type. +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type); -/// Returns the pos-th affine map of the given MemRef type. -MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, - intptr_t pos); +/// Returns the affine map of the given MemRef type. +MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type); /// Returns the memory space of the given MemRef type. 
-MLIR_CAPI_EXPORTED unsigned mlirMemRefTypeGetMemorySpace(MlirType type);
+MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type);
+
+/// Returns the strides of the MemRef if the layout map is in strided form.
+/// Both strides and offset are out params. strides must point to pre-allocated
+/// memory of length equal to the rank of the memref.
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(
+    MlirType type, int64_t *strides, int64_t *offset);
 
 /// Returns the memory spcae of the given Unranked MemRef type.
-MLIR_CAPI_EXPORTED unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type);
+MLIR_CAPI_EXPORTED MlirAttribute
+mlirUnrankedMemrefGetMemorySpace(MlirType type);
 
 //===----------------------------------------------------------------------===//
 // Tuple type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Tuple type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void);
+
 /// Checks whether the given type is a tuple type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type);
 
@@ -293,6 +521,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos);
 // Function type.
 //===----------------------------------------------------------------------===//
 
+/// Returns the typeID of a Function type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void);
+
 /// Checks whether the given type is a function type.
 MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type);
 
@@ -317,6 +548,32 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type,
 MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type,
                                                       intptr_t pos);
 
+//===----------------------------------------------------------------------===//
+// Opaque type.
+//===----------------------------------------------------------------------===//
+
+/// Returns the typeID of an Opaque type.
+MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void); + +/// Checks whether the given type is an opaque type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type); + +/// Creates an opaque type in the given context associated with the dialect +/// identified by its namespace. The type contains opaque byte data of the +/// specified length (data need not be null-terminated). +MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, + MlirStringRef dialectNamespace, + MlirStringRef typeData); + +/// Returns the namespace of the dialect with which the given opaque type +/// is associated. The namespace string is owned by the context. +MLIR_CAPI_EXPORTED MlirStringRef +mlirOpaqueTypeGetDialectNamespace(MlirType type); + +/// Returns the raw data as a string reference. The data remains live as long as +/// the context in which the type lives. +MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/Conversion.h b/mlir/include/mlir-c/Conversion.h index b69c41710..88c5143ad 100644 --- a/mlir/include/mlir-c/Conversion.h +++ b/mlir/include/mlir-c/Conversion.h @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_C_CONVERSIONS_H -#define MLIR_C_CONVERSIONS_H +#ifndef MLIR_C_CONVERSION_H +#define MLIR_C_CONVERSION_H #include "mlir-c/Support.h" #include "mlir/Conversion/Passes.capi.h.inc" -#endif // MLIR_C_CONVERSIONS_H +#endif // MLIR_C_CONVERSION_H diff --git a/mlir/include/mlir-c/Debug.h b/mlir/include/mlir-c/Debug.h new file mode 100644 index 000000000..7dad73500 --- /dev/null +++ b/mlir/include/mlir-c/Debug.h @@ -0,0 +1,43 @@ +//===-- mlir-c/Debug.h - C API for MLIR/LLVM debugging functions --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Support.h"
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/// Sets the global debugging flag.
+MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable);
+
+/// Returns `true` if the global debugging flag is set, false otherwise.
+MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled();
+
+/// Sets the current debug type, similarly to `-debug-only=type` in the
+/// command-line tools. Note that global debug should be enabled for any output
+/// to be produced.
+MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type);
+
+/// Sets multiple current debug types, similarly to `-debug-only=type1,type2` in
+/// the command-line tools. Note that global debug should be enabled for any
+/// output to be produced.
+MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n);
+
+/// Checks if `type` is set as the current debug type.
+MLIR_CAPI_EXPORTED bool mlirIsCurrentDebugType(const char *type);
+
+#ifdef __cplusplus
+}
+#endif
+
+#ifndef MLIR_C_DEBUG_H
+#define MLIR_C_DEBUG_H
+#endif // MLIR_C_DEBUG_H
diff --git a/mlir/include/mlir-c/Dialect/AMDGPU.h b/mlir/include/mlir-c/Dialect/AMDGPU.h
new file mode 100644
index 000000000..142044f7f
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/AMDGPU.h
@@ -0,0 +1,25 @@
+//===-- mlir-c/Dialect/AMDGPU.h - C API for AMDGPU dialect --*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_AMDGPU_H +#define MLIR_C_DIALECT_AMDGPU_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_AMDGPU_H diff --git a/mlir/include/mlir-c/Dialect/Standard.h b/mlir/include/mlir-c/Dialect/Arith.h similarity index 69% rename from mlir/include/mlir-c/Dialect/Standard.h rename to mlir/include/mlir-c/Dialect/Arith.h index 200962177..41e7cb2b3 100644 --- a/mlir/include/mlir-c/Dialect/Standard.h +++ b/mlir/include/mlir-c/Dialect/Arith.h @@ -1,4 +1,4 @@ -//===-- mlir-c/Dialect/Standard.h - C API for Standard dialect ----*- C -*-===// +//===-- mlir-c/Dialect/Arith.h - C API for Arith dialect ----------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. @@ -8,26 +8,26 @@ //===----------------------------------------------------------------------===// // // This header declares the C interface for registering and accessing the -// Standard dialect. A dialect should be registered with a context to make it +// Arith dialect. A dialect should be registered with a context to make it // available to users of the context. These users must load the dialect // before using any of its attributes, operations or types. Parser and pass // manager can load registered dialects automatically. 
// //===----------------------------------------------------------------------===// -#ifndef MLIR_C_DIALECT_STANDARD_H -#define MLIR_C_DIALECT_STANDARD_H +#ifndef MLIR_C_DIALECT_ARITH_H +#define MLIR_C_DIALECT_ARITH_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { #endif -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Standard, std); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Arith, arith); #ifdef __cplusplus } #endif -#endif // MLIR_C_DIALECT_STANDARD_H +#endif // MLIR_C_DIALECT_ARITH_H diff --git a/mlir/include/mlir-c/Dialect/Async.h b/mlir/include/mlir-c/Dialect/Async.h new file mode 100644 index 000000000..e4e32f86a --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Async.h @@ -0,0 +1,28 @@ +//===-- mlir-c/Dialect/Async.h - C API for Async dialect ---------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_ASYNC_H +#define MLIR_C_DIALECT_ASYNC_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Async, async); + +#ifdef __cplusplus +} +#endif + +#include "mlir/Dialect/Async/Passes.capi.h.inc" + +#endif // MLIR_C_DIALECT_ASYNC_H diff --git a/mlir/include/mlir-c/Dialect/ControlFlow.h b/mlir/include/mlir-c/Dialect/ControlFlow.h new file mode 100644 index 000000000..6d5ff8c3d --- /dev/null +++ b/mlir/include/mlir-c/Dialect/ControlFlow.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/ControlFlow.h - C API for ControlFlow ------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_CONTROLFLOW_H +#define MLIR_C_DIALECT_CONTROLFLOW_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(ControlFlow, cf); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_CONTROLFLOW_H diff --git a/mlir/include/mlir-c/Dialect/EmitC.h b/mlir/include/mlir-c/Dialect/EmitC.h new file mode 100644 index 000000000..a0e3ea08a --- /dev/null +++ b/mlir/include/mlir-c/Dialect/EmitC.h @@ -0,0 +1,137 @@ +//===-- mlir-c/Dialect/EmitC.h - C API for EmitC dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_EmitC_H +#define MLIR_C_DIALECT_EmitC_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(EmitC, emitc); + +enum MlirEmitCCmpPredicate : uint64_t { + MLIR_EMITC_CMP_PREDICATE_EQ = 0, + MLIR_EMITC_CMP_PREDICATE_NE = 1, + MLIR_EMITC_CMP_PREDICATE_LT = 2, + MLIR_EMITC_CMP_PREDICATE_LE = 3, + MLIR_EMITC_CMP_PREDICATE_GT = 4, + MLIR_EMITC_CMP_PREDICATE_GE = 5, + MLIR_EMITC_CMP_PREDICATE_THREE_WAY = 6, +}; + +//===---------------------------------------------------------------------===// +// ArrayType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCArrayType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCArrayTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCArrayTypeGet(intptr_t nDims, + int64_t *shape, + MlirType elementType); + 
+//===---------------------------------------------------------------------===// +// LValueType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCLValueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCLValueTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCLValueTypeGet(MlirType valueType); + +//===---------------------------------------------------------------------===// +// OpaqueType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCOpaqueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCOpaqueTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCOpaqueTypeGet(MlirContext ctx, + MlirStringRef value); + +//===---------------------------------------------------------------------===// +// PointerType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCPointerType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCPointerTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCPointerTypeGet(MlirType pointee); + +//===---------------------------------------------------------------------===// +// PtrDiffTType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCPtrDiffTType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCPtrDiffTTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCPtrDiffTTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// SignedSizeTType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCSignedSizeTType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCSignedSizeTTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType 
mlirEmitCSignedSizeTTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// SizeTType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCSizeTType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCSizeTTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCSizeTTypeGet(MlirContext ctx); + +//===----------------------------------------------------------------------===// +// CmpPredicate attribute. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAEmitCCmpPredicate(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirEmitCCmpPredicateAttrGet(MlirContext ctx, enum MlirEmitCCmpPredicate val); + +MLIR_CAPI_EXPORTED enum MlirEmitCCmpPredicate +mlirEmitCCmpPredicateAttrGetValue(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCCmpPredicateAttrGetTypeID(void); + +//===----------------------------------------------------------------------===// +// Opaque attribute. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAEmitCOpaque(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute mlirEmitCOpaqueAttrGet(MlirContext ctx, + MlirStringRef value); + +MLIR_CAPI_EXPORTED MlirStringRef +mlirEmitCOpaqueAttrGetValue(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCOpaqueAttrGetTypeID(void); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_EmitC_H diff --git a/mlir/include/mlir-c/Dialect/Func.h b/mlir/include/mlir-c/Dialect/Func.h new file mode 100644 index 000000000..001f915af --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Func.h @@ -0,0 +1,46 @@ +//===-- mlir-c/Dialect/Func.h - C API for Func dialect ------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Func dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_FUNC_H +#define MLIR_C_DIALECT_FUNC_H + +#include + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Func, func); + +/// Sets the argument attribute 'name' of an argument at index 'pos'. +/// Asserts that the operation is a FuncOp. +MLIR_CAPI_EXPORTED void mlirFuncSetArgAttr(MlirOperation op, intptr_t pos, + MlirStringRef name, + MlirAttribute attr); + +MLIR_CAPI_EXPORTED void mlirFuncSetResultAttr(MlirOperation op, intptr_t pos, + MlirStringRef name, + MlirAttribute attr); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_FUNC_H diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h new file mode 100644 index 000000000..321c1122c --- /dev/null +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -0,0 +1,72 @@ +//===-- mlir-c/Dialect/GPU.h - C API for GPU dialect -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_GPU_H +#define MLIR_C_DIALECT_GPU_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu); + +//===-------------------------------------------------------------------===// +// AsyncTokenType +//===-------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAGPUAsyncTokenType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// ObjectAttr +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format, + MlirStringRef objectStrRef, MlirAttribute mlirObjectProps); + +MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetWithKernels( + MlirContext mlirCtx, MlirAttribute target, uint32_t format, + MlirStringRef objectStrRef, MlirAttribute mlirObjectProps, + MlirAttribute mlirKernelsAttr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED uint32_t +mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED MlirStringRef +mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED bool +mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED bool +mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr); 
+ +#ifdef __cplusplus +} +#endif + +#include "mlir/Dialect/GPU/Transforms/Passes.capi.h.inc" + +#endif // MLIR_C_DIALECT_GPU_H diff --git a/mlir/include/mlir-c/Dialect/IRDL.h b/mlir/include/mlir-c/Dialect/IRDL.h new file mode 100644 index 000000000..c4d6ffd98 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/IRDL.h @@ -0,0 +1,29 @@ +//===-- mlir-c/Dialect/IRDL.h - C API for IRDL --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_IRDL_H +#define MLIR_C_DIALECT_IRDL_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IRDL, irdl); + +/// Loads all IRDL dialects in the provided module, registering the dialects in +/// the module's associated context. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirLoadIRDLDialects(MlirModule module); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_IRDL_H diff --git a/mlir/include/mlir-c/Dialect/Index.h b/mlir/include/mlir-c/Dialect/Index.h new file mode 100644 index 000000000..3f05694ac --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Index.h @@ -0,0 +1,24 @@ +//===-- mlir-c/Dialect/Index.h - C API for Index dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_INDEX_H
+#define MLIR_C_DIALECT_INDEX_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Index, index);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_INDEX_H
diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h
new file mode 100644
index 000000000..65b14254e
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/LLVM.h
@@ -0,0 +1,397 @@
+//===-- mlir-c/Dialect/LLVM.h - C API for LLVM --------------------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_LLVM_H
+#define MLIR_C_DIALECT_LLVM_H
+
+#include "mlir-c/IR.h"
+#include "mlir-c/Support.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm);
+
+/// Creates an llvm.ptr type.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx,
+                                                   unsigned addressSpace);
+
+/// Returns `true` if the type is an LLVM dialect pointer type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type);
+
+/// Returns address space of llvm.ptr
+MLIR_CAPI_EXPORTED unsigned
+mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType);
+
+/// Creates an llvm.void type.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx);
+
+/// Creates an llvm.array type.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGet(MlirType elementType,
+                                                 unsigned numElements);
+
+/// Returns the element type of the llvm.array type.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGetElementType(MlirType type); + +/// Creates an llvm.func type. +MLIR_CAPI_EXPORTED MlirType +mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, + MlirType const *argumentTypes, bool isVarArg); + +/// Returns the number of input types. +MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); + +/// Returns the pos-th input type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, + intptr_t pos); + +/// Returns the return type of the function type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); + +/// Returns `true` if the type is an LLVM dialect struct type. +MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); + +/// Returns `true` if the type is a literal (unnamed) LLVM struct type. +MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type); + +/// Returns the number of fields in the struct. Asserts if the struct is opaque +/// or not yet initialized. +MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type); + +/// Returns the `positions`-th field of the struct. Asserts if the struct is +/// opaque, not yet initialized or if the position is out of range. +MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type, + intptr_t position); + +/// Returns `true` if the struct is packed. +MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type); + +/// Returns the identifier of the identified struct. Asserts that the struct is +/// identified, i.e., not literal. +MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type); + +/// Returns `true` is the struct is explicitly opaque (will not have a body) or +/// uninitialized (will eventually have a body). +MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type); + +/// Creates an LLVM literal (unnamed) struct type. 
This may assert if the fields
+/// have types not compatible with the LLVM dialect. For a graceful failure, use
+/// the checked version.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes,
+                             MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM literal (unnamed) struct type if possible. Emits a
+/// diagnostic at the given location and returns null otherwise.
+MLIR_CAPI_EXPORTED MlirType
+mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes,
+                                    MlirType const *fieldTypes, bool isPacked);
+
+/// Creates an LLVM identified struct type with no body. If a struct type with
+/// this name already exists in the context, returns that type. Use
+/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type,
+/// potentially renaming it. The body should be set separately by calling
+/// mlirLLVMStructTypeSetBody, if it isn't set already.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx,
+                                                            MlirStringRef name);
+
+/// Creates an LLVM identified struct type with no body and a name starting with
+/// the given prefix. If a struct with the exact name as the given prefix
+/// already exists, appends an unspecified suffix to the name so that the name
+/// is unique in context.
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet(
+    MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes,
+    MlirType const *fieldTypes, bool isPacked);
+
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx,
+                                                        MlirStringRef name);
+
+/// Sets the body of the identified struct if it hasn't been set yet. Returns
+/// whether the operation was successful.
+MLIR_CAPI_EXPORTED MlirLogicalResult +mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes, + MlirType const *fieldTypes, bool isPacked); + +enum MlirLLVMCConv { + MlirLLVMCConvC = 0, + MlirLLVMCConvFast = 8, + MlirLLVMCConvCold = 9, + MlirLLVMCConvGHC = 10, + MlirLLVMCConvHiPE = 11, + MlirLLVMCConvAnyReg = 13, + MlirLLVMCConvPreserveMost = 14, + MlirLLVMCConvPreserveAll = 15, + MlirLLVMCConvSwift = 16, + MlirLLVMCConvCXX_FAST_TLS = 17, + MlirLLVMCConvTail = 18, + MlirLLVMCConvCFGuard_Check = 19, + MlirLLVMCConvSwiftTail = 20, + MlirLLVMCConvX86_StdCall = 64, + MlirLLVMCConvX86_FastCall = 65, + MlirLLVMCConvARM_APCS = 66, + MlirLLVMCConvARM_AAPCS = 67, + MlirLLVMCConvARM_AAPCS_VFP = 68, + MlirLLVMCConvMSP430_INTR = 69, + MlirLLVMCConvX86_ThisCall = 70, + MlirLLVMCConvPTX_Kernel = 71, + MlirLLVMCConvPTX_Device = 72, + MlirLLVMCConvSPIR_FUNC = 75, + MlirLLVMCConvSPIR_KERNEL = 76, + MlirLLVMCConvIntel_OCL_BI = 77, + MlirLLVMCConvX86_64_SysV = 78, + MlirLLVMCConvWin64 = 79, + MlirLLVMCConvX86_VectorCall = 80, + MlirLLVMCConvDUMMY_HHVM = 81, + MlirLLVMCConvDUMMY_HHVM_C = 82, + MlirLLVMCConvX86_INTR = 83, + MlirLLVMCConvAVR_INTR = 84, + MlirLLVMCConvAVR_BUILTIN = 86, + MlirLLVMCConvAMDGPU_VS = 87, + MlirLLVMCConvAMDGPU_GS = 88, + MlirLLVMCConvAMDGPU_CS = 90, + MlirLLVMCConvAMDGPU_KERNEL = 91, + MlirLLVMCConvX86_RegCall = 92, + MlirLLVMCConvAMDGPU_HS = 93, + MlirLLVMCConvMSP430_BUILTIN = 94, + MlirLLVMCConvAMDGPU_LS = 95, + MlirLLVMCConvAMDGPU_ES = 96, + MlirLLVMCConvAArch64_VectorCall = 97, + MlirLLVMCConvAArch64_SVE_VectorCall = 98, + MlirLLVMCConvWASM_EmscriptenInvoke = 99, + MlirLLVMCConvAMDGPU_Gfx = 100, + MlirLLVMCConvM68k_INTR = 101, +}; +typedef enum MlirLLVMCConv MlirLLVMCConv; + +/// Creates a LLVM CConv attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMCConvAttrGet(MlirContext ctx, + MlirLLVMCConv cconv); + +enum MlirLLVMComdat { + MlirLLVMComdatAny = 0, + MlirLLVMComdatExactMatch = 1, + MlirLLVMComdatLargest = 2, + MlirLLVMComdatNoDeduplicate = 3, + MlirLLVMComdatSameSize = 4, +}; +typedef enum MlirLLVMComdat MlirLLVMComdat; + +/// Creates a LLVM Comdat attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMComdatAttrGet(MlirContext ctx, + MlirLLVMComdat comdat); + +enum MlirLLVMLinkage { + MlirLLVMLinkageExternal = 0, + MlirLLVMLinkageAvailableExternally = 1, + MlirLLVMLinkageLinkonce = 2, + MlirLLVMLinkageLinkonceODR = 3, + MlirLLVMLinkageWeak = 4, + MlirLLVMLinkageWeakODR = 5, + MlirLLVMLinkageAppending = 6, + MlirLLVMLinkageInternal = 7, + MlirLLVMLinkagePrivate = 8, + MlirLLVMLinkageExternWeak = 9, + MlirLLVMLinkageCommon = 10, +}; +typedef enum MlirLLVMLinkage MlirLLVMLinkage; + +/// Creates a LLVM Linkage attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMLinkageAttrGet(MlirContext ctx, MlirLLVMLinkage linkage); + +/// Creates a LLVM DINullType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx); + +/// Creates a LLVM DIExpressionElem attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDIExpressionElemAttrGet(MlirContext ctx, unsigned int opcode, + intptr_t nArguments, uint64_t const *arguments); + +/// Creates a LLVM DIExpression attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIExpressionAttrGet( + MlirContext ctx, intptr_t nOperations, MlirAttribute const *operations); + +enum MlirLLVMTypeEncoding { + MlirLLVMTypeEncodingAddress = 0x1, + MlirLLVMTypeEncodingBoolean = 0x2, + MlirLLVMTypeEncodingComplexFloat = 0x31, + MlirLLVMTypeEncodingFloatT = 0x4, + MlirLLVMTypeEncodingSigned = 0x5, + MlirLLVMTypeEncodingSignedChar = 0x6, + MlirLLVMTypeEncodingUnsigned = 0x7, + MlirLLVMTypeEncodingUnsignedChar = 0x08, + MlirLLVMTypeEncodingImaginaryFloat = 0x09, + MlirLLVMTypeEncodingPackedDecimal = 0x0a, + MlirLLVMTypeEncodingNumericString = 0x0b, + MlirLLVMTypeEncodingEdited = 0x0c, + MlirLLVMTypeEncodingSignedFixed = 0x0d, + MlirLLVMTypeEncodingUnsignedFixed = 0x0e, + MlirLLVMTypeEncodingDecimalFloat = 0x0f, + MlirLLVMTypeEncodingUTF = 0x10, + MlirLLVMTypeEncodingUCS = 0x11, + MlirLLVMTypeEncodingASCII = 0x12, + MlirLLVMTypeEncodingLoUser = 0x80, + MlirLLVMTypeEncodingHiUser = 0xff, +}; +typedef enum MlirLLVMTypeEncoding MlirLLVMTypeEncoding; + +/// Creates a LLVM DIBasicType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIBasicTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, + MlirLLVMTypeEncoding encoding); + +/// Creates a self-referencing LLVM DICompositeType attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId); + +/// Creates a LLVM DICompositeType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( + MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, + MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, + uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, + MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, + MlirAttribute associated); + +/// Creates a LLVM DIDerivedType attribute. 
Note that `dwarfAddressSpace` is an +/// optional field, where `MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL` indicates null +/// and non-negative values indicate a value present. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIDerivedTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, + MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, + uint64_t offsetInBits, int64_t dwarfAddressSpace, MlirAttribute extraData); + +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIStringTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, + uint32_t alignInBits, MlirAttribute stringLength, + MlirAttribute stringLengthExp, MlirAttribute stringLocationExp, + MlirLLVMTypeEncoding encoding); + +/// Constant to represent std::nullopt for dwarfAddressSpace to omit the field. +#define MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL -1 + +/// Gets the base type from a LLVM DIDerivedType attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDIDerivedTypeAttrGetBaseType(MlirAttribute diDerivedType); + +/// Creates a LLVM DIFileAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIFileAttrGet(MlirContext ctx, + MlirAttribute name, + MlirAttribute directory); + +enum MlirLLVMDIEmissionKind { + MlirLLVMDIEmissionKindNone = 0, + MlirLLVMDIEmissionKindFull = 1, + MlirLLVMDIEmissionKindLineTablesOnly = 2, + MlirLLVMDIEmissionKindDebugDirectivesOnly = 3, +}; +typedef enum MlirLLVMDIEmissionKind MlirLLVMDIEmissionKind; + +enum MlirLLVMDINameTableKind { + MlirLLVMDINameTableKindDefault = 0, + MlirLLVMDINameTableKindGNU = 1, + MlirLLVMDINameTableKindNone = 2, + MlirLLVMDINameTableKindApple = 3, +}; +typedef enum MlirLLVMDINameTableKind MlirLLVMDINameTableKind; + +/// Creates a LLVM DICompileUnit attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompileUnitAttrGet( + MlirContext ctx, MlirAttribute id, unsigned int sourceLanguage, + MlirAttribute file, MlirAttribute producer, bool isOptimized, + MlirLLVMDIEmissionKind emissionKind, MlirLLVMDINameTableKind nameTableKind); + +/// Creates a LLVM DIFlags attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, + uint64_t value); + +/// Creates a LLVM DILexicalBlock attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILexicalBlockAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute file, unsigned int line, + unsigned int column); + +/// Creates a LLVM DILexicalBlockFile attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILexicalBlockFileAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute file, + unsigned int discriminator); + +/// Creates a LLVM DILocalVariableAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILocalVariableAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute name, + MlirAttribute diFile, unsigned int line, unsigned int arg, + unsigned int alignInBits, MlirAttribute diType, int64_t flags); + +/// Creates a self-referencing LLVM DISubprogramAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId); + +/// Creates a LLVM DISubprogramAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( + MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, + MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, + MlirAttribute linkageName, MlirAttribute file, unsigned int line, + unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes, + intptr_t nAnnotations, MlirAttribute const *annotations); + +/// Creates a LLVM DIAnnotation attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIAnnotationAttrGet( + MlirContext ctx, MlirAttribute name, MlirAttribute value); + +/// Gets the scope from this DISubprogramAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram); + +/// Gets the line from this DISubprogramAttr. +MLIR_CAPI_EXPORTED unsigned int +mlirLLVMDISubprogramAttrGetLine(MlirAttribute diSubprogram); + +/// Gets the scope line from this DISubprogram. +MLIR_CAPI_EXPORTED unsigned int +mlirLLVMDISubprogramAttrGetScopeLine(MlirAttribute diSubprogram); + +/// Gets the compile unit from this DISubprogram. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetCompileUnit(MlirAttribute diSubprogram); + +/// Gets the file from this DISubprogramAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetFile(MlirAttribute diSubprogram); + +/// Gets the type from this DISubprogramAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetType(MlirAttribute diSubprogram); + +/// Creates a LLVM DISubroutineTypeAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, unsigned int callingConvention, + intptr_t nTypes, MlirAttribute const *types); + +/// Creates a LLVM DIModuleAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGet( + MlirContext ctx, MlirAttribute file, MlirAttribute scope, + MlirAttribute name, MlirAttribute configMacros, MlirAttribute includePath, + MlirAttribute apinotes, unsigned int line, bool isDecl); + +/// Creates a LLVM DIImportedEntityAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIImportedEntityAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute scope, + MlirAttribute entity, MlirAttribute file, unsigned int line, + MlirAttribute name, intptr_t nElements, MlirAttribute const *elements); + +/// Gets the scope of this DIModuleAttr. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_LLVM_H diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 56258ac19..339e63d66 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -1,25 +1,64 @@ -//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect --------*- C -*-===// +//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect -------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===----------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { #endif +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `mlirOp` is a builtin named Linalg op. 
+MLIR_CAPI_EXPORTED void +mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); + +MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op); + +typedef struct MlirLinalgContractionDimensions { + MlirAttribute batch; + MlirAttribute m; + MlirAttribute n; + MlirAttribute k; +} MlirLinalgContractionDimensions; + +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensions(MlirOperation op); + +MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); + +typedef struct MlirLinalgConvolutionDimensions { + MlirAttribute batch; + MlirAttribute outputImage; + MlirAttribute outputChannel; + MlirAttribute filterLoop; + MlirAttribute inputChannel; + MlirAttribute depth; + MlirAttribute strides; + MlirAttribute dilations; +} MlirLinalgConvolutionDimensions; + +MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions +mlirLinalgInferConvolutionDimensions(MlirOperation op); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus } #endif +#include "mlir/Dialect/Linalg/Passes.capi.h.inc" + #endif // MLIR_C_DIALECT_LINALG_H diff --git a/mlir/include/mlir-c/Dialect/MLProgram.h b/mlir/include/mlir-c/Dialect/MLProgram.h new file mode 100644 index 000000000..0874955e3 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/MLProgram.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/MLProgram.h - C API for MLProgram dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_MLPROGRAM_H +#define MLIR_C_DIALECT_MLPROGRAM_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MLProgram, ml_program); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_MLPROGRAM_H diff --git a/mlir/include/mlir-c/Dialect/Math.h b/mlir/include/mlir-c/Dialect/Math.h new file mode 100644 index 000000000..5269e1a1b --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Math.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Math.h - C API for Math dialect ------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Math dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_MATH_H +#define MLIR_C_DIALECT_MATH_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Math, math); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_MATH_H diff --git a/mlir/include/mlir-c/Dialect/MemRef.h b/mlir/include/mlir-c/Dialect/MemRef.h new file mode 100644 index 000000000..087a4b3f8 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/MemRef.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/MemRef.h - C API for MemRef dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// MemRef dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_MEMREF_H +#define MLIR_C_DIALECT_MEMREF_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MemRef, memref); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_MEMREF_H diff --git a/mlir/include/mlir-c/Dialect/NVGPU.h b/mlir/include/mlir-c/Dialect/NVGPU.h new file mode 100644 index 000000000..e58015a4a --- /dev/null +++ b/mlir/include/mlir-c/Dialect/NVGPU.h @@ -0,0 +1,36 @@ +//===-- mlir-c/Dialect/NVGPU.h - C API for NVGPU dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_NVGPU_H +#define MLIR_C_DIALECT_NVGPU_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu); + +//===---------------------------------------------------------------------===// +// TensorMapDescriptorType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirNVGPUTensorMapDescriptorTypeGet( + MlirContext ctx, MlirType tensorMemrefType, int swizzle, int l2promo, + int oobFill, int interleave); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_NVGPU_H diff --git a/mlir/include/mlir-c/Dialect/NVVM.h b/mlir/include/mlir-c/Dialect/NVVM.h new file mode 100644 index 000000000..cf5d9301d --- /dev/null +++ b/mlir/include/mlir-c/Dialect/NVVM.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/NVVM.h - C API for NVVM dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_NVVM_H +#define MLIR_C_DIALECT_NVVM_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVVM, nvvm); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_NVVM_H diff --git a/mlir/include/mlir-c/Dialect/OpenMP.h b/mlir/include/mlir-c/Dialect/OpenMP.h new file mode 100644 index 000000000..719ed702a --- /dev/null +++ b/mlir/include/mlir-c/Dialect/OpenMP.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/OpenMP.h - C API for OpenMP Dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_OPENM_H +#define MLIR_C_DIALECT_OPENM_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(OpenMP, omp); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_OPENM_H diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h new file mode 100644 index 000000000..6ad2e2da6 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -0,0 +1,73 @@ +//===-- mlir-c/Dialect/PDL.h - C API for PDL Dialect --------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_PDL_H +#define MLIR_C_DIALECT_PDL_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PDL, pdl); + +//===---------------------------------------------------------------------===// +// PDLType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type); + +//===---------------------------------------------------------------------===// +// AttributeType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// RangeType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType); + +MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type); + +//===---------------------------------------------------------------------===// +// TypeType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx); + 
+//===---------------------------------------------------------------------===//
+// ValueType
+//===---------------------------------------------------------------------===//
+
+MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type);
+
+MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // MLIR_C_DIALECT_PDL_H
diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h
new file mode 100644
index 000000000..dc0989e53
--- /dev/null
+++ b/mlir/include/mlir-c/Dialect/Quant.h
@@ -0,0 +1,239 @@
+//===-- mlir-c/Dialect/Quant.h - C API for Quant dialect ----------*- C -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM
+// Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_C_DIALECT_QUANT_H
+#define MLIR_C_DIALECT_QUANT_H
+
+#include "mlir-c/IR.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(quant, quant);
+
+//===---------------------------------------------------------------------===//
+// QuantizedType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is a quantization dialect type.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAQuantizedType(MlirType type);
+
+/// Returns the bit flag used to indicate signedness of a quantized type.
+MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetSignedFlag(void);
+
+/// Returns the minimum possible value stored by a quantized type.
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(
+    bool isSigned, unsigned integralWidth);
+
+/// Returns the maximum possible value stored by a quantized type.
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMaximumForInteger( + bool isSigned, unsigned integralWidth); + +/// Gets the original type approximated by the given quantized type. +MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeGetExpressedType(MlirType type); + +/// Gets the flags associated with the given quantized type. +MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetFlags(MlirType type); + +/// Returns `true` if the given type is signed, `false` otherwise. +MLIR_CAPI_EXPORTED bool mlirQuantizedTypeIsSigned(MlirType type); + +/// Returns the underlying type used to store the values. +MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeGetStorageType(MlirType type); + +/// Returns the minimum value that the storage type of the given quantized type +/// can take. +MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type); + +/// Returns the maximum value that the storage type of the given quantized type +/// can take. +MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type); + +/// Returns the integral bitwidth that the storage type of the given quantized +/// type can represent exactly. +MLIR_CAPI_EXPORTED unsigned +mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type); + +/// Returns `true` if the `candidate` type is compatible with the given +/// quantized `type`. +MLIR_CAPI_EXPORTED bool +mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, MlirType candidate); + +/// Returns the element type of the given quantized type as another quantized +/// type. +MLIR_CAPI_EXPORTED MlirType +mlirQuantizedTypeGetQuantizedElementType(MlirType type); + +/// Casts from a type based on the storage type of the given type to a +/// corresponding type based on the given type. Returns a null type if the cast +/// is not valid. 
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeCastFromStorageType(MlirType type, MlirType candidate);
+
+/// Casts from a type based on a quantized type to a corresponding type based
+/// on the storage type. Returns a null type if the cast is not valid.
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeCastToStorageType(MlirType type);
+
+/// Casts from a type based on the expressed type of the given type to a
+/// corresponding type based on the given type. Returns a null type if the cast
+/// is not valid.
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeCastFromExpressedType(MlirType type, MlirType candidate);
+
+/// Casts from a type based on a quantized type to a corresponding type based
+/// on the expressed type. Returns a null type if the cast is not valid.
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeCastToExpressedType(MlirType type);
+
+/// Casts from a type based on the expressed type of the given quantized type
+/// to an equivalent type based on storage type of the same quantized type.
+MLIR_CAPI_EXPORTED MlirType
+mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate);
+
+//===---------------------------------------------------------------------===//
+// AnyQuantizedType
+//===---------------------------------------------------------------------===//
+
+/// Returns `true` if the given type is an AnyQuantizedType.
+MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type);
+
+/// Creates an instance of AnyQuantizedType with the given parameters in the
+/// same context as `storageType` and returns it. The instance is owned by the
+/// context.
+MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags, + MlirType storageType, + MlirType expressedType, + int64_t storageTypeMin, + int64_t storageTypeMax); + +//===---------------------------------------------------------------------===// +// UniformQuantizedType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type); + +/// Creates an instance of UniformQuantizedType with the given parameters in the +/// same context as `storageType` and returns it. The instance is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the scale of the given uniform quantized type. +MLIR_CAPI_EXPORTED double mlirUniformQuantizedTypeGetScale(MlirType type); + +/// Returns the zero point of the given uniform quantized type. +MLIR_CAPI_EXPORTED int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type); + +/// Returns `true` if the given uniform quantized type is fixed-point. +MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type); + +//===---------------------------------------------------------------------===// +// UniformQuantizedPerAxisType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedPerAxisType. +MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type); + +/// Creates an instance of UniformQuantizedPerAxisType with the given parameters +/// in the same context as `storageType` and returns it. `scales` and +/// `zeroPoints` point to `nDims` number of elements. The instance is owned +/// by the context. 
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedPerAxisTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + intptr_t nDims, double *scales, int64_t *zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the number of axes in the given quantized per-axis type. +MLIR_CAPI_EXPORTED intptr_t +mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type); + +/// Returns `pos`-th scale of the given quantized per-axis type. +MLIR_CAPI_EXPORTED double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, + intptr_t pos); + +/// Returns `pos`-th zero point of the given quantized per-axis type. +MLIR_CAPI_EXPORTED int64_t +mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, intptr_t pos); + +/// Returns the index of the quantized dimension in the given quantized per-axis +/// type. +MLIR_CAPI_EXPORTED int32_t +mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type); + +/// Returns `true` if the given uniform quantized per-axis type is fixed-point. +MLIR_CAPI_EXPORTED bool +mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type); + +//===---------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedSubChannel. +MLIR_CAPI_EXPORTED bool +mlirTypeIsAUniformQuantizedSubChannelType(MlirType type); + +/// Creates a UniformQuantizedSubChannelType with the given parameters. +/// +/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be +/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` +/// point to `blockSizeInfoLength` number of elements, describing respectively +/// the quantization axis and corresponding block size. 
+MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, + intptr_t blockSizeInfoLength, int32_t *quantizedDimensions, + int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the number of block sizes provided in type. +MLIR_CAPI_EXPORTED intptr_t +mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type); + +/// Returns the quantized dimension at the given position. +MLIR_CAPI_EXPORTED int32_t +mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type, + intptr_t pos); + +/// Returns the block size at the given position. +MLIR_CAPI_EXPORTED int64_t +mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos); + +/// Returns the scales of the quantized type. +MLIR_CAPI_EXPORTED MlirAttribute +mlirUniformQuantizedSubChannelTypeGetScales(MlirType type); + +/// Returns the zero-points of the quantized type. +MLIR_CAPI_EXPORTED MlirAttribute +mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type); + +//===---------------------------------------------------------------------===// +// CalibratedQuantizedType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a CalibratedQuantizedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type); + +/// Creates an instance of CalibratedQuantizedType with the given parameters +/// in the same context as `expressedType` and returns it. The instance is owned +/// by the context. +MLIR_CAPI_EXPORTED MlirType +mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, double max); + +/// Returns the min value of the given calibrated quantized type. +MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMin(MlirType type); + +/// Returns the max value of the given calibrated quantized type. 
+MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMax(MlirType type); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_QUANT_H diff --git a/mlir/include/mlir-c/Dialect/ROCDL.h b/mlir/include/mlir-c/Dialect/ROCDL.h new file mode 100644 index 000000000..e5dbb55b5 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/ROCDL.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/ROCDL.h - C API for ROCDL dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_ROCDL_H +#define MLIR_C_DIALECT_ROCDL_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(ROCDL, rocdl); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_ROCDL_H diff --git a/mlir/include/mlir-c/Dialect/SCF.h b/mlir/include/mlir-c/Dialect/SCF.h index c1b256779..75f1b6839 100644 --- a/mlir/include/mlir-c/Dialect/SCF.h +++ b/mlir/include/mlir-c/Dialect/SCF.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_SCF_H #define MLIR_C_DIALECT_SCF_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h new file mode 100644 index 000000000..0ad64746f --- /dev/null +++ b/mlir/include/mlir-c/Dialect/SMT.h @@ -0,0 +1,111 @@ +//===- SMT.h - C interface for the SMT dialect --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_SMT_H +#define MLIR_C_DIALECT_SMT_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt); + +//===----------------------------------------------------------------------===// +// Type API. +//===----------------------------------------------------------------------===// + +/// Checks if the given type is any non-func SMT value type. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type); + +/// Checks if the given type is any SMT value type. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type); + +/// Checks if the given type is a smt::ArrayType. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type); + +/// Creates an array type with the given domain and range types. +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx, + MlirType domainType, + MlirType rangeType); + +/// Checks if the given type is a smt::BitVectorType. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type); + +/// Creates a smt::BitVectorType with the given width. +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx, + int32_t width); + +/// Checks if the given type is a smt::BoolType. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type); + +/// Creates a smt::BoolType. +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx); + +/// Checks if the given type is a smt::IntType. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type); + +/// Creates a smt::IntType. +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx); + +/// Checks if the given type is a smt::FuncType. 
+MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type); + +/// Creates a smt::FuncType with the given domain and range types. +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, + size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType); + +/// Checks if the given type is a smt::SortType. +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type); + +/// Creates a smt::SortType with the given identifier and sort parameters. +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx, + MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams); + +//===----------------------------------------------------------------------===// +// Attribute API. +//===----------------------------------------------------------------------===// + +/// Checks if the given string is a valid smt::BVCmpPredicate. +MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, + MlirStringRef str); + +/// Checks if the given string is a valid smt::IntPredicate. +MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, + MlirStringRef str); + +/// Checks if the given attribute is a smt::SMTAttribute. +MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr); + +/// Creates a smt::BitVectorAttr with the given value and width. +MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, + uint64_t value, + unsigned width); + +/// Creates a smt::BVCmpPredicateAttr with the given string. +MLIR_CAPI_EXPORTED MlirAttribute +mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str); + +/// Creates a smt::IntPredicateAttr with the given string. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, + MlirStringRef str); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_SMT_H diff --git a/mlir/include/mlir-c/Dialect/SPIRV.h b/mlir/include/mlir-c/Dialect/SPIRV.h new file mode 100644 index 000000000..f22708c9d --- /dev/null +++ b/mlir/include/mlir-c/Dialect/SPIRV.h @@ -0,0 +1,26 @@ +//===-- mlir-c/Dialect/SPIRV.h - C API for SPIRV dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_SPIRV_H +#define MLIR_C_DIALECT_SPIRV_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SPIRV, spirv); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_SPIRV_H diff --git a/mlir/include/mlir-c/Dialect/Shape.h b/mlir/include/mlir-c/Dialect/Shape.h index f64da8016..3fe3ddf5c 100644 --- a/mlir/include/mlir-c/Dialect/Shape.h +++ b/mlir/include/mlir-c/Dialect/Shape.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_SHAPE_H #define MLIR_C_DIALECT_SHAPE_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h new file mode 100644 index 000000000..c816c1b58 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -0,0 +1,116 @@ +//===-- mlir-c/Dialect/SparseTensor.h - C API for SparseTensor ----*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_SPARSETENSOR_H +#define MLIR_C_DIALECT_SPARSETENSOR_H + +#include "mlir-c/AffineMap.h" +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); + +/// Dimension level types (and properties) that define sparse tensors. +/// See the documentation in SparseTensorAttrDefs.td for their meaning. +/// +/// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API. +/// If updating, keep them in sync and update the static_assert in the impl +/// file. +typedef uint64_t MlirSparseTensorLevelType; + +enum MlirSparseTensorLevelFormat { + MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000, + MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000, + MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000, +}; + +enum MlirSparseTensorLevelPropertyNondefault { + MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001, + MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002, + MLIR_SPARSE_PROPERTY_SOA = 0x0004, +}; + +//===----------------------------------------------------------------------===// +// SparseTensorEncodingAttr +//===----------------------------------------------------------------------===// + +/// Checks whether the given attribute is a `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED bool +mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); + +/// Creates a `sparse_tensor.encoding` attribute with the given parameters. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( + MlirContext ctx, intptr_t lvlRank, + MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, + MlirAffineMap lvlTodim, int posWidth, int crdWidth, + MlirAttribute explicitVal, MlirAttribute implicitVal); + +/// Returns the level-rank of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED intptr_t +mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); + +/// Returns a specified level-type of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED MlirSparseTensorLevelType +mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); + +/// Returns a specified level-format of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelFormat +mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl); + +/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding` +/// attribute. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr); + +/// Returns the level-to-dimension mapping of the `sparse_tensor.encoding` +/// attribute. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr); + +/// Returns the position bitwidth of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED int +mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr); + +/// Returns the coordinate bitwidth of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED int +mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr); + +/// Returns the explicit value of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr); + +/// Returns the implicit value of the `sparse_tensor.encoding` attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr); + +MLIR_CAPI_EXPORTED unsigned +mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType); + +MLIR_CAPI_EXPORTED unsigned +mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType); + +MLIR_CAPI_EXPORTED MlirSparseTensorLevelType +mlirSparseTensorEncodingAttrBuildLvlType( + enum MlirSparseTensorLevelFormat lvlFmt, + const enum MlirSparseTensorLevelPropertyNondefault *properties, + unsigned propSize, unsigned n, unsigned m); + +#ifdef __cplusplus +} +#endif + +#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc" + +#endif // MLIR_C_DIALECT_SPARSETENSOR_H diff --git a/mlir/include/mlir-c/Dialect/Tensor.h b/mlir/include/mlir-c/Dialect/Tensor.h index f74978248..74cbc5a6f 100644 --- a/mlir/include/mlir-c/Dialect/Tensor.h +++ b/mlir/include/mlir-c/Dialect/Tensor.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_TENSOR_H #define MLIR_C_DIALECT_TENSOR_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h new file mode 100644 index 000000000..02c99b592 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -0,0 +1,83 @@ +//===-- mlir-c/Dialect/Transform.h - C API for Transform Dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_TRANSFORM_H +#define MLIR_C_DIALECT_TRANSFORM_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); + +//===---------------------------------------------------------------------===// +// AnyOpType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// AnyParamType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// AnyValueType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType 
+mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName); + +MLIR_CAPI_EXPORTED MlirStringRef +mlirTransformOperationTypeGetOperationName(MlirType type); + +//===---------------------------------------------------------------------===// +// ParamType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, + MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_TRANSFORM_H diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h new file mode 100644 index 000000000..fa3203242 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h @@ -0,0 +1,87 @@ +//===-- mlir-c/Dialect/Transform/Interpreter.h --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// C interface to the transform dialect interpreter. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirTransformOptions, void); + +#undef DEFINE_C_API_STRUCT + +//----------------------------------------------------------------------------// +// MlirTransformOptions +//----------------------------------------------------------------------------// + +/// Creates a default-initialized transform options object. +MLIR_CAPI_EXPORTED MlirTransformOptions mlirTransformOptionsCreate(void); + +/// Enables or disables expensive checks in transform options. +MLIR_CAPI_EXPORTED void +mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions, + bool enable); + +/// Returns true if expensive checks are enabled in transform options. +MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetExpensiveChecksEnabled( + MlirTransformOptions transformOptions); + +/// Enables or disables the enforcement of the top-level transform op being +/// single in transform options. +MLIR_CAPI_EXPORTED void mlirTransformOptionsEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions, bool enable); + +/// Returns true if the enforcement of the top-level transform op being single +/// is enabled in transform options. +MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions); + +/// Destroys a transform options object previously created by +/// mlirTransformOptionsCreate. +MLIR_CAPI_EXPORTED void +mlirTransformOptionsDestroy(MlirTransformOptions transformOptions); + +//----------------------------------------------------------------------------// +// Transform interpreter and utilities. 
+//----------------------------------------------------------------------------// + +/// Applies the transformation script starting at the given transform root +/// operation to the given payload operation. The module containing the +/// transform root as well as the transform options should be provided. The +/// transform operation must implement TransformOpInterface and the module must +/// be a ModuleOp. Returns the status of the application. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence( + MlirOperation payload, MlirOperation transformRoot, + MlirOperation transformModule, MlirTransformOptions transformOptions); + +/// Merge the symbols from `other` into `target`, potentially renaming them to +/// avoid conflicts. Private symbols may be renamed during the merge, public +/// symbols must have at most one declaration. A name conflict in public symbols +/// is reported as an error before returning a failure. +/// +/// Note that this clones the `other` operation unlike the C++ counterpart that +/// takes ownership. +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other); + +#ifdef __cplusplus +} +#endif diff --git a/mlir/include/mlir-c/Dialect/Vector.h b/mlir/include/mlir-c/Dialect/Vector.h new file mode 100644 index 000000000..6256c82d1 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Vector.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Vector.h - C API for Vector dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Vector dialect. A dialect should be registered with a context to make it +// available to users of the context. 
These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_VECTOR_H +#define MLIR_C_DIALECT_VECTOR_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Vector, vector); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_VECTOR_H diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index c25635771..99cddc5c2 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -36,9 +36,15 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void); /// expected to be "translatable" to LLVM IR (only contains operations in /// dialects that implement the `LLVMTranslationDialectInterface`). The module /// ownership stays with the client and can be destroyed as soon as the call -/// returns. -/// TODO: figure out options (optimization level, etc.). -MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op); +/// returns. `optLevel` is the optimization level to be used for transformation +/// and code generation. LLVM passes at `optLevel` are run before code +/// generation. The number and array of paths corresponding to shared libraries +/// that will be loaded are specified via `numPaths` and `sharedLibPaths` +/// respectively. +/// TODO: figure out other options. +MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( + MlirModule op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths, bool enableObjectDump); /// Destroy an ExecutionEngine instance. 
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); @@ -56,13 +62,29 @@ static inline bool mlirExecutionEngineIsNull(MlirExecutionEngine jit) { MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked( MlirExecutionEngine jit, MlirStringRef name, void **arguments); +/// Lookup the wrapper of the native function in the execution engine with the +/// given name, returns nullptr if the function can't be looked-up. +MLIR_CAPI_EXPORTED void * +mlirExecutionEngineLookupPacked(MlirExecutionEngine jit, MlirStringRef name); + /// Lookup a native function in the execution engine by name, returns nullptr /// if the name can't be looked-up. MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit, MlirStringRef name); +/// Register a symbol with the jit: this symbol will be accessible to the jitted +/// code. +MLIR_CAPI_EXPORTED void +mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, MlirStringRef name, + void *sym); + +/// Dump as an object in `fileName`. 
+MLIR_CAPI_EXPORTED void +mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit, + MlirStringRef fileName); + #ifdef __cplusplus } #endif -#endif // EXECUTIONENGINE_H +#endif // MLIR_C_EXECUTIONENGINE_H diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d807cd46d..71c7d4378 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -48,12 +48,17 @@ extern "C" { }; \ typedef struct name name +DEFINE_C_API_STRUCT(MlirAsmState, void); +DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void); DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); +DEFINE_C_API_STRUCT(MlirDialectRegistry, void); DEFINE_C_API_STRUCT(MlirOperation, void); +DEFINE_C_API_STRUCT(MlirOpOperand, void); DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); +DEFINE_C_API_STRUCT(MlirSymbolTable, void); DEFINE_C_API_STRUCT(MlirAttribute, const void); DEFINE_C_API_STRUCT(MlirIdentifier, const void); @@ -68,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void); /// /// A named attribute is essentially a (name, attribute) pair where the name is /// a string. - struct MlirNamedAttribute { MlirIdentifier name; MlirAttribute attribute; @@ -80,7 +84,18 @@ typedef struct MlirNamedAttribute MlirNamedAttribute; //===----------------------------------------------------------------------===// /// Creates an MLIR context and transfers its ownership to the caller. -MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(); +/// This sets the default multithreading option (enabled). +MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void); + +/// Creates an MLIR context with an explicit setting of the multithreading +/// setting and transfers its ownership to the caller. 
+MLIR_CAPI_EXPORTED MlirContext +mlirContextCreateWithThreading(bool threadingEnabled); + +/// Creates an MLIR context, setting the multithreading setting explicitly and +/// pre-loading the dialects from the provided DialectRegistry. +MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithRegistry( + MlirDialectRegistry registry, bool threadingEnabled); /// Checks if two contexts are equal. MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2); @@ -106,6 +121,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context); MLIR_CAPI_EXPORTED intptr_t mlirContextGetNumRegisteredDialects(MlirContext context); +/// Append the contents of the given dialect registry to the registry associated +/// with the context. +MLIR_CAPI_EXPORTED void +mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry); + /// Returns the number of dialects loaded by the context. MLIR_CAPI_EXPORTED intptr_t @@ -119,6 +139,38 @@ mlirContextGetNumLoadedDialects(MlirContext context); MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); +/// Set threading mode (must be set to false to mlir-print-ir-after-all). +MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, + bool enable); + +/// Eagerly loads all available dialects registered with a context, making +/// them available for use for IR construction. +MLIR_CAPI_EXPORTED void +mlirContextLoadAllAvailableDialects(MlirContext context); + +/// Returns whether the given fully-qualified operation (i.e. +/// 'dialect.operation') is registered with the context. This will return true +/// if the dialect is loaded and the operation is registered within the +/// dialect. +MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, + MlirStringRef name); + +/// Sets the thread pool of the context explicitly, enabling multithreading in +/// the process. 
This API should be used to avoid re-creating thread pools in +/// long-running applications that perform multiple compilations, see +/// the C++ documentation for MLIRContext for details. +MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, + MlirLlvmThreadPool threadPool); + +/// Gets the number of threads of the thread pool of the context when +/// multithreading is enabled. Returns 1 if no multithreading. +MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context); + +/// Gets the thread pool of the context when enabled multithreading, otherwise +/// an assertion is raised. +MLIR_CAPI_EXPORTED MlirLlvmThreadPool +mlirContextGetThreadPool(MlirContext context); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// @@ -139,18 +191,175 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1, /// Returns the namespace of the given dialect. MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); +//===----------------------------------------------------------------------===// +// DialectHandle API. +// Registration entry-points for each dialect are declared using the common +// MLIR_DECLARE_DIALECT_REGISTRATION_CAPI macro, which takes the dialect +// API name (i.e. "Func", "Tensor", "Linalg") and namespace (i.e. "func", +// "tensor", "linalg"). The following declarations are produced: +// +// /// Gets the above hook methods in struct form for a dialect by namespace. +// /// This is intended to facilitate dynamic lookup and registration of +// /// dialects via a plugin facility based on shared library symbol lookup. +// const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__(); +// +// This is done via a common macro to facilitate future expansion to +// registration schemes. 
+//===----------------------------------------------------------------------===// + +struct MlirDialectHandle { + const void *ptr; +}; +typedef struct MlirDialectHandle MlirDialectHandle; + +#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ + MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__( \ + void) + +/// Returns the namespace associated with the provided dialect handle. +MLIR_CAPI_EXPORTED +MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); + +/// Inserts the dialect associated with the provided dialect handle into the +/// provided dialect registry +MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, + MlirDialectRegistry); + +/// Registers the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, + MlirContext); + +/// Loads the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, + MlirContext); + +//===----------------------------------------------------------------------===// +// DialectRegistry API. +//===----------------------------------------------------------------------===// + +/// Creates a dialect registry and transfers its ownership to the caller. +MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate(void); + +/// Checks if the dialect registry is null. +static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) { + return !registry.ptr; +} + +/// Takes a dialect registry owned by the caller and destroys it. +MLIR_CAPI_EXPORTED void +mlirDialectRegistryDestroy(MlirDialectRegistry registry); + //===----------------------------------------------------------------------===// // Location API. //===----------------------------------------------------------------------===// +/// Returns the underlying location attribute of this location. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirLocationGetAttribute(MlirLocation location); + +/// Creates a location from a location attribute. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationFromAttribute(MlirAttribute attribute); + /// Creates an File/Line/Column location owned by the given context. MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet( MlirContext context, MlirStringRef filename, unsigned line, unsigned col); +/// Creates an File/Line/Column range location owned by the given context. +MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet( + MlirContext context, MlirStringRef filename, unsigned start_line, + unsigned start_col, unsigned end_line, unsigned end_col); + +/// Getter for filename of FileLineColRange. +MLIR_CAPI_EXPORTED MlirIdentifier +mlirLocationFileLineColRangeGetFilename(MlirLocation location); + +/// Getter for start_line of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetStartLine(MlirLocation location); + +/// Getter for start_column of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetStartColumn(MlirLocation location); + +/// Getter for end_line of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetEndLine(MlirLocation location); + +/// Getter for end_column of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetEndColumn(MlirLocation location); + +/// TypeID Getter for FileLineColRange. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void); + +/// Checks whether the given location is an FileLineColRange. +MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location); + /// Creates a call site location with a callee and a caller. MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller); +/// Getter for callee of CallSite. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationCallSiteGetCallee(MlirLocation location); + +/// Getter for caller of CallSite. 
+MLIR_CAPI_EXPORTED MlirLocation +mlirLocationCallSiteGetCaller(MlirLocation location); + +/// TypeID Getter for CallSite. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void); + +/// Checks whether the given location is an CallSite. +MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location); + +/// Creates a fused location with an array of locations and metadata. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, + MlirLocation const *locations, MlirAttribute metadata); + +/// Getter for number of locations fused together. +MLIR_CAPI_EXPORTED unsigned +mlirLocationFusedGetNumLocations(MlirLocation location); + +/// Getter for locations of Fused. Requires pre-allocated memory of +/// #fusedLocations X sizeof(MlirLocation). +MLIR_CAPI_EXPORTED void +mlirLocationFusedGetLocations(MlirLocation location, + MlirLocation *locationsCPtr); + +/// Getter for metadata of Fused. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLocationFusedGetMetadata(MlirLocation location); + +/// TypeID Getter for Fused. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void); + +/// Checks whether the given location is an Fused. +MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location); + +/// Creates a name location owned by the given context. Providing null location +/// for childLoc is allowed and if childLoc is null location, then the behavior +/// is the same as having unknown child location. +MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, + MlirStringRef name, + MlirLocation childLoc); + +/// Getter for name of Name. +MLIR_CAPI_EXPORTED MlirIdentifier +mlirLocationNameGetName(MlirLocation location); + +/// Getter for childLoc of Name. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationNameGetChildLoc(MlirLocation location); + +/// TypeID Getter for Name. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void); + +/// Checks whether the given location is an Name. 
+MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location); + /// Creates a location with unknown position owned by the given context. MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context); @@ -183,6 +392,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location); MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module); +/// Parses a module from file and transfers ownership to the caller. +MLIR_CAPI_EXPORTED MlirModule +mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName); + /// Gets the context that a module was created with. MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module); @@ -198,6 +411,10 @@ MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module); /// Views the module as a generic operation. MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); +/// Views the generic operation as a module. +/// The returned module is null when the input operation was not a ModuleOp. +MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); + //===----------------------------------------------------------------------===// // Operation state. //===----------------------------------------------------------------------===// @@ -258,6 +475,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MLIR_CAPI_EXPORTED void mlirOperationStateEnableResultTypeInference(MlirOperationState *state); +//===----------------------------------------------------------------------===// +// AsmState API. +// While many of these are simple settings that could be represented in a +// struct, they are wrapped in a heap allocated object and accessed via +// functions to maximize the possibility of compatibility over time. 
+//===----------------------------------------------------------------------===// + +/// Creates new AsmState, as with AsmState the IR should not be mutated +/// in-between using this state. +/// Must be freed with a call to mlirAsmStateDestroy(). +// TODO: This should be expanded to handle location & resouce map. +MLIR_CAPI_EXPORTED MlirAsmState +mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags); + +/// Creates new AsmState from value. +/// Must be freed with a call to mlirAsmStateDestroy(). +// TODO: This should be expanded to handle location & resouce map. +MLIR_CAPI_EXPORTED MlirAsmState +mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags); + +/// Destroys printing flags created with mlirAsmStateCreate. +MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state); + //===----------------------------------------------------------------------===// // Op Printing flags API. // While many of these are simple settings that could be represented in a @@ -267,7 +507,7 @@ mlirOperationStateEnableResultTypeInference(MlirOperationState *state); /// Creates new printing flags with defaults, intended for customization. /// Must be freed with a call to mlirOpPrintingFlagsDestroy(). -MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(); +MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void); /// Destroys printing flags created with mlirOpPrintingFlagsCreate. MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags); @@ -280,16 +520,29 @@ MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit); -/// Enable printing of debug information. If 'prettyForm' is set to true, -/// debug information is printed in a more readable 'pretty' form. Note: The -/// IR generated with 'prettyForm' is not parsable. +/// Enables the elision of large resources strings by omitting them from the +/// `dialect_resources` section. 
The `largeResourceLimit` is used to configure +/// what is considered to be a "large" resource by providing an upper limit to +/// the string size. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, + intptr_t largeResourceLimit); + +/// Enable or disable printing of debug information (based on `enable`). If +/// 'prettyForm' is set to true, debug information is printed in a more readable +/// 'pretty' form. Note: The IR generated with 'prettyForm' is not parsable. MLIR_CAPI_EXPORTED void -mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool prettyForm); +mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, + bool prettyForm); /// Always print operations in the generic form. MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags); +/// Print the name and location, if NamedLoc, as a prefix to the SSA ID. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags); + /// Use local scope when printing the operation. This allows for using the /// printer in a more localized and thread-safe setting, but may not /// necessarily be identical to what the IR will look like when dumping @@ -297,6 +550,32 @@ mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags); MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); +/// Do not verify the operation when using custom operation printers. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags); + +/// Skip printing regions. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags); + +//===----------------------------------------------------------------------===// +// Bytecode printing flags API. +//===----------------------------------------------------------------------===// + +/// Creates new printing flags with defaults, intended for customization. 
+/// Must be freed with a call to mlirBytecodeWriterConfigDestroy(). +MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig +mlirBytecodeWriterConfigCreate(void); + +/// Destroys printing flags created with mlirBytecodeWriterConfigCreate. +MLIR_CAPI_EXPORTED void +mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config); + +/// Sets the version to emit in the writer config. +MLIR_CAPI_EXPORTED void +mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, + int64_t version); + //===----------------------------------------------------------------------===// // Operation API. //===----------------------------------------------------------------------===// @@ -311,9 +590,27 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); /// - Result type inference is enabled and cannot be performed. MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state); +/// Parses an operation, giving ownership to the caller. If parsing fails a null +/// operation will be returned, and an error diagnostic emitted. +/// +/// `sourceStr` may be either the text assembly format, or binary bytecode +/// format. `sourceName` is used as the file name of the source; any IR without +/// locations will get a `FileLineColLoc` location with `sourceName` as the file +/// name. +MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse( + MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName); + +/// Creates a deep copy of an operation. The operation is not inserted and +/// ownership is transferred to the caller. +MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op); + /// Takes an operation owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op); +/// Removes the given operation from its parent block. The operation is not +/// destroyed. The ownership of the operation is transferred to the caller. 
+MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op); + /// Checks whether the underlying operation is null. static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; } @@ -325,6 +622,14 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, /// Gets the context this operation is associated with MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); +/// Gets the location of the operation. +MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op); + +/// Gets the type id of the operation. +/// Returns null if the operation does not have a registered operation +/// description. +MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op); + /// Gets the name of the operation as an identifier. MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op); @@ -355,6 +660,15 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op); MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos); +/// Sets the `pos`-th operand of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, + MlirValue newValue); + +/// Replaces the operands of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetOperands(MlirOperation op, + intptr_t nOperands, + MlirValue const *operands); + /// Returns the number of results of the operation. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op); @@ -369,25 +683,81 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op); MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos); +/// Set `pos`-th successor of the operation. +MLIR_CAPI_EXPORTED void +mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block); + +/// Returns true if this operation defines an inherent attribute with this name. 
+/// Note: the attribute can be optional, so +/// `mlirOperationGetInherentAttributeByName` can still return a null attribute. +MLIR_CAPI_EXPORTED bool +mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name); + +/// Returns an inherent attribute attached to the operation given its name. +MLIR_CAPI_EXPORTED MlirAttribute +mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name); + +/// Sets an inherent attribute by name, replacing the existing if it exists. +/// This has no effect if "name" does not match an inherent attribute. +MLIR_CAPI_EXPORTED void +mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name, + MlirAttribute attr); + +/// Returns the number of discardable attributes attached to the operation. +MLIR_CAPI_EXPORTED intptr_t +mlirOperationGetNumDiscardableAttributes(MlirOperation op); + +/// Return `pos`-th discardable attribute of the operation. +MLIR_CAPI_EXPORTED MlirNamedAttribute +mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos); + +/// Returns a discardable attribute attached to the operation given its name. +MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName( + MlirOperation op, MlirStringRef name); + +/// Sets a discardable attribute by name, replacing the existing if it exists or +/// adding a new one otherwise. The new `attr` Attribute is not allowed to be +/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an +/// Attribute instead. +MLIR_CAPI_EXPORTED void +mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name, + MlirAttribute attr); + +/// Removes a discardable attribute by name. Returns false if the attribute was +/// not found and true if removed. +MLIR_CAPI_EXPORTED bool +mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, + MlirStringRef name); + /// Returns the number of attributes attached to the operation. 
+/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or +/// `mlirOperationGetNumDiscardableAttributes`. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op); /// Return `pos`-th attribute of the operation. +/// Deprecated, please use `mlirOperationGetInherentAttribute` or +/// `mlirOperationGetDiscardableAttribute`. MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos); /// Returns an attribute attached to the operation given its name. +/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or +/// `mlirOperationGetDiscardableAttributeByName`. MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name); /// Sets an attribute by name, replacing the existing if it exists or /// adding a new one otherwise. +/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or +/// `mlirOperationSetDiscardableAttributeByName`. MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr); /// Removes an attribute by name. Returns false if the attribute was not found /// and true if removed. +/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or +/// `mlirOperationRemoveDiscardableAttributeByName`. MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name); @@ -405,18 +775,82 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirStringCallback callback, void *userData); +/// Same as mlirOperationPrint but accepts AsmState controlling the printing +/// behavior as well as caching computed names. +MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op, + MlirAsmState state, + MlirStringCallback callback, + void *userData); + +/// Same as mlirOperationPrint but writing the bytecode format. 
+MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, + MlirStringCallback callback, + void *userData); + +/// Same as mlirOperationWriteBytecode but with writer config and returns +/// failure only if desired bytecode could not be honored. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig( + MlirOperation op, MlirBytecodeWriterConfig config, + MlirStringCallback callback, void *userData); + /// Prints an operation to stderr. MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); /// Verify the operation and return true if it passes, false if it fails. MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op); +/// Moves the given operation immediately after the other operation in its +/// parent block. The given operation may be owned by the caller or by its +/// current block. The other operation must belong to a block. In any case, the +/// ownership is transferred to the block of the other operation. +MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, + MlirOperation other); + +/// Moves the given operation immediately before the other operation in its +/// parent block. The given operation may be owner by the caller or by its +/// current block. The other operation must belong to a block. In any case, the +/// ownership is transferred to the block of the other operation. +MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, + MlirOperation other); + +/// Given an operation 'other' that is within the same parent block, return +/// whether the current operation is before 'other' in the operation list +/// of the parent block. +/// Note: This function has an average complexity of O(1), but worst case may +/// take O(N) where N is the number of operations within the parent block. +MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, + MlirOperation other); +/// Operation walk result. 
+typedef enum MlirWalkResult { + MlirWalkResultAdvance, + MlirWalkResultInterrupt, + MlirWalkResultSkip +} MlirWalkResult; + +/// Traversal order for operation walk. +typedef enum MlirWalkOrder { + MlirWalkPreOrder, + MlirWalkPostOrder +} MlirWalkOrder; + +/// Operation walker type. The handler is passed an (opaque) reference to an +/// operation and a pointer to a `userData`. +typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation, + void *userData); + +/// Walks operation `op` in `walkOrder` and calls `callback` on that operation. +/// `*userData` is passed to the callback as well and can be used to tunnel some +/// context or other data into the callback. +MLIR_CAPI_EXPORTED +void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, + void *userData, MlirWalkOrder walkOrder); + //===----------------------------------------------------------------------===// // Region API. //===----------------------------------------------------------------------===// /// Creates a new empty region and transfers ownership to the caller. -MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(); +MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void); /// Takes a region owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region); @@ -424,6 +858,10 @@ MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region); /// Checks whether a region is null. static inline bool mlirRegionIsNull(MlirRegion region) { return !region.ptr; } +/// Checks whether two region handles point to the same region. This does not +/// perform deep comparison. +MLIR_CAPI_EXPORTED bool mlirRegionEqual(MlirRegion region, MlirRegion other); + /// Gets the first block in the region. MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region); @@ -451,6 +889,17 @@ MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block); +/// Returns first region attached to the operation. 
+MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetFirstRegion(MlirOperation op); + +/// Returns the region immediately following the given region in its parent +/// operation. +MLIR_CAPI_EXPORTED MlirRegion mlirRegionGetNextInOperation(MlirRegion region); + +/// Moves the entire content of the source region to the target region. +MLIR_CAPI_EXPORTED void mlirRegionTakeBody(MlirRegion target, + MlirRegion source); + //===----------------------------------------------------------------------===// // Block API. //===----------------------------------------------------------------------===// @@ -458,11 +907,15 @@ MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, /// Creates a new empty block with the given argument types and transfers /// ownership to the caller. MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, - MlirType const *args); + MlirType const *args, + MlirLocation const *locs); /// Takes a block owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirBlockDestroy(MlirBlock block); +/// Detach a block from the owning region and assume ownership. +MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block); + /// Checks whether a block is null. static inline bool mlirBlockIsNull(MlirBlock block) { return !block.ptr; } @@ -473,6 +926,9 @@ MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other); /// Returns the closest surrounding operation that contains this block. MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock); +/// Returns the region that contains this block. +MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block); + /// Returns the block immediately following the given block in its parent /// region. MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block); @@ -514,7 +970,18 @@ MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block); /// Appends an argument of the specified type to the block. Returns the newly /// added argument. 
MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, - MlirType type); + MlirType type, + MlirLocation loc); + +/// Erase the argument at 'index' and remove it from the argument list. +MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index); + +/// Inserts an argument of the specified type at a specified index to the block. +/// Returns the newly added argument. +MLIR_CAPI_EXPORTED MlirValue mlirBlockInsertArgument(MlirBlock block, + intptr_t pos, + MlirType type, + MlirLocation loc); /// Returns `pos`-th argument of the block. MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, @@ -526,6 +993,24 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); +/// Returns the number of successor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block); + +/// Returns `pos`-th successor of the block. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block, + intptr_t pos); + +/// Returns the number of predecessor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block); + +/// Returns `pos`-th predecessor of the block. +/// +/// WARNING: This getter is more expensive than the others here because +/// the impl actually iterates the use-def chain (of block operands) anew for +/// each indexed access. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block, + intptr_t pos); + //===----------------------------------------------------------------------===// // Value API. //===----------------------------------------------------------------------===// @@ -534,7 +1019,7 @@ mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); static inline bool mlirValueIsNull(MlirValue value) { return !value.ptr; } /// Returns 1 if two values are equal, 0 otherwise. 
-bool mlirValueEqual(MlirValue value1, MlirValue value2); +MLIR_CAPI_EXPORTED bool mlirValueEqual(MlirValue value1, MlirValue value2); /// Returns 1 if the value is a block argument, 0 otherwise. MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value); @@ -564,6 +1049,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value); /// Returns the type of the value. MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value); +/// Set the type of the value. +MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type); + /// Prints the value to the standard error stream. MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value); @@ -573,6 +1061,59 @@ MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value); MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); +/// Prints a value as an operand (i.e., the ValueID). +MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, + MlirAsmState state, + MlirStringCallback callback, + void *userData); + +/// Returns an op operand representing the first use of the value, or a null op +/// operand if there are no uses. +MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); + +/// Replace all uses of 'of' value with the 'with' value, updating anything in +/// the IR that uses 'of' to use the other value instead. When this returns +/// there are zero uses of 'of'. +MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, + MlirValue with); + +/// Replace all uses of 'of' value with 'with' value, updating anything in the +/// IR that uses 'of' to use 'with' instead, except if the user is listed in +/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation +/// pointers with a length of 'numExceptions'. +MLIR_CAPI_EXPORTED void +mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, + intptr_t numExceptions, + MlirOperation *exceptions); + +/// Gets the location of the value. 
+MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v); + +/// Gets the context that a value was created with. +MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v); + +//===----------------------------------------------------------------------===// +// OpOperand API. +//===----------------------------------------------------------------------===// + +/// Returns whether the op operand is null. +MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand); + +/// Returns the value of an op operand. +MLIR_CAPI_EXPORTED MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand); + +/// Returns the owner operation of an op operand. +MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand); + +/// Returns the operand number of an op operand. +MLIR_CAPI_EXPORTED unsigned +mlirOpOperandGetOperandNumber(MlirOpOperand opOperand); + +/// Returns an op operand representing the next use of the value, or a null op +/// operand if there is no next use. +MLIR_CAPI_EXPORTED MlirOpOperand +mlirOpOperandGetNextUse(MlirOpOperand opOperand); + //===----------------------------------------------------------------------===// // Type API. //===----------------------------------------------------------------------===// @@ -584,6 +1125,12 @@ MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, /// Gets the context that a type was created with. MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type); +/// Gets the type ID of the type. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type); + +/// Gets the dialect a type belongs to. +MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type); + /// Checks whether a type is null. static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; } @@ -613,6 +1160,12 @@ MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute); /// Gets the type of this attribute. 
MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute); +/// Gets the type id of the attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); + +/// Gets the dialect of the attribute. +MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute); + /// Checks whether an attribute is null. static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } @@ -651,6 +1204,68 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident, /// Gets the string value of the identifier. MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident); +//===----------------------------------------------------------------------===// +// Symbol and SymbolTable API. +//===----------------------------------------------------------------------===// + +/// Returns the name of the attribute used to store symbol names compatible with +/// symbol tables. +MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void); + +/// Returns the name of the attribute used to store symbol visibility. +MLIR_CAPI_EXPORTED MlirStringRef +mlirSymbolTableGetVisibilityAttributeName(void); + +/// Creates a symbol table for the given operation. If the operation does not +/// have the SymbolTable trait, returns a null symbol table. +MLIR_CAPI_EXPORTED MlirSymbolTable +mlirSymbolTableCreate(MlirOperation operation); + +/// Returns true if the symbol table is null. +static inline bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable) { + return !symbolTable.ptr; +} + +/// Destroys the symbol table created with mlirSymbolTableCreate. This does not +/// affect the operations in the table. +MLIR_CAPI_EXPORTED void mlirSymbolTableDestroy(MlirSymbolTable symbolTable); + +/// Looks up a symbol with the given name in the given symbol table and returns +/// the operation that corresponds to the symbol. If the symbol cannot be found, +/// returns a null operation. 
+MLIR_CAPI_EXPORTED MlirOperation +mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name); + +/// Inserts the given operation into the given symbol table. The operation must +/// have the symbol trait. If the symbol table already has a symbol with the +/// same name, renames the symbol being inserted to ensure name uniqueness. Note +/// that this does not move the operation itself into the block of the symbol +/// table operation, this should be done separately. Returns the name of the +/// symbol after insertion. +MLIR_CAPI_EXPORTED MlirAttribute +mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation); + +/// Removes the given operation from the symbol table and erases it. +MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, + MlirOperation operation); + +/// Attempt to replace all uses that are nested within the given operation +/// of the given symbol 'oldSymbol' with the provided 'newSymbol'. This does +/// not traverse into nested symbol tables. Will fail atomically if there are +/// any unknown operations that may be potential symbol tables. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses( + MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from); + +/// Walks all symbol table operations nested within, and including, `op`. For +/// each symbol table operation, the provided callback is invoked with the op +/// and a boolean signifying if the symbols within that symbol table can be +/// treated as if all uses within the IR are visible to the caller. +/// `allSymUsesVisible` identifies whether all of the symbol uses of symbols +/// within `op` are visible. 
+MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables( + MlirOperation from, bool allSymUsesVisible, + void (*callback)(MlirOperation, bool, void *userData), void *userData); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h new file mode 100644 index 000000000..a5a3473ea --- /dev/null +++ b/mlir/include/mlir-c/Interfaces.h @@ -0,0 +1,94 @@ +//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to MLIR interface classes. It is +// intended to contain interfaces defined in lib/Interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_INTERFACES_H +#define MLIR_C_INTERFACES_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Returns `true` if the given operation implements an interface identified by +/// its TypeID. +MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID); + +/// Returns `true` if the operation identified by its canonical string name +/// implements the interface identified by its TypeID in the given context. +/// Note that interfaces may be attached to operations in some contexts and not +/// others. +MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID); + +//===----------------------------------------------------------------------===// +// InferTypeOpInterface. 
+//===----------------------------------------------------------------------===// + +/// Returns the interface TypeID of the InferTypeOpInterface. +MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(); + +/// These callbacks are used to return multiple types from functions while +/// transferring ownership to the caller. The first argument is the number of +/// consecutive elements pointed to by the second argument. The third argument +/// is an opaque pointer forwarded to the callback by the caller. +typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); + +/// Infers the return types of the operation identified by its canonical given +/// the arguments that will be supplied to its generic builder. Calls `callback` +/// with the types of inferred arguments, potentially several times, on success. +/// Returns failure otherwise. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirTypesCallback callback, void *userData); + +//===----------------------------------------------------------------------===// +// InferShapedTypeOpInterface. +//===----------------------------------------------------------------------===// + +/// Returns the interface TypeID of the InferShapedTypeOpInterface. +MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(); + +/// These callbacks are used to return multiple shaped type components from +/// functions while transferring ownership to the caller. The first argument is +/// the has rank boolean followed by the the rank and a pointer to the shape +/// (if applicable). The next argument is the element type, then the attribute. +/// The last argument is an opaque pointer forwarded to the callback by the +/// caller. 
This callback will be called potentially multiple times for each +/// shaped type components. +typedef void (*MlirShapedTypeComponentsCallback)(bool, intptr_t, + const int64_t *, MlirType, + MlirAttribute, void *); + +/// Infers the return shaped type components of the operation. Calls `callback` +/// with the types of inferred arguments on success. Returns failure otherwise. +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirInferShapedTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirShapedTypeComponentsCallback callback, void *userData); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_INTERFACES_H diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 9669a53cd..0d2e19ee7 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -41,14 +41,23 @@ extern "C" { typedef struct name name DEFINE_C_API_STRUCT(MlirPass, void); +DEFINE_C_API_STRUCT(MlirExternalPass, void); DEFINE_C_API_STRUCT(MlirPassManager, void); DEFINE_C_API_STRUCT(MlirOpPassManager, void); #undef DEFINE_C_API_STRUCT -/// Create a new top-level PassManager. +//===----------------------------------------------------------------------===// +// PassManager/OpPassManager APIs. +//===----------------------------------------------------------------------===// + +/// Create a new top-level PassManager with the default anchor. MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx); +/// Create a new top-level PassManager anchored on `anchorOp`. +MLIR_CAPI_EXPORTED MlirPassManager +mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp); + /// Destroy the provided PassManager. 
MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager); @@ -61,9 +70,27 @@ static inline bool mlirPassManagerIsNull(MlirPassManager passManager) { MLIR_CAPI_EXPORTED MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); -/// Run the provided `passManager` on the given `module`. +/// Run the provided `passManager` on the given `op`. MLIR_CAPI_EXPORTED MlirLogicalResult -mlirPassManagerRun(MlirPassManager passManager, MlirModule module); +mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); + +/// Enable IR printing. +/// The treePrintingPath argument is an optional path to a directory +/// where the dumps will be produced. If it isn't provided then dumps +/// are produced to stderr. +MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( + MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, + bool printModuleScope, bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, + MlirStringRef treePrintingPath); + +/// Enable / disable verify-each. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); + +/// Enable pass timing. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableTiming(MlirPassManager passManager); /// Nest an OpPassManager under the top-level PassManager, the nested /// passmanager will only run on operations matching the provided name. @@ -92,6 +119,13 @@ MLIR_CAPI_EXPORTED void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MLIR_CAPI_EXPORTED void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass); +/// Parse a sequence of textual MLIR pass pipeline elements and add them to the +/// provided OpPassManager. If parsing fails an error message is reported using +/// the provided callback. 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline( + MlirOpPassManager passManager, MlirStringRef pipelineElements, + MlirStringCallback callback, void *userData); + /// Print a textual MLIR pass pipeline by sending chunks of the string /// representation and forwarding `userData to `callback`. Note that the /// callback may be called several times with consecutive chunks of the string. @@ -99,10 +133,61 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData); -/// Parse a textual MLIR pass pipeline and add it to the provided OpPassManager. - +/// Parse a textual MLIR pass pipeline and assign it to the provided +/// OpPassManager. If parsing fails an error message is reported using the +/// provided callback. MLIR_CAPI_EXPORTED MlirLogicalResult -mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline); +mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, + MlirStringCallback callback, void *userData); + +//===----------------------------------------------------------------------===// +// External Pass API. +// +// This API allows to define passes outside of MLIR, not necessarily in +// C++, and register them with the MLIR pass management infrastructure. +// +//===----------------------------------------------------------------------===// + +/// Structure of external `MlirPass` callbacks. +/// All callbacks are required to be set unless otherwise specified. +struct MlirExternalPassCallbacks { + /// This callback is called from the pass is created. + /// This is analogous to a C++ pass constructor. + void (*construct)(void *userData); + + /// This callback is called when the pass is destroyed + /// This is analogous to a C++ pass destructor. + void (*destruct)(void *userData); + + /// This callback is optional. 
+ /// The callback is called before the pass is run, allowing a chance to + /// initialize any complex state necessary for running the pass. + /// See Pass::initialize(MLIRContext *). + MlirLogicalResult (*initialize)(MlirContext ctx, void *userData); + + /// This callback is called when the pass is cloned. + /// See Pass::clonePass(). + void *(*clone)(void *userData); + + /// This callback is called when the pass is run. + /// See Pass::runOnOperation(). + void (*run)(MlirOperation op, MlirExternalPass pass, void *userData); +}; +typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks; + +/// Creates an external `MlirPass` that calls the supplied `callbacks` using the +/// supplied `userData`. If `opName` is empty, the pass is a generic operation +/// pass. Otherwise it is an operation pass specific to the specified pass name. +MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass( + MlirTypeID passID, MlirStringRef name, MlirStringRef argument, + MlirStringRef description, MlirStringRef opName, + intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, void *userData); + +/// This signals that the pass has failed. This is only valid to call during +/// the `run` callback of `MlirExternalPassCallbacks`. +/// See Pass::signalPassFailure(). +MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass); #ifdef __cplusplus } diff --git a/mlir/include/mlir-c/RegisterEverything.h b/mlir/include/mlir-c/RegisterEverything.h new file mode 100644 index 000000000..ea2ea8644 --- /dev/null +++ b/mlir/include/mlir-c/RegisterEverything.h @@ -0,0 +1,38 @@ +//===-- mlir-c/RegisterEverything.h - Register all MLIR entities --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This header contains registration entry points for MLIR upstream dialects +// and passes. Downstream projects typically will not want to use this unless +// if they don't care about binary size or build bloat and just wish access +// to the entire set of upstream facilities. For those that do care, they +// should use registration functions specific to their project. +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_REGISTER_EVERYTHING_H +#define MLIR_C_REGISTER_EVERYTHING_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Appends all upstream dialects and extensions to the dialect registry. +MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirDialectRegistry registry); + +/// Register all translations to LLVM IR for dialects that can support it. +MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); + +/// Register all compiler passes of MLIR. +MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(void); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_REGISTER_EVERYTHING_H diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h deleted file mode 100644 index 4cfc96719..000000000 --- a/mlir/include/mlir-c/Registration.h +++ /dev/null @@ -1,67 +0,0 @@ -//===-- mlir-c/Registration.h - Registration functions for MLIR ---*- C -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_C_REGISTRATION_H -#define MLIR_C_REGISTRATION_H - -#include "mlir-c/IR.h" - -#ifdef __cplusplus -extern "C" { -#endif - -//===----------------------------------------------------------------------===// -// Dialect registration declarations. -// Registration entry-points for each dialect are declared using the common -// MLIR_DECLARE_DIALECT_REGISTRATION_CAPI macro, which takes the dialect -// API name (i.e. "Standard", "Tensor", "Linalg") and namespace (i.e. "std", -// "tensor", "linalg"). The following declarations are produced: -// -// /// Gets the above hook methods in struct form for a dialect by namespace. -// /// This is intended to facilitate dynamic lookup and registration of -// /// dialects via a plugin facility based on shared library symbol lookup. -// const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__(); -// -// This is done via a common macro to facilitate future expansion to -// registration schemes. -//===----------------------------------------------------------------------===// - -struct MlirDialectHandle { - const void *ptr; -}; -typedef struct MlirDialectHandle MlirDialectHandle; - -#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ - MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() - -/// Returns the namespace associated with the provided dialect handle. -MLIR_CAPI_EXPORTED -MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); - -/// Registers the dialect associated with the provided dialect handle. -MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, - MlirContext); - -/// Loads the dialect associated with the provided dialect handle. 
-MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, - MlirContext); - -/// Registers all dialects known to core MLIR with the provided Context. -/// This is needed before creating IR for these Dialects. -/// TODO: Remove this function once the real registration API is finished. -MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirContext context); - -/// Register all translations to LLVM IR for dialects that can support it. -MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); - -#ifdef __cplusplus -} -#endif - -#endif // MLIR_C_REGISTRATION_H diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h new file mode 100644 index 000000000..61d344631 --- /dev/null +++ b/mlir/include/mlir-c/Rewrite.h @@ -0,0 +1,330 @@ +//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the registration and creation method for +// rewrite patterns. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_REWRITE_H +#define MLIR_C_REWRITE_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Config/mlir-config.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +/// Opaque type declarations (see mlir-c/IR.h for more details). 
+//===----------------------------------------------------------------------===// + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirRewriterBase, void); +DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); +DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); +DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); + +//===----------------------------------------------------------------------===// +/// RewriterBase API inherited from OpBuilder +//===----------------------------------------------------------------------===// + +/// Get the MLIR context referenced by the rewriter. +MLIR_CAPI_EXPORTED MlirContext +mlirRewriterBaseGetContext(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// Insertion points methods +//===----------------------------------------------------------------------===// + +// These do not include functions using Block::iterator or Region::iterator, as +// they are not exposed by the C API yet. Similarly for methods using +// `InsertPoint` directly. + +/// Reset the insertion point to no location. Creating an operation without a +/// set insertion point is an error, but this can still be useful when the +/// current insertion point a builder refers to is being removed. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter); + +/// Sets the insertion point to the specified operation, which will cause +/// subsequent insertions to go right before it. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, + MlirOperation op); + +/// Sets the insertion point to the node after the specified operation, which +/// will cause subsequent insertions to go right after it. 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, + MlirOperation op); + +/// Sets the insertion point to the node after the specified value. If value +/// has a defining operation, sets the insertion point to the node after such +/// defining operation. This will cause subsequent insertions to go right +/// after it. Otherwise, value is a BlockArgument. Sets the insertion point to +/// the start of its block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, + MlirValue value); + +/// Sets the insertion point to the start of the specified block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, + MlirBlock block); + +/// Sets the insertion point to the end of the specified block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, + MlirBlock block); + +/// Return the block the current insertion point belongs to. Note that the +/// insertion point is not necessarily the end of the block. +MLIR_CAPI_EXPORTED MlirBlock +mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter); + +/// Returns the current block of the rewriter. +MLIR_CAPI_EXPORTED MlirBlock +mlirRewriterBaseGetBlock(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// Block and operation creation/insertion/cloning +//===----------------------------------------------------------------------===// + +// These functions do not include the IRMapper, as it is not yet exposed by the +// C API. + +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is placed before 'insertBefore'. `locs` contains the +/// locations of the inserted arguments, and should match the size of +/// `argTypes`. 
+MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore( + MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, + MlirType const *argTypes, MlirLocation const *locations); + +/// Insert the given operation at the current insertion point and return it. +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op); + +/// Creates a deep copy of the specified operation. +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op); + +/// Creates a deep copy of this operation but keep the operation regions +/// empty. +MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions( + MlirRewriterBase rewriter, MlirOperation op); + +/// Clone the blocks that belong to "region" before the given position in +/// another region "parent". +MLIR_CAPI_EXPORTED void +mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, + MlirBlock before); + +//===----------------------------------------------------------------------===// +/// RewriterBase API +//===----------------------------------------------------------------------===// + +/// Move the blocks that belong to "region" before the given position in +/// another region "parent". The two regions must be different. The caller +/// is responsible for creating or updating the operation transferring flow +/// of control to the region and passing it the correct block arguments. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, + MlirBlock before); + +/// Replace the results of the given (original) operation with the specified +/// list of values (replacements). The result types of the given op and the +/// replacements must match. The original op is erased. 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, + intptr_t nValues, MlirValue const *values); + +/// Replace the results of the given (original) operation with the specified +/// new op (replacement). The result types of the two ops must match. The +/// original op is erased. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, + MlirOperation op, MlirOperation newOp); + +/// Erases an operation that is known to have no uses. +MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, + MlirOperation op); + +/// Erases a block along with all operations inside it. +MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, + MlirBlock block); + +/// Inline the operations of block 'source' before the operation 'op'. The +/// source block will be deleted and must have no uses. 'argValues' is used to +/// replace the block arguments of 'source' +/// +/// The source block must have no successors. Otherwise, the resulting IR +/// would have unreachable operations. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, + MlirOperation op, intptr_t nArgValues, + MlirValue const *argValues); + +/// Inline the operations of block 'source' into the end of block 'dest'. The +/// source block will be deleted and must have no uses. 'argValues' is used to +/// replace the block arguments of 'source' +/// +/// The dest block must have no successors. Otherwise, the resulting IR would +/// have unreachable operation. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, + MlirBlock source, + MlirBlock dest, + intptr_t nArgValues, + MlirValue const *argValues); + +/// Unlink this operation from its current block and insert it right before +/// `existingOp` which may be in the same or another block in the same +/// function. 
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation existingOp); + +/// Unlink this operation from its current block and insert it right after +/// `existingOp` which may be in the same or another block in the same +/// function. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation existingOp); + +/// Unlink this block and insert it right before `existingBlock`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, + MlirBlock existingBlock); + +/// This method is used to notify the rewriter that an in-place operation +/// modification is about to happen. A call to this function *must* be +/// followed by a call to either `finalizeOpModification` or +/// `cancelOpModification`. This is a minor efficiency win (it avoids creating +/// a new operation and removing the old one) but also often allows simpler +/// code in the client. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// This method is used to signal the end of an in-place modification of the +/// given operation. This can only be called on operations that were provided +/// to a call to `startOpModification`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// This method cancels a pending in-place modification. This can only be +/// called on operations that were provided to a call to +/// `startOpModification`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced). 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, + MlirValue to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced). +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith( + MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, + MlirValue const *to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced) +/// and that the `from` operation is about to be replaced. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, + MlirOperation from, intptr_t nTo, + MlirValue const *to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced) +/// and that the `from` operation is about to be replaced. +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation( + MlirRewriterBase rewriter, MlirOperation from, MlirOperation to); + +/// Find uses of `from` within `block` and replace them with `to`. Also notify +/// the listener about every in-place op modification (for every use that was +/// replaced). The optional `allUsesReplaced` flag is set to "true" if all +/// uses were replaced. +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock( + MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, + MlirValue const *newValues, MlirBlock block); + +/// Find uses of `from` and replace them with `to` except if the user is +/// `exceptedUser`. Also notify the listener about every in-place op +/// modification (for every use that was replaced). 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, + MlirValue to, MlirOperation exceptedUser); + +//===----------------------------------------------------------------------===// +/// IRRewriter API +//===----------------------------------------------------------------------===// + +/// Create an IRRewriter and transfer ownership to the caller. +MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context); + +/// Create an IRRewriter and transfer ownership to the caller. Additionally +/// set the insertion point before the operation. +MLIR_CAPI_EXPORTED MlirRewriterBase +mlirIRRewriterCreateFromOp(MlirOperation op); + +/// Takes an IRRewriter owned by the caller and destroys it. It is the +/// responsibility of the user to only pass an IRRewriter class. +MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// FrozenRewritePatternSet API +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet op); + +MLIR_CAPI_EXPORTED void +mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); + +MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( + MlirModule op, MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig); + +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); + +MLIR_CAPI_EXPORTED MlirPDLPatternModule +mlirPDLPatternModuleFromModule(MlirModule op); + +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); + +MLIR_CAPI_EXPORTED MlirRewritePatternSet 
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + +#undef DEFINE_C_API_STRUCT + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_REWRITE_H diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index 340f8ec8b..78fc94f93 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -22,9 +22,17 @@ //===----------------------------------------------------------------------===// // Visibility annotations. // Use MLIR_CAPI_EXPORTED for exported functions. +// +// On Windows, if MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC is defined, then +// __declspec(dllexport) and __declspec(dllimport) will be generated. This +// can only be enabled if actually building DLLs. It is generally, mutually +// exclusive with the use of other mechanisms for managing imports/exports +// (i.e. CMake's WINDOWS_EXPORT_ALL_SYMBOLS feature). //===----------------------------------------------------------------------===// -#if defined(MLIR_CAPI_DISABLE_VISIBILITY_ANNOTATIONS) +#if (defined(_WIN32) || defined(__CYGWIN__)) && \ + !defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC) +// Visibility annotations disabled. #define MLIR_CAPI_EXPORTED #elif defined(_WIN32) || defined(__CYGWIN__) // Windows visibility declarations. @@ -42,6 +50,19 @@ extern "C" { #endif +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +/// Re-export llvm::ThreadPool so as to avoid including the LLVM C API directly. +DEFINE_C_API_STRUCT(MlirLlvmThreadPool, void); +DEFINE_C_API_STRUCT(MlirTypeID, const void); +DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void); + +#undef DEFINE_C_API_STRUCT + //===----------------------------------------------------------------------===// // MlirStringRef. 
//===----------------------------------------------------------------------===// @@ -71,6 +92,10 @@ inline static MlirStringRef mlirStringRefCreate(const char *str, MLIR_CAPI_EXPORTED MlirStringRef mlirStringRefCreateFromCString(const char *str); +/// Returns true if two string references are equal, false otherwise. +MLIR_CAPI_EXPORTED bool mlirStringRefEqual(MlirStringRef string, + MlirStringRef other); + /// A callback for returning string references. /// /// This function is called back by the functions that need to return a @@ -104,17 +129,60 @@ inline static bool mlirLogicalResultIsFailure(MlirLogicalResult res) { } /// Creates a logical result representing a success. -inline static MlirLogicalResult mlirLogicalResultSuccess() { +inline static MlirLogicalResult mlirLogicalResultSuccess(void) { MlirLogicalResult res = {1}; return res; } /// Creates a logical result representing a failure. -inline static MlirLogicalResult mlirLogicalResultFailure() { +inline static MlirLogicalResult mlirLogicalResultFailure(void) { MlirLogicalResult res = {0}; return res; } +//===----------------------------------------------------------------------===// +// MlirLlvmThreadPool. +//===----------------------------------------------------------------------===// + +/// Create an LLVM thread pool. This is reexported here to avoid directly +/// pulling in the LLVM headers directly. +MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void); + +/// Destroy an LLVM thread pool. +MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool); + +//===----------------------------------------------------------------------===// +// TypeID API. +//===----------------------------------------------------------------------===// + +/// `ptr` must be 8 byte aligned and unique to a type valid for the duration of +/// the returned type id's usage +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr); + +/// Checks whether a type id is null. 
+static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; } + +/// Checks if two type ids are equal. +MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2); + +/// Returns the hash value of the type id. +MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); + +//===----------------------------------------------------------------------===// +// TypeIDAllocator API. +//===----------------------------------------------------------------------===// + +/// Creates a type id allocator for dynamic type id creation +MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate(void); + +/// Deallocates the allocator and all allocated type ids +MLIR_CAPI_EXPORTED void +mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator); + +/// Allocates a type id that is valid for the lifetime of the allocator +MLIR_CAPI_EXPORTED MlirTypeID +mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/Target/ExportSMTLIB.h b/mlir/include/mlir-c/Target/ExportSMTLIB.h new file mode 100644 index 000000000..59beda54d --- /dev/null +++ b/mlir/include/mlir-c/Target/ExportSMTLIB.h @@ -0,0 +1,36 @@ +//===- mlir-c/Target/ExportSMTLIB.h - C API for emitting SMTLIB ---*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for emitting SMTLIB from an MLIR module. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_EXPORTSMTLIB_H +#define MLIR_C_EXPORTSMTLIB_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Emits SMTLIB for the specified module using the provided callback and user +/// data +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData, + bool inlineSingleUseValues, bool indentLetBody); + +MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB( + MlirOperation, MlirStringCallback, void *userData, + bool inlineSingleUseValues, bool indentLetBody); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_EXPORTSMTLIB_H diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h new file mode 100644 index 000000000..b5f948961 --- /dev/null +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -0,0 +1,82 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to target LLVMIR with MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_TARGET_LLVMIR_H +#define MLIR_C_TARGET_LLVMIR_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "llvm-c/Core.h" +#include "llvm-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Translate operation that satisfies LLVM dialect module requirements into an +/// LLVM IR module living in the given context. This translates operations from +/// any dilalect that has a registered implementation of +/// LLVMTranslationDialectInterface. 
+/// +/// \returns the generated LLVM IR Module from the translated MLIR module, it is +/// owned by the caller. +MLIR_CAPI_EXPORTED LLVMModuleRef +mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); + +struct MlirTypeFromLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeFromLLVMIRTranslator MlirTypeFromLLVMIRTranslator; + +/// Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx); + +/// Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeFromLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeFromLLVMIRTranslatorDestroy(MlirTypeFromLLVMIRTranslator translator); + +/// Translates the given LLVM IR type to the MLIR LLVM dialect. +MLIR_CAPI_EXPORTED MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType); + +struct MlirTypeToLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeToLLVMIRTranslator MlirTypeToLLVMIRTranslator; + +/// Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx); + +/// Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeToLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator); + +/// Translates the given MLIR LLVM dialect to the LLVM IR type. 
+MLIR_CAPI_EXPORTED LLVMTypeRef mlirTypeToLLVMIRTranslatorTranslateType( + MlirTypeToLLVMIRTranslator translator, MlirType mlirType); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_TARGET_LLVMIR_H diff --git a/mlir/include/mlir/Bindings/Python/Attributes.td b/mlir/include/mlir/Bindings/Python/Attributes.td deleted file mode 100644 index f9a7fa703..000000000 --- a/mlir/include/mlir/Bindings/Python/Attributes.td +++ /dev/null @@ -1,34 +0,0 @@ -//===-- Attributes.td - Attribute mapping for Python -------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This defines the mapping between MLIR ODS attributes and the corresponding -// Python binding classes. -// -//===----------------------------------------------------------------------===// - -#ifndef PYTHON_BINDINGS_ATTRIBUTES -#define PYTHON_BINDINGS_ATTRIBUTES - -// A mapping between the attribute storage type and the corresponding Python -// type. There is not necessarily a 1-1 match for non-builtin attributes. -class PythonAttr { - string cppStorageType = c; - string pythonType = p; -} - -// Mappings between supported builtin attribtues and Python types. 
-def : PythonAttr<"::mlir::Attribute", "_ods_ir.Attribute">; -def : PythonAttr<"::mlir::BoolAttr", "_ods_ir.BoolAttr">; -def : PythonAttr<"::mlir::IntegerAttr", "_ods_ir.IntegerAttr">; -def : PythonAttr<"::mlir::FloatAttr", "_ods_ir.FloatAttr">; -def : PythonAttr<"::mlir::StringAttr", "_ods_ir.StringAttr">; -def : PythonAttr<"::mlir::DenseElementsAttr", "_ods_ir.DenseElementsAttr">; -def : PythonAttr<"::mlir::DenseIntElementsAttr", "_ods_ir.DenseIntElementsAttr">; -def : PythonAttr<"::mlir::DenseFPElementsAttr", "_ods_ir.DenseFPElementsAttr">; - -#endif diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h new file mode 100644 index 000000000..167002d56 --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h @@ -0,0 +1,72 @@ +//===- Diagnostics.h - Helpers for diagnostics in Python bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H +#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H + +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include + +namespace mlir { +namespace python { + +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. 
+class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = + mlirContextAttachDiagnosticHandler(ctx, &handler, &messageStream, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(message.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { + std::string newMessage; + std::swap(message, newMessage); + return newMessage; + } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) + << std::string_view(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) << "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) << ": "; + mlirDiagnosticPrint(diag, printer, data); + for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) { + *static_cast(data) << "\n"; + MlirDiagnostic note = mlirDiagnosticGetNote(diag, i); + handler(note, data); + } + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + + std::string message; + llvm::raw_string_ostream messageStream{message}; +}; + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h new file mode 100644 index 000000000..ba9642cf2 --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -0,0 +1,31 @@ +//===- IRTypes.h - Type Interfaces ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H +#define MLIR_BINDINGS_PYTHON_IRTYPES_H + +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace mlir { + +/// Shaped Type Interface - ShapedType +class PyShapedType : public python::PyConcreteType { +public: + static const IsAFunctionTy isaFunction; + static constexpr const char *pyClassName = "ShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); + +private: + void requireHasRank(); +}; + +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h new file mode 100644 index 000000000..ca942c83d --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -0,0 +1,37 @@ +//===- Nanobind.h - Trampoline header with ignored warnings ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file is a trampoline for the nanobind headers while disabling warnings +// reported by the LLVM/MLIR build. This file avoids adding complexity build +// system side. 
+//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_NANOBIND_H +#define MLIR_BINDINGS_PYTHON_NANOBIND_H + +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wzero-length-array" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wnested-anon-types" +#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi" +#pragma GCC diagnostic ignored "-Wcovered-switch-default" +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // MLIR_BINDINGS_PYTHON_NANOBIND_H diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h new file mode 100644 index 000000000..1428d5ccf --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -0,0 +1,682 @@ +//===- NanobindAdaptors.h - Interop with MLIR APIs via nanobind -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file contains adaptors for clients of the core MLIR Python APIs to +// interop via MLIR CAPI types, using nanobind. The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). +// +// It is encouraged to be used both in-tree and out-of-tree. For in-tree use +// cases, it should be used for dialect implementations (versus relying on +// Pybind-based internals of the core libraries). 
+//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H +#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H + +#include +#include + +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +// clang-format off +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on +#include "llvm/ADT/Twine.h" + +// Raw CAPI type casters need to be declared before use, so always include them +// first. +namespace nanobind { +namespace detail { + +/// Helper to convert a presumed MLIR API object to a capsule, accepting either +/// an explicit Capsule (which can happen when two C APIs are communicating +/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR +/// attribute (through which supported MLIR Python API objects export their +/// contained API pointer as a capsule). Throws a type error if the object is +/// neither. This is intended to be used from type casters, which are invoked +/// with a raw handle (unowned). The returned object's lifetime may not extend +/// beyond the apiObject handle without explicitly having its refcount increased +/// (i.e. on return). +static std::optional +mlirApiObjectToCapsule(nanobind::handle apiObject) { + if (PyCapsule_CheckExact(apiObject.ptr())) + return nanobind::borrow(apiObject); + nanobind::object api = + nanobind::getattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR, nanobind::none()); + if (api.is_none()) + return {}; + return api; +} + +// Note: Currently all of the following support cast from nanobind::object to +// the Mlir* C-API type, but only a few light-weight, context-bound ones +// implicitly cast the other way because the use case has not yet emerged and +// ownership is unclear. + +/// Casts object <-> MlirAffineMap. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToAffineMap(capsule->ptr()); + return !mlirAffineMapIsNull(value); + } + return false; + } + static handle from_cpp(MlirAffineMap v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonAffineMapToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("AffineMap") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirAttribute. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToAttribute(capsule->ptr()); + return !mlirAttributeIsNull(value); + } + return false; + } + static handle from_cpp(MlirAttribute v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonAttributeToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +/// Casts object -> MlirBlock. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToBlock(capsule->ptr()); + return !mlirBlockIsNull(value); + } + return false; + } +}; + +/// Casts object -> MlirContext. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirContext, const_name("MlirContext")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (src.is_none()) { + // Gets the current thread-bound context. + // TODO: This raises an error of "No current context" currently. + // Update the implementation to pretty-print the helpful error that the + // core implementations print in this case. + src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Context") + .attr("current"); + } + std::optional capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule->ptr()); + return !mlirContextIsNull(value); + } +}; + +/// Casts object <-> MlirDialectRegistry. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToDialectRegistry(capsule->ptr()); + return !mlirDialectRegistryIsNull(value); + } + return false; + } + static handle from_cpp(MlirDialectRegistry v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = nanobind::steal( + mlirPythonDialectRegistryToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirLocation. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (src.is_none()) { + // Gets the current thread-bound context. 
+ src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr("current"); + } + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToLocation(capsule->ptr()); + return !mlirLocationIsNull(value); + } + return false; + } + static handle from_cpp(MlirLocation v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonLocationToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirModule. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirModule, const_name("MlirModule")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToModule(capsule->ptr()); + return !mlirModuleIsNull(value); + } + return false; + } + static handle from_cpp(MlirModule v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonModuleToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Module") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirFrozenRewritePatternSet. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirFrozenRewritePatternSet, + const_name("MlirFrozenRewritePatternSet")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr()); + return value.ptr != nullptr; + } + return false; + } + static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, + handle) noexcept { + nanobind::object capsule = nanobind::steal( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirOperation. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToOperation(capsule->ptr()); + return !mlirOperationIsNull(value); + } + return false; + } + static handle from_cpp(MlirOperation v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonOperationToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Operation") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirValue. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirValue, const_name("MlirValue")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToValue(capsule->ptr()); + return !mlirValueIsNull(value); + } + return false; + } + static handle from_cpp(MlirValue v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonValueToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + }; +}; + +/// Casts object -> MlirPassManager. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToPassManager(capsule->ptr()); + return !mlirPassManagerIsNull(value); + } + return false; + } +}; + +/// Casts object <-> MlirTypeID. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToTypeID(capsule->ptr()); + return !mlirTypeIDIsNull(value); + } + return false; + } + static handle from_cpp(MlirTypeID v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonTypeIDToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirType. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirType, const_name("MlirType")) + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) { + value = mlirPythonCapsuleToType(capsule->ptr()); + return !mlirTypeIsNull(value); + } + return false; + } + static handle from_cpp(MlirType t, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonTypeToCapsule(t)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +/// Casts MlirStringRef -> object. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef")) + static handle from_cpp(MlirStringRef s, rv_policy, + cleanup_list *cleanup) noexcept { + return nanobind::str(s.data, s.length).release(); + } +}; + +} // namespace detail +} // namespace nanobind + +namespace mlir { +namespace python { +namespace nanobind_adaptors { + +/// Provides a facility like nanobind::class_ for defining a new class in a +/// scope, but this allows extension of an arbitrary Python class, defining +/// methods on it is a similar way. Classes defined in this way are very similar +/// to if defined in Python in the usual way but use nanobind machinery to +/// do it. These are not "real" nanobind classes but pure Python classes +/// with no relation to a concrete C++ class. +/// +/// Derived from a discussion upstream: +/// https://github.com/pybind/pybind11/issues/1193 +/// (plus a fair amount of extra curricular poking) +/// TODO: If this proves useful, see about including it in nanobind. 
+class pure_subclass { +public: + pure_subclass(nanobind::handle scope, const char *derivedClassName, + const nanobind::object &superClass) { + nanobind::object pyType = + nanobind::borrow((PyObject *)&PyType_Type); + nanobind::object metaclass = pyType(superClass); + nanobind::dict attributes; + + thisClass = metaclass(derivedClassName, nanobind::make_tuple(superClass), + attributes); + scope.attr(derivedClassName) = thisClass; + thisClass.attr("__module__") = scope.attr("__name__"); + } + + template + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + nanobind::object cf = nanobind::cpp_function( + std::forward(f), nanobind::name(name), nanobind::is_method(), + nanobind::scope(thisClass), extra...); + thisClass.attr(name) = cf; + return *this; + } + + template + pure_subclass &def_property_readonly(const char *name, Func &&f, + const Extra &...extra) { + nanobind::object cf = nanobind::cpp_function( + std::forward(f), nanobind::name(name), nanobind::is_method(), + nanobind::scope(thisClass), extra...); + auto builtinProperty = + nanobind::borrow((PyObject *)&PyProperty_Type); + thisClass.attr(name) = builtinProperty(cf); + return *this; + } + + template + pure_subclass &def_staticmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_staticmethod(...) called with a non-static member " + "function pointer"); + nanobind::object cf = nanobind::cpp_function( + std::forward(f), + nanobind::name(name), // nanobind::scope(thisClass), + extra...); + thisClass.attr(name) = cf; + return *this; + } + + template + pure_subclass &def_classmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_classmethod(...) 
called with a non-static member " + "function pointer"); + nanobind::object cf = nanobind::cpp_function( + std::forward(f), + nanobind::name(name), // nanobind::scope(thisClass), + extra...); + thisClass.attr(name) = + nanobind::borrow(PyClassMethod_New(cf.ptr())); + return *this; + } + + nanobind::object get_class() const { return thisClass; } + +protected: + nanobind::object superClass; + nanobind::object thisClass; +}; + +/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting +/// constructor and type checking methods. +class mlir_attribute_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_attribute_subclass( + scope, attrClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Attribute super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_attribute_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it is hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. 
+ // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureTypeName]( + nanobind::object cls, nanobind::object otherAttribute) { + MlirAttribute rawAttribute; + if (!nanobind::try_cast(otherAttribute, + rawAttribute) || + !isaFunction(rawAttribute)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherAttribute)); + throw std::invalid_argument( + (llvm::Twine("Cannot cast attribute to ") + captureTypeName + + " (from " + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. 
+ static const char kIsinstanceSig[] = + "def isinstance(other_attribute: " MAKE_MLIR_PYTHON_QUALNAME( + "ir") ".Attribute) -> bool"; + def_staticmethod( + "isinstance", + [isaFunction](MlirAttribute other) { return isaFunction(other); }, + nanobind::arg("other_attribute"), nanobind::sig(kIsinstanceSig)); + def("__repr__", [superCls, captureTypeName](nanobind::object self) { + return nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(nanobind::cpp_function( + [thisClass = thisClass](const nanobind::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Type, implementing a casting +/// constructor and type checking methods. +class mlir_type_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_type_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Type super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). 
+ mlir_type_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it is hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureTypeName](nanobind::object cls, + nanobind::object otherType) { + MlirType rawType; + if (!nanobind::try_cast(otherType, rawType) || + !isaFunction(rawType)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherType)); + throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + + captureTypeName + " (from " + + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherType); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. 
+ static const char kIsinstanceSig[] = + "def isinstance(other_type: " MAKE_MLIR_PYTHON_QUALNAME( + "ir") ".Type) -> bool"; + def_staticmethod( + "isinstance", + [isaFunction](MlirType other) { return isaFunction(other); }, + nanobind::arg("other_type"), nanobind::sig(kIsinstanceSig)); + def("__repr__", [superCls, captureTypeName](nanobind::object self) { + return nanobind::cast( + nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName)); + }); + if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(nanobind::cpp_function( + [thisClass = thisClass](const nanobind::object &mlirType) { + return thisClass(mlirType); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the super-class dynamically. + mlir_value_subclass(nanobind::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction) + : mlir_value_subclass( + scope, valueClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value")) {} + + /// Subclasses with a provided mlir.ir.Value super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). 
+ mlir_value_subclass(nanobind::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls) + : pure_subclass(scope, valueClassName, superCls) { + // Casting constructor. Note that it is hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureValueName( + valueClassName); // As string in case if valueClassName is not static. + nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureValueName](nanobind::object cls, + nanobind::object otherValue) { + MlirValue rawValue; + if (!nanobind::try_cast(otherValue, rawValue) || + !isaFunction(rawValue)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherValue)); + throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + + captureValueName + " (from " + + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherValue); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_value")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. 
+ static const char kIsinstanceSig[] = + "def isinstance(other_value: " MAKE_MLIR_PYTHON_QUALNAME( + "ir") ".Value) -> bool"; + def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + nanobind::arg("other_value"), nanobind::sig(kIsinstanceSig)); + } +}; + +} // namespace nanobind_adaptors + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h new file mode 100644 index 000000000..edc69774b --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -0,0 +1,616 @@ +//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file contains adaptors for clients of the core MLIR Python APIs to +// interop via MLIR CAPI types, using pybind11. The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). +// +// It is encouraged to be used both in-tree and out-of-tree. For in-tree use +// cases, it should be used for dialect implementations (versus relying on +// Pybind-based internals of the core libraries). 
+//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H +#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H + +#include +#include +#include +#include + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" + +#include "llvm/ADT/Twine.h" + +namespace py = pybind11; +using namespace py::literals; + +// Raw CAPI type casters need to be declared before use, so always include them +// first. +namespace pybind11 { +namespace detail { + +/// Helper to convert a presumed MLIR API object to a capsule, accepting either +/// an explicit Capsule (which can happen when two C APIs are communicating +/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR +/// attribute (through which supported MLIR Python API objects export their +/// contained API pointer as a capsule). Throws a type error if the object is +/// neither. This is intended to be used from type casters, which are invoked +/// with a raw handle (unowned). The returned object's lifetime may not extend +/// beyond the apiObject handle without explicitly having its refcount increased +/// (i.e. on return). +static py::object mlirApiObjectToCapsule(py::handle apiObject) { + if (PyCapsule_CheckExact(apiObject.ptr())) + return py::reinterpret_borrow(apiObject); + if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { + auto repr = py::repr(apiObject).cast(); + throw py::type_error( + (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str()); + } + return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); +} + +// Note: Currently all of the following support cast from py::object to the +// Mlir* C-API type, but only a few light-weight, context-bound ones +// implicitly cast the other way because the use case has not yet emerged and +// ownership is unclear. + +/// Casts object <-> MlirAffineMap. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(value)) { + return false; + } + return !mlirAffineMapIsNull(value); + } + static handle cast(MlirAffineMap v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonAffineMapToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("AffineMap") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirAttribute. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAttribute(capsule.ptr()); + return !mlirAttributeIsNull(value); + } + static handle cast(MlirAttribute v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonAttributeToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +/// Casts object -> MlirBlock. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToBlock(capsule.ptr()); + return !mlirBlockIsNull(value); + } +}; + +/// Casts object -> MlirContext. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); + bool load(handle src, bool) { + if (src.is_none()) { + // Gets the current thread-bound context. + // TODO: This raises an error of "No current context" currently. 
+ // Update the implementation to pretty-print the helpful error that the + // core implementations print in this case. + src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Context") + .attr("current"); + } + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule.ptr()); + return !mlirContextIsNull(value); + } +}; + +/// Casts object <-> MlirDialectRegistry. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle cast(MlirDialectRegistry v, return_value_policy, handle) { + py::object capsule = py::reinterpret_steal( + mlirPythonDialectRegistryToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirLocation. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); + bool load(handle src, bool) { + if (src.is_none()) { + // Gets the current thread-bound context. + src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr("current"); + } + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToLocation(capsule.ptr()); + return !mlirLocationIsNull(value); + } + static handle cast(MlirLocation v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonLocationToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirModule. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToModule(capsule.ptr()); + return !mlirModuleIsNull(value); + } + static handle cast(MlirModule v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonModuleToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Module") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirFrozenRewritePatternSet. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, + _("MlirFrozenRewritePatternSet")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + return value.ptr != nullptr; + } + static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, + handle) { + py::object capsule = py::reinterpret_steal( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirOperation. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToOperation(capsule.ptr()); + return !mlirOperationIsNull(value); + } + static handle cast(MlirOperation v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal(mlirPythonOperationToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Operation") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirValue. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToValue(capsule.ptr()); + return !mlirValueIsNull(value); + } + static handle cast(MlirValue v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal(mlirPythonValueToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + }; +}; + +/// Casts object -> MlirPassManager. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + return !mlirPassManagerIsNull(value); + } +}; + +/// Casts object <-> MlirTypeID. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToTypeID(capsule.ptr()); + return !mlirTypeIDIsNull(value); + } + static handle cast(MlirTypeID v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal(mlirPythonTypeIDToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirType. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToType(capsule.ptr()); + return !mlirTypeIsNull(value); + } + static handle cast(MlirType t, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonTypeToCapsule(t)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +} // namespace detail +} // namespace pybind11 + +namespace mlir { +namespace python { +namespace adaptors { + +/// Provides a facility like py::class_ for defining a new class in a scope, +/// but this allows extension of an arbitrary Python class, defining methods +/// on it is a similar way. Classes defined in this way are very similar to +/// if defined in Python in the usual way but use Pybind11 machinery to do +/// it. These are not "real" Pybind11 classes but pure Python classes with no +/// relation to a concrete C++ class. +/// +/// Derived from a discussion upstream: +/// https://github.com/pybind/pybind11/issues/1193 +/// (plus a fair amount of extra curricular poking) +/// TODO: If this proves useful, see about including it in pybind11. 
+class pure_subclass { +public: + pure_subclass(py::handle scope, const char *derivedClassName, + const py::object &superClass) { + py::object pyType = + py::reinterpret_borrow((PyObject *)&PyType_Type); + py::object metaclass = pyType(superClass); + py::dict attributes; + + thisClass = + metaclass(derivedClassName, py::make_tuple(superClass), attributes); + scope.attr(derivedClassName) = thisClass; + } + + template + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + py::cpp_function cf( + std::forward(f), py::name(name), py::is_method(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + thisClass.attr(cf.name()) = cf; + return *this; + } + + template + pure_subclass &def_property_readonly(const char *name, Func &&f, + const Extra &...extra) { + py::cpp_function cf( + std::forward(f), py::name(name), py::is_method(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + auto builtinProperty = + py::reinterpret_borrow((PyObject *)&PyProperty_Type); + thisClass.attr(name) = builtinProperty(cf); + return *this; + } + + template + pure_subclass &def_staticmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_staticmethod(...) called with a non-static member " + "function pointer"); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); + thisClass.attr(cf.name()) = py::staticmethod(cf); + return *this; + } + + template + pure_subclass &def_classmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_classmethod(...) 
called with a non-static member "
+                  "function pointer");
+    py::cpp_function cf(std::forward(f), py::name(name),
+                        py::scope(thisClass), extra...);
+    thisClass.attr(cf.name()) =
+        py::reinterpret_borrow(PyClassMethod_New(cf.ptr()));
+    return *this;
+  }
+
+  py::object get_class() const { return thisClass; }
+
+protected:
+  py::object superClass;
+  py::object thisClass;
+};
+
+/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting
+/// constructor and type checking methods.
+class mlir_attribute_subclass : public pure_subclass {
+public:
+  using IsAFunctionTy = bool (*)(MlirAttribute);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+  /// Subclasses by looking up the super-class dynamically.
+  mlir_attribute_subclass(py::handle scope, const char *attrClassName,
+                          IsAFunctionTy isaFunction,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+      : mlir_attribute_subclass(
+            scope, attrClassName, isaFunction,
+            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
+                .attr("Attribute"),
+            getTypeIDFunction) {}
+
+  /// Subclasses with a provided mlir.ir.Attribute super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_attribute_subclass(py::handle scope, const char *typeClassName,
+                          IsAFunctionTy isaFunction, const py::object &superCls,
+                          GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // call chain to parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` and call that of a superclass, which
+    // eventually calls `__init__` of the superclass. 
Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherAttribute) { + MlirAttribute rawAttribute = py::cast(otherAttribute); + if (!isaFunction(rawAttribute)) { + auto origRepr = py::repr(otherAttribute).cast(); + throw std::invalid_argument( + (llvm::Twine("Cannot cast attribute to ") + captureTypeName + + " (from " + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirAttribute other) { return isaFunction(other); }, + py::arg("other_attribute")); + def("__repr__", [superCls, captureTypeName](py::object self) { + return py::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Type, implementing a casting +/// constructor and type checking methods. +class mlir_type_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. 
+  mlir_type_subclass(py::handle scope, const char *typeClassName,
+                     IsAFunctionTy isaFunction,
+                     GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+      : mlir_type_subclass(
+            scope, typeClassName, isaFunction,
+            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"),
+            getTypeIDFunction) {}
+
+  /// Subclasses with a provided mlir.ir.Type super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_type_subclass(py::handle scope, const char *typeClassName,
+                     IsAFunctionTy isaFunction, const py::object &superCls,
+                     GetTypeIDFunctionTy getTypeIDFunction = nullptr)
+      : pure_subclass(scope, typeClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // call chain to parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` and call that of a superclass, which
+    // eventually calls `__init__` of the superclass. Since attribute subclasses
+    // have no additional members, we can just return the instance thus created
+    // without amending it.
+    std::string captureTypeName(
+        typeClassName); // As string in case if typeClassName is not static. 
+ py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherType) { + MlirType rawType = py::cast(otherType); + if (!isaFunction(rawType)) { + auto origRepr = py::repr(otherType).cast(); + throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + + captureTypeName + " (from " + + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherType); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirType other) { return isaFunction(other); }, + py::arg("other_type")); + def("__repr__", [superCls, captureTypeName](py::object self) { + return py::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirType) { + return thisClass(mlirType); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the super-class dynamically. 
+  mlir_value_subclass(py::handle scope, const char *valueClassName,
+                      IsAFunctionTy isaFunction)
+      : mlir_value_subclass(
+            scope, valueClassName, isaFunction,
+            py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) {
+  }
+
+  /// Subclasses with a provided mlir.ir.Value super-class. This must
+  /// be used if the subclass is being defined in the same extension module
+  /// as the mlir.ir class (otherwise, it will trigger a recursive
+  /// initialization).
+  mlir_value_subclass(py::handle scope, const char *valueClassName,
+                      IsAFunctionTy isaFunction, const py::object &superCls)
+      : pure_subclass(scope, valueClassName, superCls) {
+    // Casting constructor. Note that it is hard, if not impossible, to properly
+    // call chain to parent `__init__` in pybind11 due to its special handling
+    // for init functions that don't have a fully constructed self-reference,
+    // which makes it impossible to forward it to `__init__` of a superclass.
+    // Instead, provide a custom `__new__` and call that of a superclass, which
+    // eventually calls `__init__` of the superclass. Since attribute subclasses
+    // have no additional members, we can just return the instance thus created
+    // without amending it.
+    std::string captureValueName(
+        valueClassName); // As string in case if valueClassName is not static.
+    py::cpp_function newCf(
+        [superCls, isaFunction, captureValueName](py::object cls,
+                                                  py::object otherValue) {
+          MlirValue rawValue = py::cast(otherValue);
+          if (!isaFunction(rawValue)) {
+            auto origRepr = py::repr(otherValue).cast();
+            throw std::invalid_argument((llvm::Twine("Cannot cast value to ") +
+                                         captureValueName + " (from " +
+                                         origRepr + ")")
+                                            .str());
+          }
+          py::object self = superCls.attr("__new__")(cls, otherValue);
+          return self;
+        },
+        py::name("__new__"), py::arg("cls"), py::arg("cast_from_value"));
+    thisClass.attr("__new__") = newCf;
+
+    // 'isinstance' method. 
+ def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + py::arg("other_value")); + } +}; + +} // namespace adaptors + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index ea7b265dd..1836cb0ac 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -15,21 +15,26 @@ #ifndef MLIR_CAPI_IR_H #define MLIR_CAPI_IR_H +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Identifier.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState) +DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig) DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) +DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) +DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) +DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable) DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) -DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) +DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h new file mode 100644 index 000000000..4154b8c9e --- /dev/null +++ b/mlir/include/mlir/CAPI/Interfaces.h @@ -0,0 +1,18 @@ +//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===// +// +// Part of the LLVM Project, under 
the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// MLIR interface classes. This file should not be included from C++ code other +// than C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_INTERFACES_H +#define MLIR_CAPI_INTERFACES_H + +#endif // MLIR_CAPI_INTERFACES_H diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h index ac909d1dd..355c4bcfe 100644 --- a/mlir/include/mlir/CAPI/Registration.h +++ b/mlir/include/mlir/CAPI/Registration.h @@ -10,7 +10,6 @@ #define MLIR_CAPI_REGISTRATION_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" @@ -21,23 +20,23 @@ //===----------------------------------------------------------------------===// /// Hooks for dynamic discovery of dialects. -typedef void (*MlirContextRegisterDialectHook)(MlirContext context); +typedef void (*MlirDialectRegistryInsertDialectHook)( + MlirDialectRegistry registry); typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context); typedef MlirStringRef (*MlirDialectGetNamespaceHook)(); /// Structure of dialect registration hooks. 
struct MlirDialectRegistrationHooks { - MlirContextRegisterDialectHook registerHook; + MlirDialectRegistryInsertDialectHook insertHook; MlirContextLoadDialectHook loadHook; MlirDialectGetNamespaceHook getNamespaceHook; }; typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks; #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \ - static void mlirContextRegister##Name##Dialect(MlirContext context) { \ - mlir::DialectRegistry registry; \ - registry.insert(); \ - unwrap(context)->appendDialectRegistry(registry); \ + static void mlirDialectRegistryInsert##Name##Dialect( \ + MlirDialectRegistry registry) { \ + unwrap(registry)->insert(); \ } \ static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \ return wrap(unwrap(context)->getOrLoadDialect()); \ @@ -47,8 +46,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks; } \ MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \ static MlirDialectRegistrationHooks hooks = { \ - mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \ - mlir##Name##DialectGetNamespace}; \ + mlirDialectRegistryInsert##Name##Dialect, \ + mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace}; \ return MlirDialectHandle{&hooks}; \ } diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h new file mode 100644 index 000000000..1038c0a57 --- /dev/null +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -0,0 +1,24 @@ +//===- Rewrite.h - C API Utils for Core MLIR classes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// rewrite patterns. 
This file should not be included from C++ code other than
+// C API implementation nor from C code.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CAPI_REWRITE_H
+#define MLIR_CAPI_REWRITE_H
+
+#include "mlir-c/Rewrite.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/IR/PatternMatch.h"
+
+DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase)
+
+#endif // MLIR_CAPI_REWRITE_H
diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h
index 6d9a59abf..89a460375 100644
--- a/mlir/include/mlir/CAPI/Support.h
+++ b/mlir/include/mlir/CAPI/Support.h
@@ -16,8 +16,14 @@
 #define MLIR_CAPI_SUPPORT_H
 
 #include "mlir-c/Support.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/CAPI/Wrap.h"
+#include "mlir/Support/TypeID.h"
 #include "llvm/ADT/StringRef.h"
+#include "llvm/Support/LogicalResult.h"
+
+namespace llvm {
+class ThreadPoolInterface;
+} // namespace llvm
+
+/// Converts a StringRef into its MLIR C API equivalent. 
inline MlirStringRef wrap(llvm::StringRef ref) { @@ -29,14 +35,18 @@ inline llvm::StringRef unwrap(MlirStringRef ref) { return llvm::StringRef(ref.data, ref.length); } -inline MlirLogicalResult wrap(mlir::LogicalResult res) { +inline MlirLogicalResult wrap(llvm::LogicalResult res) { if (mlir::succeeded(res)) return mlirLogicalResultSuccess(); return mlirLogicalResultFailure(); } -inline mlir::LogicalResult unwrap(MlirLogicalResult res) { +inline llvm::LogicalResult unwrap(MlirLogicalResult res) { return mlir::success(mlirLogicalResultIsSuccess(res)); } +DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPoolInterface) +DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) +DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator) + #endif // MLIR_CAPI_SUPPORT_H diff --git a/mlir/include/mlir/CAPI/Utils.h b/mlir/include/mlir/CAPI/Utils.h index c2e43850c..d78cdbf31 100644 --- a/mlir/include/mlir/CAPI/Utils.h +++ b/mlir/include/mlir/CAPI/Utils.h @@ -14,6 +14,8 @@ #ifndef MLIR_CAPI_UTILS_H #define MLIR_CAPI_UTILS_H +#include + #include "mlir-c/Support.h" #include "llvm/Support/raw_ostream.h" @@ -29,7 +31,7 @@ class CallbackOstream : public llvm::raw_ostream { public: CallbackOstream(std::function callback, void *opaqueData) - : raw_ostream(/*unbuffered=*/true), callback(callback), + : raw_ostream(/*unbuffered=*/true), callback(std::move(callback)), opaqueData(opaqueData), pos(0u) {} void write_impl(const char *ptr, size_t size) override { @@ -45,7 +47,7 @@ class CallbackOstream : public llvm::raw_ostream { void *opaqueData; uint64_t pos; }; -} // end namespace detail -} // end namespace mlir +} // namespace detail +} // namespace mlir #endif // MLIR_CAPI_UTILS_H diff --git a/mlir/include/mlir/CAPI/Wrap.h b/mlir/include/mlir/CAPI/Wrap.h index b8cc745d7..fd5b6e18d 100644 --- a/mlir/include/mlir/CAPI/Wrap.h +++ b/mlir/include/mlir/CAPI/Wrap.h @@ -44,7 +44,7 @@ static llvm::ArrayRef unwrapList(size_t size, CTy *first, "incompatible C and C++ types"); if (size == 
0) - return llvm::None; + return {}; assert(storage.empty() && "expected to populate storage"); storage.reserve(size); diff --git a/mlir/lib/Bindings/Python/.style.yapf b/mlir/lib/Bindings/Python/.style.yapf deleted file mode 100644 index 9ef1dc15b..000000000 --- a/mlir/lib/Bindings/Python/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] - based_on_style = google - column_limit = 80 - indent_width = 2 diff --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp similarity index 58% rename from mlir/lib/Bindings/Python/Transforms/Transforms.cpp rename to mlir/lib/Bindings/Python/AsyncPasses.cpp index 46c469192..cfb8dcaaa 100644 --- a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp +++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp @@ -1,4 +1,4 @@ -//===- Transforms.cpp - Pybind module for the Transforms library ----------===// +//===- AsyncPasses.cpp - Pybind module for the Async passes -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,19 +6,17 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Transforms.h" +#include "mlir-c/Dialect/Async.h" -#include - -namespace py = pybind11; +#include "mlir/Bindings/Python/Nanobind.h" // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirTransforms, m) { - m.doc() = "MLIR Transforms library"; +NB_MODULE(_mlirAsyncPasses, m) { + m.doc() = "MLIR Async Dialect Passes"; - // Register all the passes in the Transforms library on load. - mlirRegisterTransformsPasses(); + // Register all Async passes on load. 
+ mlirRegisterAsyncPasses(); } diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt deleted file mode 100644 index c444ddcc4..000000000 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ /dev/null @@ -1,112 +0,0 @@ -include(AddMLIRPythonExtension) -add_custom_target(MLIRBindingsPythonExtension) - -################################################################################ -# Copy python source tree. -################################################################################ - -file(GLOB_RECURSE PY_SRC_FILES - RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "${CMAKE_CURRENT_SOURCE_DIR}/mlir/*.py") - -add_custom_target(MLIRBindingsPythonSources ALL - DEPENDS ${PY_SRC_FILES} -) -add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources) - -foreach(PY_SRC_FILE ${PY_SRC_FILES}) - set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}") - get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY) - file(MAKE_DIRECTORY "${PY_DEST_DIR}") - add_custom_command( - TARGET MLIRBindingsPythonSources PRE_BUILD - COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}" - DEPENDS "${PY_SRC_FILE}" - COMMAND "${CMAKE_COMMAND}" -E create_symlink - "${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}" - ) -endforeach() - -################################################################################ -# Generate dialect-specific bindings. 
-################################################################################ - -add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps - TD_FILE BuiltinOps.td - DIALECT_NAME builtin) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps - TD_FILE LinalgOps.td - DIALECT_NAME linalg - DEPENDS LinalgOdsGen) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonShapeOps - TD_FILE ShapeOps.td - DIALECT_NAME shape) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonShapeOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps - TD_FILE StandardOps.td - DIALECT_NAME std) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonStandardOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps - TD_FILE TensorOps.td - DIALECT_NAME tensor) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) - -################################################################################ -# Build core python extension -################################################################################ -add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir - INSTALL_DIR - python - SOURCES - MainModule.cpp - IRModules.cpp - PybindUtils.cpp - Pass.cpp - ExecutionEngine.cpp -) -add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension) - -# Note that we copy from the source tree just like for headers because -# it will not be polluted with py_cache runtime artifacts (from testing and -# such). 
-install( - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlir - DESTINATION python - COMPONENT MLIRBindingsPythonSources - FILES_MATCHING PATTERN "*.py" -) - -if (NOT LLVM_ENABLE_IDE) - add_llvm_install_targets( - install-MLIRBindingsPythonSources - DEPENDS MLIRBindingsPythonSources - COMPONENT MLIRBindingsPythonSources) -endif() - -# Dialect sources are generated. Install separately. -# Note that __pycache__ directories may have been left by tests and other -# executions. And __init__.py is handled as a regular source file. -install( - DIRECTORY ${PROJECT_BINARY_DIR}/python/mlir/dialects - DESTINATION python/mlir - COMPONENT MLIRBindingsPythonDialects - FILES_MATCHING PATTERN "*.py" - PATTERN "__pycache__" EXCLUDE - PATTERN "__init__.py" EXCLUDE -) - -if (NOT LLVM_ENABLE_IDE) - add_llvm_install_targets( - install-MLIRBindingsPythonDialects - DEPENDS MLIRBindingsPythonSources - COMPONENT MLIRBindingsPythonDialects) -endif() - -add_subdirectory(Transforms) -add_subdirectory(Conversions) diff --git a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt deleted file mode 100644 index ad2aeefca..000000000 --- a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -################################################################################ -# Build python extension -################################################################################ - -add_mlir_python_extension(MLIRConversionsBindingsPythonExtension _mlirConversions - INSTALL_DIR - python - SOURCES - Conversions.cpp -) diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp new file mode 100644 index 000000000..e5045cf0b --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -0,0 +1,90 @@ +//===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir-c/Dialect/GPU.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace nanobind::literals; + +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +NB_MODULE(_mlirDialectsGPU, m) { + m.doc() = "MLIR GPU Dialect"; + //===-------------------------------------------------------------------===// + // AsyncTokenType + //===-------------------------------------------------------------------===// + + auto mlirGPUAsyncTokenType = + mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType); + + mlirGPUAsyncTokenType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirGPUAsyncTokenTypeGet(ctx)); + }, + "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"), + nb::arg("ctx").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // ObjectAttr + //===-------------------------------------------------------------------===// + + mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) + .def_classmethod( + "get", + [](nb::object cls, MlirAttribute target, uint32_t format, + nb::bytes object, std::optional mlirObjectProps, + std::optional mlirKernelsAttr) { + MlirStringRef objectStrRef = mlirStringRefCreate( + static_cast(const_cast(object.data())), + object.size()); + return cls(mlirGPUObjectAttrGetWithKernels( + mlirAttributeGetContext(target), target, 
format, objectStrRef, + mlirObjectProps.has_value() ? *mlirObjectProps + : MlirAttribute{nullptr}, + mlirKernelsAttr.has_value() ? *mlirKernelsAttr + : MlirAttribute{nullptr})); + }, + "cls"_a, "target"_a, "format"_a, "object"_a, + "properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(), + "Gets a gpu.object from parameters.") + .def_property_readonly( + "target", + [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) + .def_property_readonly( + "format", + [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); }) + .def_property_readonly( + "object", + [](MlirAttribute self) { + MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); + return nb::bytes(stringRef.data, stringRef.length); + }) + .def_property_readonly("properties", + [](MlirAttribute self) -> nb::object { + if (mlirGPUObjectAttrHasProperties(self)) + return nb::cast( + mlirGPUObjectAttrGetProperties(self)); + return nb::none(); + }) + .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object { + if (mlirGPUObjectAttrHasKernels(self)) + return nb::cast(mlirGPUObjectAttrGetKernels(self)); + return nb::none(); + }); +} diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp new file mode 100644 index 000000000..f211e769d --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -0,0 +1,147 @@ +//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; + +using namespace nanobind::literals; + +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +void populateDialectLLVMSubmodule(const nanobind::module_ &m) { + + //===--------------------------------------------------------------------===// + // StructType + //===--------------------------------------------------------------------===// + + auto llvmStructType = + mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); + + llvmStructType.def_classmethod( + "get_literal", + [](nb::object cls, const std::vector &elements, bool packed, + MlirLocation loc) { + CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); + + MlirType type = mlirLLVMStructTypeLiteralGetChecked( + loc, elements.size(), elements.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return cls(type); + }, + "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "loc"_a.none() = nb::none()); + + llvmStructType.def_classmethod( + "get_identified", + [](nb::object cls, const std::string &name, MlirContext context) { + return cls(mlirLLVMStructTypeIdentifiedGet( + context, mlirStringRefCreate(name.data(), name.size()))); + }, + "cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none()); + + llvmStructType.def_classmethod( + "get_opaque", + [](nb::object cls, const std::string &name, MlirContext context) { + return cls(mlirLLVMStructTypeOpaqueGet( + context, mlirStringRefCreate(name.data(), name.size()))); + }, + "cls"_a, "name"_a, 
"context"_a.none() = nb::none()); + + llvmStructType.def( + "set_body", + [](MlirType self, const std::vector &elements, bool packed) { + MlirLogicalResult result = mlirLLVMStructTypeSetBody( + self, elements.size(), elements.data(), packed); + if (!mlirLogicalResultIsSuccess(result)) { + throw nb::value_error( + "Struct body already set to different content."); + } + }, + "elements"_a, nb::kw_only(), "packed"_a = false); + + llvmStructType.def_classmethod( + "new_identified", + [](nb::object cls, const std::string &name, + const std::vector &elements, bool packed, MlirContext ctx) { + return cls(mlirLLVMStructTypeIdentifiedNewGet( + ctx, mlirStringRefCreate(name.data(), name.length()), + elements.size(), elements.data(), packed)); + }, + "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "context"_a.none() = nb::none()); + + llvmStructType.def_property_readonly( + "name", [](MlirType type) -> std::optional { + if (mlirLLVMStructTypeIsLiteral(type)) + return std::nullopt; + + MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type); + return StringRef(stringRef.data, stringRef.length).str(); + }); + + llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object { + // Don't crash in absence of a body. 
+ if (mlirLLVMStructTypeIsOpaque(type)) + return nb::none(); + + nb::list body; + for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e; + ++i) { + body.append(mlirLLVMStructTypeGetElementType(type, i)); + } + return body; + }); + + llvmStructType.def_property_readonly( + "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); }); + + llvmStructType.def_property_readonly( + "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); }); + + //===--------------------------------------------------------------------===// + // PointerType + //===--------------------------------------------------------------------===// + + mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) + .def_classmethod( + "get", + [](nb::object cls, std::optional addressSpace, + MlirContext context) { + CollectDiagnosticsToStringScope scope(context); + MlirType type = mlirLLVMPointerTypeGet( + context, addressSpace.has_value() ? *addressSpace : 0); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return cls(type); + }, + "cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(), + "context"_a.none() = nb::none()) + .def_property_readonly("address_space", [](MlirType type) { + return mlirLLVMPointerTypeGetAddressSpace(type); + }); +} + +NB_MODULE(_mlirDialectsLLVM, m) { + m.doc() = "MLIR LLVM Dialect"; + + populateDialectLLVMSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp new file mode 100644 index 000000000..015502371 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -0,0 +1,139 @@ +//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Linalg.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; +using namespace mlir::python::nanobind_adaptors; + +static std::optional +InferContractionDimensions(MlirOperation op) { + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensions(op); + + // Detect "empty" result. This occurs when `op` is not a contraction op, + // or when `linalg::inferContractionDims` fails. + if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; +} + +static std::optional +InferConvolutionDimensions(MlirOperation op) { + MlirLinalgConvolutionDimensions dims = + mlirLinalgInferConvolutionDimensions(op); + + // Detect "empty" result. This occurs when `op` is not a convolution op, + // or when `linalg::inferConvolutionDims` fails. 
+ if (mlirAttributeIsNull(dims.batch) && + mlirAttributeIsNull(dims.outputImage) && + mlirAttributeIsNull(dims.outputChannel) && + mlirAttributeIsNull(dims.filterLoop) && + mlirAttributeIsNull(dims.inputChannel) && + mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) && + mlirAttributeIsNull(dims.dilations)) { + return std::nullopt; + } + + return dims; +} + +static void populateDialectLinalgSubmodule(nb::module_ m) { + m.def( + "fill_builtin_region", + [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); }, + nb::arg("op"), + "Fill the region for `op`, which is assumed to be a builtin named Linalg " + "op."); + + m.def("isa_contraction_op", &mlirLinalgIsAContractionOp, + "Checks if the given operation is a Linalg contraction operation.", + nb::arg("op")); + + nb::class_(m, "ContractionDimensions") + .def_prop_ro("batch", + [](const MlirLinalgContractionDimensions &self) { + return self.batch; + }) + .def_prop_ro( + "m", + [](const MlirLinalgContractionDimensions &self) { return self.m; }) + .def_prop_ro( + "n", + [](const MlirLinalgContractionDimensions &self) { return self.n; }) + .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) { + return self.k; + }); + + m.def("infer_contraction_dimensions", &InferContractionDimensions, + "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction " + "op.", + nb::arg("op")); + + m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp, + "Checks if the given operation is a Linalg convolution operation.", + nb::arg("op")); + + nb::class_(m, "ConvolutionDimensions") + .def_prop_ro("batch", + [](const MlirLinalgConvolutionDimensions &self) { + return self.batch; + }) + .def_prop_ro("output_image", + [](const MlirLinalgConvolutionDimensions &self) { + return self.outputImage; + }) + .def_prop_ro("output_channel", + [](const MlirLinalgConvolutionDimensions &self) { + return self.outputChannel; + }) + .def_prop_ro("filter_loop", + [](const MlirLinalgConvolutionDimensions &self) 
{ + return self.filterLoop; + }) + .def_prop_ro("input_channel", + [](const MlirLinalgConvolutionDimensions &self) { + return self.inputChannel; + }) + .def_prop_ro("depth", + [](const MlirLinalgConvolutionDimensions &self) { + return self.depth; + }) + .def_prop_ro("strides", + [](const MlirLinalgConvolutionDimensions &self) { + return self.strides; + }) + .def_prop_ro("dilations", + [](const MlirLinalgConvolutionDimensions &self) { + return self.dilations; + }); + + m.def("infer_convolution_dimensions", &InferConvolutionDimensions, + "Infers convolution dimensions", nb::arg("op")); + + m.def( + "get_indexing_maps", + [](MlirOperation op) -> std::optional { + MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op); + if (mlirAttributeIsNull(attr)) + return std::nullopt; + return attr; + }, + "Returns the indexing_maps attribute for a linalg op."); +} + +NB_MODULE(_mlirDialectsLinalg, m) { + m.doc() = "MLIR Linalg dialect."; + + populateDialectLinalgSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp new file mode 100644 index 000000000..a0d6a4b4c --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -0,0 +1,41 @@ +//===--- DialectNVGPU.cpp - Pybind module for NVGPU dialect API support ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectNVGPUSubmodule(const nb::module_ &m) { + auto nvgpuTensorMapDescriptorType = mlir_type_subclass( + m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType); + + nvgpuTensorMapDescriptorType.def_classmethod( + "get", + [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo, + int oobFill, int interleave, MlirContext ctx) { + return cls(mlirNVGPUTensorMapDescriptorTypeGet( + ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave)); + }, + "Gets an instance of TensorMapDescriptorType in the same context", + nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"), + nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"), + nb::arg("ctx").none() = nb::none()); +} + +NB_MODULE(_mlirDialectsNVGPU, m) { + m.doc() = "MLIR NVGPU dialect."; + + populateDialectNVGPUSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp new file mode 100644 index 000000000..bcc6ff406 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -0,0 +1,103 @@ +//===- DialectPDL.cpp - 'pdl' dialect submodule ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/PDL.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +void populateDialectPDLSubmodule(const nanobind::module_ &m) { + //===-------------------------------------------------------------------===// + // PDLType + //===-------------------------------------------------------------------===// + + auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType); + + //===-------------------------------------------------------------------===// + // AttributeType + //===-------------------------------------------------------------------===// + + auto attributeType = + mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType); + attributeType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirPDLAttributeTypeGet(ctx)); + }, + "Get an instance of AttributeType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // OperationType + //===-------------------------------------------------------------------===// + + auto operationType = + mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType); + operationType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirPDLOperationTypeGet(ctx)); + }, + "Get an instance of OperationType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // RangeType + //===-------------------------------------------------------------------===// + + auto 
rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType); + rangeType.def_classmethod( + "get", + [](nb::object cls, MlirType elementType) { + return cls(mlirPDLRangeTypeGet(elementType)); + }, + "Gets an instance of RangeType in the same context as the provided " + "element type.", + nb::arg("cls"), nb::arg("element_type")); + rangeType.def_property_readonly( + "element_type", + [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); }, + "Get the element type."); + + //===-------------------------------------------------------------------===// + // TypeType + //===-------------------------------------------------------------------===// + + auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType); + typeType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirPDLTypeTypeGet(ctx)); + }, + "Get an instance of TypeType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // ValueType + //===-------------------------------------------------------------------===// + + auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType); + valueType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirPDLValueTypeGet(ctx)); + }, + "Get an instance of TypeType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); +} + +NB_MODULE(_mlirDialectsPDL, m) { + m.doc() = "MLIR PDL dialect."; + populateDialectPDLSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp new file mode 100644 index 000000000..55571cd1e --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -0,0 +1,389 @@ +//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Dialect/Quant.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; +using namespace llvm; +using namespace mlir; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectQuantSubmodule(const nb::module_ &m) { + //===-------------------------------------------------------------------===// + // QuantizedType + //===-------------------------------------------------------------------===// + + auto quantizedType = + mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); + quantizedType.def_staticmethod( + "default_minimum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, + integralWidth); + }, + "Default minimum value for the integer with the specified signedness and " + "bit width.", + nb::arg("is_signed"), nb::arg("integral_width")); + quantizedType.def_staticmethod( + "default_maximum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, + integralWidth); + }, + "Default maximum value for the integer with the specified signedness and " + "bit width.", + nb::arg("is_signed"), nb::arg("integral_width")); + quantizedType.def_property_readonly( + "expressed_type", + [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, + "Type expressed by this quantized type."); + quantizedType.def_property_readonly( + "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, + "Flags of this quantized type (named accessors should be preferred to " + "this)"); + quantizedType.def_property_readonly( + 
"is_signed", + [](MlirType type) { return mlirQuantizedTypeIsSigned(type); }, + "Signedness of this quantized type."); + quantizedType.def_property_readonly( + "storage_type", + [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, + "Storage type backing this quantized type."); + quantizedType.def_property_readonly( + "storage_type_min", + [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, + "The minimum value held by the storage type of this quantized type."); + quantizedType.def_property_readonly( + "storage_type_max", + [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, + "The maximum value held by the storage type of this quantized type."); + quantizedType.def_property_readonly( + "storage_type_integral_width", + [](MlirType type) { + return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); + }, + "The bitwidth of the storage type of this quantized type."); + quantizedType.def( + "is_compatible_expressed_type", + [](MlirType type, MlirType candidate) { + return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); + }, + "Checks whether the candidate type can be expressed by this quantized " + "type.", + nb::arg("candidate")); + quantizedType.def_property_readonly( + "quantized_element_type", + [](MlirType type) { + return mlirQuantizedTypeGetQuantizedElementType(type); + }, + "Element type of this quantized type expressed as quantized type."); + quantizedType.def( + "cast_from_storage_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on the storage type of this quantized type to a " + "corresponding type based on the quantized type. 
Raises TypeError if the " + "cast is not valid.", + nb::arg("candidate")); + quantizedType.def_staticmethod( + "cast_to_storage_type", + [](MlirType type) { + MlirType castResult = mlirQuantizedTypeCastToStorageType(type); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the storage type of this quantized type. Raises TypeError if " + "the cast is not valid.", + nb::arg("type")); + quantizedType.def( + "cast_from_expressed_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromExpressedType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type to " + "a corresponding type based on the quantized type. Raises TypeError if " + "the cast is not valid.", + nb::arg("candidate")); + quantizedType.def_staticmethod( + "cast_to_expressed_type", + [](MlirType type) { + MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the expressed type of this quantized type. Raises TypeError if " + "the cast is not valid.", + nb::arg("type")); + quantizedType.def( + "cast_expressed_to_storage_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastExpressedToStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw nb::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type to " + "a corresponding type based on the storage type. 
Raises TypeError if the " + "cast is not valid.", + nb::arg("candidate")); + + quantizedType.get_class().attr("FLAG_SIGNED") = + mlirQuantizedTypeGetSignedFlag(); + + //===-------------------------------------------------------------------===// + // AnyQuantizedType + //===-------------------------------------------------------------------===// + + auto anyQuantizedType = + mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, + quantizedType.get_class()); + anyQuantizedType.def_classmethod( + "get", + [](nb::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of AnyQuantizedType in the same context as the " + "provided storage type.", + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("storage_type_min"), + nb::arg("storage_type_max")); + + //===-------------------------------------------------------------------===// + // UniformQuantizedType + //===-------------------------------------------------------------------===// + + auto uniformQuantizedType = mlir_type_subclass( + m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, + quantizedType.get_class()); + uniformQuantizedType.def_classmethod( + "get", + [](nb::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return cls(mlirUniformQuantizedTypeGet(flags, storageType, + expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of UniformQuantizedType in the same context as the " + "provided storage type.", + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"), + nb::arg("storage_type_min"), 
nb::arg("storage_type_max")); + uniformQuantizedType.def_property_readonly( + "scale", + [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, + "The scale designates the difference between the real values " + "corresponding to consecutive quantized values differing by 1."); + uniformQuantizedType.def_property_readonly( + "zero_point", + [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, + "The storage value corresponding to the real value 0 in the affine " + "equation."); + uniformQuantizedType.def_property_readonly( + "is_fixed_point", + [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, + "Fixed point values are real numbers divided by a scale."); + + //===-------------------------------------------------------------------===// + // UniformQuantizedPerAxisType + //===-------------------------------------------------------------------===// + auto uniformQuantizedPerAxisType = mlir_type_subclass( + m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, + quantizedType.get_class()); + uniformQuantizedPerAxisType.def_classmethod( + "get", + [](nb::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, std::vector scales, + std::vector zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (scales.size() != zeroPoints.size()) + throw nb::value_error( + "Mismatching number of scales and zero points."); + auto nDims = static_cast(scales.size()); + return cls(mlirUniformQuantizedPerAxisTypeGet( + flags, storageType, expressedType, nDims, scales.data(), + zeroPoints.data(), quantizedDimension, storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedPerAxisType in the same context as " + "the provided storage type.", + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimension"), 
nb::arg("storage_type_min"), + nb::arg("storage_type_max")); + uniformQuantizedPerAxisType.def_property_readonly( + "scales", + [](MlirType type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector scales; + scales.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); + scales.push_back(scale); + } + return scales; + }, + "The scales designate the difference between the real values " + "corresponding to consecutive quantized values differing by 1. The ith " + "scale corresponds to the ith slice in the quantized_dimension."); + uniformQuantizedPerAxisType.def_property_readonly( + "zero_points", + [](MlirType type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector zeroPoints; + zeroPoints.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + int64_t zeroPoint = + mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); + zeroPoints.push_back(zeroPoint); + } + return zeroPoints; + }, + "the storage values corresponding to the real value 0 in the affine " + "equation. 
The ith zero point corresponds to the ith slice in the " + "quantized_dimension."); + uniformQuantizedPerAxisType.def_property_readonly( + "quantized_dimension", + [](MlirType type) { + return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); + }, + "Specifies the dimension of the shape that the scales and zero points " + "correspond to."); + uniformQuantizedPerAxisType.def_property_readonly( + "is_fixed_point", + [](MlirType type) { + return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); + }, + "Fixed point values are real numbers divided by a scale."); + + //===-------------------------------------------------------------------===// + // UniformQuantizedSubChannelType + //===-------------------------------------------------------------------===// + auto uniformQuantizedSubChannelType = mlir_type_subclass( + m, "UniformQuantizedSubChannelType", + mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); + uniformQuantizedSubChannelType.def_classmethod( + "get", + [](nb::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, + std::vector quantizedDimensions, + std::vector blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + return cls(mlirUniformQuantizedSubChannelTypeGet( + flags, storageType, expressedType, scales, zeroPoints, + static_cast(blockSizes.size()), + quantizedDimensions.data(), blockSizes.data(), storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedSubChannel in the same context as " + "the provided storage type.", + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimensions"), nb::arg("block_sizes"), + nb::arg("storage_type_min"), nb::arg("storage_type_max")); + uniformQuantizedSubChannelType.def_property_readonly( + "quantized_dimensions", + [](MlirType type) { + intptr_t nDim = + 
mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector quantizedDimensions; + quantizedDimensions.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + quantizedDimensions.push_back( + mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); + } + return quantizedDimensions; + }, + "Gets the quantized dimensions. Each element in the returned list " + "represents an axis of the quantized data tensor that has a specified " + "block size. The order of elements corresponds to the order of block " + "sizes returned by 'block_sizes' method. It means that the data tensor " + "is quantized along the i-th dimension in the returned list using the " + "i-th block size from block_sizes method."); + uniformQuantizedSubChannelType.def_property_readonly( + "block_sizes", + [](MlirType type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector blockSizes; + blockSizes.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + blockSizes.push_back( + mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); + } + return blockSizes; + }, + "Gets the block sizes for the quantized dimensions. 
The i-th element in " + "the returned list corresponds to the block size for the i-th dimension " + "in the list returned by quantized_dimensions method."); + uniformQuantizedSubChannelType.def_property_readonly( + "scales", + [](MlirType type) -> MlirAttribute { + return mlirUniformQuantizedSubChannelTypeGetScales(type); + }, + "The scales of the quantized type."); + uniformQuantizedSubChannelType.def_property_readonly( + "zero_points", + [](MlirType type) -> MlirAttribute { + return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); + }, + "The zero points of the quantized type."); + + //===-------------------------------------------------------------------===// + // CalibratedQuantizedType + //===-------------------------------------------------------------------===// + + auto calibratedQuantizedType = mlir_type_subclass( + m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, + quantizedType.get_class()); + calibratedQuantizedType.def_classmethod( + "get", + [](nb::object cls, MlirType expressedType, double min, double max) { + return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); + }, + "Gets an instance of CalibratedQuantizedType in the same context as the " + "provided expressed type.", + nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"), + nb::arg("max")); + calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { + return mlirCalibratedQuantizedTypeGetMin(type); + }); + calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { + return mlirCalibratedQuantizedTypeGetMax(type); + }); +} + +NB_MODULE(_mlirDialectsQuant, m) { + m.doc() = "MLIR Quantization dialect"; + + populateDialectQuantSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp new file mode 100644 index 000000000..4e7647729 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -0,0 +1,83 @@ +//===- DialectSMT.cpp - Pybind module for SMT dialect API support 
---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "NanobindUtils.h" + +#include "mlir-c/Dialect/SMT.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir-c/Target/ExportSMTLIB.h" +#include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; + +using namespace nanobind::literals; + +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +void populateDialectSMTSubmodule(nanobind::module_ &m) { + + auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) + .def_classmethod( + "get", + [](const nb::object &, MlirContext context) { + return mlirSMTTypeGetBool(context); + }, + "cls"_a, "context"_a.none() = nb::none()); + auto smtBitVectorType = + mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector) + .def_classmethod( + "get", + [](const nb::object &, int32_t width, MlirContext context) { + return mlirSMTTypeGetBitVector(context, width); + }, + "cls"_a, "width"_a, "context"_a.none() = nb::none()); + + auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, + bool indentLetBody) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(module)); + PyPrintAccumulator printAccum; + MlirLogicalResult result = mlirTranslateOperationToSMTLIB( + module, printAccum.getCallback(), printAccum.getUserData(), + inlineSingleUseValues, indentLetBody); + if (mlirLogicalResultIsSuccess(result)) + return printAccum.join(); + throw nb::value_error( + ("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage()) + .c_str()); + }; + + m.def( + "export_smtlib", + 
[&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues, + bool indentLetBody) { + return exportSMTLIB(module, inlineSingleUseValues, indentLetBody); + }, + "module"_a, "inline_single_use_values"_a = false, + "indent_let_body"_a = false); + m.def( + "export_smtlib", + [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues, + bool indentLetBody) { + return exportSMTLIB(mlirModuleGetOperation(module), + inlineSingleUseValues, indentLetBody); + }, + "module"_a, "inline_single_use_values"_a = false, + "indent_let_body"_a = false); +} + +NB_MODULE(_mlirDialectsSMT, m) { + m.doc() = "MLIR SMT Dialect"; + + populateDialectSMTSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp new file mode 100644 index 000000000..97cebccee --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -0,0 +1,148 @@ +//===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +#include "mlir-c/AffineMap.h" +#include "mlir-c/Dialect/SparseTensor.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace llvm; +using namespace mlir; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectSparseTensorSubmodule(const nb::module_ &m) { + nb::enum_(m, "LevelFormat", nb::is_arithmetic(), + nb::is_flag()) + .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) + .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) + .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) + .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) + .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED); + + nb::enum_(m, "LevelProperty") + .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED) + .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE) + .value("soa", MLIR_SPARSE_PROPERTY_SOA); + + mlir_attribute_subclass(m, "EncodingAttr", + mlirAttributeIsASparseTensorEncodingAttr) + .def_classmethod( + "get", + [](nb::object cls, std::vector lvlTypes, + std::optional dimToLvl, + std::optional lvlToDim, int posWidth, int crdWidth, + std::optional explicitVal, + std::optional implicitVal, MlirContext context) { + return cls(mlirSparseTensorEncodingAttrGet( + context, lvlTypes.size(), lvlTypes.data(), + dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, + lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth, + crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr}, + implicitVal ? 
*implicitVal : MlirAttribute{nullptr})); + }, + nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(), + nb::arg("lvl_to_dim").none(), nb::arg("pos_width"), + nb::arg("crd_width"), nb::arg("explicit_val").none() = nb::none(), + nb::arg("implicit_val").none() = nb::none(), + nb::arg("context").none() = nb::none(), + "Gets a sparse_tensor.encoding from parameters.") + .def_classmethod( + "build_level_type", + [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt, + const std::vector + &properties, + unsigned n, unsigned m) { + return mlirSparseTensorEncodingAttrBuildLvlType( + lvlFmt, properties.data(), properties.size(), n, m); + }, + nb::arg("cls"), nb::arg("lvl_fmt"), + nb::arg("properties") = + std::vector(), + nb::arg("n") = 0, nb::arg("m") = 0, + "Builds a sparse_tensor.encoding.level_type from parameters.") + .def_property_readonly( + "lvl_types", + [](MlirAttribute self) { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + std::vector ret; + ret.reserve(lvlRank); + for (int l = 0; l < lvlRank; ++l) + ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l)); + return ret; + }) + .def_property_readonly( + "dim_to_lvl", + [](MlirAttribute self) -> std::optional { + MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return ret; + }) + .def_property_readonly( + "lvl_to_dim", + [](MlirAttribute self) -> std::optional { + MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return ret; + }) + .def_property_readonly("pos_width", + mlirSparseTensorEncodingAttrGetPosWidth) + .def_property_readonly("crd_width", + mlirSparseTensorEncodingAttrGetCrdWidth) + .def_property_readonly( + "explicit_val", + [](MlirAttribute self) -> std::optional { + MlirAttribute ret = + mlirSparseTensorEncodingAttrGetExplicitVal(self); + if (mlirAttributeIsNull(ret)) + return {}; + return ret; + }) + .def_property_readonly( + 
"implicit_val", + [](MlirAttribute self) -> std::optional { + MlirAttribute ret = + mlirSparseTensorEncodingAttrGetImplicitVal(self); + if (mlirAttributeIsNull(ret)) + return {}; + return ret; + }) + .def_property_readonly( + "structured_n", + [](MlirAttribute self) -> unsigned { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + return mlirSparseTensorEncodingAttrGetStructuredN( + mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); + }) + .def_property_readonly( + "structured_m", + [](MlirAttribute self) -> unsigned { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + return mlirSparseTensorEncodingAttrGetStructuredM( + mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); + }) + .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + std::vector ret; + ret.reserve(lvlRank); + for (int l = 0; l < lvlRank; l++) + ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l)); + return ret; + }); +} + +NB_MODULE(_mlirDialectsSparseTensor, m) { + m.doc() = "MLIR SparseTensor dialect."; + populateDialectSparseTensorSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp new file mode 100644 index 000000000..59a030ac6 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -0,0 +1,121 @@ +//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include + +#include "mlir-c/Dialect/Transform.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +void populateDialectTransformSubmodule(const nb::module_ &m) { + //===-------------------------------------------------------------------===// + // AnyOpType + //===-------------------------------------------------------------------===// + + auto anyOpType = + mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType, + mlirTransformAnyOpTypeGetTypeID); + anyOpType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirTransformAnyOpTypeGet(ctx)); + }, + "Get an instance of AnyOpType in the given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // AnyParamType + //===-------------------------------------------------------------------===// + + auto anyParamType = + mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType, + mlirTransformAnyParamTypeGetTypeID); + anyParamType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirTransformAnyParamTypeGet(ctx)); + }, + "Get an instance of AnyParamType in the given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // AnyValueType + //===-------------------------------------------------------------------===// + + auto anyValueType = + mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType, + mlirTransformAnyValueTypeGetTypeID); + 
anyValueType.def_classmethod( + "get", + [](nb::object cls, MlirContext ctx) { + return cls(mlirTransformAnyValueTypeGet(ctx)); + }, + "Get an instance of AnyValueType in the given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); + + //===-------------------------------------------------------------------===// + // OperationType + //===-------------------------------------------------------------------===// + + auto operationType = + mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType, + mlirTransformOperationTypeGetTypeID); + operationType.def_classmethod( + "get", + [](nb::object cls, const std::string &operationName, MlirContext ctx) { + MlirStringRef cOperationName = + mlirStringRefCreate(operationName.data(), operationName.size()); + return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); + }, + "Get an instance of OperationType for the given kind in the given " + "context", + nb::arg("cls"), nb::arg("operation_name"), + nb::arg("context").none() = nb::none()); + operationType.def_property_readonly( + "operation_name", + [](MlirType type) { + MlirStringRef operationName = + mlirTransformOperationTypeGetOperationName(type); + return nb::str(operationName.data, operationName.length); + }, + "Get the name of the payload operation accepted by the handle."); + + //===-------------------------------------------------------------------===// + // ParamType + //===-------------------------------------------------------------------===// + + auto paramType = + mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType, + mlirTransformParamTypeGetTypeID); + paramType.def_classmethod( + "get", + [](nb::object cls, MlirType type, MlirContext ctx) { + return cls(mlirTransformParamTypeGet(ctx, type)); + }, + "Get an instance of ParamType for the given type in the given context.", + nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none()); + paramType.def_property_readonly( + "type", + [](MlirType type) { 
+ MlirType paramType = mlirTransformParamTypeGetType(type); + return paramType; + }, + "Get the type this ParamType is associated with."); +} + +NB_MODULE(_mlirDialectsTransform, m) { + m.doc() = "MLIR Transform dialect."; + populateDialectTransformSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp deleted file mode 100644 index f6f52e2e0..000000000 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ /dev/null @@ -1,87 +0,0 @@ -//===- ExecutionEngine.cpp - Python MLIR ExecutionEngine Bindings ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "ExecutionEngine.h" - -#include "IRModules.h" -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/ExecutionEngine.h" - -namespace py = pybind11; -using namespace mlir; -using namespace mlir::python; - -namespace { - -/// Owning Wrapper around an ExecutionEngine. 
-class PyExecutionEngine { -public: - PyExecutionEngine(MlirExecutionEngine executionEngine) - : executionEngine(executionEngine) {} - PyExecutionEngine(PyExecutionEngine &&other) - : executionEngine(other.executionEngine) { - other.executionEngine.ptr = nullptr; - } - ~PyExecutionEngine() { - if (!mlirExecutionEngineIsNull(executionEngine)) - mlirExecutionEngineDestroy(executionEngine); - } - MlirExecutionEngine get() { return executionEngine; } - - void release() { executionEngine.ptr = nullptr; } - pybind11::object getCapsule() { - return py::reinterpret_steal( - mlirPythonExecutionEngineToCapsule(get())); - } - - static pybind11::object createFromCapsule(pybind11::object capsule) { - MlirExecutionEngine rawPm = - mlirPythonCapsuleToExecutionEngine(capsule.ptr()); - if (mlirExecutionEngineIsNull(rawPm)) - throw py::error_already_set(); - return py::cast(PyExecutionEngine(rawPm), py::return_value_policy::move); - } - -private: - MlirExecutionEngine executionEngine; -}; - -} // anonymous namespace - -/// Create the `mlir.execution_engine` module here. -void mlir::python::populateExecutionEngineSubmodule(py::module &m) { - //---------------------------------------------------------------------------- - // Mapping of the top-level PassManager - //---------------------------------------------------------------------------- - py::class_(m, "ExecutionEngine") - .def(py::init<>([](PyModule &module) { - MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module.get()); - if (mlirExecutionEngineIsNull(executionEngine)) - throw std::runtime_error( - "Failure while creating the ExecutionEngine."); - return new PyExecutionEngine(executionEngine); - }), - "Create a new ExecutionEngine instance for the given Module. 
The " - "module must " - "contain only dialects that can be translated to LLVM.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyExecutionEngine::getCapsule) - .def("_testing_release", &PyExecutionEngine::release, - "Releases (leaks) the backing ExecutionEngine (for testing purpose)") - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule) - .def( - "raw_lookup", - [](PyExecutionEngine &executionEngine, const std::string &func) { - auto *res = mlirExecutionEngineLookup( - executionEngine.get(), - mlirStringRefCreate(func.c_str(), func.size())); - return (int64_t)res; - }, - "Lookup function `func` in the ExecutionEngine."); -} diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp new file mode 100644 index 000000000..81dada355 --- /dev/null +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -0,0 +1,135 @@ +//===- ExecutionEngineModule.cpp - Python module for execution engine -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/ExecutionEngine.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; + +namespace { + +/// Owning Wrapper around an ExecutionEngine. 
+class PyExecutionEngine { +public: + PyExecutionEngine(MlirExecutionEngine executionEngine) + : executionEngine(executionEngine) {} + PyExecutionEngine(PyExecutionEngine &&other) noexcept + : executionEngine(other.executionEngine) { + other.executionEngine.ptr = nullptr; + } + ~PyExecutionEngine() { + if (!mlirExecutionEngineIsNull(executionEngine)) + mlirExecutionEngineDestroy(executionEngine); + } + MlirExecutionEngine get() { return executionEngine; } + + void release() { + executionEngine.ptr = nullptr; + referencedObjects.clear(); + } + nb::object getCapsule() { + return nb::steal(mlirPythonExecutionEngineToCapsule(get())); + } + + // Add an object to the list of referenced objects whose lifetime must exceed + // those of the ExecutionEngine. + void addReferencedObject(const nb::object &obj) { + referencedObjects.push_back(obj); + } + + static nb::object createFromCapsule(nb::object capsule) { + MlirExecutionEngine rawPm = + mlirPythonCapsuleToExecutionEngine(capsule.ptr()); + if (mlirExecutionEngineIsNull(rawPm)) + throw nb::python_error(); + return nb::cast(PyExecutionEngine(rawPm), nb::rv_policy::move); + } + +private: + MlirExecutionEngine executionEngine; + // We support Python ctypes closures as callbacks. Keep a list of the objects + // so that they don't get garbage collected. (The ExecutionEngine itself + // just holds raw pointers with no lifetime semantics). + std::vector referencedObjects; +}; + +} // namespace + +/// Create the `mlir.execution_engine` module here. 
+NB_MODULE(_mlirExecutionEngine, m) { + m.doc() = "MLIR Execution Engine"; + + //---------------------------------------------------------------------------- + // Mapping of the top-level PassManager + //---------------------------------------------------------------------------- + nb::class_(m, "ExecutionEngine") + .def( + "__init__", + [](PyExecutionEngine &self, MlirModule module, int optLevel, + const std::vector &sharedLibPaths, + bool enableObjectDump) { + llvm::SmallVector libPaths; + for (const std::string &path : sharedLibPaths) + libPaths.push_back({path.c_str(), path.length()}); + MlirExecutionEngine executionEngine = + mlirExecutionEngineCreate(module, optLevel, libPaths.size(), + libPaths.data(), enableObjectDump); + if (mlirExecutionEngineIsNull(executionEngine)) + throw std::runtime_error( + "Failure while creating the ExecutionEngine."); + new (&self) PyExecutionEngine(executionEngine); + }, + nb::arg("module"), nb::arg("opt_level") = 2, + nb::arg("shared_libs") = nb::list(), + nb::arg("enable_object_dump") = true, + "Create a new ExecutionEngine instance for the given Module. The " + "module must contain only dialects that can be translated to LLVM. " + "Perform transformations and code generation at the optimization " + "level `opt_level` if specified, or otherwise at the default " + "level of two (-O2). 
Load a list of libraries specified in " + "`shared_libs`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule) + .def("_testing_release", &PyExecutionEngine::release, + "Releases (leaks) the backing ExecutionEngine (for testing purpose)") + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule) + .def( + "raw_lookup", + [](PyExecutionEngine &executionEngine, const std::string &func) { + auto *res = mlirExecutionEngineLookupPacked( + executionEngine.get(), + mlirStringRefCreate(func.c_str(), func.size())); + return reinterpret_cast(res); + }, + nb::arg("func_name"), + "Lookup function `func` in the ExecutionEngine.") + .def( + "raw_register_runtime", + [](PyExecutionEngine &executionEngine, const std::string &name, + nb::object callbackObj) { + executionEngine.addReferencedObject(callbackObj); + uintptr_t rawSym = + nb::cast(nb::getattr(callbackObj, "value")); + mlirExecutionEngineRegisterSymbol( + executionEngine.get(), + mlirStringRefCreate(name.c_str(), name.size()), + reinterpret_cast(rawSym)); + }, + nb::arg("name"), nb::arg("callback"), + "Register `callback` as the runtime symbol `name`.") + .def( + "dump_to_object_file", + [](PyExecutionEngine &executionEngine, const std::string &fileName) { + mlirExecutionEngineDumpToObjectFile( + executionEngine.get(), + mlirStringRefCreate(fileName.c_str(), fileName.size())); + }, + nb::arg("file_name"), "Dump ExecutionEngine to an object file."); +} diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp new file mode 100644 index 000000000..be474edbe --- /dev/null +++ b/mlir/lib/Bindings/Python/GPUPasses.cpp @@ -0,0 +1,22 @@ +//===- GPUPasses.cpp - Pybind module for the GPU passes ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir-c/Dialect/GPU.h" + +#include "mlir/Bindings/Python/Nanobind.h" + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +NB_MODULE(_mlirGPUPasses, m) { + m.doc() = "MLIR GPU Dialect Passes"; + + // Register all GPU passes on load. + mlirRegisterGPUPasses(); +} diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 6613d2b69..71a051cb3 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,19 +9,26 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H +#include +#include #include +#include #include -#include "PybindUtils.h" - -#include "llvm/ADT/Optional.h" +#include "NanobindUtils.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/Support.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Regex.h" namespace mlir { namespace python { /// Globals that are always accessible once the extension has been initialized. +/// Methods of this class are thread-safe. class PyGlobals { public: PyGlobals(); @@ -35,73 +42,137 @@ class PyGlobals { /// Get and set the list of parent modules to search for dialect /// implementation classes. - std::vector &getDialectSearchPrefixes() { + std::vector getDialectSearchPrefixes() { + nanobind::ft_lock_guard lock(mutex); return dialectSearchPrefixes; } void setDialectSearchPrefixes(std::vector newValues) { + nanobind::ft_lock_guard lock(mutex); dialectSearchPrefixes.swap(newValues); } - - /// Clears positive and negative caches regarding what implementations are - /// available. Future lookups will do more expensive existence checks. 
- void clearImportCache(); + void addDialectSearchPrefix(std::string value) { + nanobind::ft_lock_guard lock(mutex); + dialectSearchPrefixes.push_back(std::move(value)); + } /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises /// an error on any evaluation issues. /// Note that this returns void because it is expected that the module /// contains calls to decorators and helpers that register the salient - /// entities. - void loadDialectModule(llvm::StringRef dialectNamespace); + /// entities. Returns true if dialect is successfully loaded. + bool loadDialectModule(llvm::StringRef dialectNamespace); - /// Decorator for registering a custom Dialect class. The class object must - /// have a DIALECT_NAMESPACE attribute. - pybind11::object registerDialectDecorator(pybind11::object pyClass); + /// Adds a user-friendly Attribute builder. + /// Raises an exception if the mapping already exists and replace == false. + /// This is intended to be called by implementation code. + void registerAttributeBuilder(const std::string &attributeKind, + nanobind::callable pyFunc, + bool replace = false); + + /// Adds a user-friendly type caster. Raises an exception if the mapping + /// already exists and replace == false. This is intended to be called by + /// implementation code. + void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, + bool replace = false); + + /// Adds a user-friendly value caster. Raises an exception if the mapping + /// already exists and replace == false. This is intended to be called by + /// implementation code. + void registerValueCaster(MlirTypeID mlirTypeID, + nanobind::callable valueCaster, + bool replace = false); /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. 
void registerDialectImpl(const std::string &dialectNamespace, - pybind11::object pyClass); + nanobind::object pyClass); /// Adds a concrete implementation operation class. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass, - pybind11::object rawOpViewClass); + nanobind::object pyClass, bool replace = false); + + /// Returns the custom Attribute builder for Attribute kind. + std::optional + lookupAttributeBuilder(const std::string &attributeKind); + + /// Returns the custom type caster for MlirTypeID mlirTypeID. + std::optional lookupTypeCaster(MlirTypeID mlirTypeID, + MlirDialect dialect); + + /// Returns the custom value caster for MlirTypeID mlirTypeID. + std::optional lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. - llvm::Optional + std::optional lookupDialectClass(const std::string &dialectNamespace); - /// Looks up a registered raw OpView class by operation name. Note that this - /// may trigger a load of the dialect, which can arbitrarily re-enter. - llvm::Optional - lookupRawOpViewClass(llvm::StringRef operationName); + /// Looks up a registered operation class (deriving from OpView) by operation + /// name. Note that this may trigger a load of the dialect, which can + /// arbitrarily re-enter. 
+ std::optional + lookupOperationClass(llvm::StringRef operationName); + + class TracebackLoc { + public: + bool locTracebacksEnabled(); + + void setLocTracebacksEnabled(bool value); + + size_t locTracebackFramesLimit(); + + void setLocTracebackFramesLimit(size_t value); + + void registerTracebackFileInclusion(const std::string &file); + + void registerTracebackFileExclusion(const std::string &file); + + bool isUserTracebackFilename(llvm::StringRef file); + + static constexpr size_t kMaxFrames = 512; + + private: + nanobind::ft_mutex mutex; + bool locTracebackEnabled_ = false; + size_t locTracebackFramesLimit_ = 10; + std::unordered_set userTracebackIncludeFiles; + std::unordered_set userTracebackExcludeFiles; + std::regex userTracebackIncludeRegex; + bool rebuildUserTracebackIncludeRegex = false; + std::regex userTracebackExcludeRegex; + bool rebuildUserTracebackExcludeRegex = false; + llvm::StringMap isUserTracebackFilenameCache; + }; + + TracebackLoc &getTracebackLoc() { return tracebackLoc; } private: static PyGlobals *instance; + + nanobind::ft_mutex mutex; + /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. - llvm::StringMap dialectClassMap; + llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. - llvm::StringMap operationClassMap; - /// Map of operation name to custom subclass that directly initializes - /// the OpView base class (bypassing the user class constructor). - llvm::StringMap rawOpViewClassMap; - + llvm::StringMap operationClassMap; + /// Map of attribute ODS name to custom builder. + llvm::StringMap attributeBuilderMap; + /// Map of MlirTypeID to custom type caster. + llvm::DenseMap typeCasterMap; + /// Map of MlirTypeID to custom value caster. + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
- llvm::StringSet<> loadedDialectModulesCache; - /// Cache of operation name to custom OpView subclass that directly - /// initializes the OpView base class (or an undefined object for negative - /// lookup). This is maintained on loopup as a shadow of rawOpViewClassMap - /// in order for repeat lookups of the OpView classes to only incur the cost - /// of one hashtable lookup. - llvm::StringMap rawOpViewClassMapCache; + llvm::StringSet<> loadedDialectModules; + + TracebackLoc tracebackLoc; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp new file mode 100644 index 000000000..50f2a4f95 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -0,0 +1,990 @@ +//===- IRAffine.cpp - Exports 'ir' module affine related bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include +#include + +#include "IRModule.h" +#include "NanobindUtils.h" +#include "mlir-c/AffineExpr.h" +#include "mlir-c/AffineMap.h" +#include "mlir-c/IntegerSet.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
+#include "mlir/Support/LLVM.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +static const char kDumpDocstring[] = + R"(Dumps a debug representation of the object to stderr.)"; + +/// Attempts to populate `result` with the content of `list` casted to the +/// appropriate type (Python and C types are provided as template arguments). +/// Throws errors in case of failure, using "action" to describe what the caller +/// was attempting to do. +template +static void pyListToVector(const nb::list &list, + llvm::SmallVectorImpl &result, + StringRef action) { + result.reserve(nb::len(list)); + for (nb::handle item : list) { + try { + result.push_back(nb::cast(item)); + } catch (nb::cast_error &err) { + std::string msg = (llvm::Twine("Invalid expression when ") + action + + " (" + err.what() + ")") + .str(); + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { + std::string msg = (llvm::Twine("Invalid expression (None?) when ") + + action + " (" + err.what() + ")") + .str(); + throw std::runtime_error(msg.c_str()); + } + } +} + +template +static bool isPermutation(std::vector permutation) { + llvm::SmallVector seen(permutation.size(), false); + for (auto val : permutation) { + if (val < permutation.size()) { + if (seen[val]) + return false; + seen[val] = true; + continue; + } + return false; + } + return true; +} + +namespace { + +/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr +/// and should be castable from it. Intermediate hierarchy classes can be +/// modeled by specifying BaseTy. 
+template +class PyConcreteAffineExpr : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = nb::class_; + using IsAFunctionTy = bool (*)(MlirAffineExpr); + + PyConcreteAffineExpr() = default; + PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) + : BaseTy(std::move(contextRef), affineExpr) {} + PyConcreteAffineExpr(PyAffineExpr &orig) + : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} + + static MlirAffineExpr castFrom(PyAffineExpr &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast affine expression to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str() + .c_str()); + } + return orig; + } + + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::arg("expr")); + cls.def_static( + "isinstance", + [](PyAffineExpr &otherAffineExpr) -> bool { + return DerivedTy::isaFunction(otherAffineExpr); + }, + nb::arg("other")); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. 
+ static void bindDerived(ClassTy &m) {} +}; + +class PyAffineConstantExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; + static constexpr const char *pyClassName = "AffineConstantExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineConstantExpr get(intptr_t value, + DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = + mlirAffineConstantExprGet(context->get(), static_cast(value)); + return PyAffineConstantExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("value", [](PyAffineConstantExpr &self) { + return mlirAffineConstantExprGetValue(self); + }); + } +}; + +class PyAffineDimExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; + static constexpr const char *pyClassName = "AffineDimExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); + return PyAffineDimExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineDimExpr &self) { + return mlirAffineDimExprGetPosition(self); + }); + } +}; + +class PyAffineSymbolExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; + static constexpr const char *pyClassName = "AffineSymbolExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); + return 
PyAffineSymbolExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { + return mlirAffineSymbolExprGetPosition(self); + }); + } +}; + +class PyAffineBinaryExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; + static constexpr const char *pyClassName = "AffineBinaryExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + PyAffineExpr lhs() { + MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); + return PyAffineExpr(getContext(), lhsExpr); + } + + PyAffineExpr rhs() { + MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); + return PyAffineExpr(getContext(), rhsExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs); + c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs); + } +}; + +class PyAffineAddExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; + static constexpr const char *pyClassName = "AffineAddExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineAddExpr(rhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + 
c.def_static("get", &PyAffineAddExpr::get); + } +}; + +class PyAffineMulExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; + static constexpr const char *pyClassName = "AffineMulExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); + return PyAffineMulExpr(lhs.getContext(), expr); + } + + static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineMulExpr(lhs.getContext(), expr); + } + + static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineMulExpr(rhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineMulExpr::get); + } +}; + +class PyAffineModExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; + static constexpr const char *pyClassName = "AffineModExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { + MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static PyAffineModExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineModExpr(rhs.getContext(), 
expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineModExpr::get); + } +}; + +class PyAffineFloorDivExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; + static constexpr const char *pyClassName = "AffineFloorDivExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineFloorDivExpr(rhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineFloorDivExpr::get); + } +}; + +class PyAffineCeilDivExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; + static constexpr const char *pyClassName = "AffineCeilDivExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); + return PyAffineCeilDivExpr(lhs.getContext(), expr); + } + + static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineCeilDivExpr(lhs.getContext(), expr); + } + + static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { 
+ MlirAffineExpr expr = mlirAffineCeilDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineCeilDivExpr(rhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineCeilDivExpr::get); + } +}; + +} // namespace + +bool PyAffineExpr::operator==(const PyAffineExpr &other) const { + return mlirAffineExprEqual(affineExpr, other.affineExpr); +} + +nb::object PyAffineExpr::getCapsule() { + return nb::steal(mlirPythonAffineExprToCapsule(*this)); +} + +PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { + MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); + if (mlirAffineExprIsNull(rawAffineExpr)) + throw nb::python_error(); + return PyAffineExpr( + PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), + rawAffineExpr); +} + +//------------------------------------------------------------------------------ +// PyAffineMap and utilities. +//------------------------------------------------------------------------------ +namespace { + +/// A list of expressions contained in an affine map. Internally these are +/// stored as a consecutive array leading to inexpensive random access. Both +/// the map and the expression are owned by the context so we need not bother +/// with lifetime extension. +class PyAffineMapExprList + : public Sliceable { +public: + static constexpr const char *pyClassName = "AffineExprList"; + + PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirAffineMapGetNumResults(map) : length, + step), + affineMap(map) {} + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); } + + PyAffineExpr getRawElement(intptr_t pos) { + return PyAffineExpr(affineMap.getContext(), + mlirAffineMapGetResult(affineMap, pos)); + } + + PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyAffineMapExprList(affineMap, startIndex, length, step); + } + + PyAffineMap affineMap; +}; +} // namespace + +bool PyAffineMap::operator==(const PyAffineMap &other) const { + return mlirAffineMapEqual(affineMap, other.affineMap); +} + +nb::object PyAffineMap::getCapsule() { + return nb::steal(mlirPythonAffineMapToCapsule(*this)); +} + +PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { + MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(rawAffineMap)) + throw nb::python_error(); + return PyAffineMap( + PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), + rawAffineMap); +} + +//------------------------------------------------------------------------------ +// PyIntegerSet and utilities. 
+//------------------------------------------------------------------------------ +namespace { + +class PyIntegerSetConstraint { +public: + PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) + : set(std::move(set)), pos(pos) {} + + PyAffineExpr getExpr() { + return PyAffineExpr(set.getContext(), + mlirIntegerSetGetConstraint(set, pos)); + } + + bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } + + static void bind(nb::module_ &m) { + nb::class_(m, "IntegerSetConstraint") + .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr) + .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq); + } + +private: + PyIntegerSet set; + intptr_t pos; +}; + +class PyIntegerSetConstraintList + : public Sliceable { +public: + static constexpr const char *pyClassName = "IntegerSetConstraintList"; + + PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, + step), + set(set) {} + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); } + + PyIntegerSetConstraint getRawElement(intptr_t pos) { + return PyIntegerSetConstraint(set, pos); + } + + PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyIntegerSetConstraintList(set, startIndex, length, step); + } + + PyIntegerSet set; +}; +} // namespace + +bool PyIntegerSet::operator==(const PyIntegerSet &other) const { + return mlirIntegerSetEqual(integerSet, other.integerSet); +} + +nb::object PyIntegerSet::getCapsule() { + return nb::steal(mlirPythonIntegerSetToCapsule(*this)); +} + +PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { + MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); + if (mlirIntegerSetIsNull(rawIntegerSet)) + throw nb::python_error(); + return PyIntegerSet( + PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), + rawIntegerSet); +} + +void mlir::python::populateIRAffine(nb::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of PyAffineExpr and derived classes. 
+ //---------------------------------------------------------------------------- + nb::class_(m, "AffineExpr") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) + .def("__add__", &PyAffineAddExpr::get) + .def("__add__", &PyAffineAddExpr::getRHSConstant) + .def("__radd__", &PyAffineAddExpr::getRHSConstant) + .def("__mul__", &PyAffineMulExpr::get) + .def("__mul__", &PyAffineMulExpr::getRHSConstant) + .def("__rmul__", &PyAffineMulExpr::getRHSConstant) + .def("__mod__", &PyAffineModExpr::get) + .def("__mod__", &PyAffineModExpr::getRHSConstant) + .def("__rmod__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineModExpr::get( + PyAffineConstantExpr::get(other, *self.getContext().get()), + self); + }) + .def("__sub__", + [](PyAffineExpr &self, PyAffineExpr &other) { + auto negOne = + PyAffineConstantExpr::get(-1, *self.getContext().get()); + return PyAffineAddExpr::get(self, + PyAffineMulExpr::get(negOne, other)); + }) + .def("__sub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::get( + self, + PyAffineConstantExpr::get(-other, *self.getContext().get())); + }) + .def("__rsub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::getLHSConstant( + other, PyAffineMulExpr::getLHSConstant(-1, self)); + }) + .def("__eq__", [](PyAffineExpr &self, + PyAffineExpr &other) { return self == other; }) + .def("__eq__", + [](PyAffineExpr &self, nb::object &other) { return false; }) + .def("__str__", + [](PyAffineExpr &self) { + PyPrintAccumulator printAccum; + mlirAffineExprPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyAffineExpr &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("AffineExpr("); + mlirAffineExprPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + 
.def("__hash__", + [](PyAffineExpr &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def_prop_ro( + "context", + [](PyAffineExpr &self) { return self.getContext().getObject(); }) + .def("compose", + [](PyAffineExpr &self, PyAffineMap &other) { + return PyAffineExpr(self.getContext(), + mlirAffineExprCompose(self, other)); + }) + .def( + "shift_dims", + [](PyAffineExpr &self, uint32_t numDims, uint32_t shift, + uint32_t offset) { + return PyAffineExpr( + self.getContext(), + mlirAffineExprShiftDims(self, numDims, shift, offset)); + }, + nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0) + .def( + "shift_symbols", + [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift, + uint32_t offset) { + return PyAffineExpr( + self.getContext(), + mlirAffineExprShiftSymbols(self, numSymbols, shift, offset)); + }, + nb::arg("num_symbols"), nb::arg("shift"), + nb::arg("offset").none() = 0) + .def_static( + "simplify_affine_expr", + [](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) { + return PyAffineExpr( + self.getContext(), + mlirSimplifyAffineExpr(self, numDims, numSymbols)); + }, + nb::arg("expr"), nb::arg("num_dims"), nb::arg("num_symbols"), + "Simplify an affine expression by flattening and some amount of " + "simple analysis.") + .def_static( + "get_add", &PyAffineAddExpr::get, + "Gets an affine expression containing a sum of two expressions.") + .def_static("get_add", &PyAffineAddExpr::getLHSConstant, + "Gets an affine expression containing a sum of a constant " + "and another expression.") + .def_static("get_add", &PyAffineAddExpr::getRHSConstant, + "Gets an affine expression containing a sum of an expression " + "and a constant.") + .def_static( + "get_mul", &PyAffineMulExpr::get, + "Gets an affine expression containing a product of two expressions.") + .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, + "Gets an affine expression containing a product of a " + "constant and another expression.") + 
.def_static("get_mul", &PyAffineMulExpr::getRHSConstant, + "Gets an affine expression containing a product of an " + "expression and a constant.") + .def_static("get_mod", &PyAffineModExpr::get, + "Gets an affine expression containing the modulo of dividing " + "one expression by another.") + .def_static("get_mod", &PyAffineModExpr::getLHSConstant, + "Gets a semi-affine expression containing the modulo of " + "dividing a constant by an expression.") + .def_static("get_mod", &PyAffineModExpr::getRHSConstant, + "Gets an affine expression containing the module of dividing" + "an expression by a constant.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::get, + "Gets an affine expression containing the rounded-down " + "result of dividing one expression by another.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-down " + "result of dividing a constant by an expression.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-down " + "result of dividing an expression by a constant.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, + "Gets an affine expression containing the rounded-up result " + "of dividing one expression by another.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-up " + "result of dividing a constant by an expression.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-up result " + "of dividing an expression by a constant.") + .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none(), + "Gets a constant affine expression with the given value.") + .def_static( + "get_dim", &PyAffineDimExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), + "Gets an affine expression of a 
dimension at the given position.") + .def_static( + "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), + "Gets an affine expression of a symbol at the given position.") + .def( + "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, + kDumpDocstring); + PyAffineConstantExpr::bind(m); + PyAffineDimExpr::bind(m); + PyAffineSymbolExpr::bind(m); + PyAffineBinaryExpr::bind(m); + PyAffineAddExpr::bind(m); + PyAffineMulExpr::bind(m); + PyAffineModExpr::bind(m); + PyAffineFloorDivExpr::bind(m); + PyAffineCeilDivExpr::bind(m); + + //---------------------------------------------------------------------------- + // Mapping of PyAffineMap. + //---------------------------------------------------------------------------- + nb::class_(m, "AffineMap") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) + .def("__eq__", + [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) + .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; }) + .def("__str__", + [](PyAffineMap &self) { + PyPrintAccumulator printAccum; + mlirAffineMapPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyAffineMap &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("AffineMap("); + mlirAffineMapPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def("__hash__", + [](PyAffineMap &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def_static("compress_unused_symbols", + [](nb::list affineMaps, DefaultingPyMlirContext context) { + SmallVector maps; + pyListToVector( + affineMaps, maps, "attempting to create an AffineMap"); + std::vector compressed(affineMaps.size()); + auto populate = [](void *result, intptr_t idx, + MlirAffineMap m) { + 
static_cast(result)[idx] = (m); + }; + mlirAffineMapCompressUnusedSymbols( + maps.data(), maps.size(), compressed.data(), populate); + std::vector res; + res.reserve(compressed.size()); + for (auto m : compressed) + res.emplace_back(context->getRef(), m); + return res; + }) + .def_prop_ro( + "context", + [](PyAffineMap &self) { return self.getContext().getObject(); }, + "Context that owns the Affine Map") + .def( + "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, + kDumpDocstring) + .def_static( + "get", + [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, + DefaultingPyMlirContext context) { + SmallVector affineExprs; + pyListToVector( + exprs, affineExprs, "attempting to create an AffineMap"); + MlirAffineMap map = + mlirAffineMapGet(context->get(), dimCount, symbolCount, + affineExprs.size(), affineExprs.data()); + return PyAffineMap(context->getRef(), map); + }, + nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), + nb::arg("context").none() = nb::none(), + "Gets a map with the given expressions as results.") + .def_static( + "get_constant", + [](intptr_t value, DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapConstantGet(context->get(), value); + return PyAffineMap(context->getRef(), affineMap); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an affine map with a single constant result") + .def_static( + "get_empty", + [](DefaultingPyMlirContext context) { + MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); + return PyAffineMap(context->getRef(), affineMap); + }, + nb::arg("context").none() = nb::none(), "Gets an empty affine map.") + .def_static( + "get_identity", + [](intptr_t nDims, DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapMultiDimIdentityGet(context->get(), nDims); + return PyAffineMap(context->getRef(), affineMap); + }, + nb::arg("n_dims"), nb::arg("context").none() = nb::none(), + "Gets an identity map with the given 
number of dimensions.") + .def_static( + "get_minor_identity", + [](intptr_t nDims, intptr_t nResults, + DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); + return PyAffineMap(context->getRef(), affineMap); + }, + nb::arg("n_dims"), nb::arg("n_results"), + nb::arg("context").none() = nb::none(), + "Gets a minor identity map with the given number of dimensions and " + "results.") + .def_static( + "get_permutation", + [](std::vector permutation, + DefaultingPyMlirContext context) { + if (!isPermutation(permutation)) + throw std::runtime_error("Invalid permutation when attempting to " + "create an AffineMap"); + MlirAffineMap affineMap = mlirAffineMapPermutationGet( + context->get(), permutation.size(), permutation.data()); + return PyAffineMap(context->getRef(), affineMap); + }, + nb::arg("permutation"), nb::arg("context").none() = nb::none(), + "Gets an affine map that permutes its inputs.") + .def( + "get_submap", + [](PyAffineMap &self, std::vector &resultPos) { + intptr_t numResults = mlirAffineMapGetNumResults(self); + for (intptr_t pos : resultPos) { + if (pos < 0 || pos >= numResults) + throw nb::value_error("result position out of bounds"); + } + MlirAffineMap affineMap = mlirAffineMapGetSubMap( + self, resultPos.size(), resultPos.data()); + return PyAffineMap(self.getContext(), affineMap); + }, + nb::arg("result_positions")) + .def( + "get_major_submap", + [](PyAffineMap &self, intptr_t nResults) { + if (nResults >= mlirAffineMapGetNumResults(self)) + throw nb::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMajorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }, + nb::arg("n_results")) + .def( + "get_minor_submap", + [](PyAffineMap &self, intptr_t nResults) { + if (nResults >= mlirAffineMapGetNumResults(self)) + throw nb::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + 
mlirAffineMapGetMinorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }, + nb::arg("n_results")) + .def( + "replace", + [](PyAffineMap &self, PyAffineExpr &expression, + PyAffineExpr &replacement, intptr_t numResultDims, + intptr_t numResultSyms) { + MlirAffineMap affineMap = mlirAffineMapReplace( + self, expression, replacement, numResultDims, numResultSyms); + return PyAffineMap(self.getContext(), affineMap); + }, + nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"), + nb::arg("n_result_syms")) + .def_prop_ro( + "is_permutation", + [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) + .def_prop_ro("is_projected_permutation", + [](PyAffineMap &self) { + return mlirAffineMapIsProjectedPermutation(self); + }) + .def_prop_ro( + "n_dims", + [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) + .def_prop_ro( + "n_inputs", + [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) + .def_prop_ro( + "n_symbols", + [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) + .def_prop_ro("results", + [](PyAffineMap &self) { return PyAffineMapExprList(self); }); + PyAffineMapExprList::bind(m); + + //---------------------------------------------------------------------------- + // Mapping of PyIntegerSet. 
+ //---------------------------------------------------------------------------- + nb::class_(m, "IntegerSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) + .def("__eq__", [](PyIntegerSet &self, + PyIntegerSet &other) { return self == other; }) + .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) + .def("__str__", + [](PyIntegerSet &self) { + PyPrintAccumulator printAccum; + mlirIntegerSetPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyIntegerSet &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("IntegerSet("); + mlirIntegerSetPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def("__hash__", + [](PyIntegerSet &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def_prop_ro( + "context", + [](PyIntegerSet &self) { return self.getContext().getObject(); }) + .def( + "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, + kDumpDocstring) + .def_static( + "get", + [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, + std::vector eqFlags, DefaultingPyMlirContext context) { + if (exprs.size() != eqFlags.size()) + throw nb::value_error( + "Expected the number of constraints to match " + "that of equality flags"); + if (exprs.size() == 0) + throw nb::value_error("Expected non-empty list of constraints"); + + // Copy over to a SmallVector because std::vector has a + // specialization for booleans that packs data and does not + // expose a `bool *`. 
+ SmallVector flags(eqFlags.begin(), eqFlags.end()); + + SmallVector affineExprs; + pyListToVector(exprs, affineExprs, + "attempting to create an IntegerSet"); + MlirIntegerSet set = mlirIntegerSetGet( + context->get(), numDims, numSymbols, exprs.size(), + affineExprs.data(), flags.data()); + return PyIntegerSet(context->getRef(), set); + }, + nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), + nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) + .def_static( + "get_empty", + [](intptr_t numDims, intptr_t numSymbols, + DefaultingPyMlirContext context) { + MlirIntegerSet set = + mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); + return PyIntegerSet(context->getRef(), set); + }, + nb::arg("num_dims"), nb::arg("num_symbols"), + nb::arg("context").none() = nb::none()) + .def( + "get_replaced", + [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, + intptr_t numResultDims, intptr_t numResultSymbols) { + if (static_cast(dimExprs.size()) != + mlirIntegerSetGetNumDims(self)) + throw nb::value_error( + "Expected the number of dimension replacement expressions " + "to match that of dimensions"); + if (static_cast(symbolExprs.size()) != + mlirIntegerSetGetNumSymbols(self)) + throw nb::value_error( + "Expected the number of symbol replacement expressions " + "to match that of symbols"); + + SmallVector dimAffineExprs, symbolAffineExprs; + pyListToVector( + dimExprs, dimAffineExprs, + "attempting to create an IntegerSet by replacing dimensions"); + pyListToVector( + symbolExprs, symbolAffineExprs, + "attempting to create an IntegerSet by replacing symbols"); + MlirIntegerSet set = mlirIntegerSetReplaceGet( + self, dimAffineExprs.data(), symbolAffineExprs.data(), + numResultDims, numResultSymbols); + return PyIntegerSet(self.getContext(), set); + }, + nb::arg("dim_exprs"), nb::arg("symbol_exprs"), + nb::arg("num_result_dims"), nb::arg("num_result_symbols")) + .def_prop_ro("is_canonical_empty", + [](PyIntegerSet &self) { + return 
mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_prop_ro( + "n_dims", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) + .def_prop_ro( + "n_symbols", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) + .def_prop_ro( + "n_inputs", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) + .def_prop_ro("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_prop_ro("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_prop_ro("constraints", [](PyIntegerSet &self) { + return PyIntegerSetConstraintList(self); + }); + PyIntegerSetConstraint::bind(m); + PyIntegerSetConstraintList::bind(m); +} diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp new file mode 100644 index 000000000..db84ee1fc --- /dev/null +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -0,0 +1,1818 @@ +//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include "IRModule.h" +#include "NanobindUtils.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" + +namespace nb = nanobind; +using namespace nanobind::literals; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; + +//------------------------------------------------------------------------------ +// Docstrings (trivial, non-duplicated docstrings are included inline). +//------------------------------------------------------------------------------ + +static const char kDenseElementsAttrGetDocstring[] = + R"(Gets a DenseElementsAttr from a Python buffer or array. + +When `type` is not provided, then some limited type inferencing is done based +on the buffer format. Support presently exists for 8/16/32/64 signed and +unsigned integers and float16/float32/float64. DenseElementsAttrs of these +types can also be converted back to a corresponding buffer. + +For conversions outside of these types, a `type=` must be explicitly provided +and the buffer contents must be bit-castable to the MLIR internal +representation: + + * Integer types (except for i1): the buffer must be byte aligned to the + next byte boundary. + * Floating point types: Must be bit-castable to the given floating point + size. + * i1 (bool): Bit packed into 8bit words where the bit pattern matches a + row major ordering. An arbitrary Numpy `bool_` array can be bit packed to + this specification with: `np.packbits(ary, axis=None, bitorder='little')`. + +If a single element buffer is passed (or for i1, a single byte with value 0 +or 255), then a splat will be created. 
+ +Args: + array: The array or buffer to convert. + signless: If inferring an appropriate MLIR type, use signless types for + integers (defaults True). + type: Skips inference of the MLIR element type and uses this instead. The + storage size must be consistent with the actual contents of the buffer. + shape: Overrides the shape of the buffer when constructing the MLIR + shaped type. This is needed when the physical and logical shape differ (as + for i1). + context: Explicit context, if not from context manager. + +Returns: + DenseElementsAttr on success. + +Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. +)"; + +static const char kDenseElementsAttrGetFromListDocstring[] = + R"(Gets a DenseElementsAttr from a Python list of attributes. + +Note that it can be expensive to construct attributes individually. +For a large number of elements, consider using a Python buffer or array instead. + +Args: + attrs: A list of attributes. + type: The desired shape and type of the resulting DenseElementsAttr. + If not provided, the element type is determined based on the type + of the 0th attribute and the shape is `[len(attrs)]`. + context: Explicit context, if not from context manager. + +Returns: + DenseElementsAttr on success. + +Raises: + ValueError: If the type of the attributes does not match the type + specified by `shaped_type`. +)"; + +static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = + R"(Gets a DenseResourceElementsAttr from a Python buffer or array. + +This function does minimal validation or massaging of the data, and it is +up to the caller to ensure that the buffer meets the characteristics +implied by the shape. + +The backing buffer and any user objects will be retained for the lifetime +of the resource blob. This is typically bounded to the context but the +resource can have a shorter lifespan depending on how it is used in +subsequent processing. 
+ +Args: + buffer: The array or buffer to convert. + name: Name to provide to the resource (may be changed upon collision). + type: The explicit ShapedType to construct the attribute with. + context: Explicit context, if not from context manager. + +Returns: + DenseResourceElementsAttr on success. + +Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. +)"; + +namespace { + +struct nb_buffer_info { + void *ptr = nullptr; + ssize_t itemsize = 0; + ssize_t size = 0; + const char *format = nullptr; + ssize_t ndim = 0; + SmallVector shape; + SmallVector strides; + bool readonly = false; + + nb_buffer_info( + void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, + SmallVector shape_in, SmallVector strides_in, + bool readonly = false, + std::unique_ptr owned_view_in = + std::unique_ptr(nullptr, nullptr)) + : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)), + readonly(readonly), owned_view(std::move(owned_view_in)) { + size = 1; + for (ssize_t i = 0; i < ndim; ++i) { + size *= shape[i]; + } + } + + explicit nb_buffer_info(Py_buffer *view) + : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, + // TODO(phawkins): check for null strides + {view->strides, view->strides + view->ndim}, + view->readonly != 0, + std::unique_ptr( + view, PyBuffer_Release)) {} + + nb_buffer_info(const nb_buffer_info &) = delete; + nb_buffer_info(nb_buffer_info &&) = default; + nb_buffer_info &operator=(const nb_buffer_info &) = delete; + nb_buffer_info &operator=(nb_buffer_info &&) = default; + +private: + std::unique_ptr owned_view; +}; + +class nb_buffer : public nb::object { + NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); + + nb_buffer_info request() const { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + auto *view = new Py_buffer(); + if 
(PyObject_GetBuffer(ptr(), view, flags) != 0) { + delete view; + throw nb::python_error(); + } + return nb_buffer_info(view); + } +}; + +template +struct nb_format_descriptor {}; + +template <> +struct nb_format_descriptor { + static const char *format() { return "?"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "b"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "B"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "h"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "H"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "i"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "I"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "q"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "Q"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "f"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "d"; } +}; + +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + +class PyAffineMapAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; + static constexpr const char *pyClassName = "AffineMapAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAffineMapAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyAffineMap &affineMap) { + MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); + return 
PyAffineMapAttribute(affineMap.getContext(), attr); + }, + nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_prop_ro("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); + } +}; + +class PyIntegerSetAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; + static constexpr const char *pyClassName = "IntegerSetAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerSetAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyIntegerSet &integerSet) { + MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); + return PyIntegerSetAttribute(integerSet.getContext(), attr); + }, + nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); + } +}; + +template +static T pyTryCast(nb::handle object) { + try { + return nb::cast(object); + } catch (nb::cast_error &err) { + std::string msg = std::string("Invalid attribute when attempting to " + "create an ArrayAttribute (") + + err.what() + ")"; + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { + std::string msg = std::string("Invalid attribute (None?) when attempting " + "to create an ArrayAttribute (") + + err.what() + ")"; + throw std::runtime_error(msg.c_str()); + } +} + +/// A python-wrapped dense array attribute with an element type and a derived +/// implementation class. +template +class PyDenseArrayAttribute : public PyConcreteAttribute { +public: + using PyConcreteAttribute::PyConcreteAttribute; + + /// Iterator over the integer elements of a dense array. + class PyDenseArrayIterator { + public: + PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} + + /// Return a copy of the iterator. + PyDenseArrayIterator dunderIter() { return *this; } + + /// Return the next element. 
+ EltTy dunderNext() { + // Throw if the index has reached the end. + if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) + throw nb::stop_iteration(); + return DerivedT::getElement(attr.get(), nextIndex++); + } + + /// Bind the iterator class. + static void bind(nb::module_ &m) { + nb::class_(m, DerivedT::pyIteratorName) + .def("__iter__", &PyDenseArrayIterator::dunderIter) + .def("__next__", &PyDenseArrayIterator::dunderNext); + } + + private: + /// The referenced dense array attribute. + PyAttribute attr; + /// The next index to read. + int nextIndex = 0; + }; + + /// Get the element at the given index. + EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } + + /// Bind the attribute class. + static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { + // Bind the constructor. + if constexpr (std::is_same_v) { + c.def_static( + "get", + [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { + std::vector values; + for (nb::handle py_value : py_values) { + int is_true = PyObject_IsTrue(py_value.ptr()); + if (is_true < 0) { + throw nb::python_error(); + } + values.push_back(is_true); + } + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } else { + c.def_static( + "get", + [](const std::vector &values, DefaultingPyMlirContext ctx) { + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } + // Bind the array methods. 
+ c.def("__getitem__", [](DerivedT &arr, intptr_t i) { + if (i >= mlirDenseArrayGetNumElements(arr)) + throw nb::index_error("DenseArray index out of range"); + return arr.getItem(i); + }); + c.def("__len__", [](const DerivedT &arr) { + return mlirDenseArrayGetNumElements(arr); + }); + c.def("__iter__", + [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); + c.def("__add__", [](DerivedT &arr, const nb::list &extras) { + std::vector values; + intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); + values.reserve(numOldElements + nb::len(extras)); + for (intptr_t i = 0; i < numOldElements; ++i) + values.push_back(arr.getItem(i)); + for (nb::handle attr : extras) + values.push_back(pyTryCast(attr)); + return getAttribute(values, arr.getContext()); + }); + } + +private: + static DerivedT getAttribute(const std::vector &values, + PyMlirContextRef ctx) { + if constexpr (std::is_same_v) { + std::vector intValues(values.begin(), values.end()); + MlirAttribute attr = DerivedT::getAttribute(ctx->get(), intValues.size(), + intValues.data()); + return DerivedT(ctx, attr); + } else { + MlirAttribute attr = + DerivedT::getAttribute(ctx->get(), values.size(), values.data()); + return DerivedT(ctx, attr); + } + } +}; + +/// Instantiate the python dense array classes. 
+struct PyDenseBoolArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; + static constexpr auto getAttribute = mlirDenseBoolArrayGet; + static constexpr auto getElement = mlirDenseBoolArrayGetElement; + static constexpr const char *pyClassName = "DenseBoolArrayAttr"; + static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI8ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; + static constexpr auto getAttribute = mlirDenseI8ArrayGet; + static constexpr auto getElement = mlirDenseI8ArrayGetElement; + static constexpr const char *pyClassName = "DenseI8ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI16ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; + static constexpr auto getAttribute = mlirDenseI16ArrayGet; + static constexpr auto getElement = mlirDenseI16ArrayGetElement; + static constexpr const char *pyClassName = "DenseI16ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI32ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; + static constexpr auto getAttribute = mlirDenseI32ArrayGet; + static constexpr auto getElement = mlirDenseI32ArrayGetElement; + static constexpr const char *pyClassName = "DenseI32ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI64ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = 
mlirAttributeIsADenseI64Array; + static constexpr auto getAttribute = mlirDenseI64ArrayGet; + static constexpr auto getElement = mlirDenseI64ArrayGetElement; + static constexpr const char *pyClassName = "DenseI64ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseF32ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; + static constexpr auto getAttribute = mlirDenseF32ArrayGet; + static constexpr auto getElement = mlirDenseF32ArrayGetElement; + static constexpr const char *pyClassName = "DenseF32ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseF64ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; + static constexpr auto getAttribute = mlirDenseF64ArrayGet; + static constexpr auto getElement = mlirDenseF64ArrayGetElement; + static constexpr const char *pyClassName = "DenseF64ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + +class PyArrayAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; + static constexpr const char *pyClassName = "ArrayAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirArrayAttrGetTypeID; + + class PyArrayAttributeIterator { + public: + PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} + + PyArrayAttributeIterator &dunderIter() { return *this; } + + MlirAttribute dunderNext() { + // TODO: Throw is an inefficient way to stop iteration. 
+ if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) + throw nb::stop_iteration(); + return mlirArrayAttrGetElement(attr.get(), nextIndex++); + } + + static void bind(nb::module_ &m) { + nb::class_(m, "ArrayAttributeIterator") + .def("__iter__", &PyArrayAttributeIterator::dunderIter) + .def("__next__", &PyArrayAttributeIterator::dunderNext); + } + + private: + PyAttribute attr; + int nextIndex = 0; + }; + + MlirAttribute getItem(intptr_t i) { + return mlirArrayAttrGetElement(*this, i); + } + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](nb::list attributes, DefaultingPyMlirContext context) { + SmallVector mlirAttributes; + mlirAttributes.reserve(nb::len(attributes)); + for (auto attribute : attributes) { + mlirAttributes.push_back(pyTryCast(attribute)); + } + MlirAttribute attr = mlirArrayAttrGet( + context->get(), mlirAttributes.size(), mlirAttributes.data()); + return PyArrayAttribute(context->getRef(), attr); + }, + nb::arg("attributes"), nb::arg("context").none() = nb::none(), + "Gets a uniqued Array attribute"); + c.def("__getitem__", + [](PyArrayAttribute &arr, intptr_t i) { + if (i >= mlirArrayAttrGetNumElements(arr)) + throw nb::index_error("ArrayAttribute index out of range"); + return arr.getItem(i); + }) + .def("__len__", + [](const PyArrayAttribute &arr) { + return mlirArrayAttrGetNumElements(arr); + }) + .def("__iter__", [](const PyArrayAttribute &arr) { + return PyArrayAttributeIterator(arr); + }); + c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { + std::vector attributes; + intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); + attributes.reserve(numOldElements + nb::len(extras)); + for (intptr_t i = 0; i < numOldElements; ++i) + attributes.push_back(arr.getItem(i)); + for (nb::handle attr : extras) + attributes.push_back(pyTryCast(attr)); + MlirAttribute arrayAttr = mlirArrayAttrGet( + arr.getContext()->get(), attributes.size(), attributes.data()); + return PyArrayAttribute(arr.getContext(), 
arrayAttr); + }); + } +}; + +/// Float Point Attribute subclass - FloatAttr. +class PyFloatAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; + static constexpr const char *pyClassName = "FloatAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, double value, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); + return PyFloatAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), + "Gets an uniqued float point attribute associated to a type"); + c.def_static( + "get_f32", + [](double value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirFloatAttrDoubleGet( + context->get(), mlirF32TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an uniqued float point attribute associated to a f32 type"); + c.def_static( + "get_f64", + [](double value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirFloatAttrDoubleGet( + context->get(), mlirF64TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an uniqued float point attribute associated to a f64 type"); + c.def_prop_ro("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); + c.def("__float__", mlirFloatAttrGetValueDouble, + "Converts the value of the float attribute to a Python float"); + } +}; + +/// Integer Attribute subclass - IntegerAttr. 
+class PyIntegerAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; + static constexpr const char *pyClassName = "IntegerAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, int64_t value) { + MlirAttribute attr = mlirIntegerAttrGet(type, value); + return PyIntegerAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), + "Gets an uniqued integer attribute associated to a type"); + c.def_prop_ro("value", toPyInt, + "Returns the value of the integer attribute"); + c.def("__int__", toPyInt, + "Converts the value of the integer attribute to a Python int"); + c.def_prop_ro_static("static_typeid", + [](nb::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); + } + +private: + static int64_t toPyInt(PyIntegerAttribute &self) { + MlirType type = mlirAttributeGetType(self); + if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) + return mlirIntegerAttrGetValueInt(self); + if (mlirIntegerTypeIsSigned(type)) + return mlirIntegerAttrGetValueSInt(self); + return mlirIntegerAttrGetValueUInt(self); + } +}; + +/// Bool Attribute subclass - BoolAttr. 
+class PyBoolAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; + static constexpr const char *pyClassName = "BoolAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](bool value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirBoolAttrGet(context->get(), value); + return PyBoolAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets an uniqued bool attribute"); + c.def_prop_ro("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); + c.def("__bool__", mlirBoolAttrGetValue, + "Converts the value of the bool attribute to a Python bool"); + } +}; + +class PySymbolRefAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; + static constexpr const char *pyClassName = "SymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static MlirAttribute fromList(const std::vector &symbols, + PyMlirContext &context) { + if (symbols.empty()) + throw std::runtime_error("SymbolRefAttr must be composed of at least " + "one symbol."); + MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); + SmallVector referenceAttrs; + for (size_t i = 1; i < symbols.size(); ++i) { + referenceAttrs.push_back( + mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); + } + return mlirSymbolRefAttrGet(context.get(), rootSymbol, + referenceAttrs.size(), referenceAttrs.data()); + } + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::vector &symbols, + DefaultingPyMlirContext context) { + return PySymbolRefAttribute::fromList(symbols, context.resolve()); + }, + nb::arg("symbols"), nb::arg("context").none() = nb::none(), + "Gets a uniqued SymbolRef attribute from a list of symbol names"); + c.def_prop_ro( + "value", + [](PySymbolRefAttribute &self) { + 
std::vector symbols = { + unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; + for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); + ++i) + symbols.push_back( + unwrap(mlirSymbolRefAttrGetRootReference( + mlirSymbolRefAttrGetNestedReference(self, i))) + .str()); + return symbols; + }, + "Returns the value of the SymbolRef attribute as a list[str]"); + } +}; + +class PyFlatSymbolRefAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; + static constexpr const char *pyClassName = "FlatSymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); + return PyFlatSymbolRefAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued FlatSymbolRef attribute"); + c.def_prop_ro( + "value", + [](PyFlatSymbolRefAttribute &self) { + MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the value of the FlatSymbolRef attribute as a string"); + } +}; + +class PyOpaqueAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; + static constexpr const char *pyClassName = "OpaqueAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string dialectNamespace, nb_buffer buffer, PyType &type, + DefaultingPyMlirContext context) { + const nb_buffer_info bufferInfo = buffer.request(); + intptr_t bufferSize = bufferInfo.size; + MlirAttribute attr = mlirOpaqueAttrGet( + context->get(), toMlirStringRef(dialectNamespace), 
bufferSize, + static_cast(bufferInfo.ptr), type); + return PyOpaqueAttribute(context->getRef(), attr); + }, + nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), + nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); + c.def_prop_ro( + "dialect_namespace", + [](PyOpaqueAttribute &self) { + MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque attribute as a string"); + c.def_prop_ro( + "data", + [](PyOpaqueAttribute &self) { + MlirStringRef stringRef = mlirOpaqueAttrGetData(self); + return nb::bytes(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaqued attributes as `bytes`"); + } +}; + +class PyStringAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; + static constexpr const char *pyClassName = "StringAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get", + [](nb::bytes value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get_typed", + [](PyType &type, std::string value) { + MlirAttribute attr = + mlirStringAttrTypedGet(type, toMlirStringRef(value)); + return PyStringAttribute(type.getContext(), attr); + }, + 
nb::arg("type"), nb::arg("value"), + "Gets a uniqued string attribute associated to a type"); + c.def_prop_ro( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); + c.def_prop_ro( + "value_bytes", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return nb::bytes(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute as `bytes`"); + } +}; + +// TODO: Support construction of string elements. +class PyDenseElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; + static constexpr const char *pyClassName = "DenseElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseElementsAttribute + getFromList(nb::list attributes, std::optional explicitType, + DefaultingPyMlirContext contextWrapper) { + const size_t numAttributes = nb::len(attributes); + if (numAttributes == 0) + throw nb::value_error("Attributes list must be non-empty."); + + MlirType shapedType; + if (explicitType) { + if ((!mlirTypeIsAShaped(*explicitType) || + !mlirShapedTypeHasStaticShape(*explicitType))) { + + std::string message; + llvm::raw_string_ostream os(message); + os << "Expected a static ShapedType for the shaped_type parameter: " + << nb::cast(nb::repr(nb::cast(*explicitType))); + throw nb::value_error(message.c_str()); + } + shapedType = *explicitType; + } else { + SmallVector shape = {static_cast(numAttributes)}; + shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), + mlirAttributeGetType(pyTryCast(attributes[0])), + mlirAttributeGetNull()); + } + + SmallVector mlirAttributes; + mlirAttributes.reserve(numAttributes); + for (const nb::handle &attribute : attributes) { + MlirAttribute mlirAttribute = pyTryCast(attribute); + MlirType attrType = 
mlirAttributeGetType(mlirAttribute); + mlirAttributes.push_back(mlirAttribute); + + if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { + std::string message; + llvm::raw_string_ostream os(message); + os << "All attributes must be of the same type and match " + << "the type parameter: expected=" + << nb::cast(nb::repr(nb::cast(shapedType))) + << ", but got=" + << nb::cast(nb::repr(nb::cast(attrType))); + throw nb::value_error(message.c_str()); + } + } + + MlirAttribute elements = mlirDenseElementsAttrGet( + shapedType, mlirAttributes.size(), mlirAttributes.data()); + + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + + static PyDenseElementsAttribute + getFromBuffer(nb_buffer array, bool signless, + std::optional explicitType, + std::optional> explicitShape, + DefaultingPyMlirContext contextWrapper) { + // Request a contiguous view. In exotic cases, this will cause a copy. + int flags = PyBUF_ND; + if (!explicitType) { + flags |= PyBUF_FORMAT; + } + Py_buffer view; + if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { + throw nb::python_error(); + } + auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); + + MlirContext context = contextWrapper->get(); + MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, + explicitShape, context); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseElementsAttr could not be constructed from the given buffer. 
" + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + return PyDenseElementsAttribute(contextWrapper->getRef(), attr); + } + + static PyDenseElementsAttribute getSplat(const PyType &shapedType, + PyAttribute &elementAttr) { + auto contextWrapper = + PyMlirContext::forContext(mlirTypeGetContext(shapedType)); + if (!mlirAttributeIsAInteger(elementAttr) && + !mlirAttributeIsAFloat(elementAttr)) { + std::string message = "Illegal element type for DenseElementsAttr: "; + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); + } + if (!mlirTypeIsAShaped(shapedType) || + !mlirShapedTypeHasStaticShape(shapedType)) { + std::string message = + "Expected a static ShapedType for the shaped_type parameter: "; + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + throw nb::value_error(message.c_str()); + } + MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); + MlirType attrType = mlirAttributeGetType(elementAttr); + if (!mlirTypeEqual(shapedElementType, attrType)) { + std::string message = + "Shaped element type and attribute type must be equal: shaped="; + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + message.append(", element="); + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); + } + + MlirAttribute elements = + mlirDenseElementsAttrSplatGet(shapedType, elementAttr); + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + + intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } + + std::unique_ptr accessBuffer() { + MlirType shapedType = mlirAttributeGetType(*this); + MlirType elementType = mlirShapedTypeGetElementType(shapedType); + std::string format; + + if (mlirTypeIsAF32(elementType)) { + // f32 + return bufferInfo(shapedType); + } + if (mlirTypeIsAF64(elementType)) { + // f64 + return bufferInfo(shapedType); + } + if 
(mlirTypeIsAF16(elementType)) { + // f16 + return bufferInfo(shapedType, "e"); + } + if (mlirTypeIsAIndex(elementType)) { + // Same as IndexType::kInternalStorageBitWidth + return bufferInfo(shapedType); + } + if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 32) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i32 + return bufferInfo(shapedType); + } + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i32 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 64) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i64 + return bufferInfo(shapedType); + } + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i64 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 8) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i8 + return bufferInfo(shapedType); + } + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i8 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 16) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i16 + return bufferInfo(shapedType); + } + if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i16 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 1) { + // i1 / bool + // We can not send the buffer directly back to Python, because the i1 + // values are bitpacked within MLIR. We call numpy's unpackbits function + // to convert the bytes. + return getBooleanBufferFromBitpackedAttribute(); + } + + // TODO: Currently crashes the program. 
+ // Reported as https://github.com/pybind/pybind11/issues/3336 + throw std::invalid_argument( + "unsupported data type for conversion to Python buffer"); + } + + static void bindDerived(ClassTy &c) { +#if PY_VERSION_HEX < 0x03090000 + PyTypeObject *tp = reinterpret_cast(c.ptr()); + tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; + tp->tp_as_buffer->bf_releasebuffer = + PyDenseElementsAttribute::bf_releasebuffer; +#endif + c.def("__len__", &PyDenseElementsAttribute::dunderLen) + .def_static("get", PyDenseElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("signless") = true, + nb::arg("type").none() = nb::none(), + nb::arg("shape").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kDenseElementsAttrGetDocstring) + .def_static("get", PyDenseElementsAttribute::getFromList, + nb::arg("attrs"), nb::arg("type").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kDenseElementsAttrGetFromListDocstring) + .def_static("get_splat", PyDenseElementsAttribute::getSplat, + nb::arg("shaped_type"), nb::arg("element_attr"), + "Gets a DenseElementsAttr where all values are the same") + .def_prop_ro("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def("get_splat_value", [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + "get_splat_value called on a non-splat attribute"); + return mlirDenseElementsAttrGetSplatValue(self); + }); + } + + static PyType_Slot slots[]; + +private: + static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); + static void bf_releasebuffer(PyObject *, Py_buffer *buffer); + + static bool isUnsignedIntegerFormat(std::string_view format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'I' || code == 'B' || code == 'H' || code == 'L' || + code == 'Q'; + } + + static bool isSignedIntegerFormat(std::string_view format) { + if 
(format.empty()) + return false; + char code = format[0]; + return code == 'i' || code == 'b' || code == 'h' || code == 'l' || + code == 'q'; + } + + static MlirType + getShapedType(std::optional bulkLoadElementType, + std::optional> explicitShape, + Py_buffer &view) { + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(view.shape, view.shape + view.ndim); + } + + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + return *bulkLoadElementType; + } else { + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); + } + } + + static MlirAttribute getAttributeFromBuffer( + Py_buffer &view, bool signless, std::optional explicitType, + std::optional> explicitShape, MlirContext &context) { + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes exotics types which do not have a direct + // representation in the buffer protocol (i.e. complex, etc). 
+ std::optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (format == "?") { + // i1 + // The i1 type needs to be bit-packed, so we will handle it seperately + return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, + context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); + } + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless + ? 
mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); + } + } + if (!bulkLoadElementType) { + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + std::string(format)); + } + } + + MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); + } + + // There is a complication for boolean numpy arrays, as numpy represents + // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 + // booleans per byte. + static MlirAttribute getBitpackedAttributeFromBooleanBuffer( + Py_buffer &view, std::optional> explicitShape, + MlirContext &context) { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a bit-packed MLIR attribute is " + "unsupported on big-endian systems"); + } + nb::ndarray, nb::c_contig> unpackedArray( + /*data=*/static_cast(view.buf), + /*shape=*/{static_cast(view.len)}); + + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object packbitsFunc = numpy.attr("packbits"); + nb::object packedBooleans = + packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); + nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); + + MlirType bitpackedType = + getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); + assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); + // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of + // packedBooleans, hence the MlirAttribute will remain valid even when + // packedBooleans get reclaimed by the end of the function. 
+ return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, + pythonBuffer.ptr); + } + + // This does the opposite transformation of + // `getBitpackedAttributeFromBooleanBuffer` + std::unique_ptr getBooleanBufferFromBitpackedAttribute() { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a numpy array from a MLIR attribute " + "is unsupported on big-endian systems"); + } + + int64_t numBooleans = mlirElementsAttrGetNumElements(*this); + int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); + uint8_t *bitpackedData = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + nb::ndarray, nb::c_contig> packedArray( + /*data=*/bitpackedData, + /*shape=*/{static_cast(numBitpackedBytes)}); + + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object unpackbitsFunc = numpy.attr("unpackbits"); + nb::object equalFunc = numpy.attr("equal"); + nb::object reshapeFunc = numpy.attr("reshape"); + nb::object unpackedBooleans = + unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); + + // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. + // We need to: + // 1. Slice away the padded bits + // 2. Make the boolean array have the correct shape + // 3. 
Convert the array to a boolean array + unpackedBooleans = unpackedBooleans[nb::slice( + nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; + unpackedBooleans = equalFunc(unpackedBooleans, 1); + + MlirType shapedType = mlirAttributeGetType(*this); + intptr_t rank = mlirShapedTypeGetRank(shapedType); + std::vector shape(rank); + for (intptr_t i = 0; i < rank; ++i) { + shape[i] = mlirShapedTypeGetDimSize(shapedType, i); + } + unpackedBooleans = reshapeFunc(unpackedBooleans, shape); + + // Make sure the returned nb::buffer_view claims ownership of the data in + // `pythonBuffer` so it remains valid when Python reads it + nb_buffer pythonBuffer = nb::cast(unpackedBooleans); + return std::make_unique(pythonBuffer.request()); + } + + template + std::unique_ptr + bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); + // Prepare the data for the buffer_info. + // Buffer is configured for read-only access below. + Type *data = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + // Prepare the shape for the buffer_info. + SmallVector shape; + for (intptr_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + // Prepare the strides for the buffer_info. + SmallVector strides; + if (mlirDenseElementsAttrIsSplat(*this)) { + // Splats are special, only the single value is stored. 
+ strides.assign(rank, 0); + } else { + for (intptr_t i = 1; i < rank; ++i) { + intptr_t strideFactor = 1; + for (intptr_t j = i; j < rank; ++j) + strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + strides.push_back(sizeof(Type) * strideFactor); + } + strides.push_back(sizeof(Type)); + } + const char *format; + if (explicitFormat) { + format = explicitFormat; + } else { + format = nb_format_descriptor::format(); + } + return std::make_unique( + data, sizeof(Type), format, rank, std::move(shape), std::move(strides), + /*readonly=*/true); + } +}; // namespace + +PyType_Slot PyDenseElementsAttribute::slots[] = { +// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec. +#if PY_VERSION_HEX >= 0x03090000 + {Py_bf_getbuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_getbuffer)}, + {Py_bf_releasebuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_releasebuffer)}, +#endif + {0, nullptr}, +}; + +/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, + Py_buffer *view, + int flags) { + view->obj = nullptr; + std::unique_ptr info; + try { + auto *attr = nb::cast(nb::handle(obj)); + info = attr->accessBuffer(); + } catch (nb::python_error &e) { + e.restore(); + nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); + return -1; + } + view->obj = obj; + view->ndim = 1; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = info->itemsize; + for (auto s : info->shape) { + view->len *= s; + } + view->readonly = info->readonly; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info->format); + } + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = static_cast(info->ndim); + view->strides = info->strides.data(); + view->shape = info->shape.data(); + } + view->suboffsets = nullptr; + view->internal = info.release(); + Py_INCREF(obj); + return 0; +} + +/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, + Py_buffer *view) { + delete 
reinterpret_cast(view->internal); +} + +/// Refinement of the PyDenseElementsAttribute for attributes containing +/// integer (and boolean) values. Supports element access. +class PyDenseIntElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; + static constexpr const char *pyClassName = "DenseIntElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + /// Returns the element at the given linear position. Asserts if the index + /// is out of range. + nb::object dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw nb::index_error("attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + // Index type can also appear as a DenseIntElementsAttr and therefore can be + // casted to integer. + assert(mlirTypeIsAInteger(type) || + mlirTypeIsAIndex(type) && "expected integer/index element type in " + "dense int elements attribute"); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. nb::int_ is implicitly constructible + // from any C++ integral type and handles bitwidth correctly. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. 
+ if (mlirTypeIsAIndex(type)) { + return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); + } + unsigned width = mlirIntegerTypeGetWidth(type); + bool isUnsigned = mlirIntegerTypeIsUnsigned(type); + if (isUnsigned) { + if (width == 1) { + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); + } + if (width == 8) { + return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); + } + if (width == 16) { + return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); + } + if (width == 32) { + return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); + } + if (width == 64) { + return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); + } + } else { + if (width == 1) { + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); + } + if (width == 8) { + return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); + } + if (width == 16) { + return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); + } + if (width == 32) { + return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); + } + if (width == 64) { + return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); + } + } + throw nb::type_error("Unsupported integer type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); + } +}; + +// Check if the python version is less than 3.13. Py_IsFinalizing is a part +// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing. 
+#if PY_VERSION_HEX < 0x030d0000 +#define Py_IsFinalizing _Py_IsFinalizing +#endif + +class PyDenseResourceElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = + mlirAttributeIsADenseResourceElements; + static constexpr const char *pyClassName = "DenseResourceElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseResourceElementsAttribute + getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, + std::optional alignment, bool isMutable, + DefaultingPyMlirContext contextWrapper) { + if (!mlirTypeIsAShaped(type)) { + throw std::invalid_argument( + "Constructing a DenseResourceElementsAttr requires a ShapedType."); + } + + // Do not request any conversions as we must ensure to use caller + // managed memory. + int flags = PyBUF_STRIDES; + std::unique_ptr view = std::make_unique(); + if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { + throw nb::python_error(); + } + + // This scope releaser will only release if we haven't yet transferred + // ownership. + auto freeBuffer = llvm::make_scope_exit([&]() { + if (view) + PyBuffer_Release(view.get()); + }); + + if (!PyBuffer_IsContiguous(view.get(), 'A')) { + throw std::invalid_argument("Contiguous buffer is required."); + } + + // Infer alignment to be the stride of one element if not explicit. + size_t inferredAlignment; + if (alignment) + inferredAlignment = *alignment; + else + inferredAlignment = view->strides[view->ndim - 1]; + + // The userData is a Py_buffer* that the deleter owns. 
+ auto deleter = [](void *userData, const void *data, size_t size, + size_t align) { + if (Py_IsFinalizing()) + return; + assert(Py_IsInitialized() && "expected interpreter to be initialized"); + Py_buffer *ownedView = static_cast(userData); + nb::gil_scoped_acquire gil; + PyBuffer_Release(ownedView); + delete ownedView; + }; + + size_t rawBufferSize = view->len; + MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( + type, toMlirStringRef(name), view->buf, rawBufferSize, + inferredAlignment, isMutable, deleter, static_cast(view.get())); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseResourceElementsAttr could not be constructed from the given " + "buffer. " + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + view.release(); + return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); + } + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("name"), nb::arg("type"), + nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, + nb::arg("context").none() = nb::none(), + kDenseResourceElementsAttrGetFromBufferDocstring); + } +}; + +class PyDictAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; + static constexpr const char *pyClassName = "DictAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirDictionaryAttrGetTypeID; + + intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } + + bool dunderContains(const std::string &name) { + return !mlirAttributeIsNull( + mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); + } + + static void bindDerived(ClassTy &c) { + c.def("__contains__", &PyDictAttribute::dunderContains); + c.def("__len__", &PyDictAttribute::dunderLen); + 
c.def_static( + "get", + [](nb::dict attributes, DefaultingPyMlirContext context) { + SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(attributes.size()); + for (std::pair it : attributes) { + auto &mlirAttr = nb::cast(it.second); + auto name = nb::cast(it.first); + mlirNamedAttributes.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), + toMlirStringRef(name)), + mlirAttr)); + } + MlirAttribute attr = + mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + return PyDictAttribute(context->getRef(), attr); + }, + nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), + "Gets an uniqued dict attribute"); + c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { + MlirAttribute attr = + mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) + throw nb::key_error("attempt to access a non-existent attribute"); + return attr; + }); + c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { + if (index < 0 || index >= self.dunderLen()) { + throw nb::index_error("attempt to access out of bounds attribute"); + } + MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); + return PyNamedAttribute( + namedAttr.attribute, + std::string(mlirIdentifierStr(namedAttr.name).data)); + }); + } +}; + +/// Refinement of PyDenseElementsAttribute for attributes containing +/// floating-point values. Supports element access. 
+class PyDenseFPElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; + static constexpr const char *pyClassName = "DenseFPElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + nb::float_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw nb::index_error("attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. nb::float_ is implicitly constructible + // from float and double. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. + if (mlirTypeIsAF32(type)) { + return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); + } + if (mlirTypeIsAF64(type)) { + return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); + } + throw nb::type_error("Unsupported floating-point type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); + } +}; + +class PyTypeAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; + static constexpr const char *pyClassName = "TypeAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTypeAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirTypeAttrGet(value.get()); + return PyTypeAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued Type attribute"); + c.def_prop_ro("value", [](PyTypeAttribute &self) { + return mlirTypeAttrGetValue(self.get()); + }); + } +}; + +/// Unit Attribute 
subclass. Unit attributes don't have values. +class PyUnitAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; + static constexpr const char *pyClassName = "UnitAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnitAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return PyUnitAttribute(context->getRef(), + mlirUnitAttrGet(context->get())); + }, + nb::arg("context").none() = nb::none(), "Create a Unit attribute."); + } +}; + +/// Strided layout attribute subclass. +class PyStridedLayoutAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; + static constexpr const char *pyClassName = "StridedLayoutAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStridedLayoutAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int64_t offset, const std::vector strides, + DefaultingPyMlirContext ctx) { + MlirAttribute attr = mlirStridedLayoutAttrGet( + ctx->get(), offset, strides.size(), strides.data()); + return PyStridedLayoutAttribute(ctx->getRef(), attr); + }, + nb::arg("offset"), nb::arg("strides"), + nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute."); + c.def_static( + "get_fully_dynamic", + [](int64_t rank, DefaultingPyMlirContext ctx) { + auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); + std::vector strides(rank); + llvm::fill(strides, dynamic); + MlirAttribute attr = mlirStridedLayoutAttrGet( + ctx->get(), dynamic, strides.size(), strides.data()); + return PyStridedLayoutAttribute(ctx->getRef(), attr); + }, + nb::arg("rank"), nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute with dynamic offset and strides of " + "a " + "given 
rank."); + c.def_prop_ro( + "offset", + [](PyStridedLayoutAttribute &self) { + return mlirStridedLayoutAttrGetOffset(self); + }, + "Returns the value of the float point attribute"); + c.def_prop_ro( + "strides", + [](PyStridedLayoutAttribute &self) { + intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); + std::vector strides(size); + for (intptr_t i = 0; i < size; i++) { + strides[i] = mlirStridedLayoutAttrGetStride(self, i); + } + return strides; + }, + "Returns the value of the float point attribute"); + } +}; + +nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); + if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); + if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); + if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); + if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); + if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); + if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); +} + +nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); + if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) + return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); + std::string msg = + std::string( + "Can't cast unknown element type DenseIntOrFPElementsAttr 
(") + + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); +} + +nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { + if (PyBoolAttribute::isaFunction(pyAttribute)) + return nb::cast(PyBoolAttribute(pyAttribute)); + if (PyIntegerAttribute::isaFunction(pyAttribute)) + return nb::cast(PyIntegerAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); +} + +nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { + if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) + return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); + if (PySymbolRefAttribute::isaFunction(pyAttribute)) + return nb::cast(PySymbolRefAttribute(pyAttribute)); + std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + + nb::cast(nb::repr(nb::cast(pyAttribute))) + + ")"; + throw nb::type_error(msg.c_str()); +} + +} // namespace + +void mlir::python::populateIRAttributes(nb::module_ &m) { + PyAffineMapAttribute::bind(m); + PyDenseBoolArrayAttribute::bind(m); + PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI8ArrayAttribute::bind(m); + PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI16ArrayAttribute::bind(m); + PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI32ArrayAttribute::bind(m); + PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI64ArrayAttribute::bind(m); + PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseF32ArrayAttribute::bind(m); + PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseF64ArrayAttribute::bind(m); + PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseArrayAttrGetTypeID(), + nb::cast(nb::cpp_function(denseArrayAttributeCaster))); + + PyArrayAttribute::bind(m); + 
PyArrayAttribute::PyArrayAttributeIterator::bind(m); + PyBoolAttribute::bind(m); + PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); + PyDenseFPElementsAttribute::bind(m); + PyDenseIntElementsAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseIntOrFPElementsAttrGetTypeID(), + nb::cast( + nb::cpp_function(denseIntOrFPElementsAttributeCaster))); + PyDenseResourceElementsAttribute::bind(m); + + PyDictAttribute::bind(m); + PySymbolRefAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirSymbolRefAttrGetTypeID(), + nb::cast( + nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); + + PyFlatSymbolRefAttribute::bind(m); + PyOpaqueAttribute::bind(m); + PyFloatAttribute::bind(m); + PyIntegerAttribute::bind(m); + PyIntegerSetAttribute::bind(m); + PyStringAttribute::bind(m); + PyTypeAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirIntegerAttrGetTypeID(), + nb::cast(nb::cpp_function(integerOrBoolAttributeCaster))); + PyUnitAttribute::bind(m); + + PyStridedLayoutAttribute::bind(m); +} diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp new file mode 100644 index 000000000..390cdc542 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -0,0 +1,4412 @@ +//===- IRModules.cpp - IR Submodules of pybind module ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Globals.h" +#include "IRModule.h" +#include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
+#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Debug.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "nanobind/nanobind.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" + +#include + +namespace nb = nanobind; +using namespace nb::literals; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +//------------------------------------------------------------------------------ +// Docstrings (trivial, non-duplicated docstrings are included inline). +//------------------------------------------------------------------------------ + +static const char kContextParseTypeDocstring[] = + R"(Parses the assembly form of a type. + +Returns a Type object or raises an MLIRError if the type cannot be parsed. + +See also: https://mlir.llvm.org/docs/LangRef/#type-system +)"; + +static const char kContextGetCallSiteLocationDocstring[] = + R"(Gets a Location representing a caller and callsite)"; + +static const char kContextGetFileLocationDocstring[] = + R"(Gets a Location representing a file, line and column)"; + +static const char kContextGetFileRangeDocstring[] = + R"(Gets a Location representing a file, line and column range)"; + +static const char kContextGetFusedLocationDocstring[] = + R"(Gets a Location representing a fused location with optional metadata)"; + +static const char kContextGetNameLocationDocString[] = + R"(Gets a Location representing a named location with optional child location)"; + +static const char kModuleParseDocstring[] = + R"(Parses a module's assembly format from a string. + +Returns a new MlirModule or raises an MLIRError if the parsing fails. + +See also: https://mlir.llvm.org/docs/LangRef/ +)"; + +static const char kOperationCreateDocstring[] = + R"(Creates a new operation. 
+ +Args: + name: Operation name (e.g. "dialect.operation"). + results: Sequence of Type representing op result types. + attributes: Dict of str:Attribute. + successors: List of Block for the operation's successors. + regions: Number of regions to create. + location: A Location object (defaults to resolve from context manager). + ip: An InsertionPoint (defaults to resolve from context manager or set to + False to disable insertion, even with an insertion point set in the + context manager). + infer_type: Whether to infer result types. +Returns: + A new "detached" Operation object. Detached operations can be added + to blocks, which causes them to become "attached." +)"; + +static const char kOperationPrintDocstring[] = + R"(Prints the assembly form of the operation to a file like object. + +Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + large_elements_limit: Whether to elide elements attributes above this + number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource attributes above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_Scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. 
While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. + skip_regions: Whether to skip printing regions. Defaults to False. +)"; + +static const char kOperationPrintStateDocstring[] = + R"(Prints the assembly form of the operation to a file like object. + +Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + state: AsmState capturing the operation numbering and flags. +)"; + +static const char kOperationGetAsmDocstring[] = + R"(Gets the assembly form of the operation with all options available. + +Args: + binary: Whether to return a bytes (True) or str (False) object. Defaults to + False. + ... others ...: See the print() method for common keyword arguments for + configuring the printout. +Returns: + Either a bytes or str object, depending on the setting of the 'binary' + argument. +)"; + +static const char kOperationPrintBytecodeDocstring[] = + R"(Write the bytecode form of the operation to a file like object. + +Args: + file: The file like object to write to. + desired_version: The version of bytecode to emit. +Returns: + The bytecode writer status. +)"; + +static const char kOperationStrDunderDocstring[] = + R"(Gets the assembly form of the operation with default options. + +If more advanced control over the assembly formatting or I/O options is needed, +use the dedicated print or get_asm method, which supports keyword arguments to +customize behavior. +)"; + +static const char kDumpDocstring[] = + R"(Dumps a debug representation of the object to stderr.)"; + +static const char kAppendBlockDocstring[] = + R"(Appends a new block, with argument types as positional args. + +Returns: + The created block. 
+)"; + +static const char kValueDunderStrDocstring[] = + R"(Returns the string form of the value. + +If the value is a block argument, this is the assembly form of its type and the +position in the argument list. If the value is an operation result, this is +equivalent to printing the operation that produced it. +)"; + +static const char kGetNameAsOperand[] = + R"(Returns the string form of value as an operand (i.e., the ValueID). +)"; + +static const char kValueReplaceAllUsesWithDocstring[] = + R"(Replace all uses of value with the new value, updating anything in +the IR that uses 'self' to use the other value instead. +)"; + +static const char kValueReplaceAllUsesExceptDocstring[] = + R"("Replace all uses of this value with the 'with' value, except for those +in 'exceptions'. 'exceptions' can be either a single operation or a list of +operations. +)"; + +//------------------------------------------------------------------------------ +// Utilities. +//------------------------------------------------------------------------------ + +/// Helper for creating an @classmethod. +template +nb::object classmethod(Func f, Args... args) { + nb::object cf = nb::cpp_function(f, args...); + return nb::borrow((PyClassMethod_New(cf.ptr()))); +} + +static nb::object +createCustomDialectWrapper(const std::string &dialectNamespace, + nb::object dialectDescriptor) { + auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); + if (!dialectClass) { + // Use the base class. + return nb::cast(PyDialect(std::move(dialectDescriptor))); + } + + // Create the custom implementation. 
+ return (*dialectClass)(std::move(dialectDescriptor)); +} + +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +static MlirStringRef toMlirStringRef(std::string_view s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + +/// Create a block, using the current location context if no locations are +/// specified. +static MlirBlock createBlock(const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { + SmallVector argTypes; + argTypes.reserve(nb::len(pyArgTypes)); + for (const auto &pyType : pyArgTypes) + argTypes.push_back(nb::cast(pyType)); + + SmallVector argLocs; + if (pyArgLocs) { + argLocs.reserve(nb::len(*pyArgLocs)); + for (const auto &pyLoc : *pyArgLocs) + argLocs.push_back(nb::cast(pyLoc)); + } else if (!argTypes.empty()) { + argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); + } + + if (argTypes.size() != argLocs.size()) + throw nb::value_error(("Expected " + Twine(argTypes.size()) + + " locations, got: " + Twine(argLocs.size())) + .str() + .c_str()); + return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); +} + +/// Wrapper for the global LLVM debugging flag. +struct PyGlobalDebugFlag { + static void set(nb::object &o, bool enable) { + nb::ft_lock_guard lock(mutex); + mlirEnableGlobalDebug(enable); + } + + static bool get(const nb::object &) { + nb::ft_lock_guard lock(mutex); + return mlirIsGlobalDebugEnabled(); + } + + static void bind(nb::module_ &m) { + // Debug flags. 
+ nb::class_(m, "_GlobalDebug") + .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + .def_static( + "set_types", + [](const std::string &type) { + nb::ft_lock_guard lock(mutex); + mlirSetGlobalDebugType(type.c_str()); + }, + "types"_a, "Sets specific debug types to be produced by LLVM") + .def_static("set_types", [](const std::vector &types) { + std::vector pointers; + pointers.reserve(types.size()); + for (const std::string &str : types) + pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); + mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); + }); + } + +private: + static nb::ft_mutex mutex; +}; + +nb::ft_mutex PyGlobalDebugFlag::mutex; + +struct PyAttrBuilderMap { + static bool dunderContains(const std::string &attributeKind) { + return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); + } + static nb::callable dunderGetItemNamed(const std::string &attributeKind) { + auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); + if (!builder) + throw nb::key_error(attributeKind.c_str()); + return *builder; + } + static void dunderSetItemNamed(const std::string &attributeKind, + nb::callable func, bool replace) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), + replace); + } + + static void bind(nb::module_ &m) { + nb::class_(m, "AttrBuilder") + .def_static("contains", &PyAttrBuilderMap::dunderContains) + .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) + .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, + "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, + "Register an attribute builder for building MLIR " + "attributes from python values."); + } +}; + +//------------------------------------------------------------------------------ +// PyBlock +//------------------------------------------------------------------------------ + +nb::object PyBlock::getCapsule() { + return 
nb::steal(mlirPythonBlockToCapsule(get())); +} + +//------------------------------------------------------------------------------ +// Collections. +//------------------------------------------------------------------------------ + +namespace { + +class PyRegionIterator { +public: + PyRegionIterator(PyOperationRef operation) + : operation(std::move(operation)) {} + + PyRegionIterator &dunderIter() { return *this; } + + PyRegion dunderNext() { + operation->checkValid(); + if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { + throw nb::stop_iteration(); + } + MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); + return PyRegion(operation, region); + } + + static void bind(nb::module_ &m) { + nb::class_(m, "RegionIterator") + .def("__iter__", &PyRegionIterator::dunderIter) + .def("__next__", &PyRegionIterator::dunderNext); + } + +private: + PyOperationRef operation; + int nextIndex = 0; +}; + +/// Regions of an op are fixed length and indexed numerically so are represented +/// with a sequence-like container. +class PyRegionList : public Sliceable { +public: + static constexpr const char *pyClassName = "RegionSequence"; + + PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumRegions(operation->get()) + : length, + step), + operation(std::move(operation)) {} + + PyRegionIterator dunderIter() { + operation->checkValid(); + return PyRegionIterator(operation); + } + + static void bindDerived(ClassTy &c) { + c.def("__iter__", &PyRegionList::dunderIter); + } + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumRegions(operation->get()); + } + + PyRegion getRawElement(intptr_t pos) { + operation->checkValid(); + return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); + } + + PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyRegionList(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + +class PyBlockIterator { +public: + PyBlockIterator(PyOperationRef operation, MlirBlock next) + : operation(std::move(operation)), next(next) {} + + PyBlockIterator &dunderIter() { return *this; } + + PyBlock dunderNext() { + operation->checkValid(); + if (mlirBlockIsNull(next)) { + throw nb::stop_iteration(); + } + + PyBlock returnBlock(operation, next); + next = mlirBlockGetNextInRegion(next); + return returnBlock; + } + + static void bind(nb::module_ &m) { + nb::class_(m, "BlockIterator") + .def("__iter__", &PyBlockIterator::dunderIter) + .def("__next__", &PyBlockIterator::dunderNext); + } + +private: + PyOperationRef operation; + MlirBlock next; +}; + +/// Blocks are exposed by the C-API as a forward-only linked list. In Python, +/// we present them as a more full-featured list-like container but optimize +/// it for forward iteration. Blocks are always owned by a region. 
+class PyBlockList { +public: + PyBlockList(PyOperationRef operation, MlirRegion region) + : operation(std::move(operation)), region(region) {} + + PyBlockIterator dunderIter() { + operation->checkValid(); + return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); + } + + intptr_t dunderLen() { + operation->checkValid(); + intptr_t count = 0; + MlirBlock block = mlirRegionGetFirstBlock(region); + while (!mlirBlockIsNull(block)) { + count += 1; + block = mlirBlockGetNextInRegion(block); + } + return count; + } + + PyBlock dunderGetItem(intptr_t index) { + operation->checkValid(); + if (index < 0) { + index += dunderLen(); + } + if (index < 0) { + throw nb::index_error("attempt to access out of bounds block"); + } + MlirBlock block = mlirRegionGetFirstBlock(region); + while (!mlirBlockIsNull(block)) { + if (index == 0) { + return PyBlock(operation, block); + } + block = mlirBlockGetNextInRegion(block); + index -= 1; + } + throw nb::index_error("attempt to access out of bounds block"); + } + + PyBlock appendBlock(const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { + operation->checkValid(); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); + mlirRegionAppendOwnedBlock(region, block); + return PyBlock(operation, block); + } + + static void bind(nb::module_ &m) { + nb::class_(m, "BlockList") + .def("__getitem__", &PyBlockList::dunderGetItem) + .def("__iter__", &PyBlockList::dunderIter) + .def("__len__", &PyBlockList::dunderLen) + .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, + nb::arg("args"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt); + } + +private: + PyOperationRef operation; + MlirRegion region; +}; + +class PyOperationIterator { +public: + PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) + : parentOperation(std::move(parentOperation)), next(next) {} + + PyOperationIterator &dunderIter() { return *this; } + + nb::object dunderNext() { + parentOperation->checkValid(); + if 
(mlirOperationIsNull(next)) { + throw nb::stop_iteration(); + } + + PyOperationRef returnOperation = + PyOperation::forOperation(parentOperation->getContext(), next); + next = mlirOperationGetNextInBlock(next); + return returnOperation->createOpView(); + } + + static void bind(nb::module_ &m) { + nb::class_(m, "OperationIterator") + .def("__iter__", &PyOperationIterator::dunderIter) + .def("__next__", &PyOperationIterator::dunderNext); + } + +private: + PyOperationRef parentOperation; + MlirOperation next; +}; + +/// Operations are exposed by the C-API as a forward-only linked list. In +/// Python, we present them as a more full-featured list-like container but +/// optimize it for forward iteration. Iterable operations are always owned +/// by a block. +class PyOperationList { +public: + PyOperationList(PyOperationRef parentOperation, MlirBlock block) + : parentOperation(std::move(parentOperation)), block(block) {} + + PyOperationIterator dunderIter() { + parentOperation->checkValid(); + return PyOperationIterator(parentOperation, + mlirBlockGetFirstOperation(block)); + } + + intptr_t dunderLen() { + parentOperation->checkValid(); + intptr_t count = 0; + MlirOperation childOp = mlirBlockGetFirstOperation(block); + while (!mlirOperationIsNull(childOp)) { + count += 1; + childOp = mlirOperationGetNextInBlock(childOp); + } + return count; + } + + nb::object dunderGetItem(intptr_t index) { + parentOperation->checkValid(); + if (index < 0) { + index += dunderLen(); + } + if (index < 0) { + throw nb::index_error("attempt to access out of bounds operation"); + } + MlirOperation childOp = mlirBlockGetFirstOperation(block); + while (!mlirOperationIsNull(childOp)) { + if (index == 0) { + return PyOperation::forOperation(parentOperation->getContext(), childOp) + ->createOpView(); + } + childOp = mlirOperationGetNextInBlock(childOp); + index -= 1; + } + throw nb::index_error("attempt to access out of bounds operation"); + } + + static void bind(nb::module_ &m) { + 
nb::class_(m, "OperationList") + .def("__getitem__", &PyOperationList::dunderGetItem) + .def("__iter__", &PyOperationList::dunderIter) + .def("__len__", &PyOperationList::dunderLen); + } + +private: + PyOperationRef parentOperation; + MlirBlock block; +}; + +class PyOpOperand { +public: + PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} + + nb::object getOwner() { + MlirOperation owner = mlirOpOperandGetOwner(opOperand); + PyMlirContextRef context = + PyMlirContext::forContext(mlirOperationGetContext(owner)); + return PyOperation::forOperation(context, owner)->createOpView(); + } + + size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } + + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperand") + .def_prop_ro("owner", &PyOpOperand::getOwner) + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); + } + +private: + MlirOpOperand opOperand; +}; + +class PyOpOperandIterator { +public: + PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} + + PyOpOperandIterator &dunderIter() { return *this; } + + PyOpOperand dunderNext() { + if (mlirOpOperandIsNull(opOperand)) + throw nb::stop_iteration(); + + PyOpOperand returnOpOperand(opOperand); + opOperand = mlirOpOperandGetNextUse(opOperand); + return returnOpOperand; + } + + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperandIterator") + .def("__iter__", &PyOpOperandIterator::dunderIter) + .def("__next__", &PyOpOperandIterator::dunderNext); + } + +private: + MlirOpOperand opOperand; +}; + +} // namespace + +//------------------------------------------------------------------------------ +// PyMlirContext +//------------------------------------------------------------------------------ + +PyMlirContext::PyMlirContext(MlirContext context) : context(context) { + nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); + auto &liveContexts = getLiveContexts(); + liveContexts[context.ptr] = this; +} + 
+PyMlirContext::~PyMlirContext() { + // Note that the only public way to construct an instance is via the + // forContext method, which always puts the associated handle into + // liveContexts. + nb::gil_scoped_acquire acquire; + { + nb::ft_lock_guard lock(live_contexts_mutex); + getLiveContexts().erase(context.ptr); + } + mlirContextDestroy(context); +} + +nb::object PyMlirContext::getCapsule() { + return nb::steal(mlirPythonContextToCapsule(get())); +} + +nb::object PyMlirContext::createFromCapsule(nb::object capsule) { + MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); + if (mlirContextIsNull(rawContext)) + throw nb::python_error(); + return forContext(rawContext).releaseObject(); +} + +PyMlirContextRef PyMlirContext::forContext(MlirContext context) { + nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); + auto &liveContexts = getLiveContexts(); + auto it = liveContexts.find(context.ptr); + if (it == liveContexts.end()) { + // Create. + PyMlirContext *unownedContextWrapper = new PyMlirContext(context); + nb::object pyRef = nb::cast(unownedContextWrapper); + assert(pyRef && "cast to nb::object failed"); + liveContexts[context.ptr] = unownedContextWrapper; + return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); + } + // Use existing. 
+ nb::object pyRef = nb::cast(it->second); + return PyMlirContextRef(it->second, std::move(pyRef)); +} + +nb::ft_mutex PyMlirContext::live_contexts_mutex; + +PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { + static LiveContextMap liveContexts; + return liveContexts; +} + +size_t PyMlirContext::getLiveCount() { + nb::ft_lock_guard lock(live_contexts_mutex); + return getLiveContexts().size(); +} + +size_t PyMlirContext::getLiveOperationCount() { + nb::ft_lock_guard lock(liveOperationsMutex); + return liveOperations.size(); +} + +std::vector PyMlirContext::getLiveOperationObjects() { + std::vector liveObjects; + nb::ft_lock_guard lock(liveOperationsMutex); + for (auto &entry : liveOperations) + liveObjects.push_back(entry.second.second); + return liveObjects; +} + +size_t PyMlirContext::clearLiveOperations() { + + LiveOperationMap operations; + { + nb::ft_lock_guard lock(liveOperationsMutex); + std::swap(operations, liveOperations); + } + for (auto &op : operations) + op.second.second->setInvalid(); + size_t numInvalidated = operations.size(); + return numInvalidated; +} + +void PyMlirContext::clearOperation(MlirOperation op) { + PyOperation *py_op; + { + nb::ft_lock_guard lock(liveOperationsMutex); + auto it = liveOperations.find(op.ptr); + if (it == liveOperations.end()) { + return; + } + py_op = it->second.second; + liveOperations.erase(it); + } + py_op->setInvalid(); +} + +void PyMlirContext::clearOperationsInside(PyOperationBase &op) { + typedef struct { + PyOperation &rootOp; + bool rootSeen; + } callBackData; + callBackData data{op.getOperation(), false}; + // Mark all ops below the op that the passmanager will be rooted + // at (but not op itself - note the preorder) as invalid. 
+ MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, + void *userData) { + callBackData *data = static_cast(userData); + if (LLVM_LIKELY(data->rootSeen)) + data->rootOp.getOperation().getContext()->clearOperation(op); + else + data->rootSeen = true; + return MlirWalkResult::MlirWalkResultAdvance; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + static_cast(&data), MlirWalkPreOrder); +} +void PyMlirContext::clearOperationsInside(MlirOperation op) { + PyOperationRef opRef = PyOperation::forOperation(getRef(), op); + clearOperationsInside(opRef->getOperation()); +} + +void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { + MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, + void *userData) { + PyMlirContextRef &contextRef = *static_cast(userData); + contextRef->clearOperation(op); + return MlirWalkResult::MlirWalkResultAdvance; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + &op.getOperation().getContext(), MlirWalkPreOrder); +} + +size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } + +nb::object PyMlirContext::contextEnter(nb::object context) { + return PyThreadContextEntry::pushContext(context); +} + +void PyMlirContext::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { + PyThreadContextEntry::popContext(*this); +} + +nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { + // Note that ownership is transferred to the delete callback below by way of + // an explicit inc_ref (borrow). + PyDiagnosticHandler *pyHandler = + new PyDiagnosticHandler(get(), std::move(callback)); + nb::object pyHandlerObject = + nb::cast(pyHandler, nb::rv_policy::take_ownership); + pyHandlerObject.inc_ref(); + + // In these C callbacks, the userData is a PyDiagnosticHandler* that is + // guaranteed to be known to pybind. 
+ auto handlerCallback = + +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { + PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); + nb::object pyDiagnosticObject = + nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); + + auto *pyHandler = static_cast(userData); + bool result = false; + { + // Since this can be called from arbitrary C++ contexts, always get the + // gil. + nb::gil_scoped_acquire gil; + try { + result = nb::cast(pyHandler->callback(pyDiagnostic)); + } catch (std::exception &e) { + fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", + e.what()); + pyHandler->hadError = true; + } + } + + pyDiagnostic->invalidate(); + return result ? mlirLogicalResultSuccess() : mlirLogicalResultFailure(); + }; + auto deleteCallback = +[](void *userData) { + auto *pyHandler = static_cast(userData); + assert(pyHandler->registeredID && "handler is not registered"); + pyHandler->registeredID.reset(); + + // Decrement reference, balancing the inc_ref() above. + nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); + pyHandlerObject.dec_ref(); + }; + + pyHandler->registeredID = mlirContextAttachDiagnosticHandler( + get(), handlerCallback, static_cast(pyHandler), deleteCallback); + return pyHandlerObject; +} + +MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, + void *userData) { + auto *self = static_cast(userData); + // Check if the context requested we emit errors instead of capturing them. 
+ if (self->ctx->emitErrorDiagnostics) + return mlirLogicalResultFailure(); + + if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) + return mlirLogicalResultFailure(); + + self->errors.emplace_back(PyDiagnostic(diag).getInfo()); + return mlirLogicalResultSuccess(); +} + +PyMlirContext &DefaultingPyMlirContext::resolve() { + PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); + if (!context) { + throw std::runtime_error( + "An MLIR function requires a Context but none was provided in the call " + "or from the surrounding environment. Either pass to the function with " + "a 'context=' argument or establish a default using 'with Context():'"); + } + return *context; +} + +//------------------------------------------------------------------------------ +// PyThreadContextEntry management +//------------------------------------------------------------------------------ + +std::vector &PyThreadContextEntry::getStack() { + static thread_local std::vector stack; + return stack; +} + +PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { + auto &stack = getStack(); + if (stack.empty()) + return nullptr; + return &stack.back(); +} + +void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, + nb::object insertionPoint, + nb::object location) { + auto &stack = getStack(); + stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), + std::move(location)); + // If the new stack has more than one entry and the context of the new top + // entry matches the previous, copy the insertionPoint and location from the + // previous entry if missing from the new top entry. + if (stack.size() > 1) { + auto &prev = *(stack.rbegin() + 1); + auto ¤t = stack.back(); + if (current.context.is(prev.context)) { + // Default non-context objects from the previous entry. 
+ if (!current.insertionPoint) + current.insertionPoint = prev.insertionPoint; + if (!current.location) + current.location = prev.location; + } + } +} + +PyMlirContext *PyThreadContextEntry::getContext() { + if (!context) + return nullptr; + return nb::cast(context); +} + +PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { + if (!insertionPoint) + return nullptr; + return nb::cast(insertionPoint); +} + +PyLocation *PyThreadContextEntry::getLocation() { + if (!location) + return nullptr; + return nb::cast(location); +} + +PyMlirContext *PyThreadContextEntry::getDefaultContext() { + auto *tos = getTopOfStack(); + return tos ? tos->getContext() : nullptr; +} + +PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { + auto *tos = getTopOfStack(); + return tos ? tos->getInsertionPoint() : nullptr; +} + +PyLocation *PyThreadContextEntry::getDefaultLocation() { + auto *tos = getTopOfStack(); + return tos ? tos->getLocation() : nullptr; +} + +nb::object PyThreadContextEntry::pushContext(nb::object context) { + push(FrameKind::Context, /*context=*/context, + /*insertionPoint=*/nb::object(), + /*location=*/nb::object()); + return context; +} + +void PyThreadContextEntry::popContext(PyMlirContext &context) { + auto &stack = getStack(); + if (stack.empty()) + throw std::runtime_error("Unbalanced Context enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) + throw std::runtime_error("Unbalanced Context enter/exit"); + stack.pop_back(); +} + +nb::object +PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { + PyInsertionPoint &insertionPoint = + nb::cast(insertionPointObj); + nb::object contextObj = + insertionPoint.getBlock().getParentOperation()->getContext().getObject(); + push(FrameKind::InsertionPoint, + /*context=*/contextObj, + /*insertionPoint=*/insertionPointObj, + /*location=*/nb::object()); + return insertionPointObj; +} + +void 
PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { + auto &stack = getStack(); + if (stack.empty()) + throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::InsertionPoint && + tos.getInsertionPoint() != &insertionPoint) + throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); + stack.pop_back(); +} + +nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { + PyLocation &location = nb::cast(locationObj); + nb::object contextObj = location.getContext().getObject(); + push(FrameKind::Location, /*context=*/contextObj, + /*insertionPoint=*/nb::object(), + /*location=*/locationObj); + return locationObj; +} + +void PyThreadContextEntry::popLocation(PyLocation &location) { + auto &stack = getStack(); + if (stack.empty()) + throw std::runtime_error("Unbalanced Location enter/exit"); + auto &tos = stack.back(); + if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) + throw std::runtime_error("Unbalanced Location enter/exit"); + stack.pop_back(); +} + +//------------------------------------------------------------------------------ +// PyDiagnostic* +//------------------------------------------------------------------------------ + +void PyDiagnostic::invalidate() { + valid = false; + if (materializedNotes) { + for (nb::handle noteObject : *materializedNotes) { + PyDiagnostic *note = nb::cast(noteObject); + note->invalidate(); + } + } +} + +PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, + nb::object callback) + : context(context), callback(std::move(callback)) {} + +PyDiagnosticHandler::~PyDiagnosticHandler() = default; + +void PyDiagnosticHandler::detach() { + if (!registeredID) + return; + MlirDiagnosticHandlerID localID = *registeredID; + mlirContextDetachDiagnosticHandler(context, localID); + assert(!registeredID && "should have unregistered"); + // Not strictly necessary but keeps stale pointers from being 
around to cause + // issues. + context = {nullptr}; +} + +void PyDiagnostic::checkValid() { + if (!valid) { + throw std::invalid_argument( + "Diagnostic is invalid (used outside of callback)"); + } +} + +MlirDiagnosticSeverity PyDiagnostic::getSeverity() { + checkValid(); + return mlirDiagnosticGetSeverity(diagnostic); +} + +PyLocation PyDiagnostic::getLocation() { + checkValid(); + MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); + MlirContext context = mlirLocationGetContext(loc); + return PyLocation(PyMlirContext::forContext(context), loc); +} + +nb::str PyDiagnostic::getMessage() { + checkValid(); + nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); + PyFileAccumulator accum(fileObject, /*binary=*/false); + mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); + return nb::cast(fileObject.attr("getvalue")()); +} + +nb::tuple PyDiagnostic::getNotes() { + checkValid(); + if (materializedNotes) + return *materializedNotes; + intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); + nb::tuple notes = nb::steal(PyTuple_New(numNotes)); + for (intptr_t i = 0; i < numNotes; ++i) { + MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); + nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); + PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); + } + materializedNotes = std::move(notes); + + return *materializedNotes; +} + +PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { + std::vector notes; + for (nb::handle n : getNotes()) + notes.emplace_back(nb::cast(n).getInfo()); + return {getSeverity(), getLocation(), nb::cast(getMessage()), + std::move(notes)}; +} + +//------------------------------------------------------------------------------ +// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry +//------------------------------------------------------------------------------ + +MlirDialect PyDialects::getDialectForKey(const std::string &key, + bool attrError) { + MlirDialect 
dialect = mlirContextGetOrLoadDialect(getContext()->get(), + {key.data(), key.size()}); + if (mlirDialectIsNull(dialect)) { + std::string msg = (Twine("Dialect '") + key + "' not found").str(); + if (attrError) + throw nb::attribute_error(msg.c_str()); + throw nb::index_error(msg.c_str()); + } + return dialect; +} + +nb::object PyDialectRegistry::getCapsule() { + return nb::steal(mlirPythonDialectRegistryToCapsule(*this)); +} + +PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { + MlirDialectRegistry rawRegistry = + mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + if (mlirDialectRegistryIsNull(rawRegistry)) + throw nb::python_error(); + return PyDialectRegistry(rawRegistry); +} + +//------------------------------------------------------------------------------ +// PyLocation +//------------------------------------------------------------------------------ + +nb::object PyLocation::getCapsule() { + return nb::steal(mlirPythonLocationToCapsule(*this)); +} + +PyLocation PyLocation::createFromCapsule(nb::object capsule) { + MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); + if (mlirLocationIsNull(rawLoc)) + throw nb::python_error(); + return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), + rawLoc); +} + +nb::object PyLocation::contextEnter(nb::object locationObj) { + return PyThreadContextEntry::pushLocation(locationObj); +} + +void PyLocation::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { + PyThreadContextEntry::popLocation(*this); +} + +PyLocation &DefaultingPyLocation::resolve() { + auto *location = PyThreadContextEntry::getDefaultLocation(); + if (!location) { + throw std::runtime_error( + "An MLIR function requires a Location but none was provided in the " + "call or from the surrounding environment. 
Either pass to the function " + "with a 'loc=' argument or establish a default using 'with loc:'"); + } + return *location; +} + +//------------------------------------------------------------------------------ +// PyModule +//------------------------------------------------------------------------------ + +PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) + : BaseContextObject(std::move(contextRef)), module(module) {} + +PyModule::~PyModule() { + nb::gil_scoped_acquire acquire; + auto &liveModules = getContext()->liveModules; + assert(liveModules.count(module.ptr) == 1 && + "destroying module not in live map"); + liveModules.erase(module.ptr); + mlirModuleDestroy(module); +} + +PyModuleRef PyModule::forModule(MlirModule module) { + MlirContext context = mlirModuleGetContext(module); + PyMlirContextRef contextRef = PyMlirContext::forContext(context); + + nb::gil_scoped_acquire acquire; + auto &liveModules = contextRef->liveModules; + auto it = liveModules.find(module.ptr); + if (it == liveModules.end()) { + // Create. + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + unownedModule->handle = pyRef; + liveModules[module.ptr] = + std::make_pair(unownedModule->handle, unownedModule); + return PyModuleRef(unownedModule, std::move(pyRef)); + } + // Use existing. 
+ PyModule *existing = it->second.second; + nb::object pyRef = nb::borrow(it->second.first); + return PyModuleRef(existing, std::move(pyRef)); +} + +nb::object PyModule::createFromCapsule(nb::object capsule) { + MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); + if (mlirModuleIsNull(rawModule)) + throw nb::python_error(); + return forModule(rawModule).releaseObject(); +} + +nb::object PyModule::getCapsule() { + return nb::steal(mlirPythonModuleToCapsule(get())); +} + +//------------------------------------------------------------------------------ +// PyOperation +//------------------------------------------------------------------------------ + +PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) + : BaseContextObject(std::move(contextRef)), operation(operation) {} + +PyOperation::~PyOperation() { + // If the operation has already been invalidated there is nothing to do. + if (!valid) + return; + + // Otherwise, invalidate the operation and remove it from live map when it is + // attached. + if (isAttached()) { + getContext()->clearOperation(*this); + } else { + // And destroy it when it is detached, i.e. owned by Python, in which case + // all nested operations must be invalidated at removed from the live map as + // well. + erase(); + } +} + +namespace { + +// Constructs a new object of type T in-place on the Python heap, returning a +// PyObjectRef to it, loosely analogous to std::make_shared(). +template +PyObjectRef makeObjectRef(Args &&...args) { + nb::handle type = nb::type(); + nb::object instance = nb::inst_alloc(type); + T *ptr = nb::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nb::inst_mark_ready(instance); + return PyObjectRef(ptr, std::move(instance)); +} + +} // namespace + +PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, + MlirOperation operation, + nb::object parentKeepAlive) { + // Create. 
+ PyOperationRef unownedOperation = + makeObjectRef(std::move(contextRef), operation); + unownedOperation->handle = unownedOperation.getObject(); + if (parentKeepAlive) { + unownedOperation->parentKeepAlive = std::move(parentKeepAlive); + } + return unownedOperation; +} + +PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, + MlirOperation operation, + nb::object parentKeepAlive) { + nb::ft_lock_guard lock(contextRef->liveOperationsMutex); + auto &liveOperations = contextRef->liveOperations; + auto it = liveOperations.find(operation.ptr); + if (it == liveOperations.end()) { + // Create. + PyOperationRef result = createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + liveOperations[operation.ptr] = + std::make_pair(result.getObject(), result.get()); + return result; + } + // Use existing. + PyOperation *existing = it->second.second; + nb::object pyRef = nb::borrow(it->second.first); + return PyOperationRef(existing, std::move(pyRef)); +} + +PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, + MlirOperation operation, + nb::object parentKeepAlive) { + nb::ft_lock_guard lock(contextRef->liveOperationsMutex); + auto &liveOperations = contextRef->liveOperations; + assert(liveOperations.count(operation.ptr) == 0 && + "cannot create detached operation that already exists"); + (void)liveOperations; + PyOperationRef created = createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + liveOperations[operation.ptr] = + std::make_pair(created.getObject(), created.get()); + created->attached = false; + return created; +} + +PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, + const std::string &sourceStr, + const std::string &sourceName) { + PyMlirContext::ErrorCapture errors(contextRef); + MlirOperation op = + mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), + toMlirStringRef(sourceName)); + if (mlirOperationIsNull(op)) + throw MLIRError("Unable to parse 
operation assembly", errors.take()); + return PyOperation::createDetached(std::move(contextRef), op); +} + +void PyOperation::checkValid() const { + if (!valid) { + throw std::runtime_error("the operation has been invalidated"); + } +} + +void PyOperationBase::print(std::optional largeElementsLimit, + std::optional largeResourceLimit, + bool enableDebugInfo, bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope, + bool useNameLocAsPrefix, bool assumeVerified, + nb::object fileObject, bool binary, + bool skipRegions) { + PyOperation &operation = getOperation(); + operation.checkValid(); + if (fileObject.is_none()) + fileObject = nb::module_::import_("sys").attr("stdout"); + + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (largeElementsLimit) + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); + if (largeResourceLimit) + mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit); + if (enableDebugInfo) + mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + /*prettyForm=*/prettyDebugInfo); + if (printGenericOpForm) + mlirOpPrintingFlagsPrintGenericOpForm(flags); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + if (assumeVerified) + mlirOpPrintingFlagsAssumeVerified(flags); + if (skipRegions) + mlirOpPrintingFlagsSkipRegions(flags); + if (useNameLocAsPrefix) + mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); + + PyFileAccumulator accum(fileObject, binary); + mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), + accum.getUserData()); + mlirOpPrintingFlagsDestroy(flags); +} + +void PyOperationBase::print(PyAsmState &state, nb::object fileObject, + bool binary) { + PyOperation &operation = getOperation(); + operation.checkValid(); + if (fileObject.is_none()) + fileObject = nb::module_::import_("sys").attr("stdout"); + PyFileAccumulator accum(fileObject, binary); + mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), + accum.getUserData()); +} + +void 
PyOperationBase::writeBytecode(const nb::object &fileOrStringObject, + std::optional bytecodeVersion) { + PyOperation &operation = getOperation(); + operation.checkValid(); + PyFileAccumulator accum(fileOrStringObject, /*binary=*/true); + + if (!bytecodeVersion.has_value()) + return mlirOperationWriteBytecode(operation, accum.getCallback(), + accum.getUserData()); + + MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); + mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); + MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( + operation, config, accum.getCallback(), accum.getUserData()); + mlirBytecodeWriterConfigDestroy(config); + if (mlirLogicalResultIsFailure(res)) + throw nb::value_error((Twine("Unable to honor desired bytecode version ") + + Twine(*bytecodeVersion)) + .str() + .c_str()); +} + +void PyOperationBase::walk( + std::function callback, + MlirWalkOrder walkOrder) { + PyOperation &operation = getOperation(); + operation.checkValid(); + struct UserData { + std::function callback; + bool gotException; + std::string exceptionWhat; + nb::object exceptionType; + }; + UserData userData{callback, false, {}, {}}; + MlirOperationWalkCallback walkCallback = [](MlirOperation op, + void *userData) { + UserData *calleeUserData = static_cast(userData); + try { + return (calleeUserData->callback)(op); + } catch (nb::python_error &e) { + calleeUserData->gotException = true; + calleeUserData->exceptionWhat = std::string(e.what()); + calleeUserData->exceptionType = nb::borrow(e.type()); + return MlirWalkResult::MlirWalkResultInterrupt; + } + }; + mlirOperationWalk(operation, walkCallback, &userData, walkOrder); + if (userData.gotException) { + std::string message("Exception raised in callback: "); + message.append(userData.exceptionWhat); + throw std::runtime_error(message); + } +} + +nb::object PyOperationBase::getAsm(bool binary, + std::optional largeElementsLimit, + std::optional largeResourceLimit, + bool enableDebugInfo, 
bool prettyDebugInfo, + bool printGenericOpForm, bool useLocalScope, + bool useNameLocAsPrefix, bool assumeVerified, + bool skipRegions) { + nb::object fileObject; + if (binary) { + fileObject = nb::module_::import_("io").attr("BytesIO")(); + } else { + fileObject = nb::module_::import_("io").attr("StringIO")(); + } + print(/*largeElementsLimit=*/largeElementsLimit, + /*largeResourceLimit=*/largeResourceLimit, + /*enableDebugInfo=*/enableDebugInfo, + /*prettyDebugInfo=*/prettyDebugInfo, + /*printGenericOpForm=*/printGenericOpForm, + /*useLocalScope=*/useLocalScope, + /*useNameLocAsPrefix=*/useNameLocAsPrefix, + /*assumeVerified=*/assumeVerified, + /*fileObject=*/fileObject, + /*binary=*/binary, + /*skipRegions=*/skipRegions); + + return fileObject.attr("getvalue")(); +} + +void PyOperationBase::moveAfter(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + mlirOperationMoveAfter(operation, otherOp); + operation.parentKeepAlive = otherOp.parentKeepAlive; +} + +void PyOperationBase::moveBefore(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + mlirOperationMoveBefore(operation, otherOp); + operation.parentKeepAlive = otherOp.parentKeepAlive; +} + +bool PyOperationBase::isBeforeInBlock(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + return mlirOperationIsBeforeInBlock(operation, otherOp); +} + +bool PyOperationBase::verify() { + PyOperation &op = getOperation(); + PyMlirContext::ErrorCapture errors(op.getContext()); + if (!mlirOperationVerify(op.get())) + throw MLIRError("Verification failed", errors.take()); + return true; +} + +std::optional PyOperation::getParentOperation() { + checkValid(); + if (!isAttached()) + throw 
nb::value_error("Detached operations have no parent"); + MlirOperation operation = mlirOperationGetParentOperation(get()); + if (mlirOperationIsNull(operation)) + return {}; + return PyOperation::forOperation(getContext(), operation); +} + +PyBlock PyOperation::getBlock() { + checkValid(); + std::optional parentOperation = getParentOperation(); + MlirBlock block = mlirOperationGetBlock(get()); + assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); + assert(parentOperation && "Operation has no parent"); + return PyBlock{std::move(*parentOperation), block}; +} + +nb::object PyOperation::getCapsule() { + checkValid(); + return nb::steal(mlirPythonOperationToCapsule(get())); +} + +nb::object PyOperation::createFromCapsule(nb::object capsule) { + MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); + if (mlirOperationIsNull(rawOperation)) + throw nb::python_error(); + MlirContext rawCtxt = mlirOperationGetContext(rawOperation); + return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) + .releaseObject(); +} + +static void maybeInsertOperation(PyOperationRef &op, + const nb::object &maybeIp) { + // InsertPoint active? + if (!maybeIp.is(nb::cast(false))) { + PyInsertionPoint *ip; + if (maybeIp.is_none()) { + ip = PyThreadContextEntry::getDefaultInsertionPoint(); + } else { + ip = nb::cast(maybeIp); + } + if (ip) + ip->insert(*op.get()); + } +} + +nb::object PyOperation::create(std::string_view name, + std::optional> results, + llvm::ArrayRef operands, + std::optional attributes, + std::optional> successors, + int regions, PyLocation &location, + const nb::object &maybeIp, bool inferType) { + llvm::SmallVector mlirResults; + llvm::SmallVector mlirSuccessors; + llvm::SmallVector, 4> mlirAttributes; + + // General parameter validation. + if (regions < 0) + throw nb::value_error("number of regions must be >= 0"); + + // Unpack/validate results. 
+ if (results) { + mlirResults.reserve(results->size()); + for (PyType *result : *results) { + // TODO: Verify result type originate from the same context. + if (!result) + throw nb::value_error("result type cannot be None"); + mlirResults.push_back(*result); + } + } + // Unpack/validate attributes. + if (attributes) { + mlirAttributes.reserve(attributes->size()); + for (std::pair it : *attributes) { + std::string key; + try { + key = nb::cast(it.first); + } catch (nb::cast_error &err) { + std::string msg = "Invalid attribute key (not a string) when " + "attempting to create the operation \"" + + std::string(name) + "\" (" + err.what() + ")"; + throw nb::type_error(msg.c_str()); + } + try { + auto &attribute = nb::cast(it.second); + // TODO: Verify attribute originates from the same context. + mlirAttributes.emplace_back(std::move(key), attribute); + } catch (nb::cast_error &err) { + std::string msg = "Invalid attribute value for the key \"" + key + + "\" when attempting to create the operation \"" + + std::string(name) + "\" (" + err.what() + ")"; + throw nb::type_error(msg.c_str()); + } catch (std::runtime_error &) { + // This exception seems thrown when the value is "None". + std::string msg = + "Found an invalid (`None`?) attribute value for the key \"" + key + + "\" when attempting to create the operation \"" + + std::string(name) + "\""; + throw std::runtime_error(msg); + } + } + } + // Unpack/validate successors. + if (successors) { + mlirSuccessors.reserve(successors->size()); + for (auto *successor : *successors) { + // TODO: Verify successor originate from the same context. + if (!successor) + throw nb::value_error("successor block cannot be None"); + mlirSuccessors.push_back(successor->get()); + } + } + + // Apply unpacked/validated to the operation state. Beyond this + // point, exceptions cannot be thrown or else the state will leak. 
+ MlirOperationState state = + mlirOperationStateGet(toMlirStringRef(name), location); + if (!operands.empty()) + mlirOperationStateAddOperands(&state, operands.size(), operands.data()); + state.enableResultTypeInference = inferType; + if (!mlirResults.empty()) + mlirOperationStateAddResults(&state, mlirResults.size(), + mlirResults.data()); + if (!mlirAttributes.empty()) { + // Note that the attribute names directly reference bytes in + // mlirAttributes, so that vector must not be changed from here + // on. + llvm::SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(mlirAttributes.size()); + for (auto &it : mlirAttributes) + mlirNamedAttributes.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(mlirAttributeGetContext(it.second), + toMlirStringRef(it.first)), + it.second)); + mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + } + if (!mlirSuccessors.empty()) + mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), + mlirSuccessors.data()); + if (regions) { + llvm::SmallVector mlirRegions; + mlirRegions.resize(regions); + for (int i = 0; i < regions; ++i) + mlirRegions[i] = mlirRegionCreate(); + mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), + mlirRegions.data()); + } + + // Construct the operation. 
+ MlirOperation operation = mlirOperationCreate(&state); + if (!operation.ptr) + throw nb::value_error("Operation creation failed"); + PyOperationRef created = + PyOperation::createDetached(location.getContext(), operation); + maybeInsertOperation(created, maybeIp); + + return created.getObject(); +} + +nb::object PyOperation::clone(const nb::object &maybeIp) { + MlirOperation clonedOperation = mlirOperationClone(operation); + PyOperationRef cloned = + PyOperation::createDetached(getContext(), clonedOperation); + maybeInsertOperation(cloned, maybeIp); + + return cloned->createOpView(); +} + +nb::object PyOperation::createOpView() { + checkValid(); + MlirIdentifier ident = mlirOperationGetName(get()); + MlirStringRef identStr = mlirIdentifierStr(ident); + auto operationCls = PyGlobals::get().lookupOperationClass( + StringRef(identStr.data, identStr.length)); + if (operationCls) + return PyOpView::constructDerived(*operationCls, getRef().getObject()); + return nb::cast(PyOpView(getRef().getObject())); +} + +void PyOperation::erase() { + checkValid(); + getContext()->clearOperationAndInside(*this); + mlirOperationDestroy(operation); +} + +namespace { +/// CRTP base class for Python MLIR values that subclass Value and should be +/// castable from it. The value hierarchy is one level deep and is not supposed +/// to accommodate other levels unless core MLIR changes. +template +class PyConcreteValue : public PyValue { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = nb::class_; + using IsAFunctionTy = bool (*)(MlirValue); + + PyConcreteValue() = default; + PyConcreteValue(PyOperationRef operationRef, MlirValue value) + : PyValue(operationRef, value) {} + PyConcreteValue(PyValue &orig) + : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} + + /// Attempts to cast the original value to the derived type and throws on + /// type mismatches. 
+ static MlirValue castFrom(PyValue &orig) { + if (!DerivedTy::isaFunction(orig.get())) { + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str() + .c_str()); + } + return orig.get(); + } + + /// Binds the Python module objects to functions of this class. + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); + cls.def_static( + "isinstance", + [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }, + nb::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +} // namespace + +/// Python wrapper for MlirOpResult. +class PyOpResult : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; + static constexpr const char *pyClassName = "OpResult"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("owner", [](PyOpResult &self) { + assert( + mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in the IR"); + return self.getParentOperation().getObject(); + }); + c.def_prop_ro("result_number", [](PyOpResult &self) { + return mlirOpResultGetResultNumber(self.get()); + }); + } +}; + +/// Returns the list of types of the values held by container. 
+template +static std::vector getValueTypes(Container &container, + PyMlirContextRef &context) { + std::vector result; + result.reserve(container.size()); + for (int i = 0, e = container.size(); i < e; ++i) { + result.push_back(mlirValueGetType(container.getElement(i).get())); + } + return result; +} + +/// A list of operation results. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) result list is associated +/// with the operation whose results these are, and thus extends the lifetime of +/// this operation. +class PyOpResultList : public Sliceable { +public: + static constexpr const char *pyClassName = "OpResultList"; + using SliceableT = Sliceable; + + PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumResults(operation->get()) + : length, + step), + operation(std::move(operation)) {} + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("types", [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + c.def_prop_ro("owner", [](PyOpResultList &self) { + return self.operation->createOpView(); + }); + } + + PyOperationRef &getOperation() { return operation; } + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumResults(operation->get()); + } + + PyOpResult getRawElement(intptr_t index) { + PyValue value(operation, mlirOperationGetResult(operation->get(), index)); + return PyOpResult(value); + } + + PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyOpResultList(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + +//------------------------------------------------------------------------------ +// PyOpView +//------------------------------------------------------------------------------ + +static void populateResultTypes(StringRef name, nb::list resultTypeList, + const nb::object &resultSegmentSpecObj, + std::vector &resultSegmentLengths, + std::vector &resultTypes) { + resultTypes.reserve(resultTypeList.size()); + if (resultSegmentSpecObj.is_none()) { + // Non-variadic result unpacking. + for (const auto &it : llvm::enumerate(resultTypeList)) { + try { + resultTypes.push_back(nb::cast(it.value())); + if (!resultTypes.back()) + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Type (" + err.what() + ")") + .str() + .c_str()); + } + } + } else { + // Sized result unpacking. 
+ auto resultSegmentSpec = nb::cast>(resultSegmentSpecObj); + if (resultSegmentSpec.size() != resultTypeList.size()) { + throw nb::value_error((llvm::Twine("Operation \"") + name + + "\" requires " + + llvm::Twine(resultSegmentSpec.size()) + + " result segments but was provided " + + llvm::Twine(resultTypeList.size())) + .str() + .c_str()); + } + resultSegmentLengths.reserve(resultTypeList.size()); + for (const auto &it : + llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1 || segmentSpec == 0) { + // Unpack unary element. + try { + auto *resultType = nb::cast(std::get<0>(it.value())); + if (resultType) { + resultTypes.push_back(resultType); + resultSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + resultSegmentLengths.push_back(0); + } else { + throw nb::value_error( + (llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Type (was None and result is not optional)") + .str() + .c_str()); + } + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Type (" + err.what() + + ")") + .str() + .c_str()); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + if (std::get<0>(it.value()).is_none()) { + // Treat it as an empty list. + resultSegmentLengths.push_back(0); + } else { + // Unpack the list. + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + resultTypes.push_back(nb::cast(segmentItem)); + if (!resultTypes.back()) { + throw nb::type_error("contained a None item"); + } + } + resultSegmentLengths.push_back(nb::len(segment)); + } + } catch (std::exception &err) { + // NOTE: Sloppy to be using a catch-all here, but there are at least + // three different unrelated exceptions that can be thrown in the + // above "casts". 
Just keep the scope above small and catch them all. + throw nb::value_error((llvm::Twine("Result ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Sequence of Types (" + + err.what() + ")") + .str() + .c_str()); + } + } else { + throw nb::value_error("Unexpected segment spec"); + } + } + } +} + +static MlirValue getUniqueResult(MlirOperation operation) { + auto numResults = mlirOperationGetNumResults(operation); + if (numResults != 1) { + auto name = mlirIdentifierStr(mlirOperationGetName(operation)); + throw nb::value_error((Twine("Cannot call .result on operation ") + + StringRef(name.data, name.length) + " which has " + + Twine(numResults) + + " results (it is only valid for operations with a " + "single result)") + .str() + .c_str()); + } + return mlirOperationGetResult(operation, 0); +} + +static MlirValue getOpResultOrValue(nb::handle operand) { + if (operand.is_none()) { + throw nb::value_error("contained a None item"); + } + PyOperationBase *op; + if (nb::try_cast(operand, op)) { + return getUniqueResult(op->getOperation()); + } + PyOpResultList *opResultList; + if (nb::try_cast(operand, opResultList)) { + return getUniqueResult(opResultList->getOperation()->get()); + } + PyValue *value; + if (nb::try_cast(operand, value)) { + return value->get(); + } + throw nb::value_error("is not a Value"); +} + +nb::object PyOpView::buildGeneric( + std::string_view name, std::tuple opRegionSpec, + nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj, + std::optional resultTypeList, nb::list operandList, + std::optional attributes, + std::optional> successors, + std::optional regions, PyLocation &location, + const nb::object &maybeIp) { + PyMlirContextRef context = location.getContext(); + + // Class level operation construction metadata. 
+ // Operand and result segment specs are either none, which does no + // variadic unpacking, or a list of ints with segment sizes, where each + // element is either a positive number (typically 1 for a scalar) or -1 to + // indicate that it is derived from the length of the same-indexed operand + // or result (implying that it is a list at that position). + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; + + // Validate/determine region count. + int opMinRegionCount = std::get<0>(opRegionSpec); + bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); + if (!regions) { + regions = opMinRegionCount; + } + if (*regions < opMinRegionCount) { + throw nb::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str() + .c_str()); + } + if (opHasNoVariadicRegions && *regions > opMinRegionCount) { + throw nb::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str() + .c_str()); + } + + // Unpack results. + std::vector resultTypes; + if (resultTypeList.has_value()) { + populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, + resultSegmentLengths, resultTypes); + } + + // Unpack operands. + llvm::SmallVector operands; + operands.reserve(operands.size()); + if (operandSegmentSpecObj.is_none()) { + // Non-sized operand unpacking. + for (const auto &it : llvm::enumerate(operandList)) { + try { + operands.push_back(getOpResultOrValue(it.value())); + } catch (nb::builtin_exception &err) { + throw nb::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Value (" + err.what() + ")") + .str() + .c_str()); + } + } + } else { + // Sized operand unpacking. 
+ auto operandSegmentSpec = nb::cast>(operandSegmentSpecObj); + if (operandSegmentSpec.size() != operandList.size()) { + throw nb::value_error((llvm::Twine("Operation \"") + name + + "\" requires " + + llvm::Twine(operandSegmentSpec.size()) + + "operand segments but was provided " + + llvm::Twine(operandList.size())) + .str() + .c_str()); + } + operandSegmentLengths.reserve(operandList.size()); + for (const auto &it : + llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { + int segmentSpec = std::get<1>(it.value()); + if (segmentSpec == 1 || segmentSpec == 0) { + // Unpack unary element. + auto &operand = std::get<0>(it.value()); + if (!operand.is_none()) { + try { + + operands.push_back(getOpResultOrValue(operand)); + } catch (nb::builtin_exception &err) { + throw nb::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (" + err.what() + ")") + .str() + .c_str()); + } + + operandSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + operandSegmentLengths.push_back(0); + } else { + throw nb::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (was None and operand is not optional)") + .str() + .c_str()); + } + } else if (segmentSpec == -1) { + // Unpack sequence by appending. + try { + if (std::get<0>(it.value()).is_none()) { + // Treat it as an empty list. + operandSegmentLengths.push_back(0); + } else { + // Unpack the list. + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + operands.push_back(getOpResultOrValue(segmentItem)); + } + operandSegmentLengths.push_back(nb::len(segment)); + } + } catch (std::exception &err) { + // NOTE: Sloppy to be using a catch-all here, but there are at least + // three different unrelated exceptions that can be thrown in the + // above "casts". Just keep the scope above small and catch them all. 
+ throw nb::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + " of operation \"" + + name + "\" must be a Sequence of Values (" + + err.what() + ")") + .str() + .c_str()); + } + } else { + throw nb::value_error("Unexpected segment spec"); + } + } + } + + // Merge operand/result segment lengths into attributes if needed. + if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { + // Dup. + if (attributes) { + attributes = nb::dict(*attributes); + } else { + attributes = nb::dict(); + } + if (attributes->contains("resultSegmentSizes") || + attributes->contains("operandSegmentSizes")) { + throw nb::value_error("Manually setting a 'resultSegmentSizes' or " + "'operandSegmentSizes' attribute is unsupported. " + "Use Operation.create for such low-level access."); + } + + // Add resultSegmentSizes attribute. + if (!resultSegmentLengths.empty()) { + MlirAttribute segmentLengthAttr = + mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), + resultSegmentLengths.data()); + (*attributes)["resultSegmentSizes"] = + PyAttribute(context, segmentLengthAttr); + } + + // Add operandSegmentSizes attribute. + if (!operandSegmentLengths.empty()) { + MlirAttribute segmentLengthAttr = + mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), + operandSegmentLengths.data()); + (*attributes)["operandSegmentSizes"] = + PyAttribute(context, segmentLengthAttr); + } + } + + // Delegate to create. 
+ return PyOperation::create(name, + /*results=*/std::move(resultTypes), + /*operands=*/std::move(operands), + /*attributes=*/std::move(attributes), + /*successors=*/std::move(successors), + /*regions=*/*regions, location, maybeIp, + !resultTypeList); +} + +nb::object PyOpView::constructDerived(const nb::object &cls, + const nb::object &operation) { + nb::handle opViewType = nb::type(); + nb::object instance = cls.attr("__new__")(cls); + opViewType.attr("__init__")(instance, operation); + return instance; +} + +PyOpView::PyOpView(const nb::object &operationObject) + // Casting through the PyOperationBase base-class and then back to the + // Operation lets us accept any PyOperationBase subclass. + : operation(nb::cast(operationObject).getOperation()), + operationObject(operation.getRef().getObject()) {} + +//------------------------------------------------------------------------------ +// PyInsertionPoint. +//------------------------------------------------------------------------------ + +PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} + +PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) + : refOperation(beforeOperationBase.getOperation().getRef()), + block((*refOperation)->getBlock()) {} + +void PyInsertionPoint::insert(PyOperationBase &operationBase) { + PyOperation &operation = operationBase.getOperation(); + if (operation.isAttached()) + throw nb::value_error( + "Attempt to insert operation that is already attached"); + block.getParentOperation()->checkValid(); + MlirOperation beforeOp = {nullptr}; + if (refOperation) { + // Insert before operation. + (*refOperation)->checkValid(); + beforeOp = (*refOperation)->get(); + } else { + // Insert at end (before null) is only valid if the block does not + // already end in a known terminator (violating this will cause assertion + // failures later). 
+ if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { + throw nb::index_error("Cannot insert operation at the end of a block " + "that already has a terminator. Did you mean to " + "use 'InsertionPoint.at_block_terminator(block)' " + "versus 'InsertionPoint(block)'?"); + } + } + mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); + operation.setAttached(); +} + +PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { + MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); + if (mlirOperationIsNull(firstOp)) { + // Just insert at end. + return PyInsertionPoint(block); + } + + // Insert before first op. + PyOperationRef firstOpRef = PyOperation::forOperation( + block.getParentOperation()->getContext(), firstOp); + return PyInsertionPoint{block, std::move(firstOpRef)}; +} + +PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { + MlirOperation terminator = mlirBlockGetTerminator(block.get()); + if (mlirOperationIsNull(terminator)) + throw nb::value_error("Block has no terminator"); + PyOperationRef terminatorOpRef = PyOperation::forOperation( + block.getParentOperation()->getContext(), terminator); + return PyInsertionPoint{block, std::move(terminatorOpRef)}; +} + +nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { + return PyThreadContextEntry::pushInsertionPoint(insertPoint); +} + +void PyInsertionPoint::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { + PyThreadContextEntry::popInsertionPoint(*this); +} + +//------------------------------------------------------------------------------ +// PyAttribute. 
+//------------------------------------------------------------------------------ + +bool PyAttribute::operator==(const PyAttribute &other) const { + return mlirAttributeEqual(attr, other.attr); +} + +nb::object PyAttribute::getCapsule() { + return nb::steal(mlirPythonAttributeToCapsule(*this)); +} + +PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { + MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); + if (mlirAttributeIsNull(rawAttr)) + throw nb::python_error(); + return PyAttribute( + PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); +} + +//------------------------------------------------------------------------------ +// PyNamedAttribute. +//------------------------------------------------------------------------------ + +PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) + : ownedName(new std::string(std::move(ownedName))) { + namedAttr = mlirNamedAttributeGet( + mlirIdentifierGet(mlirAttributeGetContext(attr), + toMlirStringRef(*this->ownedName)), + attr); +} + +//------------------------------------------------------------------------------ +// PyType. +//------------------------------------------------------------------------------ + +bool PyType::operator==(const PyType &other) const { + return mlirTypeEqual(type, other.type); +} + +nb::object PyType::getCapsule() { + return nb::steal(mlirPythonTypeToCapsule(*this)); +} + +PyType PyType::createFromCapsule(nb::object capsule) { + MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); + if (mlirTypeIsNull(rawType)) + throw nb::python_error(); + return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), + rawType); +} + +//------------------------------------------------------------------------------ +// PyTypeID. 
+//------------------------------------------------------------------------------ + +nb::object PyTypeID::getCapsule() { + return nb::steal(mlirPythonTypeIDToCapsule(*this)); +} + +PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { + MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); + if (mlirTypeIDIsNull(mlirTypeID)) + throw nb::python_error(); + return PyTypeID(mlirTypeID); +} +bool PyTypeID::operator==(const PyTypeID &other) const { + return mlirTypeIDEqual(typeID, other.typeID); +} + +//------------------------------------------------------------------------------ +// PyValue and subclasses. +//------------------------------------------------------------------------------ + +nb::object PyValue::getCapsule() { + return nb::steal(mlirPythonValueToCapsule(get())); +} + +nb::object PyValue::maybeDownCast() { + MlirType type = mlirValueGetType(get()); + MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional valueCaster = + PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); + // nb::rv_policy::move means use std::move to move the return value + // contents into a new instance that will be owned by Python. 
+ nb::object thisObj = nb::cast(this, nb::rv_policy::move); + if (!valueCaster) + return thisObj; + return valueCaster.value()(thisObj); +} + +PyValue PyValue::createFromCapsule(nb::object capsule) { + MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); + if (mlirValueIsNull(value)) + throw nb::python_error(); + MlirOperation owner; + if (mlirValueIsAOpResult(value)) + owner = mlirOpResultGetOwner(value); + if (mlirValueIsABlockArgument(value)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); + if (mlirOperationIsNull(owner)) + throw nb::python_error(); + MlirContext ctx = mlirOperationGetContext(owner); + PyOperationRef ownerRef = + PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); + return PyValue(ownerRef, value); +} + +//------------------------------------------------------------------------------ +// PySymbolTable. +//------------------------------------------------------------------------------ + +PySymbolTable::PySymbolTable(PyOperationBase &operation) + : operation(operation.getOperation().getRef()) { + symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); + if (mlirSymbolTableIsNull(symbolTable)) { + throw nb::type_error("Operation is not a Symbol Table."); + } +} + +nb::object PySymbolTable::dunderGetItem(const std::string &name) { + operation->checkValid(); + MlirOperation symbol = mlirSymbolTableLookup( + symbolTable, mlirStringRefCreate(name.data(), name.length())); + if (mlirOperationIsNull(symbol)) + throw nb::key_error( + ("Symbol '" + name + "' not in the symbol table.").c_str()); + + return PyOperation::forOperation(operation->getContext(), symbol, + operation.getObject()) + ->createOpView(); +} + +void PySymbolTable::erase(PyOperationBase &symbol) { + operation->checkValid(); + symbol.getOperation().checkValid(); + mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); + // The operation is also erased, so we must invalidate it. 
There may be Python + // references to this operation so we don't want to delete it from the list of + // live operations here. + symbol.getOperation().valid = false; +} + +void PySymbolTable::dunderDel(const std::string &name) { + nb::object operation = dunderGetItem(name); + erase(nb::cast(operation)); +} + +MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { + operation->checkValid(); + symbol.getOperation().checkValid(); + MlirAttribute symbolAttr = mlirOperationGetAttributeByName( + symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); + if (mlirAttributeIsNull(symbolAttr)) + throw nb::value_error("Expected operation to have a symbol name."); + return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); +} + +MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { + // Op must already be a symbol. + PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); + MlirAttribute existingNameAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingNameAttr)) + throw nb::value_error("Expected operation to have a symbol name."); + return existingNameAttr; +} + +void PySymbolTable::setSymbolName(PyOperationBase &symbol, + const std::string &name) { + // Op must already be a symbol. 
+ PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); + MlirAttribute existingNameAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingNameAttr)) + throw nb::value_error("Expected operation to have a symbol name."); + MlirAttribute newNameAttr = + mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); + mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); +} + +MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { + PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); + MlirAttribute existingVisAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingVisAttr)) + throw nb::value_error("Expected operation to have a symbol visibility."); + return existingVisAttr; +} + +void PySymbolTable::setVisibility(PyOperationBase &symbol, + const std::string &visibility) { + if (visibility != "public" && visibility != "private" && + visibility != "nested") + throw nb::value_error( + "Expected visibility to be 'public', 'private' or 'nested'"); + PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); + MlirAttribute existingVisAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingVisAttr)) + throw nb::value_error("Expected operation to have a symbol visibility."); + MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), + toMlirStringRef(visibility)); + mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); +} + +void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, + const std::string &newSymbol, + PyOperationBase &from) { + PyOperation &fromOperation = 
from.getOperation(); + fromOperation.checkValid(); + if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( + toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), + from.getOperation()))) + + throw nb::value_error("Symbol rename failed"); +} + +void PySymbolTable::walkSymbolTables(PyOperationBase &from, + bool allSymUsesVisible, + nb::object callback) { + PyOperation &fromOperation = from.getOperation(); + fromOperation.checkValid(); + struct UserData { + PyMlirContextRef context; + nb::object callback; + bool gotException; + std::string exceptionWhat; + nb::object exceptionType; + }; + UserData userData{ + fromOperation.getContext(), std::move(callback), false, {}, {}}; + mlirSymbolTableWalkSymbolTables( + fromOperation.get(), allSymUsesVisible, + [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { + UserData *calleeUserData = static_cast(calleeUserDataVoid); + auto pyFoundOp = + PyOperation::forOperation(calleeUserData->context, foundOp); + if (calleeUserData->gotException) + return; + try { + calleeUserData->callback(pyFoundOp.getObject(), isVisible); + } catch (nb::python_error &e) { + calleeUserData->gotException = true; + calleeUserData->exceptionWhat = e.what(); + calleeUserData->exceptionType = nb::borrow(e.type()); + } + }, + static_cast(&userData)); + if (userData.gotException) { + std::string message("Exception raised in callback: "); + message.append(userData.exceptionWhat); + throw std::runtime_error(message); + } +} + +namespace { + +/// Python wrapper for MlirBlockArgument. 
+class PyBlockArgument : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; + static constexpr const char *pyClassName = "BlockArgument"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("owner", [](PyBlockArgument &self) { + return PyBlock(self.getParentOperation(), + mlirBlockArgumentGetOwner(self.get())); + }); + c.def_prop_ro("arg_number", [](PyBlockArgument &self) { + return mlirBlockArgumentGetArgNumber(self.get()); + }); + c.def( + "set_type", + [](PyBlockArgument &self, PyType type) { + return mlirBlockArgumentSetType(self.get(), type); + }, + nb::arg("type")); + } +}; + +/// A list of block arguments. Internally, these are stored as consecutive +/// elements, random access is cheap. The argument list is associated with the +/// operation that contains the block (detached blocks are not allowed in +/// Python bindings) and extends its lifetime. +class PyBlockArgumentList + : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockArgumentList"; + using SliceableT = Sliceable; + + PyBlockArgumentList(PyOperationRef operation, MlirBlock block, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumArguments(block) : length, + step), + operation(std::move(operation)), block(block) {} + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("types", [](PyBlockArgumentList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + /// Returns the number of arguments in the list. + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirBlockGetNumArguments(block); + } + + /// Returns `pos`-the element in the list. 
+ PyBlockArgument getRawElement(intptr_t pos) { + MlirValue argument = mlirBlockGetArgument(block, pos); + return PyBlockArgument(operation, argument); + } + + /// Returns a sublist of this list. + PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockArgumentList(operation, block, startIndex, length, step); + } + + PyOperationRef operation; + MlirBlock block; +}; + +/// A list of operation operands. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) operand list is associated +/// with the operation whose operands these are, and thus extends the lifetime +/// of this operation. +class PyOpOperandList : public Sliceable { +public: + static constexpr const char *pyClassName = "OpOperandList"; + using SliceableT = Sliceable; + + PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumOperands(operation->get()) + : length, + step), + operation(operation) {} + + void dunderSetItem(intptr_t index, PyValue value) { + index = wrapIndex(index); + mlirOperationSetOperand(operation->get(), index, value.get()); + } + + static void bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpOperandList::dunderSetItem); + } + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumOperands(operation->get()); + } + + PyValue getRawElement(intptr_t pos) { + MlirValue operand = mlirOperationGetOperand(operation->get(), pos); + MlirOperation owner; + if (mlirValueIsAOpResult(operand)) + owner = mlirOpResultGetOwner(operand); + else if (mlirValueIsABlockArgument(operand)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); + else + assert(false && "Value must be an block arg or op result."); + PyOperationRef pyOwner = + PyOperation::forOperation(operation->getContext(), owner); + return PyValue(pyOwner, operand); + } + + PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyOpOperandList(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + +/// A list of operation successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation whose successors these are, and thus extends +/// the lifetime of this operation. +class PyOpSuccessors : public Sliceable { +public: + static constexpr const char *pyClassName = "OpSuccessors"; + + PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumSuccessors(operation->get()) + : length, + step), + operation(operation) {} + + void dunderSetItem(intptr_t index, PyBlock block) { + index = wrapIndex(index); + mlirOperationSetSuccessor(operation->get(), index, block.get()); + } + + static void bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpSuccessors::dunderSetItem); + } + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumSuccessors(operation->get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); + return PyBlock(operation, block); + } + + PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyOpSuccessors(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + +/// A list of block successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation and block whose successors these are, and thus +/// extends the lifetime of this operation and block. +class PyBlockSuccessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockSuccessors"; + + PyBlockSuccessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumSuccessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumSuccessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyBlockSuccessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of block predecessors. The (returned) predecessor list is +/// associated with the operation and block whose predecessors these are, and +/// thus extends the lifetime of this operation and block. 
+/// +/// WARNING: This Sliceable is more expensive than the others here because +/// mlirBlockGetPredecessor actually iterates the use-def chain (of block +/// operands) anew for each indexed access. +class PyBlockPredecessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockPredecessors"; + + PyBlockPredecessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumPredecessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumPredecessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockPredecessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockPredecessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of operation attributes. Can be indexed by name, producing +/// attributes, or by index, producing named attributes. 
+class PyOpAttributeMap { +public: + PyOpAttributeMap(PyOperationRef operation) + : operation(std::move(operation)) {} + + MlirAttribute dunderGetItemNamed(const std::string &name) { + MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), + toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) { + throw nb::key_error("attempt to access a non-existent attribute"); + } + return attr; + } + + PyNamedAttribute dunderGetItemIndexed(intptr_t index) { + if (index < 0) { + index += dunderLen(); + } + if (index < 0 || index >= dunderLen()) { + throw nb::index_error("attempt to access out of bounds attribute"); + } + MlirNamedAttribute namedAttr = + mlirOperationGetAttribute(operation->get(), index); + return PyNamedAttribute( + namedAttr.attribute, + std::string(mlirIdentifierStr(namedAttr.name).data, + mlirIdentifierStr(namedAttr.name).length)); + } + + void dunderSetItem(const std::string &name, const PyAttribute &attr) { + mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), + attr); + } + + void dunderDelItem(const std::string &name) { + int removed = mlirOperationRemoveAttributeByName(operation->get(), + toMlirStringRef(name)); + if (!removed) + throw nb::key_error("attempt to delete a non-existent attribute"); + } + + intptr_t dunderLen() { + return mlirOperationGetNumAttributes(operation->get()); + } + + bool dunderContains(const std::string &name) { + return !mlirAttributeIsNull(mlirOperationGetAttributeByName( + operation->get(), toMlirStringRef(name))); + } + + static void bind(nb::module_ &m) { + nb::class_(m, "OpAttributeMap") + .def("__contains__", &PyOpAttributeMap::dunderContains) + .def("__len__", &PyOpAttributeMap::dunderLen) + .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) + .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) + .def("__setitem__", &PyOpAttributeMap::dunderSetItem) + .def("__delitem__", &PyOpAttributeMap::dunderDelItem); + } + +private: + PyOperationRef operation; +}; + +// see 
+// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h + +#ifndef _Py_CAST +#define _Py_CAST(type, expr) ((type)(expr)) +#endif + +// Static inline functions should use _Py_NULL rather than using directly NULL +// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, +// _Py_NULL is defined as nullptr. +#ifndef _Py_NULL +#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \ + (defined(__cplusplus) && __cplusplus >= 201103) +#define _Py_NULL nullptr +#else +#define _Py_NULL NULL +#endif +#endif + +// Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 + +// bpo-42262 added Py_XNewRef() +#if !defined(Py_XNewRef) +PyObject *_Py_XNewRef(PyObject *obj) { + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) +#endif + +// bpo-42262 added Py_NewRef() +#if !defined(Py_NewRef) +PyObject *_Py_NewRef(PyObject *obj) { + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) +#endif + +#endif // Python 3.10.0a3 + +// Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) + +// bpo-40429 added PyThreadState_GetFrame() +PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { + assert(tstate != _Py_NULL && "expected tstate != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} + +// bpo-40421 added PyFrame_GetBack() +PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back)); +} + +// bpo-40421 added PyFrame_GetCode() +PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL"); + return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code)); +} + +#endif // Python 3.9.0b1 + +MlirLocation tracebackToLocation(MlirContext ctx) { + size_t framesLimit = + 
PyGlobals::get().getTracebackLoc().locTracebackFramesLimit(); + // Use a thread_local here to avoid requiring a large amount of space. + thread_local std::array + frames; + size_t count = 0; + + nb::gil_scoped_acquire acquire; + PyThreadState *tstate = PyThreadState_GET(); + PyFrameObject *next; + PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate); + // In the increment expression: + // 1. get the next prev frame; + // 2. decrement the ref count on the current frame (in order that it can get + // gc'd, along with any objects in its closure and etc); + // 3. set current = next. + for (; pyFrame != nullptr && count < framesLimit; + next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) { + PyCodeObject *code = PyFrame_GetCode(pyFrame); + auto fileNameStr = + nb::cast(nb::borrow(code->co_filename)); + llvm::StringRef fileName(fileNameStr); + if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName)) + continue; + + // co_qualname and PyCode_Addr2Location added in py3.11 +#if PY_VERSION_HEX < 0x030B00F0 + std::string name = + nb::cast(nb::borrow(code->co_name)); + llvm::StringRef funcName(name); + int startLine = PyFrame_GetLineNumber(pyFrame); + MlirLocation loc = + mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0); +#else + std::string name = + nb::cast(nb::borrow(code->co_qualname)); + llvm::StringRef funcName(name); + int startLine, startCol, endLine, endCol; + int lasti = PyFrame_GetLasti(pyFrame); + if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine, + &endCol)) { + throw nb::python_error(); + } + MlirLocation loc = mlirLocationFileLineColRangeGet( + ctx, wrap(fileName), startLine, startCol, endLine, endCol); +#endif + + frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc); + ++count; + } + // When the loop breaks (after the last iter), current frame (if non-null) + // is leaked without this. 
+ Py_XDECREF(pyFrame); + + if (count == 0) + return mlirLocationUnknownGet(ctx); + + MlirLocation callee = frames[0]; + assert(!mlirLocationIsNull(callee) && "expected non-null callee location"); + if (count == 1) + return callee; + + MlirLocation caller = frames[count - 1]; + assert(!mlirLocationIsNull(caller) && "expected non-null caller location"); + for (int i = count - 2; i >= 1; i--) + caller = mlirLocationCallSiteGet(frames[i], caller); + + return mlirLocationCallSiteGet(callee, caller); +} + +PyLocation +maybeGetTracebackLocation(const std::optional &location) { + if (location.has_value()) + return location.value(); + if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) + return DefaultingPyLocation::resolve(); + + PyMlirContext &ctx = DefaultingPyMlirContext::resolve(); + MlirLocation mlirLoc = tracebackToLocation(ctx.get()); + PyMlirContextRef ref = PyMlirContext::forContext(ctx.get()); + return {ref, mlirLoc}; +} + +} // namespace + +//------------------------------------------------------------------------------ +// Populates the core exports of the 'ir' submodule. +//------------------------------------------------------------------------------ + +void mlir::python::populateIRCore(nb::module_ &m) { + // disable leak warnings which tend to be false positives. + nb::set_leak_warnings(false); + //---------------------------------------------------------------------------- + // Enums. 
+ //---------------------------------------------------------------------------- + nb::enum_(m, "DiagnosticSeverity") + .value("ERROR", MlirDiagnosticError) + .value("WARNING", MlirDiagnosticWarning) + .value("NOTE", MlirDiagnosticNote) + .value("REMARK", MlirDiagnosticRemark); + + nb::enum_(m, "WalkOrder") + .value("PRE_ORDER", MlirWalkPreOrder) + .value("POST_ORDER", MlirWalkPostOrder); + + nb::enum_(m, "WalkResult") + .value("ADVANCE", MlirWalkResultAdvance) + .value("INTERRUPT", MlirWalkResultInterrupt) + .value("SKIP", MlirWalkResultSkip); + + //---------------------------------------------------------------------------- + // Mapping of Diagnostics. + //---------------------------------------------------------------------------- + nb::class_(m, "Diagnostic") + .def_prop_ro("severity", &PyDiagnostic::getSeverity) + .def_prop_ro("location", &PyDiagnostic::getLocation) + .def_prop_ro("message", &PyDiagnostic::getMessage) + .def_prop_ro("notes", &PyDiagnostic::getNotes) + .def("__str__", [](PyDiagnostic &self) -> nb::str { + if (!self.isValid()) + return nb::str(""); + return self.getMessage(); + }); + + nb::class_(m, "DiagnosticInfo") + .def("__init__", + [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { + new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); + }) + .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) + .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) + .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) + .def("__str__", + [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); + + nb::class_(m, "DiagnosticHandler") + .def("detach", &PyDiagnosticHandler::detach) + .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) + .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) + .def("__enter__", &PyDiagnosticHandler::contextEnter) + .def("__exit__", &PyDiagnosticHandler::contextExit, + nb::arg("exc_type").none(), 
nb::arg("exc_value").none(), + nb::arg("traceback").none()); + + //---------------------------------------------------------------------------- + // Mapping of MlirContext. + // Note that this is exported as _BaseContext. The containing, Python level + // __init__.py will subclass it with site-specific functionality and set a + // "Context" attribute on this module. + //---------------------------------------------------------------------------- + + // Expose DefaultThreadPool to python + nb::class_(m, "ThreadPool") + .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) + .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); + + nb::class_(m, "_BaseContext") + .def("__init__", + [](PyMlirContext &self) { + MlirContext context = mlirContextCreateWithThreading(false); + new (&self) PyMlirContext(context); + }) + .def_static("_get_live_count", &PyMlirContext::getLiveCount) + .def("_get_context_again", + [](PyMlirContext &self) { + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); + }) + .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def("_get_live_operation_objects", + &PyMlirContext::getLiveOperationObjects) + .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) + .def("_clear_live_operations_inside", + nb::overload_cast( + &PyMlirContext::clearOperationsInside)) + .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def("__enter__", &PyMlirContext::contextEnter) + .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), nb::arg("traceback").none()) + .def_prop_ro_static( + "current", + [](nb::object & /*class*/) { + auto *context = PyThreadContextEntry::getDefaultContext(); + if 
(!context) + return nb::none(); + return nb::cast(context); + }, + "Gets the Context bound to the current thread or raises ValueError") + .def_prop_ro( + "dialects", + [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Gets a container for accessing dialects by name") + .def_prop_ro( + "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Alias for 'dialect'") + .def( + "get_dialect_descriptor", + [=](PyMlirContext &self, std::string &name) { + MlirDialect dialect = mlirContextGetOrLoadDialect( + self.get(), {name.data(), name.size()}); + if (mlirDialectIsNull(dialect)) { + throw nb::value_error( + (Twine("Dialect '") + name + "' not found").str().c_str()); + } + return PyDialectDescriptor(self.getRef(), dialect); + }, + nb::arg("dialect_name"), + "Gets or loads a dialect by name, returning its descriptor object") + .def_prop_rw( + "allow_unregistered_dialects", + [](PyMlirContext &self) -> bool { + return mlirContextGetAllowUnregisteredDialects(self.get()); + }, + [](PyMlirContext &self, bool value) { + mlirContextSetAllowUnregisteredDialects(self.get(), value); + }) + .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, + nb::arg("callback"), + "Attaches a diagnostic handler that will receive callbacks") + .def( + "enable_multithreading", + [](PyMlirContext &self, bool enable) { + mlirContextEnableMultithreading(self.get(), enable); + }, + nb::arg("enable")) + .def("set_thread_pool", + [](PyMlirContext &self, PyThreadPool &pool) { + // we should disable multi-threading first before setting + // new thread pool otherwise the assert in + // MLIRContext::setThreadPool will be raised. 
+ mlirContextEnableMultithreading(self.get(), false); + mlirContextSetThreadPool(self.get(), pool.get()); + }) + .def("get_num_threads", + [](PyMlirContext &self) { + return mlirContextGetNumThreads(self.get()); + }) + .def("_mlir_thread_pool_ptr", + [](PyMlirContext &self) { + MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); + std::stringstream ss; + ss << pool.ptr; + return ss.str(); + }) + .def( + "is_registered_operation", + [](PyMlirContext &self, std::string &name) { + return mlirContextIsRegisteredOperation( + self.get(), MlirStringRef{name.data(), name.size()}); + }, + nb::arg("operation_name")) + .def( + "append_dialect_registry", + [](PyMlirContext &self, PyDialectRegistry ®istry) { + mlirContextAppendDialectRegistry(self.get(), registry); + }, + nb::arg("registry")) + .def_prop_rw("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") + .def("load_all_available_dialects", [](PyMlirContext &self) { + mlirContextLoadAllAvailableDialects(self.get()); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialectDescriptor + //---------------------------------------------------------------------------- + nb::class_(m, "DialectDescriptor") + .def_prop_ro("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + return nb::str(ns.data, ns.length); + }) + .def("__repr__", [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + std::string repr(""); + return repr; + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialects + //---------------------------------------------------------------------------- + nb::class_(m, "Dialects") + .def("__getitem__", + [=](PyDialects &self, std::string 
keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(keyName, std::move(descriptor)); + }) + .def("__getattr__", [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(attrName, std::move(descriptor)); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialect + //---------------------------------------------------------------------------- + nb::class_(m, "Dialect") + .def(nb::init(), nb::arg("descriptor")) + .def_prop_ro("descriptor", + [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](nb::object self) { + auto clazz = self.attr("__class__"); + return nb::str(""); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyDialectRegistry + //---------------------------------------------------------------------------- + nb::class_(m, "DialectRegistry") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) + .def(nb::init<>()); + + //---------------------------------------------------------------------------- + // Mapping of Location + //---------------------------------------------------------------------------- + nb::class_(m, "Location") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) + .def("__enter__", &PyLocation::contextEnter) + .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), nb::arg("traceback").none()) + .def("__eq__", + [](PyLocation &self, 
PyLocation &other) -> bool { + return mlirLocationEqual(self, other); + }) + .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) + .def_prop_ro_static( + "current", + [](nb::object & /*class*/) -> std::optional { + auto *loc = PyThreadContextEntry::getDefaultLocation(); + if (!loc) + return std::nullopt; + return loc; + }, + "Gets the Location bound to the current thread or raises ValueError") + .def_static( + "unknown", + [](DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationUnknownGet(context->get())); + }, + nb::arg("context").none() = nb::none(), + "Gets a Location representing an unknown location") + .def_static( + "callsite", + [](PyLocation callee, const std::vector &frames, + DefaultingPyMlirContext context) { + if (frames.empty()) + throw nb::value_error("No caller frames provided"); + MlirLocation caller = frames.back().get(); + for (const PyLocation &frame : + llvm::reverse(llvm::ArrayRef(frames).drop_back())) + caller = mlirLocationCallSiteGet(frame.get(), caller); + return PyLocation(context->getRef(), + mlirLocationCallSiteGet(callee.get(), caller)); + }, + nb::arg("callee"), nb::arg("frames"), + nb::arg("context").none() = nb::none(), + kContextGetCallSiteLocationDocstring) + .def("is_a_callsite", mlirLocationIsACallSite) + .def_prop_ro("callee", mlirLocationCallSiteGetCallee) + .def_prop_ro("caller", mlirLocationCallSiteGetCaller) + .def_static( + "file", + [](std::string filename, int line, int col, + DefaultingPyMlirContext context) { + return PyLocation( + context->getRef(), + mlirLocationFileLineColGet( + context->get(), toMlirStringRef(filename), line, col)); + }, + nb::arg("filename"), nb::arg("line"), nb::arg("col"), + nb::arg("context").none() = nb::none(), + kContextGetFileLocationDocstring) + .def_static( + "file", + [](std::string filename, int startLine, int startCol, int endLine, + int endCol, DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + 
mlirLocationFileLineColRangeGet( + context->get(), toMlirStringRef(filename), + startLine, startCol, endLine, endCol)); + }, + nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), + nb::arg("end_line"), nb::arg("end_col"), + nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) + .def("is_a_file", mlirLocationIsAFileLineColRange) + .def_prop_ro("filename", + [](MlirLocation loc) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(loc)); + }) + .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) + .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) + .def_prop_ro("end_line", mlirLocationFileLineColRangeGetEndLine) + .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn) + .def_static( + "fused", + [](const std::vector &pyLocations, + std::optional metadata, + DefaultingPyMlirContext context) { + llvm::SmallVector locations; + locations.reserve(pyLocations.size()); + for (auto &pyLocation : pyLocations) + locations.push_back(pyLocation.get()); + MlirLocation location = mlirLocationFusedGet( + context->get(), locations.size(), locations.data(), + metadata ? metadata->get() : MlirAttribute{0}); + return PyLocation(context->getRef(), location); + }, + nb::arg("locations"), nb::arg("metadata").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetFusedLocationDocstring) + .def("is_a_fused", mlirLocationIsAFused) + .def_prop_ro("locations", + [](MlirLocation loc) { + unsigned numLocations = + mlirLocationFusedGetNumLocations(loc); + std::vector locations(numLocations); + if (numLocations) + mlirLocationFusedGetLocations(loc, locations.data()); + return locations; + }) + .def_static( + "name", + [](std::string name, std::optional childLoc, + DefaultingPyMlirContext context) { + return PyLocation( + context->getRef(), + mlirLocationNameGet( + context->get(), toMlirStringRef(name), + childLoc ? 
childLoc->get() + : mlirLocationUnknownGet(context->get()))); + }, + nb::arg("name"), nb::arg("childLoc").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetNameLocationDocString) + .def("is_a_name", mlirLocationIsAName) + .def_prop_ro("name_str", + [](MlirLocation loc) { + return mlirIdentifierStr(mlirLocationNameGetName(loc)); + }) + .def_prop_ro("child_loc", mlirLocationNameGetChildLoc) + .def_static( + "from_attr", + [](PyAttribute &attribute, DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationFromAttribute(attribute)); + }, + nb::arg("attribute"), nb::arg("context").none() = nb::none(), + "Gets a Location from a LocationAttr") + .def_prop_ro( + "context", + [](PyLocation &self) { return self.getContext().getObject(); }, + "Context that owns the Location") + .def_prop_ro( + "attr", + [](PyLocation &self) { return mlirLocationGetAttribute(self); }, + "Get the underlying LocationAttr") + .def( + "emit_error", + [](PyLocation &self, std::string message) { + mlirEmitError(self, message.c_str()); + }, + nb::arg("message"), "Emits an error at this location") + .def("__repr__", [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }); + + //---------------------------------------------------------------------------- + // Mapping of Module + //---------------------------------------------------------------------------- + nb::class_(m, "Module", nb::is_weak_referenceable()) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def_static( + "parse", + [](const std::string &moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParse( + context->get(), toMlirStringRef(moduleAsm)); + if (mlirModuleIsNull(module)) + throw 
MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) + .def_static( + "parse", + [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParse( + context->get(), toMlirStringRef(moduleAsm)); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) + .def_static( + "parseFile", + [](const std::string &path, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParseFromFile( + context->get(), toMlirStringRef(path)); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("path"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) + .def_static( + "create", + [](const std::optional &loc) { + PyLocation pyLoc = maybeGetTracebackLocation(loc); + MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("loc").none() = nb::none(), "Creates an empty module") + .def_prop_ro( + "context", + [](PyModule &self) { return self.getContext().getObject(); }, + "Context that created the Module") + .def_prop_ro( + "operation", + [](PyModule &self) { + return PyOperation::forOperation(self.getContext(), + mlirModuleGetOperation(self.get()), + self.getRef().releaseObject()) + .releaseObject(); + }, + "Accesses the module as an operation") + .def_prop_ro( + "body", + [](PyModule &self) { + PyOperationRef moduleOp = PyOperation::forOperation( + self.getContext(), 
mlirModuleGetOperation(self.get()), + self.getRef().releaseObject()); + PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); + return returnBlock; + }, + "Return the block for this module") + .def( + "dump", + [](PyModule &self) { + mlirOperationDump(mlirModuleGetOperation(self.get())); + }, + kDumpDocstring) + .def( + "__str__", + [](nb::object self) { + // Defer to the operation's __str__. + return self.attr("operation").attr("__str__")(); + }, + kOperationStrDunderDocstring); + + //---------------------------------------------------------------------------- + // Mapping of Operation. + //---------------------------------------------------------------------------- + nb::class_(m, "_OperationBase") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }) + .def("__eq__", + [](PyOperationBase &self, PyOperationBase &other) { + return &self.getOperation() == &other.getOperation(); + }) + .def("__eq__", + [](PyOperationBase &self, nb::object other) { return false; }) + .def("__hash__", + [](PyOperationBase &self) { + return static_cast(llvm::hash_value(&self.getOperation())); + }) + .def_prop_ro("attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap(self.getOperation().getRef()); + }) + .def_prop_ro( + "context", + [](PyOperationBase &self) { + PyOperation &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + return concreteOperation.getContext().getObject(); + }, + "Context that owns the Operation") + .def_prop_ro("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + return mlirIdentifierStr(mlirOperationGetName(operation)); + }) + .def_prop_ro("operands", + [](PyOperationBase &self) { + return PyOpOperandList(self.getOperation().getRef()); + }) + .def_prop_ro("regions", + [](PyOperationBase &self) { + return 
PyRegionList(self.getOperation().getRef()); + }) + .def_prop_ro( + "results", + [](PyOperationBase &self) { + return PyOpResultList(self.getOperation().getRef()); + }, + "Returns the list of Operation results.") + .def_prop_ro( + "result", + [](PyOperationBase &self) { + auto &operation = self.getOperation(); + return PyOpResult(operation.getRef(), getUniqueResult(operation)) + .maybeDownCast(); + }, + "Shortcut to get an op result if it has only one (throws an error " + "otherwise).") + .def_prop_ro( + "location", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + return PyLocation(operation.getContext(), + mlirOperationGetLocation(operation.get())); + }, + "Returns the source location the operation was defined or derived " + "from.") + .def_prop_ro("parent", + [](PyOperationBase &self) -> nb::object { + auto parent = self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return nb::none(); + }) + .def( + "__str__", + [](PyOperationBase &self) { + return self.getAsm(/*binary=*/false, + /*largeElementsLimit=*/std::nullopt, + /*largeResourceLimit=*/std::nullopt, + /*enableDebugInfo=*/false, + /*prettyDebugInfo=*/false, + /*printGenericOpForm=*/false, + /*useLocalScope=*/false, + /*useNameLocAsPrefix=*/false, + /*assumeVerified=*/false, + /*skipRegions=*/false); + }, + "Returns the assembly form of the operation.") + .def("print", + nb::overload_cast( + &PyOperationBase::print), + nb::arg("state"), nb::arg("file").none() = nb::none(), + nb::arg("binary") = false, kOperationPrintStateDocstring) + .def("print", + nb::overload_cast, std::optional, + bool, bool, bool, bool, bool, bool, nb::object, + bool, bool>(&PyOperationBase::print), + // Careful: Lots of arguments must match up with print method. 
+ nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("large_resource_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false, + nb::arg("assume_verified") = false, + nb::arg("file").none() = nb::none(), nb::arg("binary") = false, + nb::arg("skip_regions") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), + nb::arg("desired_version").none() = nb::none(), + kOperationPrintBytecodeDocstring) + .def("get_asm", &PyOperationBase::getAsm, + // Careful: Lots of arguments must match up with get_asm method. + nb::arg("binary") = false, + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("large_resource_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false, + nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, + kOperationGetAsmDocstring) + .def("verify", &PyOperationBase::verify, + "Verify the operation. 
Raises MLIRError if verification fails, and " + "returns true otherwise.") + .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), + "Puts self immediately after the other operation in its parent " + "block.") + .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), + "Puts self immediately before the other operation in its parent " + "block.") + .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, + nb::arg("other"), + "Given an operation 'other' that is within the same parent block, " + "return" + "whether the current operation is before 'other' in the operation " + "list" + "of the parent block.") + .def( + "clone", + [](PyOperationBase &self, nb::object ip) { + return self.getOperation().clone(ip); + }, + nb::arg("ip").none() = nb::none()) + .def( + "detach_from_parent", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + if (!operation.isAttached()) + throw nb::value_error("Detached operation has no parent."); + + operation.detachFromParent(); + return operation.createOpView(); + }, + "Detaches the operation from its parent block.") + .def_prop_ro( + "attached", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + return operation.isAttached(); + }, + "Reports if the operation is attached to its parent block.") + .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) + .def("walk", &PyOperationBase::walk, nb::arg("callback"), + nb::arg("walk_order") = MlirWalkPostOrder); + + nb::class_(m, "Operation") + .def_static( + "create", + [](std::string_view name, + std::optional> results, + std::optional> operands, + std::optional attributes, + std::optional> successors, int regions, + const std::optional &location, + const nb::object &maybeIp, bool inferType) { + // Unpack/validate operands. 
+ llvm::SmallVector mlirOperands; + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue *operand : *operands) { + if (!operand) + throw nb::value_error("operand value cannot be None"); + mlirOperands.push_back(operand->get()); + } + } + + PyLocation pyLoc = maybeGetTracebackLocation(location); + return PyOperation::create(name, results, mlirOperands, attributes, + successors, regions, pyLoc, maybeIp, + inferType); + }, + nb::arg("name"), nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0, + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + nb::arg("infer_type") = false, kOperationCreateDocstring) + .def_static( + "parse", + [](const std::string &sourceStr, const std::string &sourceName, + DefaultingPyMlirContext context) { + return PyOperation::parse(context->getRef(), sourceStr, sourceName) + ->createOpView(); + }, + nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", + nb::arg("context").none() = nb::none(), + "Parses an operation. 
Supports both text assembly format and binary " + "bytecode format.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) + .def_prop_ro("operation", [](nb::object self) { return self; }) + .def_prop_ro("opview", &PyOperation::createOpView) + .def_prop_ro("block", &PyOperation::getBlock) + .def_prop_ro( + "successors", + [](PyOperationBase &self) { + return PyOpSuccessors(self.getOperation().getRef()); + }, + "Returns the list of Operation successors."); + + auto opViewClass = + nb::class_(m, "OpView") + .def(nb::init(), nb::arg("operation")) + .def( + "__init__", + [](PyOpView *self, std::string_view name, + std::tuple opRegionSpec, + nb::object operandSegmentSpecObj, + nb::object resultSegmentSpecObj, + std::optional resultTypeList, nb::list operandList, + std::optional attributes, + std::optional> successors, + std::optional regions, + const std::optional &location, + const nb::object &maybeIp) { + PyLocation pyLoc = maybeGetTracebackLocation(location); + new (self) PyOpView(PyOpView::buildGeneric( + name, opRegionSpec, operandSegmentSpecObj, + resultSegmentSpecObj, resultTypeList, operandList, + attributes, successors, regions, pyLoc, maybeIp)); + }, + nb::arg("name"), nb::arg("opRegionSpec"), + nb::arg("operandSegmentSpecObj").none() = nb::none(), + nb::arg("resultSegmentSpecObj").none() = nb::none(), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("loc").none() = nb::none(), + nb::arg("ip").none() = nb::none()) + + .def_prop_ro("operation", &PyOpView::getOperationObject) + .def_prop_ro("opview", [](nb::object self) { return self; }) + .def( + "__str__", + [](PyOpView &self) { return nb::str(self.getOperationObject()); }) + .def_prop_ro( + "successors", + [](PyOperationBase &self) { + return 
PyOpSuccessors(self.getOperation().getRef()); + }, + "Returns the list of Operation successors."); + opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); + opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); + opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); + // It is faster to pass the operation_name, ods_regions, and + // ods_operand_segments/ods_result_segments as arguments to the constructor, + // rather than to access them as attributes. + opViewClass.attr("build_generic") = classmethod( + [](nb::handle cls, std::optional resultTypeList, + nb::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, std::optional location, + const nb::object &maybeIp) { + std::string name = nb::cast(cls.attr("OPERATION_NAME")); + std::tuple opRegionSpec = + nb::cast>(cls.attr("_ODS_REGIONS")); + nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); + nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); + PyLocation pyLoc = maybeGetTracebackLocation(location); + return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, + resultSegmentSpec, resultTypeList, + operandList, attributes, successors, + regions, pyLoc, maybeIp); + }, + nb::arg("cls"), nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + "Builds a specific, generated OpView based on class level attributes."); + opViewClass.attr("parse") = classmethod( + [](const nb::object &cls, const std::string &sourceStr, + const std::string &sourceName, DefaultingPyMlirContext context) { + PyOperationRef parsed = + PyOperation::parse(context->getRef(), sourceStr, sourceName); + + // Check if the expected operation was parsed, and cast to to the + // appropriate `OpView` subclass if successful. 
+ // NOTE: This accesses attributes that have been automatically added to + // `OpView` subclasses, and is not intended to be used on `OpView` + // directly. + std::string clsOpName = + nb::cast(cls.attr("OPERATION_NAME")); + MlirStringRef identifier = + mlirIdentifierStr(mlirOperationGetName(*parsed.get())); + std::string_view parsedOpName(identifier.data, identifier.length); + if (clsOpName != parsedOpName) + throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + + parsedOpName + "'"); + return PyOpView::constructDerived(cls, parsed.getObject()); + }, + nb::arg("cls"), nb::arg("source"), nb::kw_only(), + nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), + "Parses a specific, generated OpView based on class level attributes"); + + //---------------------------------------------------------------------------- + // Mapping of PyRegion. + //---------------------------------------------------------------------------- + nb::class_(m, "Region") + .def_prop_ro( + "blocks", + [](PyRegion &self) { + return PyBlockList(self.getParentOperation(), self.get()); + }, + "Returns a forward-optimized sequence of blocks.") + .def_prop_ro( + "owner", + [](PyRegion &self) { + return self.getParentOperation()->createOpView(); + }, + "Returns the operation owning this region.") + .def( + "__iter__", + [](PyRegion &self) { + self.checkValid(); + MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); + return PyBlockIterator(self.getParentOperation(), firstBlock); + }, + "Iterates over blocks in the region.") + .def("__eq__", + [](PyRegion &self, PyRegion &other) { + return self.get().ptr == other.get().ptr; + }) + .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); + + //---------------------------------------------------------------------------- + // Mapping of PyBlock. 
+ //---------------------------------------------------------------------------- + nb::class_(m, "Block") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_prop_ro( + "owner", + [](PyBlock &self) { + return self.getParentOperation()->createOpView(); + }, + "Returns the owning operation of this block.") + .def_prop_ro( + "region", + [](PyBlock &self) { + MlirRegion region = mlirBlockGetParentRegion(self.get()); + return PyRegion(self.getParentOperation(), region); + }, + "Returns the owning region of this block.") + .def_prop_ro( + "arguments", + [](PyBlock &self) { + return PyBlockArgumentList(self.getParentOperation(), self.get()); + }, + "Returns a list of block arguments.") + .def( + "add_argument", + [](PyBlock &self, const PyType &type, const PyLocation &loc) { + return mlirBlockAddArgument(self.get(), type, loc); + }, + "Append an argument of the specified type to the block and returns " + "the newly added argument.") + .def( + "erase_argument", + [](PyBlock &self, unsigned index) { + return mlirBlockEraseArgument(self.get(), index); + }, + "Erase the argument at 'index' and remove it from the argument list.") + .def_prop_ro( + "operations", + [](PyBlock &self) { + return PyOperationList(self.getParentOperation(), self.get()); + }, + "Returns a forward-optimized sequence of operations.") + .def_static( + "create_at_start", + [](PyRegion &parent, const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { + parent.checkValid(); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + mlirRegionInsertOwnedBlock(parent, 0, block); + return PyBlock(parent.getParentOperation(), block); + }, + nb::arg("parent"), nb::arg("arg_types") = nb::list(), + nb::arg("arg_locs") = std::nullopt, + "Creates and returns a new Block at the beginning of the given " + "region (with given argument types and locations).") + .def( + "append_to", + [](PyBlock &self, PyRegion ®ion) { + MlirBlock b = self.get(); + if 
(!mlirRegionIsNull(mlirBlockGetParentRegion(b))) + mlirBlockDetach(b); + mlirRegionAppendOwnedBlock(region.get(), b); + }, + "Append this block to a region, transferring ownership if necessary") + .def( + "create_before", + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { + self.checkValid(); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); + MlirRegion region = mlirBlockGetParentRegion(self.get()); + mlirRegionInsertOwnedBlockBefore(region, self.get(), block); + return PyBlock(self.getParentOperation(), block); + }, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, + "Creates and returns a new Block before this block " + "(with given argument types and locations).") + .def( + "create_after", + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { + self.checkValid(); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); + MlirRegion region = mlirBlockGetParentRegion(self.get()); + mlirRegionInsertOwnedBlockAfter(region, self.get(), block); + return PyBlock(self.getParentOperation(), block); + }, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, + "Creates and returns a new Block after this block " + "(with given argument types and locations).") + .def( + "__iter__", + [](PyBlock &self) { + self.checkValid(); + MlirOperation firstOperation = + mlirBlockGetFirstOperation(self.get()); + return PyOperationIterator(self.getParentOperation(), + firstOperation); + }, + "Iterates over operations in the block.") + .def("__eq__", + [](PyBlock &self, PyBlock &other) { + return self.get().ptr == other.get().ptr; + }) + .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) + .def("__hash__", + [](PyBlock &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def( + "__str__", + [](PyBlock &self) { + self.checkValid(); + PyPrintAccumulator printAccum; + mlirBlockPrint(self.get(), 
printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly form of the block.") + .def( + "append", + [](PyBlock &self, PyOperationBase &operation) { + if (operation.getOperation().isAttached()) + operation.getOperation().detachFromParent(); + + MlirOperation mlirOperation = operation.getOperation().get(); + mlirBlockAppendOwnedOperation(self.get(), mlirOperation); + operation.getOperation().setAttached( + self.getParentOperation().getObject()); + }, + nb::arg("operation"), + "Appends an operation to this block. If the operation is currently " + "in another block, it will be moved.") + .def_prop_ro( + "successors", + [](PyBlock &self) { + return PyBlockSuccessors(self, self.getParentOperation()); + }, + "Returns the list of Block successors.") + .def_prop_ro( + "predecessors", + [](PyBlock &self) { + return PyBlockPredecessors(self, self.getParentOperation()); + }, + "Returns the list of Block predecessors."); + + //---------------------------------------------------------------------------- + // Mapping of PyInsertionPoint. 
+ //---------------------------------------------------------------------------- + + nb::class_(m, "InsertionPoint") + .def(nb::init(), nb::arg("block"), + "Inserts after the last operation but still inside the block.") + .def("__enter__", &PyInsertionPoint::contextEnter) + .def("__exit__", &PyInsertionPoint::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()) + .def_prop_ro_static( + "current", + [](nb::object & /*class*/) { + auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); + if (!ip) + throw nb::value_error("No current InsertionPoint"); + return ip; + }, + "Gets the InsertionPoint bound to the current thread or raises " + "ValueError if none has been set") + .def(nb::init(), nb::arg("beforeOperation"), + "Inserts before a referenced operation.") + .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, + nb::arg("block"), "Inserts at the beginning of the block.") + .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, + nb::arg("block"), "Inserts before the block terminator.") + .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), + "Inserts an operation.") + .def_prop_ro( + "block", [](PyInsertionPoint &self) { return self.getBlock(); }, + "Returns the block that this InsertionPoint points to.") + .def_prop_ro( + "ref_operation", + [](PyInsertionPoint &self) -> nb::object { + auto refOperation = self.getRefOperation(); + if (refOperation) + return refOperation->getObject(); + return nb::none(); + }, + "The reference operation before which new operations are " + "inserted, or None if the insertion point is at the end of " + "the block"); + + //---------------------------------------------------------------------------- + // Mapping of PyAttribute. 
+ //---------------------------------------------------------------------------- + nb::class_(m, "Attribute") + // Delegate to the PyAttribute copy constructor, which will also lifetime + // extend the backing context which owns the MlirAttribute. + .def(nb::init(), nb::arg("cast_from_type"), + "Casts the passed attribute to the generic Attribute") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) + .def_static( + "parse", + [](const std::string &attrSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute attr = mlirAttributeParseGet( + context->get(), toMlirStringRef(attrSpec)); + if (mlirAttributeIsNull(attr)) + throw MLIRError("Unable to parse attribute", errors.take()); + return attr; + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + "Parses an attribute from an assembly form. Raises an MLIRError on " + "failure.") + .def_prop_ro( + "context", + [](PyAttribute &self) { return self.getContext().getObject(); }, + "Context that owns the Attribute") + .def_prop_ro("type", + [](PyAttribute &self) { return mlirAttributeGetType(self); }) + .def( + "get_named", + [](PyAttribute &self, std::string name) { + return PyNamedAttribute(self, std::move(name)); + }, + nb::keep_alive<0, 1>(), "Binds a name to the attribute") + .def("__eq__", + [](PyAttribute &self, PyAttribute &other) { return self == other; }) + .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) + .def("__hash__", + [](PyAttribute &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def( + "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, + kDumpDocstring) + .def( + "__str__", + [](PyAttribute &self) { + PyPrintAccumulator printAccum; + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly form of the Attribute.") + 
.def("__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive. + PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_prop_ro("typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return mlirTypeID; + }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirAttributeGetDialect(self)); + if (!typeCaster) + return nb::cast(self); + return typeCaster.value()(self); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyNamedAttribute + //---------------------------------------------------------------------------- + nb::class_(m, "NamedAttribute") + .def("__repr__", + [](PyNamedAttribute &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("NamedAttribute("); + printAccum.parts.append( + nb::str(mlirIdentifierStr(self.namedAttr.name).data, + mlirIdentifierStr(self.namedAttr.name).length)); + printAccum.parts.append("="); + mlirAttributePrint(self.namedAttr.attribute, + printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_prop_ro( + "name", + [](PyNamedAttribute &self) { + return 
mlirIdentifierStr(self.namedAttr.name); + }, + "The name of the NamedAttribute binding") + .def_prop_ro( + "attr", + [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, + nb::keep_alive<0, 1>(), + "The underlying generic attribute of the NamedAttribute binding"); + + //---------------------------------------------------------------------------- + // Mapping of PyType. + //---------------------------------------------------------------------------- + nb::class_(m, "Type") + // Delegate to the PyType copy constructor, which will also lifetime + // extend the backing context which owns the MlirType. + .def(nb::init(), nb::arg("cast_from_type"), + "Casts the passed type to the generic Type") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) + .def_static( + "parse", + [](std::string typeSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType type = + mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); + if (mlirTypeIsNull(type)) + throw MLIRError("Unable to parse type", errors.take()); + return type; + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kContextParseTypeDocstring) + .def_prop_ro( + "context", [](PyType &self) { return self.getContext().getObject(); }, + "Context that owns the Type") + .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) + .def( + "__eq__", [](PyType &self, nb::object &other) { return false; }, + nb::arg("other").none()) + .def("__hash__", + [](PyType &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def( + "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) + .def( + "__str__", + [](PyType &self) { + PyPrintAccumulator printAccum; + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + "Returns the assembly form of the type.") + .def("__repr__", + [](PyType 
&self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. + PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyType &self) { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirTypeGetDialect(self)); + if (!typeCaster) + return nb::cast(self); + return typeCaster.value()(self); + }) + .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return mlirTypeID; + auto origRepr = nb::cast(nb::repr(nb::cast(self))); + throw nb::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyTypeID. + //---------------------------------------------------------------------------- + nb::class_(m, "TypeID") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the Python objects are the same (i.e., PyTypeID is a value type). 
+ .def("__eq__", + [](PyTypeID &self, PyTypeID &other) { return self == other; }) + .def("__eq__", + [](PyTypeID &self, const nb::object &other) { return false; }) + // Note, this gives the hash value of the underlying TypeID, not the + // hash value of the Python object, nor the hash value of the + // MlirTypeID wrapper. + .def("__hash__", [](PyTypeID &self) { + return static_cast(mlirTypeIDHashValue(self)); + }); + + //---------------------------------------------------------------------------- + // Mapping of Value. + //---------------------------------------------------------------------------- + nb::class_(m, "Value") + .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) + .def_prop_ro( + "context", + [](PyValue &self) { return self.getParentOperation()->getContext(); }, + "Context in which the value lives.") + .def( + "dump", [](PyValue &self) { mlirValueDump(self.get()); }, + kDumpDocstring) + .def_prop_ro( + "owner", + [](PyValue &self) -> nb::object { + MlirValue v = self.get(); + if (mlirValueIsAOpResult(v)) { + assert( + mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation().getObject(); + } + + if (mlirValueIsABlockArgument(v)) { + MlirBlock block = mlirBlockArgumentGetOwner(self.get()); + return nb::cast(PyBlock(self.getParentOperation(), block)); + } + + assert(false && "Value must be a block argument or an op result"); + return nb::none(); + }) + .def_prop_ro("uses", + [](PyValue &self) { + return PyOpOperandIterator( + mlirValueGetFirstUse(self.get())); + }) + .def("__eq__", + [](PyValue &self, PyValue &other) { + return self.get().ptr == other.get().ptr; + }) + .def("__eq__", [](PyValue &self, nb::object other) { return false; }) + .def("__hash__", + [](PyValue 
&self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) + .def( + "__str__", + [](PyValue &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("Value("); + mlirValuePrint(self.get(), printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }, + kValueDunderStrDocstring) + .def( + "get_name", + [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) { + PyPrintAccumulator printAccum; + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + if (useNameLocAsPrefix) + mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); + MlirAsmState valueState = + mlirAsmStateCreateForValue(self.get(), flags); + mlirValuePrintAsOperand(self.get(), valueState, + printAccum.getCallback(), + printAccum.getUserData()); + mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(valueState); + return printAccum.join(); + }, + nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false) + .def( + "get_name", + [](PyValue &self, PyAsmState &state) { + PyPrintAccumulator printAccum; + MlirAsmState valueState = state.get(); + mlirValuePrintAsOperand(self.get(), valueState, + printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }, + nb::arg("state"), kGetNameAsOperand) + .def_prop_ro("type", + [](PyValue &self) { return mlirValueGetType(self.get()); }) + .def( + "set_type", + [](PyValue &self, const PyType &type) { + return mlirValueSetType(self.get(), type); + }, + nb::arg("type")) + .def( + "replace_all_uses_with", + [](PyValue &self, PyValue &with) { + mlirValueReplaceAllUsesOfWith(self.get(), with.get()); + }, + kValueReplaceAllUsesWithDocstring) + .def( + "replace_all_uses_except", + [](MlirValue self, MlirValue with, PyOperation &exception) { + MlirOperation exceptedUser = exception.get(); + mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); + }, + nb::arg("with"), 
nb::arg("exceptions"), + kValueReplaceAllUsesExceptDocstring) + .def( + "replace_all_uses_except", + [](MlirValue self, MlirValue with, nb::list exceptions) { + // Convert Python list to a SmallVector of MlirOperations + llvm::SmallVector exceptionOps; + for (nb::handle exception : exceptions) { + exceptionOps.push_back(nb::cast(exception).get()); + } + + mlirValueReplaceAllUsesExcept( + self, with, static_cast(exceptionOps.size()), + exceptionOps.data()); + }, + nb::arg("with"), nb::arg("exceptions"), + kValueReplaceAllUsesExceptDocstring) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) { return self.maybeDownCast(); }) + .def_prop_ro( + "location", + [](MlirValue self) { + return PyLocation( + PyMlirContext::forContext(mlirValueGetContext(self)), + mlirValueGetLocation(self)); + }, + "Returns the source location the value"); + + PyBlockArgument::bind(m); + PyOpResult::bind(m); + PyOpOperand::bind(m); + + nb::class_(m, "AsmState") + .def(nb::init(), nb::arg("value"), + nb::arg("use_local_scope") = false) + .def(nb::init(), nb::arg("op"), + nb::arg("use_local_scope") = false); + + //---------------------------------------------------------------------------- + // Mapping of SymbolTable. + //---------------------------------------------------------------------------- + nb::class_(m, "SymbolTable") + .def(nb::init()) + .def("__getitem__", &PySymbolTable::dunderGetItem) + .def("insert", &PySymbolTable::insert, nb::arg("operation")) + .def("erase", &PySymbolTable::erase, nb::arg("operation")) + .def("__delitem__", &PySymbolTable::dunderDel) + .def("__contains__", + [](PySymbolTable &table, const std::string &name) { + return !mlirOperationIsNull(mlirSymbolTableLookup( + table, mlirStringRefCreate(name.data(), name.length()))); + }) + // Static helpers. 
+ .def_static("set_symbol_name", &PySymbolTable::setSymbolName, + nb::arg("symbol"), nb::arg("name")) + .def_static("get_symbol_name", &PySymbolTable::getSymbolName, + nb::arg("symbol")) + .def_static("get_visibility", &PySymbolTable::getVisibility, + nb::arg("symbol")) + .def_static("set_visibility", &PySymbolTable::setVisibility, + nb::arg("symbol"), nb::arg("visibility")) + .def_static("replace_all_symbol_uses", + &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), + nb::arg("new_symbol"), nb::arg("from_op")) + .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, + nb::arg("from_op"), nb::arg("all_sym_uses_visible"), + nb::arg("callback")); + + // Container bindings. + PyBlockArgumentList::bind(m); + PyBlockIterator::bind(m); + PyBlockList::bind(m); + PyBlockSuccessors::bind(m); + PyBlockPredecessors::bind(m); + PyOperationIterator::bind(m); + PyOperationList::bind(m); + PyOpAttributeMap::bind(m); + PyOpOperandIterator::bind(m); + PyOpOperandList::bind(m); + PyOpResultList::bind(m); + PyOpSuccessors::bind(m); + PyRegionIterator::bind(m); + PyRegionList::bind(m); + + // Debug bindings. + PyGlobalDebugFlag::bind(m); + + // Attribute builder getter. + PyAttrBuilderMap::bind(m); + + nb::register_exception_translator([](const std::exception_ptr &p, + void *payload) { + // We can't define exceptions with custom fields through pybind, so instead + // the exception class is defined in python and imported here. 
+ try { + if (p) + std::rethrow_exception(p); + } catch (const MLIRError &e) { + nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("MLIRError")(e.message, e.errorDiagnostics); + PyErr_SetObject(PyExc_Exception, obj.ptr()); + } + }); +} diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp new file mode 100644 index 000000000..9e1fedaab --- /dev/null +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -0,0 +1,482 @@ +//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include +#include + +#include "IRModule.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/IR.h" +#include "mlir-c/Interfaces.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" + +namespace nb = nanobind; + +namespace mlir { +namespace python { + +constexpr static const char *constructorDoc = + R"(Creates an interface from a given operation/opview object or from a +subclass of OpView. Raises ValueError if the operation does not implement the +interface.)"; + +constexpr static const char *operationDoc = + R"(Returns an Operation for which the interface was constructed.)"; + +constexpr static const char *opviewDoc = + R"(Returns an OpView subclass _instance_ for which the interface was +constructed)"; + +constexpr static const char *inferReturnTypesDoc = + R"(Given the arguments required to build an operation, attempts to infer +its return types. 
Raises ValueError on failure.)";
+
+constexpr static const char *inferReturnTypeComponentsDoc =
+    R"(Given the arguments required to build an operation, attempts to infer
+its return shaped type components. Raises ValueError on failure.)";
+
+namespace {
+
+/// Takes in an optional list of operands and converts them into a SmallVector
+/// of MlirValues. Returns an empty SmallVector if the list is empty.
+llvm::SmallVector wrapOperands(std::optional operandList) {
+  llvm::SmallVector mlirOperands;
+
+  if (!operandList || operandList->size() == 0) {
+    return mlirOperands;
+  }
+
+  // Note: as the list may contain other lists this may not be final size.
+  mlirOperands.reserve(operandList->size());
+  for (const auto &&it : llvm::enumerate(*operandList)) {
+    if (it.value().is_none())
+      continue;
+
+    PyValue *val;
+    try {
+      val = nb::cast(it.value());
+      if (!val)
+        throw nb::cast_error();
+      mlirOperands.push_back(val->get());
+      continue;
+    } catch (nb::cast_error &err) {
+      // Intentionally unhandled to try sequence below first.
+      (void)err;
+    }
+
+    try {
+      auto vals = nb::cast(it.value());
+      for (nb::handle v : vals) {
+        try {
+          val = nb::cast(v);
+          if (!val)
+            throw nb::cast_error();
+          mlirOperands.push_back(val->get());
+        } catch (nb::cast_error &err) {
+          throw nb::value_error(
+              (llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+               " must be a Value or Sequence of Values (" + err.what() + ")")
+                  .str()
+                  .c_str());
+        }
+      }
+      continue;
+    } catch (nb::cast_error &err) {
+      throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) +
+                             " must be a Value or Sequence of Values (" +
+                             err.what() + ")")
+                                .str()
+                                .c_str());
+    }
+
+    throw nb::cast_error();
+  }
+
+  return mlirOperands;
+}
+
+/// Takes in an optional vector of PyRegions and returns a SmallVector of
+/// MlirRegion. Returns an empty SmallVector if the list is empty.
+llvm::SmallVector
+wrapRegions(std::optional> regions) {
+  llvm::SmallVector mlirRegions;
+
+  if (regions) {
+    mlirRegions.reserve(regions->size());
+    for (PyRegion &region : *regions) {
+      mlirRegions.push_back(region);
+    }
+  }
+
+  return mlirRegions;
+}
+
+} // namespace
+
+/// CRTP base class for Python classes representing MLIR Op interfaces.
+/// Interface hierarchies are flat so no base class is expected here. The
+/// derived class is expected to define the following static fields:
+///   - `const char *pyClassName` - the name of the Python class to create;
+///   - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID
+///     of the interface.
+/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
+/// interface-specific methods.
+///
+/// An interface class may be constructed from either an Operation/OpView object
+/// or from a subclass of OpView. In the latter case, only the static interface
+/// methods are available, similarly to calling ConcreteOp::staticMethod on the
+/// C++ side. Implementations of concrete interfaces can use the `isStatic`
+/// method to check whether the interface object was constructed from a class or
+/// an operation/opview instance. The `getOpName` always succeeds and returns a
+/// canonical name of the operation suitable for lookups.
+template
+class PyConcreteOpInterface {
+protected:
+  using ClassTy = nb::class_;
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+public:
+  /// Constructs an interface instance from an object that is either an
+  /// operation or a subclass of OpView. In the latter case, only the static
+  /// methods of the interface are accessible to the caller.
+  PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context)
+      : obj(std::move(object)) {
+    try {
+      operation = &nb::cast(obj);
+    } catch (nb::cast_error &) {
+      // Do nothing.
+    }
+
+    try {
+      operation = &nb::cast(obj).getOperation();
+    } catch (nb::cast_error &) {
+      // Do nothing.
+ } + + if (operation != nullptr) { + if (!mlirOperationImplementsInterface(*operation, + ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); + } + + MlirIdentifier identifier = mlirOperationGetName(*operation); + MlirStringRef stringRef = mlirIdentifierStr(identifier); + opName = std::string(stringRef.data, stringRef.length); + } else { + try { + opName = nb::cast(obj.attr("OPERATION_NAME")); + } catch (nb::cast_error &) { + throw nb::type_error( + "Op interface does not refer to an operation or OpView class"); + } + + if (!mlirOperationImplementsInterfaceStatic( + mlirStringRefCreate(opName.data(), opName.length()), + context.resolve().get(), ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); + } + } + } + + /// Creates the Python bindings for this class in the given module. + static void bind(nb::module_ &m) { + nb::class_ cls(m, ConcreteIface::pyClassName); + cls.def(nb::init(), nb::arg("object"), + nb::arg("context").none() = nb::none(), constructorDoc) + .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); + ConcreteIface::bindDerived(cls); + } + + /// Hook for derived classes to add class-specific bindings. + static void bindDerived(ClassTy &cls) {} + + /// Returns `true` if this object was constructed from a subclass of OpView + /// rather than from an operation instance. + bool isStatic() { return operation == nullptr; } + + /// Returns the operation instance from which this object was constructed. + /// Throws a type error if this object was constructed from a subclass of + /// OpView. 
+ nb::object getOperationObject() { + if (operation == nullptr) { + throw nb::type_error("Cannot get an operation from a static interface"); + } + + return operation->getRef().releaseObject(); + } + + /// Returns the opview of the operation instance from which this object was + /// constructed. Throws a type error if this object was constructed form a + /// subclass of OpView. + nb::object getOpView() { + if (operation == nullptr) { + throw nb::type_error("Cannot get an opview from a static interface"); + } + + return operation->createOpView(); + } + + /// Returns the canonical name of the operation this interface is constructed + /// from. + const std::string &getOpName() { return opName; } + +private: + PyOperation *operation = nullptr; + std::string opName; + nb::object obj; +}; + +/// Python wrapper for InferTypeOpInterface. This interface has only static +/// methods. +class PyInferTypeOpInterface + : public PyConcreteOpInterface { +public: + using PyConcreteOpInterface::PyConcreteOpInterface; + + constexpr static const char *pyClassName = "InferTypeOpInterface"; + constexpr static GetTypeIDFunctionTy getInterfaceID = + &mlirInferTypeOpInterfaceTypeID; + + /// C-style user-data structure for type appending callback. + struct AppendResultsCallbackData { + std::vector &inferredTypes; + PyMlirContext &pyMlirContext; + }; + + /// Appends the types provided as the two first arguments to the user-data + /// structure (expects AppendResultsCallbackData). + static void appendResultsCallback(intptr_t nTypes, MlirType *types, + void *userData) { + auto *data = static_cast(userData); + data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); + for (intptr_t i = 0; i < nTypes; ++i) { + data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); + } + } + + /// Given the arguments required to build an operation, attempts to infer its + /// return types. Throws value_error on failure. 
+ std::vector + inferReturnTypes(std::optional operandList, + std::optional attributes, void *properties, + std::optional> regions, + DefaultingPyMlirContext context, + DefaultingPyLocation location) { + llvm::SmallVector mlirOperands = + wrapOperands(std::move(operandList)); + llvm::SmallVector mlirRegions = wrapRegions(std::move(regions)); + + std::vector inferredTypes; + PyMlirContext &pyContext = context.resolve(); + AppendResultsCallbackData data{inferredTypes, pyContext}; + MlirStringRef opNameRef = + mlirStringRefCreate(getOpName().data(), getOpName().length()); + MlirAttribute attributeDict = + attributes ? attributes->get() : mlirAttributeGetNull(); + + MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( + opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), + mlirOperands.data(), attributeDict, properties, mlirRegions.size(), + mlirRegions.data(), &appendResultsCallback, &data); + + if (mlirLogicalResultIsFailure(result)) { + throw nb::value_error("Failed to infer result types"); + } + + return inferredTypes; + } + + static void bindDerived(ClassTy &cls) { + cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); + } +}; + +/// Wrapper around an shaped type components. 
+class PyShapedTypeComponents { +public: + PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} + PyShapedTypeComponents(nb::list shape, MlirType elementType) + : shape(std::move(shape)), elementType(elementType), ranked(true) {} + PyShapedTypeComponents(nb::list shape, MlirType elementType, + MlirAttribute attribute) + : shape(std::move(shape)), elementType(elementType), attribute(attribute), + ranked(true) {} + PyShapedTypeComponents(PyShapedTypeComponents &) = delete; + PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept + : shape(other.shape), elementType(other.elementType), + attribute(other.attribute), ranked(other.ranked) {} + + static void bind(nb::module_ &m) { + nb::class_(m, "ShapedTypeComponents") + .def_prop_ro( + "element_type", + [](PyShapedTypeComponents &self) { return self.elementType; }, + "Returns the element type of the shaped type components.") + .def_static( + "get", + [](PyType &elementType) { + return PyShapedTypeComponents(elementType); + }, + nb::arg("element_type"), + "Create an shaped type components object with only the element " + "type.") + .def_static( + "get", + [](nb::list shape, PyType &elementType) { + return PyShapedTypeComponents(std::move(shape), elementType); + }, + nb::arg("shape"), nb::arg("element_type"), + "Create a ranked shaped type components object.") + .def_static( + "get", + [](nb::list shape, PyType &elementType, PyAttribute &attribute) { + return PyShapedTypeComponents(std::move(shape), elementType, + attribute); + }, + nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), + "Create a ranked shaped type components object with attribute.") + .def_prop_ro( + "has_rank", + [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, + "Returns whether the given shaped type component is ranked.") + .def_prop_ro( + "rank", + [](PyShapedTypeComponents &self) -> nb::object { + if (!self.ranked) { + return nb::none(); + } + return nb::int_(self.shape.size()); + }, + 
"Returns the rank of the given ranked shaped type components. If " + "the shaped type components does not have a rank, None is " + "returned.") + .def_prop_ro( + "shape", + [](PyShapedTypeComponents &self) -> nb::object { + if (!self.ranked) { + return nb::none(); + } + return nb::list(self.shape); + }, + "Returns the shape of the ranked shaped type components as a list " + "of integers. Returns none if the shaped type component does not " + "have a rank."); + } + + nb::object getCapsule(); + static PyShapedTypeComponents createFromCapsule(nb::object capsule); + +private: + nb::list shape; + MlirType elementType; + MlirAttribute attribute; + bool ranked{false}; +}; + +/// Python wrapper for InferShapedTypeOpInterface. This interface has only +/// static methods. +class PyInferShapedTypeOpInterface + : public PyConcreteOpInterface { +public: + using PyConcreteOpInterface< + PyInferShapedTypeOpInterface>::PyConcreteOpInterface; + + constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; + constexpr static GetTypeIDFunctionTy getInterfaceID = + &mlirInferShapedTypeOpInterfaceTypeID; + + /// C-style user-data structure for type appending callback. + struct AppendResultsCallbackData { + std::vector &inferredShapedTypeComponents; + }; + + /// Appends the shaped type components provided as unpacked shape, element + /// type, attribute to the user-data. + static void appendResultsCallback(bool hasRank, intptr_t rank, + const int64_t *shape, MlirType elementType, + MlirAttribute attribute, void *userData) { + auto *data = static_cast(userData); + if (!hasRank) { + data->inferredShapedTypeComponents.emplace_back(elementType); + } else { + nb::list shapeList; + for (intptr_t i = 0; i < rank; ++i) { + shapeList.append(shape[i]); + } + data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, + attribute); + } + } + + /// Given the arguments required to build an operation, attempts to infer the + /// shaped type components. 
Throws value_error on failure. + std::vector inferReturnTypeComponents( + std::optional operandList, + std::optional attributes, void *properties, + std::optional> regions, + DefaultingPyMlirContext context, DefaultingPyLocation location) { + llvm::SmallVector mlirOperands = + wrapOperands(std::move(operandList)); + llvm::SmallVector mlirRegions = wrapRegions(std::move(regions)); + + std::vector inferredShapedTypeComponents; + PyMlirContext &pyContext = context.resolve(); + AppendResultsCallbackData data{inferredShapedTypeComponents}; + MlirStringRef opNameRef = + mlirStringRefCreate(getOpName().data(), getOpName().length()); + MlirAttribute attributeDict = + attributes ? attributes->get() : mlirAttributeGetNull(); + + MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( + opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), + mlirOperands.data(), attributeDict, properties, mlirRegions.size(), + mlirRegions.data(), &appendResultsCallback, &data); + + if (mlirLogicalResultIsFailure(result)) { + throw nb::value_error("Failed to infer result shape type components"); + } + + return inferredShapedTypeComponents; + } + + static void bindDerived(ClassTy &cls) { + cls.def("inferReturnTypeComponents", + &PyInferShapedTypeOpInterface::inferReturnTypeComponents, + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); + } +}; + +void populateIRInterfaces(nb::module_ &m) { + PyInferTypeOpInterface::bind(m); + PyShapedTypeComponents::bind(m); + PyInferShapedTypeOpInterface::bind(m); +} + +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp new file mode 100644 index 000000000..0de2f1711 --- /dev/null +++ 
b/mlir/lib/Bindings/Python/IRModule.cpp @@ -0,0 +1,267 @@ +//===- IRModule.cpp - IR pybind module ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include +#include + +#include "Globals.h" +#include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; + +// ----------------------------------------------------------------------------- +// PyGlobals +// ----------------------------------------------------------------------------- + +PyGlobals *PyGlobals::instance = nullptr; + +PyGlobals::PyGlobals() { + assert(!instance && "PyGlobals already constructed"); + instance = this; + // The default search path include {mlir.}dialects, where {mlir.} is the + // package prefix configured at compile time. + dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); +} + +PyGlobals::~PyGlobals() { instance = nullptr; } + +bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { + { + nb::ft_lock_guard lock(mutex); + if (loadedDialectModules.contains(dialectNamespace)) + return true; + } + // Since re-entrancy is possible, make a copy of the search prefixes. 
+ std::vector localSearchPrefixes = dialectSearchPrefixes; + nb::object loaded = nb::none(); + for (std::string moduleName : localSearchPrefixes) { + moduleName.push_back('.'); + moduleName.append(dialectNamespace.data(), dialectNamespace.size()); + + try { + loaded = nb::module_::import_(moduleName.c_str()); + } catch (nb::python_error &e) { + if (e.matches(PyExc_ModuleNotFoundError)) { + continue; + } + throw; + } + break; + } + + if (loaded.is_none()) + return false; + // Note: Iterator cannot be shared from prior to loading, since re-entrancy + // may have occurred, which may do anything. + nb::ft_lock_guard lock(mutex); + loadedDialectModules.insert(dialectNamespace); + return true; +} + +void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, + nb::callable pyFunc, bool replace) { + nb::ft_lock_guard lock(mutex); + nb::object &found = attributeBuilderMap[attributeKind]; + if (found && !replace) { + throw std::runtime_error((llvm::Twine("Attribute builder for '") + + attributeKind + + "' is already registered with func: " + + nb::cast(nb::str(found))) + .str()); + } + found = std::move(pyFunc); +} + +void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, + nb::callable typeCaster, bool replace) { + nb::ft_lock_guard lock(mutex); + nb::object &found = typeCasterMap[mlirTypeID]; + if (found && !replace) + throw std::runtime_error("Type caster is already registered with caster: " + + nb::cast(nb::str(found))); + found = std::move(typeCaster); +} + +void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, + nb::callable valueCaster, bool replace) { + nb::ft_lock_guard lock(mutex); + nb::object &found = valueCasterMap[mlirTypeID]; + if (found && !replace) + throw std::runtime_error("Value caster is already registered: " + + nb::cast(nb::repr(found))); + found = std::move(valueCaster); +} + +void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, + nb::object pyClass) { + nb::ft_lock_guard lock(mutex); + nb::object &found = 
dialectClassMap[dialectNamespace]; + if (found) { + throw std::runtime_error((llvm::Twine("Dialect namespace '") + + dialectNamespace + "' is already registered.") + .str()); + } + found = std::move(pyClass); +} + +void PyGlobals::registerOperationImpl(const std::string &operationName, + nb::object pyClass, bool replace) { + nb::ft_lock_guard lock(mutex); + nb::object &found = operationClassMap[operationName]; + if (found && !replace) { + throw std::runtime_error((llvm::Twine("Operation '") + operationName + + "' is already registered.") + .str()); + } + found = std::move(pyClass); +} + +std::optional +PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { + nb::ft_lock_guard lock(mutex); + const auto foundIt = attributeBuilderMap.find(attributeKind); + if (foundIt != attributeBuilderMap.end()) { + assert(foundIt->second && "attribute builder is defined"); + return foundIt->second; + } + return std::nullopt; +} + +std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, + MlirDialect dialect) { + // Try to load dialect module. + (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + nb::ft_lock_guard lock(mutex); + const auto foundIt = typeCasterMap.find(mlirTypeID); + if (foundIt != typeCasterMap.end()) { + assert(foundIt->second && "type caster is defined"); + return foundIt->second; + } + return std::nullopt; +} + +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect) { + // Try to load dialect module. + (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + nb::ft_lock_guard lock(mutex); + const auto foundIt = valueCasterMap.find(mlirTypeID); + if (foundIt != valueCasterMap.end()) { + assert(foundIt->second && "value caster is defined"); + return foundIt->second; + } + return std::nullopt; +} + +std::optional +PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { + // Make sure dialect module is loaded. 
+ if (!loadDialectModule(dialectNamespace)) + return std::nullopt; + nb::ft_lock_guard lock(mutex); + const auto foundIt = dialectClassMap.find(dialectNamespace); + if (foundIt != dialectClassMap.end()) { + assert(foundIt->second && "dialect class is defined"); + return foundIt->second; + } + // Not found and loading did not yield a registration. + return std::nullopt; +} + +std::optional +PyGlobals::lookupOperationClass(llvm::StringRef operationName) { + // Make sure dialect module is loaded. + auto split = operationName.split('.'); + llvm::StringRef dialectNamespace = split.first; + if (!loadDialectModule(dialectNamespace)) + return std::nullopt; + + nb::ft_lock_guard lock(mutex); + auto foundIt = operationClassMap.find(operationName); + if (foundIt != operationClassMap.end()) { + assert(foundIt->second && "OpView is defined"); + return foundIt->second; + } + // Not found and loading did not yield a registration. + return std::nullopt; +} + +bool PyGlobals::TracebackLoc::locTracebacksEnabled() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackEnabled_; +} + +void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackEnabled_ = value; +} + +size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackFramesLimit_; +} + +void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackFramesLimit_ = std::min(value, kMaxFrames); +} + +void PyGlobals::TracebackLoc::registerTracebackFileInclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackIncludeFiles.insert(reg).second) + rebuildUserTracebackIncludeRegex = true; + if (userTracebackExcludeFiles.count(reg)) { + if (userTracebackExcludeFiles.erase(reg)) + rebuildUserTracebackExcludeRegex = true; + } +} + +void 
PyGlobals::TracebackLoc::registerTracebackFileExclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackExcludeFiles.insert(reg).second) + rebuildUserTracebackExcludeRegex = true; + if (userTracebackIncludeFiles.count(reg)) { + if (userTracebackIncludeFiles.erase(reg)) + rebuildUserTracebackIncludeRegex = true; + } +} + +bool PyGlobals::TracebackLoc::isUserTracebackFilename( + const llvm::StringRef file) { + nanobind::ft_lock_guard lock(mutex); + if (rebuildUserTracebackIncludeRegex) { + userTracebackIncludeRegex.assign( + llvm::join(userTracebackIncludeFiles, "|")); + rebuildUserTracebackIncludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (rebuildUserTracebackExcludeRegex) { + userTracebackExcludeRegex.assign( + llvm::join(userTracebackExcludeFiles, "|")); + rebuildUserTracebackExcludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (!isUserTracebackFilenameCache.contains(file)) { + std::string fileStr = file.str(); + bool include = std::regex_search(fileStr, userTracebackIncludeRegex); + bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex); + isUserTracebackFilenameCache[file] = include || !exclude; + } + return isUserTracebackFilenameCache[file]; +} diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h new file mode 100644 index 000000000..fa16ae3ce --- /dev/null +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -0,0 +1,1355 @@ +//===- IRModules.h - IR Submodules of pybind module -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H +#define MLIR_BINDINGS_PYTHON_IRMODULES_H + +#include +#include +#include +#include + +#include "Globals.h" +#include "NanobindUtils.h" +#include "mlir-c/AffineExpr.h" +#include "mlir-c/AffineMap.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" +#include "mlir-c/Transforms.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ThreadPool.h" + +namespace mlir { +namespace python { + +class PyBlock; +class PyDiagnostic; +class PyDiagnosticHandler; +class PyInsertionPoint; +class PyLocation; +class DefaultingPyLocation; +class PyMlirContext; +class DefaultingPyMlirContext; +class PyModule; +class PyOperation; +class PyOperationBase; +class PyType; +class PySymbolTable; +class PyValue; + +/// Template for a reference to a concrete type which captures a python +/// reference to its underlying python object. +template +class PyObjectRef { +public: + PyObjectRef(T *referrent, nanobind::object object) + : referrent(referrent), object(std::move(object)) { + assert(this->referrent && + "cannot construct PyObjectRef with null referrent"); + assert(this->object && "cannot construct PyObjectRef with null object"); + } + PyObjectRef(PyObjectRef &&other) noexcept + : referrent(other.referrent), object(std::move(other.object)) { + other.referrent = nullptr; + assert(!other.object); + } + PyObjectRef(const PyObjectRef &other) + : referrent(other.referrent), object(other.object /* copies */) {} + ~PyObjectRef() = default; + + int getRefCount() { + if (!object) + return 0; + return Py_REFCNT(object.ptr()); + } + + /// Releases the object held by this instance, returning it. 
+ /// This is the proper thing to return from a function that wants to return + /// the reference. Note that this does not work from initializers. + nanobind::object releaseObject() { + assert(referrent && object); + referrent = nullptr; + auto stolen = std::move(object); + return stolen; + } + + T *get() { return referrent; } + T *operator->() { + assert(referrent && object); + return referrent; + } + nanobind::object getObject() { + assert(referrent && object); + return object; + } + operator bool() const { return referrent && object; } + +private: + T *referrent; + nanobind::object object; +}; + +/// Tracks an entry in the thread context stack. New entries are pushed onto +/// here for each with block that activates a new InsertionPoint, Context or +/// Location. +/// +/// Pushing either a Location or InsertionPoint also pushes its associated +/// Context. Pushing a Context will not modify the Location or InsertionPoint +/// unless if they are from a different context, in which case, they are +/// cleared. +class PyThreadContextEntry { +public: + enum class FrameKind { + Context, + InsertionPoint, + Location, + }; + + PyThreadContextEntry(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, + nanobind::object location) + : context(std::move(context)), insertionPoint(std::move(insertionPoint)), + location(std::move(location)), frameKind(frameKind) {} + + /// Gets the top of stack context and return nullptr if not defined. + static PyMlirContext *getDefaultContext(); + + /// Gets the top of stack insertion point and return nullptr if not defined. + static PyInsertionPoint *getDefaultInsertionPoint(); + + /// Gets the top of stack location and returns nullptr if not defined. + static PyLocation *getDefaultLocation(); + + PyMlirContext *getContext(); + PyInsertionPoint *getInsertionPoint(); + PyLocation *getLocation(); + FrameKind getFrameKind() { return frameKind; } + + /// Stack management. 
+ static PyThreadContextEntry *getTopOfStack(); + static nanobind::object pushContext(nanobind::object context); + static void popContext(PyMlirContext &context); + static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); + static void popInsertionPoint(PyInsertionPoint &insertionPoint); + static nanobind::object pushLocation(nanobind::object location); + static void popLocation(PyLocation &location); + + /// Gets the thread local stack. + static std::vector &getStack(); + +private: + static void push(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, nanobind::object location); + + /// An object reference to the PyContext. + nanobind::object context; + /// An object reference to the current insertion point. + nanobind::object insertionPoint; + /// An object reference to the current location. + nanobind::object location; + // The kind of push that was performed. + FrameKind frameKind; +}; + +/// Wrapper around MlirLlvmThreadPool +/// Python object owns the C++ thread pool +class PyThreadPool { +public: + PyThreadPool() { + ownedThreadPool = std::make_unique(); + } + PyThreadPool(const PyThreadPool &) = delete; + PyThreadPool(PyThreadPool &&) = delete; + + int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } + MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } + + std::string _mlir_thread_pool_ptr() const { + std::stringstream ss; + ss << ownedThreadPool.get(); + return ss.str(); + } + +private: + std::unique_ptr ownedThreadPool; +}; + +/// Wrapper around MlirContext. +using PyMlirContextRef = PyObjectRef; +class PyMlirContext { +public: + PyMlirContext() = delete; + PyMlirContext(MlirContext context); + PyMlirContext(const PyMlirContext &) = delete; + PyMlirContext(PyMlirContext &&) = delete; + + /// Returns a context reference for the singleton PyMlirContext wrapper for + /// the given context. 
+ static PyMlirContextRef forContext(MlirContext context); + ~PyMlirContext(); + + /// Accesses the underlying MlirContext. + MlirContext get() { return context; } + + /// Gets a strong reference to this context, which will ensure it is kept + /// alive for the life of the reference. + PyMlirContextRef getRef() { + return PyMlirContextRef(this, nanobind::cast(this)); + } + + /// Gets a capsule wrapping the void* within the MlirContext. + nanobind::object getCapsule(); + + /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. + /// Note that PyMlirContext instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirContext + /// is taken by calling this function. + static nanobind::object createFromCapsule(nanobind::object capsule); + + /// Gets the count of live context objects. Used for testing. + static size_t getLiveCount(); + + /// Get a list of Python objects which are still in the live context map. + std::vector getLiveOperationObjects(); + + /// Gets the count of live operations associated with this context. + /// Used for testing. + size_t getLiveOperationCount(); + + /// Clears the live operations map, returning the number of entries which were + /// invalidated. To be used as a safety mechanism so that API end-users can't + /// corrupt by holding references they shouldn't have accessed in the first + /// place. + size_t clearLiveOperations(); + + /// Removes an operation from the live operations map and sets it invalid. + /// This is useful for when some non-bindings code destroys the operation and + /// the bindings need to made aware. For example, in the case when pass + /// manager is run. + /// + /// Note that this does *NOT* clear the nested operations. + void clearOperation(MlirOperation op); + + /// Clears all operations nested inside the given op using + /// `clearOperation(MlirOperation)`. 
+ void clearOperationsInside(PyOperationBase &op); + void clearOperationsInside(MlirOperation op); + + /// Clears the operaiton _and_ all operations inside using + /// `clearOperation(MlirOperation)`. + void clearOperationAndInside(PyOperationBase &op); + + /// Gets the count of live modules associated with this context. + /// Used for testing. + size_t getLiveModuleCount(); + + /// Enter and exit the context manager. + static nanobind::object contextEnter(nanobind::object context); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); + + /// Attaches a Python callback as a diagnostic handler, returning a + /// registration object (internally a PyDiagnosticHandler). + nanobind::object attachDiagnosticHandler(nanobind::object callback); + + /// Controls whether error diagnostics should be propagated to diagnostic + /// handlers, instead of being captured by `ErrorCapture`. + void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } + struct ErrorCapture; + +private: + // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, + // preserving the relationship that an MlirContext maps to a single + // PyMlirContext wrapper. This could be replaced in the future with an + // extension mechanism on the MlirContext for stashing user pointers. + // Note that this holds a handle, which does not imply ownership. + // Mappings will be removed when the context is destructed. + using LiveContextMap = llvm::DenseMap; + static nanobind::ft_mutex live_contexts_mutex; + static LiveContextMap &getLiveContexts(); + + // Interns all live modules associated with this context. Modules tracked + // in this map are valid. When a module is invalidated, it is removed + // from this map, and while it still exists as an instance, any + // attempt to access it will raise an error. 
+ using LiveModuleMap = + llvm::DenseMap>; + LiveModuleMap liveModules; + + // Interns all live operations associated with this context. Operations + // tracked in this map are valid. When an operation is invalidated, it is + // removed from this map, and while it still exists as an instance, any + // attempt to access it will raise an error. + using LiveOperationMap = + llvm::DenseMap>; + nanobind::ft_mutex liveOperationsMutex; + + // Guarded by liveOperationsMutex in free-threading mode. + LiveOperationMap liveOperations; + + bool emitErrorDiagnostics = false; + + MlirContext context; + friend class PyModule; + friend class PyOperation; +}; + +/// Used in function arguments when None should resolve to the current context +/// manager set instance. +class DefaultingPyMlirContext + : public Defaulting { +public: + using Defaulting::Defaulting; + static constexpr const char kTypeDescription[] = "mlir.ir.Context"; + static PyMlirContext &resolve(); +}; + +/// Base class for all objects that directly or indirectly depend on an +/// MlirContext. The lifetime of the context will extend at least to the +/// lifetime of these instances. +/// Immutable objects that depend on a context extend this directly. +class BaseContextObject { +public: + BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { + assert(this->contextRef && + "context object constructed with null context ref"); + } + + /// Accesses the context reference. + PyMlirContextRef &getContext() { return contextRef; } + +private: + PyMlirContextRef contextRef; +}; + +/// Wrapper around an MlirLocation. +class PyLocation : public BaseContextObject { +public: + PyLocation(PyMlirContextRef contextRef, MlirLocation loc) + : BaseContextObject(std::move(contextRef)), loc(loc) {} + + operator MlirLocation() const { return loc; } + MlirLocation get() const { return loc; } + + /// Enter and exit the context manager. 
+ static nanobind::object contextEnter(nanobind::object location); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); + + /// Gets a capsule wrapping the void* within the MlirLocation. + nanobind::object getCapsule(); + + /// Creates a PyLocation from the MlirLocation wrapped by a capsule. + /// Note that PyLocation instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirLocation + /// is taken by calling this function. + static PyLocation createFromCapsule(nanobind::object capsule); + +private: + MlirLocation loc; +}; + +/// Python class mirroring the C MlirDiagnostic struct. Note that these structs +/// are only valid for the duration of a diagnostic callback and attempting +/// to access them outside of that will raise an exception. This applies to +/// nested diagnostics (in the notes) as well. +class PyDiagnostic { +public: + PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} + void invalidate(); + bool isValid() { return valid; } + MlirDiagnosticSeverity getSeverity(); + PyLocation getLocation(); + nanobind::str getMessage(); + nanobind::tuple getNotes(); + + /// Materialized diagnostic information. This is safe to access outside the + /// diagnostic callback. + struct DiagnosticInfo { + MlirDiagnosticSeverity severity; + PyLocation location; + std::string message; + std::vector notes; + }; + DiagnosticInfo getInfo(); + +private: + MlirDiagnostic diagnostic; + + void checkValid(); + /// If notes have been materialized from the diagnostic, then this will + /// be populated with the corresponding objects (all castable to + /// PyDiagnostic). + std::optional materializedNotes; + bool valid = true; +}; + +/// Represents a diagnostic handler attached to the context. The handler's +/// callback will be invoked with PyDiagnostic instances until the detach() +/// method is called or the context is destroyed. 
A diagnostic handler can be +/// the subject of a `with` block, which will detach it when the block exits. +/// +/// Since diagnostic handlers can call back into Python code which can do +/// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, +/// etc), this is generally not deemed to be a great user-level API. Users +/// should generally use some form of DiagnosticCollector. If the handler raises +/// any exceptions, they will just be emitted to stderr and dropped. +/// +/// The unique usage of this class means that its lifetime management is +/// different from most other parts of the API. Instances are always created +/// in an attached state and can transition to a detached state by either: +/// a) The context being destroyed and unregistering all handlers. +/// b) An explicit call to detach(). +/// The object may remain live from a Python perspective for an arbitrary time +/// after detachment, but there is nothing the user can do with it (since there +/// is no way to attach an existing handler object). +class PyDiagnosticHandler { +public: + PyDiagnosticHandler(MlirContext context, nanobind::object callback); + ~PyDiagnosticHandler(); + + bool isAttached() { return registeredID.has_value(); } + bool getHadError() { return hadError; } + + /// Detaches the handler. Does nothing if not attached. + void detach(); + + nanobind::object contextEnter() { return nanobind::cast(this); } + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb) { + detach(); + } + +private: + MlirContext context; + nanobind::object callback; + std::optional registeredID; + bool hadError = false; + friend class PyMlirContext; +}; + +/// RAII object that captures any error diagnostics emitted to the provided +/// context. 
+struct PyMlirContext::ErrorCapture { + ErrorCapture(PyMlirContextRef ctx) + : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( + ctx->get(), handler, /*userData=*/this, + /*deleteUserData=*/nullptr)) {} + ~ErrorCapture() { + mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); + assert(errors.empty() && "unhandled captured errors"); + } + + std::vector take() { + return std::move(errors); + }; + +private: + PyMlirContextRef ctx; + MlirDiagnosticHandlerID handlerID; + std::vector errors; + + static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); +}; + +/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in +/// order to differentiate it from the `Dialect` base class which is extended by +/// plugins which extend dialect functionality through extension python code. +/// This should be seen as the "low-level" object and `Dialect` as the +/// high-level, user facing object. +class PyDialectDescriptor : public BaseContextObject { +public: + PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) + : BaseContextObject(std::move(contextRef)), dialect(dialect) {} + + MlirDialect get() { return dialect; } + +private: + MlirDialect dialect; +}; + +/// User-level object for accessing dialects with dotted syntax such as: +/// ctx.dialect.std +class PyDialects : public BaseContextObject { +public: + PyDialects(PyMlirContextRef contextRef) + : BaseContextObject(std::move(contextRef)) {} + + MlirDialect getDialectForKey(const std::string &key, bool attrError); +}; + +/// User-level dialect object. For dialects that have a registered extension, +/// this will be the base class of the extension dialect type. For un-extended, +/// objects of this type will be returned directly. 
+class PyDialect { +public: + PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} + + nanobind::object getDescriptor() { return descriptor; } + +private: + nanobind::object descriptor; +}; + +/// Wrapper around an MlirDialectRegistry. +/// Upon construction, the Python wrapper takes ownership of the +/// underlying MlirDialectRegistry. +class PyDialectRegistry { +public: + PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} + PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} + ~PyDialectRegistry() { + if (!mlirDialectRegistryIsNull(registry)) + mlirDialectRegistryDestroy(registry); + } + PyDialectRegistry(PyDialectRegistry &) = delete; + PyDialectRegistry(PyDialectRegistry &&other) noexcept + : registry(other.registry) { + other.registry = {nullptr}; + } + + operator MlirDialectRegistry() const { return registry; } + MlirDialectRegistry get() const { return registry; } + + nanobind::object getCapsule(); + static PyDialectRegistry createFromCapsule(nanobind::object capsule); + +private: + MlirDialectRegistry registry; +}; + +/// Used in function arguments when None should resolve to the current context +/// manager set instance. +class DefaultingPyLocation + : public Defaulting { +public: + using Defaulting::Defaulting; + static constexpr const char kTypeDescription[] = "mlir.ir.Location"; + static PyLocation &resolve(); + + operator MlirLocation() const { return *get(); } +}; + +/// Wrapper around MlirModule. +/// This is the top-level, user-owned object that contains regions/ops/blocks. +class PyModule; +using PyModuleRef = PyObjectRef; +class PyModule : public BaseContextObject { +public: + /// Returns a PyModule reference for the given MlirModule. This may return + /// a pre-existing or new object. + static PyModuleRef forModule(MlirModule module); + PyModule(PyModule &) = delete; + PyModule(PyMlirContext &&) = delete; + ~PyModule(); + + /// Gets the backing MlirModule. 
+ MlirModule get() { return module; } + + /// Gets a strong reference to this module. + PyModuleRef getRef() { + return PyModuleRef(this, nanobind::borrow(handle)); + } + + /// Gets a capsule wrapping the void* within the MlirModule. + /// Note that the module does not (yet) provide a corresponding factory for + /// constructing from a capsule as that would require uniquing PyModule + /// instances, which is not currently done. + nanobind::object getCapsule(); + + /// Creates a PyModule from the MlirModule wrapped by a capsule. + /// Note that PyModule instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirModule + /// is taken by calling this function. + static nanobind::object createFromCapsule(nanobind::object capsule); + +private: + PyModule(PyMlirContextRef contextRef, MlirModule module); + MlirModule module; + nanobind::handle handle; +}; + +class PyAsmState; + +/// Base class for PyOperation and PyOpView which exposes the primary, user +/// visible methods for manipulating it. +class PyOperationBase { +public: + virtual ~PyOperationBase() = default; + /// Implements the bound 'print' method and helps with others. + void print(std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, + bool useNameLocAsPrefix, bool assumeVerified, + nanobind::object fileObject, bool binary, bool skipRegions); + void print(PyAsmState &state, nanobind::object fileObject, bool binary); + + nanobind::object + getAsm(bool binary, std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, + bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions); + + // Implement the bound 'writeBytecode' method. 
+ void writeBytecode(const nanobind::object &fileObject, + std::optional bytecodeVersion); + + // Implement the walk method. + void walk(std::function callback, + MlirWalkOrder walkOrder); + + /// Moves the operation before or after the other operation. + void moveAfter(PyOperationBase &other); + void moveBefore(PyOperationBase &other); + + /// Given an operation 'other' that is within the same parent block, return + /// whether the current operation is before 'other' in the operation list + /// of the parent block. + /// Note: This function has an average complexity of O(1), but worst case may + /// take O(N) where N is the number of operations within the parent block. + bool isBeforeInBlock(PyOperationBase &other); + + /// Verify the operation. Throws `MLIRError` if verification fails, and + /// returns `true` otherwise. + bool verify(); + + /// Each must provide access to the raw Operation. + virtual PyOperation &getOperation() = 0; +}; + +/// Wrapper around PyOperation. +/// Operations exist in either an attached (dependent) or detached (top-level) +/// state. In the detached state (as on creation), an operation is owned by +/// the creator and its lifetime extends either until its reference count +/// drops to zero or it is attached to a parent, at which point its lifetime +/// is bounded by its top-level parent reference. +class PyOperation; +using PyOperationRef = PyObjectRef; +class PyOperation : public PyOperationBase, public BaseContextObject { +public: + ~PyOperation() override; + PyOperation &getOperation() override { return *this; } + + /// Returns a PyOperation for the given MlirOperation, optionally associating + /// it with a parentKeepAlive. + static PyOperationRef + forOperation(PyMlirContextRef contextRef, MlirOperation operation, + nanobind::object parentKeepAlive = nanobind::object()); + + /// Creates a detached operation. The operation must not be associated with + /// any existing live operation. 
+ static PyOperationRef + createDetached(PyMlirContextRef contextRef, MlirOperation operation, + nanobind::object parentKeepAlive = nanobind::object()); + + /// Parses a source string (either text assembly or bytecode), creating a + /// detached operation. + static PyOperationRef parse(PyMlirContextRef contextRef, + const std::string &sourceStr, + const std::string &sourceName); + + /// Detaches the operation from its parent block and updates its state + /// accordingly. + void detachFromParent() { + mlirOperationRemoveFromParent(getOperation()); + setDetached(); + parentKeepAlive = nanobind::object(); + } + + /// Gets the backing operation. + operator MlirOperation() const { return get(); } + MlirOperation get() const { + checkValid(); + return operation; + } + + PyOperationRef getRef() { + return PyOperationRef(this, nanobind::borrow(handle)); + } + + bool isAttached() { return attached; } + void setAttached(const nanobind::object &parent = nanobind::object()) { + assert(!attached && "operation already attached"); + attached = true; + } + void setDetached() { + assert(attached && "operation already detached"); + attached = false; + } + void checkValid() const; + + /// Gets the owning block or raises an exception if the operation has no + /// owning block. + PyBlock getBlock(); + + /// Gets the parent operation or raises an exception if the operation has + /// no parent. + std::optional getParentOperation(); + + /// Gets a capsule wrapping the void* within the MlirOperation. + nanobind::object getCapsule(); + + /// Creates a PyOperation from the MlirOperation wrapped by a capsule. + /// Ownership of the underlying MlirOperation is taken by calling this + /// function. + static nanobind::object createFromCapsule(nanobind::object capsule); + + /// Creates an operation. See corresponding python docstring. 
+ static nanobind::object + create(std::string_view name, std::optional> results, + llvm::ArrayRef operands, + std::optional attributes, + std::optional> successors, int regions, + PyLocation &location, const nanobind::object &ip, bool inferType); + + /// Creates an OpView suitable for this operation. + nanobind::object createOpView(); + + /// Erases the underlying MlirOperation, removes its pointer from the + /// parent context's live operations map, and sets the valid bit false. + void erase(); + + /// Invalidate the operation. + void setInvalid() { valid = false; } + + /// Clones this operation. + nanobind::object clone(const nanobind::object &ip); + + PyOperation(PyMlirContextRef contextRef, MlirOperation operation); + +private: + static PyOperationRef createInstance(PyMlirContextRef contextRef, + MlirOperation operation, + nanobind::object parentKeepAlive); + + MlirOperation operation; + nanobind::handle handle; + // Keeps the parent alive, regardless of whether it is an Operation or + // Module. + // TODO: As implemented, this facility is only sufficient for modeling the + // trivial module parent back-reference. Generalize this to also account for + // transitions from detached to attached and address TODOs in the + // ir_operation.py regarding testing corresponding lifetime guarantees. + nanobind::object parentKeepAlive; + bool attached = true; + bool valid = true; + + friend class PyOperationBase; + friend class PySymbolTable; +}; + +/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for +/// providing more instance-specific accessors and serve as the base class for +/// custom ODS-style operation classes. Since this class is subclass on the +/// python side, it must present an __init__ method that operates in pure +/// python types. 
+class PyOpView : public PyOperationBase { +public: + PyOpView(const nanobind::object &operationObject); + PyOperation &getOperation() override { return operation; } + + nanobind::object getOperationObject() { return operationObject; } + + static nanobind::object + buildGeneric(std::string_view name, std::tuple opRegionSpec, + nanobind::object operandSegmentSpecObj, + nanobind::object resultSegmentSpecObj, + std::optional resultTypeList, + nanobind::list operandList, + std::optional attributes, + std::optional> successors, + std::optional regions, PyLocation &location, + const nanobind::object &maybeIp); + + /// Construct an instance of a class deriving from OpView, bypassing its + /// `__init__` method. The derived class will typically define a constructor + /// that provides a convenient builder, but we need to side-step this when + /// constructing an `OpView` for an already-built operation. + /// + /// The caller is responsible for verifying that `operation` is a valid + /// operation to construct `cls` with. + static nanobind::object constructDerived(const nanobind::object &cls, + const nanobind::object &operation); + +private: + PyOperation &operation; // For efficient, cast-free access from C++ + nanobind::object operationObject; // Holds the reference. +}; + +/// Wrapper around an MlirRegion. +/// Regions are managed completely by their containing operation. Unlike the +/// C++ API, the python API does not support detached regions. 
+class PyRegion { +public: + PyRegion(PyOperationRef parentOperation, MlirRegion region) + : parentOperation(std::move(parentOperation)), region(region) { + assert(!mlirRegionIsNull(region) && "python region cannot be null"); + } + operator MlirRegion() const { return region; } + + MlirRegion get() { return region; } + PyOperationRef &getParentOperation() { return parentOperation; } + + void checkValid() { return parentOperation->checkValid(); } + +private: + PyOperationRef parentOperation; + MlirRegion region; +}; + +/// Wrapper around an MlirAsmState. +class PyAsmState { +public: + PyAsmState(MlirValue value, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = mlirAsmStateCreateForValue(value, flags); + } + + PyAsmState(PyOperationBase &operation, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = + mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); + } + ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } + // Delete copy constructors. + PyAsmState(PyAsmState &other) = delete; + PyAsmState(const PyAsmState &other) = delete; + + MlirAsmState get() { return state; } + +private: + MlirAsmState state; + MlirOpPrintingFlags flags; +}; + +/// Wrapper around an MlirBlock. +/// Blocks are managed completely by their containing operation. Unlike the +/// C++ API, the python API does not support detached blocks. 
+class PyBlock { +public: + PyBlock(PyOperationRef parentOperation, MlirBlock block) + : parentOperation(std::move(parentOperation)), block(block) { + assert(!mlirBlockIsNull(block) && "python block cannot be null"); + } + + MlirBlock get() { return block; } + PyOperationRef &getParentOperation() { return parentOperation; } + + void checkValid() { return parentOperation->checkValid(); } + + /// Gets a capsule wrapping the void* within the MlirBlock. + nanobind::object getCapsule(); + +private: + PyOperationRef parentOperation; + MlirBlock block; +}; + +/// An insertion point maintains a pointer to a Block and a reference operation. +/// Calls to insert() will insert a new operation before the +/// reference operation. If the reference operation is null, then appends to +/// the end of the block. +class PyInsertionPoint { +public: + /// Creates an insertion point positioned after the last operation in the + /// block, but still inside the block. + PyInsertionPoint(PyBlock &block); + /// Creates an insertion point positioned before a reference operation. + PyInsertionPoint(PyOperationBase &beforeOperationBase); + + /// Shortcut to create an insertion point at the beginning of the block. + static PyInsertionPoint atBlockBegin(PyBlock &block); + /// Shortcut to create an insertion point before the block terminator. + static PyInsertionPoint atBlockTerminator(PyBlock &block); + + /// Inserts an operation. + void insert(PyOperationBase &operationBase); + + /// Enter and exit the context manager. + static nanobind::object contextEnter(nanobind::object insertionPoint); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); + + PyBlock &getBlock() { return block; } + std::optional &getRefOperation() { return refOperation; } + +private: + // Trampoline constructor that avoids null initializing members while + // looking up parents. 
+ PyInsertionPoint(PyBlock block, std::optional refOperation) + : refOperation(std::move(refOperation)), block(std::move(block)) {} + + std::optional refOperation; + PyBlock block; +}; +/// Wrapper around the generic MlirType. +/// The lifetime of a type is bound by the PyContext that created it. +class PyType : public BaseContextObject { +public: + PyType(PyMlirContextRef contextRef, MlirType type) + : BaseContextObject(std::move(contextRef)), type(type) {} + bool operator==(const PyType &other) const; + operator MlirType() const { return type; } + MlirType get() const { return type; } + + /// Gets a capsule wrapping the void* within the MlirType. + nanobind::object getCapsule(); + + /// Creates a PyType from the MlirType wrapped by a capsule. + /// Note that PyType instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirType + /// is taken by calling this function. + static PyType createFromCapsule(nanobind::object capsule); + +private: + MlirType type; +}; + +/// A TypeID provides an efficient and unique identifier for a specific C++ +/// type. This allows for a C++ type to be compared, hashed, and stored in an +/// opaque context. This class wraps around the generic MlirTypeID. +class PyTypeID { +public: + PyTypeID(MlirTypeID typeID) : typeID(typeID) {} + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the PyTypeID objects are the same (i.e., PyTypeID is a value type). + bool operator==(const PyTypeID &other) const; + operator MlirTypeID() const { return typeID; } + MlirTypeID get() { return typeID; } + + /// Gets a capsule wrapping the void* within the MlirTypeID. + nanobind::object getCapsule(); + + /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. 
+ static PyTypeID createFromCapsule(nanobind::object capsule); + +private: + MlirTypeID typeID; +}; + +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = nanobind::class_; + using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); + } + return orig; + } + + static void bind(nanobind::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_type")); + cls.def_static( + "isinstance", + [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }, + nanobind::arg("other")); + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); + }); + 
cls.def_prop_ro("typeid", [](PyType &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + nanobind::cast(nanobind::cpp_function( + [](PyType pyType) -> DerivedTy { return pyType; }))); + } + + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +/// Wrapper around the generic MlirAttribute. +/// The lifetime of a type is bound by the PyContext that created it. +class PyAttribute : public BaseContextObject { +public: + PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) + : BaseContextObject(std::move(contextRef)), attr(attr) {} + bool operator==(const PyAttribute &other) const; + operator MlirAttribute() const { return attr; } + MlirAttribute get() const { return attr; } + + /// Gets a capsule wrapping the void* within the MlirAttribute. + nanobind::object getCapsule(); + + /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. + /// Note that PyAttribute instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirAttribute + /// is taken by calling this function. + static PyAttribute createFromCapsule(nanobind::object capsule); + +private: + MlirAttribute attr; +}; + +/// Represents a Python MlirNamedAttr, carrying an optional owned name. +/// TODO: Refactor this and the C-API to be based on an Identifier owned +/// by the context so as to avoid ownership issues here. 
+class PyNamedAttribute { +public: + /// Constructs a PyNamedAttr that retains an owned name. This should be + /// used in any code that originates an MlirNamedAttribute from a python + /// string. + /// The lifetime of the PyNamedAttr must extend to the lifetime of the + /// passed attribute. + PyNamedAttribute(MlirAttribute attr, std::string ownedName); + + MlirNamedAttribute namedAttr; + +private: + // Since the MlirNamedAttr contains an internal pointer to the actual + // memory of the owned string, it must be heap allocated to remain valid. + // Otherwise, strings that fit within the small object optimization threshold + // will have their memory address change as the containing object is moved, + // resulting in an invalid aliased pointer. + std::unique_ptr ownedName; +}; + +/// CRTP base classes for Python attributes that subclass Attribute and should +/// be castable from it (i.e. via something like StringAttr(attr)). +/// By default, attribute class hierarchies are one level deep (i.e. a +/// concrete attribute class extends PyAttribute); however, intermediate +/// python-visible base classes can be modeled by specifying a BaseTy. 
+template +class PyConcreteAttribute : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = nanobind::class_; + using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; + + PyConcreteAttribute() = default; + PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) + : BaseTy(std::move(contextRef), attr) {} + PyConcreteAttribute(PyAttribute &orig) + : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} + + static MlirAttribute castFrom(PyAttribute &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); + } + return orig; + } + + static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { + ClassTy cls; + if (slots) { + cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); + } else { + cls = ClassTy(m, DerivedTy::pyClassName); + } + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_attr")); + cls.def_static( + "isinstance", + [](PyAttribute &otherAttr) -> bool { + return DerivedTy::isaFunction(otherAttr); + }, + nanobind::arg("other")); + cls.def_prop_ro( + "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); + }); + cls.def_prop_ro("typeid", [](PyAttribute &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); + }); + cls.def("__repr__", [](DerivedTy &self) { 
+ PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + nanobind::cast( + nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + }))); + } + + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +/// Wrapper around the generic MlirValue. +/// Values are managed completely by the operation that resulted in their +/// definition. For op result value, this is the operation that defines the +/// value. For block argument values, this is the operation that contains the +/// block to which the value is an argument (blocks cannot be detached in Python +/// bindings so such operation always exists). +class PyValue { +public: + // The virtual here is "load bearing" in that it enables RTTI + // for PyConcreteValue CRTP classes that support maybeDownCast. + // See PyValue::maybeDownCast. + virtual ~PyValue() = default; + PyValue(PyOperationRef parentOperation, MlirValue value) + : parentOperation(std::move(parentOperation)), value(value) {} + operator MlirValue() const { return value; } + + MlirValue get() { return value; } + PyOperationRef &getParentOperation() { return parentOperation; } + + void checkValid() { return parentOperation->checkValid(); } + + /// Gets a capsule wrapping the void* within the MlirValue. + nanobind::object getCapsule(); + + nanobind::object maybeDownCast(); + + /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of + /// the underlying MlirValue is still tied to the owning operation. 
+ static PyValue createFromCapsule(nanobind::object capsule); + +private: + PyOperationRef parentOperation; + MlirValue value; +}; + +/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. +class PyAffineExpr : public BaseContextObject { +public: + PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) + : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} + bool operator==(const PyAffineExpr &other) const; + operator MlirAffineExpr() const { return affineExpr; } + MlirAffineExpr get() const { return affineExpr; } + + /// Gets a capsule wrapping the void* within the MlirAffineExpr. + nanobind::object getCapsule(); + + /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. + /// Note that PyAffineExpr instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr + /// is taken by calling this function. + static PyAffineExpr createFromCapsule(nanobind::object capsule); + + PyAffineExpr add(const PyAffineExpr &other) const; + PyAffineExpr mul(const PyAffineExpr &other) const; + PyAffineExpr floorDiv(const PyAffineExpr &other) const; + PyAffineExpr ceilDiv(const PyAffineExpr &other) const; + PyAffineExpr mod(const PyAffineExpr &other) const; + +private: + MlirAffineExpr affineExpr; +}; + +class PyAffineMap : public BaseContextObject { +public: + PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) + : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} + bool operator==(const PyAffineMap &other) const; + operator MlirAffineMap() const { return affineMap; } + MlirAffineMap get() const { return affineMap; } + + /// Gets a capsule wrapping the void* within the MlirAffineMap. + nanobind::object getCapsule(); + + /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. + /// Note that PyAffineMap instances are uniqued, so the returned object + /// may be a pre-existing object. 
Ownership of the underlying MlirAffineMap + /// is taken by calling this function. + static PyAffineMap createFromCapsule(nanobind::object capsule); + +private: + MlirAffineMap affineMap; +}; + +class PyIntegerSet : public BaseContextObject { +public: + PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) + : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} + bool operator==(const PyIntegerSet &other) const; + operator MlirIntegerSet() const { return integerSet; } + MlirIntegerSet get() const { return integerSet; } + + /// Gets a capsule wrapping the void* within the MlirIntegerSet. + nanobind::object getCapsule(); + + /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. + /// Note that PyIntegerSet instances may be uniqued, so the returned object + /// may be a pre-existing object. Integer sets are owned by the context. + static PyIntegerSet createFromCapsule(nanobind::object capsule); + +private: + MlirIntegerSet integerSet; +}; + +/// Bindings for MLIR symbol tables. +class PySymbolTable { +public: + /// Constructs a symbol table for the given operation. + explicit PySymbolTable(PyOperationBase &operation); + + /// Destroys the symbol table. + ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } + + /// Returns the symbol (opview) with the given name, throws if there is no + /// such symbol in the table. + nanobind::object dunderGetItem(const std::string &name); + + /// Removes the given operation from the symbol table and erases it. + void erase(PyOperationBase &symbol); + + /// Removes the operation with the given name from the symbol table and erases + /// it, throws if there is no such symbol in the table. + void dunderDel(const std::string &name); + + /// Inserts the given operation into the symbol table. The operation must have + /// the symbol trait. + MlirAttribute insert(PyOperationBase &symbol); + + /// Gets and sets the name of a symbol op. 
+ static MlirAttribute getSymbolName(PyOperationBase &symbol); + static void setSymbolName(PyOperationBase &symbol, const std::string &name); + + /// Gets and sets the visibility of a symbol op. + static MlirAttribute getVisibility(PyOperationBase &symbol); + static void setVisibility(PyOperationBase &symbol, + const std::string &visibility); + + /// Replaces all symbol uses within an operation. See the API + /// mlirSymbolTableReplaceAllSymbolUses for all caveats. + static void replaceAllSymbolUses(const std::string &oldSymbol, + const std::string &newSymbol, + PyOperationBase &from); + + /// Walks all symbol tables under and including 'from'. + static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, + nanobind::object callback); + + /// Casts the bindings class into the C API structure. + operator MlirSymbolTable() { return symbolTable; } + +private: + PyOperationRef operation; + MlirSymbolTable symbolTable; +}; + +/// Custom exception that allows access to error diagnostic information. This is +/// converted to the `ir.MLIRError` python exception when thrown. 
+struct MLIRError { + MLIRError(llvm::Twine message, + std::vector &&errorDiagnostics = {}) + : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} + std::string message; + std::vector errorDiagnostics; +}; + +void populateIRAffine(nanobind::module_ &m); +void populateIRAttributes(nanobind::module_ &m); +void populateIRCore(nanobind::module_ &m); +void populateIRInterfaces(nanobind::module_ &m); +void populateIRTypes(nanobind::module_ &m); + +} // namespace python +} // namespace mlir + +namespace nanobind { +namespace detail { + +template <> +struct type_caster + : MlirDefaultingCaster {}; +template <> +struct type_caster + : MlirDefaultingCaster {}; + +} // namespace detail +} // namespace nanobind + +#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp deleted file mode 100644 index 9152fd06d..000000000 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ /dev/null @@ -1,4497 +0,0 @@ -//===- IRModules.cpp - IR Submodules of pybind module ---------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "IRModules.h" - -#include "Globals.h" -#include "PybindUtils.h" - -#include "mlir-c/AffineMap.h" -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/BuiltinAttributes.h" -#include "mlir-c/BuiltinTypes.h" -#include "mlir-c/IntegerSet.h" -#include "mlir-c/Registration.h" -#include "llvm/ADT/SmallVector.h" -#include - -namespace py = pybind11; -using namespace mlir; -using namespace mlir::python; - -using llvm::SmallVector; -using llvm::StringRef; -using llvm::Twine; - -//------------------------------------------------------------------------------ -// Docstrings (trivial, non-duplicated docstrings are included inline). -//------------------------------------------------------------------------------ - -static const char kContextParseTypeDocstring[] = - R"(Parses the assembly form of a type. - -Returns a Type object or raises a ValueError if the type cannot be parsed. - -See also: https://mlir.llvm.org/docs/LangRef/#type-system -)"; - -static const char kContextGetFileLocationDocstring[] = - R"(Gets a Location representing a file, line and column)"; - -static const char kModuleParseDocstring[] = - R"(Parses a module's assembly format from a string. - -Returns a new MlirModule or raises a ValueError if the parsing fails. - -See also: https://mlir.llvm.org/docs/LangRef/ -)"; - -static const char kOperationCreateDocstring[] = - R"(Creates a new operation. - -Args: - name: Operation name (e.g. "dialect.operation"). - results: Sequence of Type representing op result types. - attributes: Dict of str:Attribute. - successors: List of Block for the operation's successors. - regions: Number of regions to create. - location: A Location object (defaults to resolve from context manager). 
- ip: An InsertionPoint (defaults to resolve from context manager or set to - False to disable insertion, even with an insertion point set in the - context manager). -Returns: - A new "detached" Operation object. Detached operations can be added - to blocks, which causes them to become "attached." -)"; - -static const char kOperationPrintDocstring[] = - R"(Prints the assembly form of the operation to a file like object. - -Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - large_elements_limit: Whether to elide elements attributes above this - number of elements. Defaults to None (no limit). - enable_debug_info: Whether to print debug/location information. Defaults - to False. - pretty_debug_info: Whether to format debug information for easier reading - by a human (warning: the result is unparseable). - print_generic_op_form: Whether to print the generic assembly forms of all - ops. Defaults to False. - use_local_Scope: Whether to print in a way that is more optimized for - multi-threaded access but may not be consistent with how the overall - module prints. -)"; - -static const char kOperationGetAsmDocstring[] = - R"(Gets the assembly form of the operation with all options available. - -Args: - binary: Whether to return a bytes (True) or str (False) object. Defaults to - False. - ... others ...: See the print() method for common keyword arguments for - configuring the printout. -Returns: - Either a bytes or str object, depending on the setting of the 'binary' - argument. -)"; - -static const char kOperationStrDunderDocstring[] = - R"(Gets the assembly form of the operation with default options. - -If more advanced control over the assembly formatting or I/O options is needed, -use the dedicated print or get_asm method, which supports keyword arguments to -customize behavior. 
-)"; - -static const char kDumpDocstring[] = - R"(Dumps a debug representation of the object to stderr.)"; - -static const char kAppendBlockDocstring[] = - R"(Appends a new block, with argument types as positional args. - -Returns: - The created block. -)"; - -static const char kValueDunderStrDocstring[] = - R"(Returns the string form of the value. - -If the value is a block argument, this is the assembly form of its type and the -position in the argument list. If the value is an operation result, this is -equivalent to printing the operation that produced it. -)"; - -//------------------------------------------------------------------------------ -// Utilities. -//------------------------------------------------------------------------------ - -// Helper for creating an @classmethod. -template -py::object classmethod(Func f, Args... args) { - py::object cf = py::cpp_function(f, args...); - return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); -} - -/// Checks whether the given type is an integer or float type. -static int mlirTypeIsAIntegerOrFloat(MlirType type) { - return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || - mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); -} - -static py::object -createCustomDialectWrapper(const std::string &dialectNamespace, - py::object dialectDescriptor) { - auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); - if (!dialectClass) { - // Use the base class. - return py::cast(PyDialect(std::move(dialectDescriptor))); - } - - // Create the custom implementation. 
- return (*dialectClass)(std::move(dialectDescriptor)); -} - -static MlirStringRef toMlirStringRef(const std::string &s) { - return mlirStringRefCreate(s.data(), s.size()); -} - -template -static bool isPermutation(std::vector permutation) { - llvm::SmallVector seen(permutation.size(), false); - for (auto val : permutation) { - if (val < permutation.size()) { - if (seen[val]) - return false; - seen[val] = true; - continue; - } - return false; - } - return true; -} - -//------------------------------------------------------------------------------ -// Collections. -//------------------------------------------------------------------------------ - -namespace { - -class PyRegionIterator { -public: - PyRegionIterator(PyOperationRef operation) - : operation(std::move(operation)) {} - - PyRegionIterator &dunderIter() { return *this; } - - PyRegion dunderNext() { - operation->checkValid(); - if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { - throw py::stop_iteration(); - } - MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); - return PyRegion(operation, region); - } - - static void bind(py::module &m) { - py::class_(m, "RegionIterator") - .def("__iter__", &PyRegionIterator::dunderIter) - .def("__next__", &PyRegionIterator::dunderNext); - } - -private: - PyOperationRef operation; - int nextIndex = 0; -}; - -/// Regions of an op are fixed length and indexed numerically so are represented -/// with a sequence-like container. -class PyRegionList { -public: - PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} - - intptr_t dunderLen() { - operation->checkValid(); - return mlirOperationGetNumRegions(operation->get()); - } - - PyRegion dunderGetItem(intptr_t index) { - // dunderLen checks validity. 
- if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds region"); - } - MlirRegion region = mlirOperationGetRegion(operation->get(), index); - return PyRegion(operation, region); - } - - static void bind(py::module &m) { - py::class_(m, "RegionSequence") - .def("__len__", &PyRegionList::dunderLen) - .def("__getitem__", &PyRegionList::dunderGetItem); - } - -private: - PyOperationRef operation; -}; - -class PyBlockIterator { -public: - PyBlockIterator(PyOperationRef operation, MlirBlock next) - : operation(std::move(operation)), next(next) {} - - PyBlockIterator &dunderIter() { return *this; } - - PyBlock dunderNext() { - operation->checkValid(); - if (mlirBlockIsNull(next)) { - throw py::stop_iteration(); - } - - PyBlock returnBlock(operation, next); - next = mlirBlockGetNextInRegion(next); - return returnBlock; - } - - static void bind(py::module &m) { - py::class_(m, "BlockIterator") - .def("__iter__", &PyBlockIterator::dunderIter) - .def("__next__", &PyBlockIterator::dunderNext); - } - -private: - PyOperationRef operation; - MlirBlock next; -}; - -/// Blocks are exposed by the C-API as a forward-only linked list. In Python, -/// we present them as a more full-featured list-like container but optimize -/// it for forward iteration. Blocks are always owned by a region. 
-class PyBlockList { -public: - PyBlockList(PyOperationRef operation, MlirRegion region) - : operation(std::move(operation)), region(region) {} - - PyBlockIterator dunderIter() { - operation->checkValid(); - return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); - } - - intptr_t dunderLen() { - operation->checkValid(); - intptr_t count = 0; - MlirBlock block = mlirRegionGetFirstBlock(region); - while (!mlirBlockIsNull(block)) { - count += 1; - block = mlirBlockGetNextInRegion(block); - } - return count; - } - - PyBlock dunderGetItem(intptr_t index) { - operation->checkValid(); - if (index < 0) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds block"); - } - MlirBlock block = mlirRegionGetFirstBlock(region); - while (!mlirBlockIsNull(block)) { - if (index == 0) { - return PyBlock(operation, block); - } - block = mlirBlockGetNextInRegion(block); - index -= 1; - } - throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); - } - - PyBlock appendBlock(py::args pyArgTypes) { - operation->checkValid(); - llvm::SmallVector argTypes; - argTypes.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); - mlirRegionAppendOwnedBlock(region, block); - return PyBlock(operation, block); - } - - static void bind(py::module &m) { - py::class_(m, "BlockList") - .def("__getitem__", &PyBlockList::dunderGetItem) - .def("__iter__", &PyBlockList::dunderIter) - .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); - } - -private: - PyOperationRef operation; - MlirRegion region; -}; - -class PyOperationIterator { -public: - PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) - : parentOperation(std::move(parentOperation)), next(next) {} - - PyOperationIterator &dunderIter() { return *this; } - - py::object dunderNext() { - 
parentOperation->checkValid(); - if (mlirOperationIsNull(next)) { - throw py::stop_iteration(); - } - - PyOperationRef returnOperation = - PyOperation::forOperation(parentOperation->getContext(), next); - next = mlirOperationGetNextInBlock(next); - return returnOperation->createOpView(); - } - - static void bind(py::module &m) { - py::class_(m, "OperationIterator") - .def("__iter__", &PyOperationIterator::dunderIter) - .def("__next__", &PyOperationIterator::dunderNext); - } - -private: - PyOperationRef parentOperation; - MlirOperation next; -}; - -/// Operations are exposed by the C-API as a forward-only linked list. In -/// Python, we present them as a more full-featured list-like container but -/// optimize it for forward iteration. Iterable operations are always owned -/// by a block. -class PyOperationList { -public: - PyOperationList(PyOperationRef parentOperation, MlirBlock block) - : parentOperation(std::move(parentOperation)), block(block) {} - - PyOperationIterator dunderIter() { - parentOperation->checkValid(); - return PyOperationIterator(parentOperation, - mlirBlockGetFirstOperation(block)); - } - - intptr_t dunderLen() { - parentOperation->checkValid(); - intptr_t count = 0; - MlirOperation childOp = mlirBlockGetFirstOperation(block); - while (!mlirOperationIsNull(childOp)) { - count += 1; - childOp = mlirOperationGetNextInBlock(childOp); - } - return count; - } - - py::object dunderGetItem(intptr_t index) { - parentOperation->checkValid(); - if (index < 0) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds operation"); - } - MlirOperation childOp = mlirBlockGetFirstOperation(block); - while (!mlirOperationIsNull(childOp)) { - if (index == 0) { - return PyOperation::forOperation(parentOperation->getContext(), childOp) - ->createOpView(); - } - childOp = mlirOperationGetNextInBlock(childOp); - index -= 1; - } - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds operation"); - } - - static void bind(py::module 
&m) { - py::class_(m, "OperationList") - .def("__getitem__", &PyOperationList::dunderGetItem) - .def("__iter__", &PyOperationList::dunderIter) - .def("__len__", &PyOperationList::dunderLen); - } - -private: - PyOperationRef parentOperation; - MlirBlock block; -}; - -} // namespace - -//------------------------------------------------------------------------------ -// PyMlirContext -//------------------------------------------------------------------------------ - -PyMlirContext::PyMlirContext(MlirContext context) : context(context) { - py::gil_scoped_acquire acquire; - auto &liveContexts = getLiveContexts(); - liveContexts[context.ptr] = this; -} - -PyMlirContext::~PyMlirContext() { - // Note that the only public way to construct an instance is via the - // forContext method, which always puts the associated handle into - // liveContexts. - py::gil_scoped_acquire acquire; - getLiveContexts().erase(context.ptr); - mlirContextDestroy(context); -} - -py::object PyMlirContext::getCapsule() { - return py::reinterpret_steal(mlirPythonContextToCapsule(get())); -} - -py::object PyMlirContext::createFromCapsule(py::object capsule) { - MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); - if (mlirContextIsNull(rawContext)) - throw py::error_already_set(); - return forContext(rawContext).releaseObject(); -} - -PyMlirContext *PyMlirContext::createNewContextForInit() { - MlirContext context = mlirContextCreate(); - mlirRegisterAllDialects(context); - return new PyMlirContext(context); -} - -PyMlirContextRef PyMlirContext::forContext(MlirContext context) { - py::gil_scoped_acquire acquire; - auto &liveContexts = getLiveContexts(); - auto it = liveContexts.find(context.ptr); - if (it == liveContexts.end()) { - // Create. 
- PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - py::object pyRef = py::cast(unownedContextWrapper); - assert(pyRef && "cast to py::object failed"); - liveContexts[context.ptr] = unownedContextWrapper; - return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); - } - // Use existing. - py::object pyRef = py::cast(it->second); - return PyMlirContextRef(it->second, std::move(pyRef)); -} - -PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { - static LiveContextMap liveContexts; - return liveContexts; -} - -size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } - -size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } - -size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } - -pybind11::object PyMlirContext::contextEnter() { - return PyThreadContextEntry::pushContext(*this); -} - -void PyMlirContext::contextExit(pybind11::object excType, - pybind11::object excVal, - pybind11::object excTb) { - PyThreadContextEntry::popContext(*this); -} - -PyMlirContext &DefaultingPyMlirContext::resolve() { - PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); - if (!context) { - throw SetPyError( - PyExc_RuntimeError, - "An MLIR function requires a Context but none was provided in the call " - "or from the surrounding environment. 
Either pass to the function with " - "a 'context=' argument or establish a default using 'with Context():'"); - } - return *context; -} - -//------------------------------------------------------------------------------ -// PyThreadContextEntry management -//------------------------------------------------------------------------------ - -std::vector &PyThreadContextEntry::getStack() { - static thread_local std::vector stack; - return stack; -} - -PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { - auto &stack = getStack(); - if (stack.empty()) - return nullptr; - return &stack.back(); -} - -void PyThreadContextEntry::push(FrameKind frameKind, py::object context, - py::object insertionPoint, - py::object location) { - auto &stack = getStack(); - stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), - std::move(location)); - // If the new stack has more than one entry and the context of the new top - // entry matches the previous, copy the insertionPoint and location from the - // previous entry if missing from the new top entry. - if (stack.size() > 1) { - auto &prev = *(stack.rbegin() + 1); - auto ¤t = stack.back(); - if (current.context.is(prev.context)) { - // Default non-context objects from the previous entry. - if (!current.insertionPoint) - current.insertionPoint = prev.insertionPoint; - if (!current.location) - current.location = prev.location; - } - } -} - -PyMlirContext *PyThreadContextEntry::getContext() { - if (!context) - return nullptr; - return py::cast(context); -} - -PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { - if (!insertionPoint) - return nullptr; - return py::cast(insertionPoint); -} - -PyLocation *PyThreadContextEntry::getLocation() { - if (!location) - return nullptr; - return py::cast(location); -} - -PyMlirContext *PyThreadContextEntry::getDefaultContext() { - auto *tos = getTopOfStack(); - return tos ? 
tos->getContext() : nullptr; -} - -PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { - auto *tos = getTopOfStack(); - return tos ? tos->getInsertionPoint() : nullptr; -} - -PyLocation *PyThreadContextEntry::getDefaultLocation() { - auto *tos = getTopOfStack(); - return tos ? tos->getLocation() : nullptr; -} - -py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { - py::object contextObj = py::cast(context); - push(FrameKind::Context, /*context=*/contextObj, - /*insertionPoint=*/py::object(), - /*location=*/py::object()); - return contextObj; -} - -void PyThreadContextEntry::popContext(PyMlirContext &context) { - auto &stack = getStack(); - if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); - auto &tos = stack.back(); - if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); - stack.pop_back(); -} - -py::object -PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { - py::object contextObj = - insertionPoint.getBlock().getParentOperation()->getContext().getObject(); - py::object insertionPointObj = py::cast(insertionPoint); - push(FrameKind::InsertionPoint, - /*context=*/contextObj, - /*insertionPoint=*/insertionPointObj, - /*location=*/py::object()); - return insertionPointObj; -} - -void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { - auto &stack = getStack(); - if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced InsertionPoint enter/exit"); - auto &tos = stack.back(); - if (tos.frameKind != FrameKind::InsertionPoint && - tos.getInsertionPoint() != &insertionPoint) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced InsertionPoint enter/exit"); - stack.pop_back(); -} - -py::object PyThreadContextEntry::pushLocation(PyLocation &location) { - py::object contextObj = location.getContext().getObject(); - py::object locationObj = 
py::cast(location); - push(FrameKind::Location, /*context=*/contextObj, - /*insertionPoint=*/py::object(), - /*location=*/locationObj); - return locationObj; -} - -void PyThreadContextEntry::popLocation(PyLocation &location) { - auto &stack = getStack(); - if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); - auto &tos = stack.back(); - if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); - stack.pop_back(); -} - -//------------------------------------------------------------------------------ -// PyDialect, PyDialectDescriptor, PyDialects -//------------------------------------------------------------------------------ - -MlirDialect PyDialects::getDialectForKey(const std::string &key, - bool attrError) { - // If the "std" dialect was asked for, substitute the empty namespace :( - static const std::string emptyKey; - const std::string *canonKey = key == "std" ? &emptyKey : &key; - MlirDialect dialect = mlirContextGetOrLoadDialect( - getContext()->get(), {canonKey->data(), canonKey->size()}); - if (mlirDialectIsNull(dialect)) { - throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, - Twine("Dialect '") + key + "' not found"); - } - return dialect; -} - -//------------------------------------------------------------------------------ -// PyLocation -//------------------------------------------------------------------------------ - -py::object PyLocation::getCapsule() { - return py::reinterpret_steal(mlirPythonLocationToCapsule(*this)); -} - -PyLocation PyLocation::createFromCapsule(py::object capsule) { - MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); - if (mlirLocationIsNull(rawLoc)) - throw py::error_already_set(); - return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), - rawLoc); -} - -py::object PyLocation::contextEnter() { - return PyThreadContextEntry::pushLocation(*this); -} - -void PyLocation::contextExit(py::object excType, py::object excVal, - py::object excTb) { - PyThreadContextEntry::popLocation(*this); -} - -PyLocation &DefaultingPyLocation::resolve() { - auto *location = PyThreadContextEntry::getDefaultLocation(); - if (!location) { - throw SetPyError( - PyExc_RuntimeError, - "An MLIR function requires a Location but none was provided in the " - "call or from the surrounding environment. 
Either pass to the function " - "with a 'loc=' argument or establish a default using 'with loc:'"); - } - return *location; -} - -//------------------------------------------------------------------------------ -// PyModule -//------------------------------------------------------------------------------ - -PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) - : BaseContextObject(std::move(contextRef)), module(module) {} - -PyModule::~PyModule() { - py::gil_scoped_acquire acquire; - auto &liveModules = getContext()->liveModules; - assert(liveModules.count(module.ptr) == 1 && - "destroying module not in live map"); - liveModules.erase(module.ptr); - mlirModuleDestroy(module); -} - -PyModuleRef PyModule::forModule(MlirModule module) { - MlirContext context = mlirModuleGetContext(module); - PyMlirContextRef contextRef = PyMlirContext::forContext(context); - - py::gil_scoped_acquire acquire; - auto &liveModules = contextRef->liveModules; - auto it = liveModules.find(module.ptr); - if (it == liveModules.end()) { - // Create. - PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. - py::object pyRef = - py::cast(unownedModule, py::return_value_policy::take_ownership); - unownedModule->handle = pyRef; - liveModules[module.ptr] = - std::make_pair(unownedModule->handle, unownedModule); - return PyModuleRef(unownedModule, std::move(pyRef)); - } - // Use existing. 
- PyModule *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); - return PyModuleRef(existing, std::move(pyRef)); -} - -py::object PyModule::createFromCapsule(py::object capsule) { - MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); - if (mlirModuleIsNull(rawModule)) - throw py::error_already_set(); - return forModule(rawModule).releaseObject(); -} - -py::object PyModule::getCapsule() { - return py::reinterpret_steal(mlirPythonModuleToCapsule(get())); -} - -//------------------------------------------------------------------------------ -// PyOperation -//------------------------------------------------------------------------------ - -PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) - : BaseContextObject(std::move(contextRef)), operation(operation) {} - -PyOperation::~PyOperation() { - auto &liveOperations = getContext()->liveOperations; - assert(liveOperations.count(operation.ptr) == 1 && - "destroying operation not in live map"); - liveOperations.erase(operation.ptr); - if (!isAttached()) { - mlirOperationDestroy(operation); - } -} - -PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, - MlirOperation operation, - py::object parentKeepAlive) { - auto &liveOperations = contextRef->liveOperations; - // Create. - PyOperation *unownedOperation = - new PyOperation(std::move(contextRef), operation); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. 
- py::object pyRef = - py::cast(unownedOperation, py::return_value_policy::take_ownership); - unownedOperation->handle = pyRef; - if (parentKeepAlive) { - unownedOperation->parentKeepAlive = std::move(parentKeepAlive); - } - liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); - return PyOperationRef(unownedOperation, std::move(pyRef)); -} - -PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, - MlirOperation operation, - py::object parentKeepAlive) { - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. - return createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - } - // Use existing. - PyOperation *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); -} - -PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, - MlirOperation operation, - py::object parentKeepAlive) { - auto &liveOperations = contextRef->liveOperations; - assert(liveOperations.count(operation.ptr) == 0 && - "cannot create detached operation that already exists"); - (void)liveOperations; - - PyOperationRef created = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - created->attached = false; - return created; -} - -void PyOperation::checkValid() const { - if (!valid) { - throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); - } -} - -void PyOperationBase::print(py::object fileObject, bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope) { - PyOperation &operation = getOperation(); - operation.checkValid(); - if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); - - if (!printGenericOpForm && !mlirOperationVerify(operation)) { - fileObject.attr("write")("// 
Verification failed, printing generic form\n"); - printGenericOpForm = true; - } - - MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (largeElementsLimit) - mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); - if (enableDebugInfo) - mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); - if (printGenericOpForm) - mlirOpPrintingFlagsPrintGenericOpForm(flags); - - PyFileAccumulator accum(fileObject, binary); - py::gil_scoped_release(); - mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), - accum.getUserData()); - mlirOpPrintingFlagsDestroy(flags); -} - -py::object PyOperationBase::getAsm(bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, - bool useLocalScope) { - py::object fileObject; - if (binary) { - fileObject = py::module::import("io").attr("BytesIO")(); - } else { - fileObject = py::module::import("io").attr("StringIO")(); - } - print(fileObject, /*binary=*/binary, - /*largeElementsLimit=*/largeElementsLimit, - /*enableDebugInfo=*/enableDebugInfo, - /*prettyDebugInfo=*/prettyDebugInfo, - /*printGenericOpForm=*/printGenericOpForm, - /*useLocalScope=*/useLocalScope); - - return fileObject.attr("getvalue")(); -} - -PyOperationRef PyOperation::getParentOperation() { - if (!isAttached()) - throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); - MlirOperation operation = mlirOperationGetParentOperation(get()); - if (mlirOperationIsNull(operation)) - throw SetPyError(PyExc_ValueError, "Operation has no parent."); - return PyOperation::forOperation(getContext(), operation); -} - -PyBlock PyOperation::getBlock() { - PyOperationRef parentOperation = getParentOperation(); - MlirBlock block = mlirOperationGetBlock(get()); - assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); - return PyBlock{std::move(parentOperation), block}; -} - -py::object PyOperation::create( - std::string name, 
llvm::Optional> results, - llvm::Optional> operands, - llvm::Optional attributes, - llvm::Optional> successors, int regions, - DefaultingPyLocation location, py::object maybeIp) { - llvm::SmallVector mlirOperands; - llvm::SmallVector mlirResults; - llvm::SmallVector mlirSuccessors; - llvm::SmallVector, 4> mlirAttributes; - - // General parameter validation. - if (regions < 0) - throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); - - // Unpack/validate operands. - if (operands) { - mlirOperands.reserve(operands->size()); - for (PyValue *operand : *operands) { - if (!operand) - throw SetPyError(PyExc_ValueError, "operand value cannot be None"); - mlirOperands.push_back(operand->get()); - } - } - - // Unpack/validate results. - if (results) { - mlirResults.reserve(results->size()); - for (PyType *result : *results) { - // TODO: Verify result type originate from the same context. - if (!result) - throw SetPyError(PyExc_ValueError, "result type cannot be None"); - mlirResults.push_back(*result); - } - } - // Unpack/validate attributes. - if (attributes) { - mlirAttributes.reserve(attributes->size()); - for (auto &it : *attributes) { - std::string key; - try { - key = it.first.cast(); - } catch (py::cast_error &err) { - std::string msg = "Invalid attribute key (not a string) when " - "attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); - } - try { - auto &attribute = it.second.cast(); - // TODO: Verify attribute originates from the same context. - mlirAttributes.emplace_back(std::move(key), attribute); - } catch (py::reference_cast_error &) { - // This exception seems thrown when the value is "None". - std::string msg = - "Found an invalid (`None`?) 
attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + name + "\""; - throw py::cast_error(msg); - } catch (py::cast_error &err) { - std::string msg = "Invalid attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); - } - } - } - // Unpack/validate successors. - if (successors) { - llvm::SmallVector mlirSuccessors; - mlirSuccessors.reserve(successors->size()); - for (auto *successor : *successors) { - // TODO: Verify successor originate from the same context. - if (!successor) - throw SetPyError(PyExc_ValueError, "successor block cannot be None"); - mlirSuccessors.push_back(successor->get()); - } - } - - // Apply unpacked/validated to the operation state. Beyond this - // point, exceptions cannot be thrown or else the state will leak. - MlirOperationState state = - mlirOperationStateGet(toMlirStringRef(name), location); - if (!mlirOperands.empty()) - mlirOperationStateAddOperands(&state, mlirOperands.size(), - mlirOperands.data()); - if (!mlirResults.empty()) - mlirOperationStateAddResults(&state, mlirResults.size(), - mlirResults.data()); - if (!mlirAttributes.empty()) { - // Note that the attribute names directly reference bytes in - // mlirAttributes, so that vector must not be changed from here - // on. 
- llvm::SmallVector mlirNamedAttributes; - mlirNamedAttributes.reserve(mlirAttributes.size()); - for (auto &it : mlirAttributes) - mlirNamedAttributes.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(it.second), - toMlirStringRef(it.first)), - it.second)); - mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), - mlirNamedAttributes.data()); - } - if (!mlirSuccessors.empty()) - mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), - mlirSuccessors.data()); - if (regions) { - llvm::SmallVector mlirRegions; - mlirRegions.resize(regions); - for (int i = 0; i < regions; ++i) - mlirRegions[i] = mlirRegionCreate(); - mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), - mlirRegions.data()); - } - - // Construct the operation. - MlirOperation operation = mlirOperationCreate(&state); - PyOperationRef created = - PyOperation::createDetached(location->getContext(), operation); - - // InsertPoint active? - if (!maybeIp.is(py::cast(false))) { - PyInsertionPoint *ip; - if (maybeIp.is_none()) { - ip = PyThreadContextEntry::getDefaultInsertionPoint(); - } else { - ip = py::cast(maybeIp); - } - if (ip) - ip->insert(*created.get()); - } - - return created->createOpView(); -} - -py::object PyOperation::createOpView() { - MlirIdentifier ident = mlirOperationGetName(get()); - MlirStringRef identStr = mlirIdentifierStr(ident); - auto opViewClass = PyGlobals::get().lookupRawOpViewClass( - StringRef(identStr.data, identStr.length)); - if (opViewClass) - return (*opViewClass)(getRef().getObject()); - return py::cast(PyOpView(getRef().getObject())); -} - -//------------------------------------------------------------------------------ -// PyOpView -//------------------------------------------------------------------------------ - -py::object -PyOpView::buildGeneric(py::object cls, py::list resultTypeList, - py::list operandList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, - 
DefaultingPyLocation location, py::object maybeIp) { - PyMlirContextRef context = location->getContext(); - // Class level operation construction metadata. - std::string name = py::cast(cls.attr("OPERATION_NAME")); - // Operand and result segment specs are either none, which does no - // variadic unpacking, or a list of ints with segment sizes, where each - // element is either a positive number (typically 1 for a scalar) or -1 to - // indicate that it is derived from the length of the same-indexed operand - // or result (implying that it is a list at that position). - py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); - - std::vector operandSegmentLengths; - std::vector resultSegmentLengths; - - // Validate/determine region count. - auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); - int opMinRegionCount = std::get<0>(opRegionSpec); - bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); - if (!regions) { - regions = opMinRegionCount; - } - if (*regions < opMinRegionCount) { - throw py::value_error( - (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + - llvm::Twine(opMinRegionCount) + - " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); - } - if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw py::value_error( - (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + - llvm::Twine(opMinRegionCount) + - " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); - } - - // Unpack results. - std::vector resultTypes; - resultTypes.reserve(resultTypeList.size()); - if (resultSegmentSpecObj.is_none()) { - // Non-variadic result unpacking. 
- for (auto it : llvm::enumerate(resultTypeList)) { - try { - resultTypes.push_back(py::cast(it.value())); - if (!resultTypes.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Type (" + err.what() + ")") - .str()); - } - } - } else { - // Sized result unpacking. - auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); - if (resultSegmentSpec.size() != resultTypeList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + - "\" requires " + - llvm::Twine(resultSegmentSpec.size()) + - "result segments but was provided " + - llvm::Twine(resultTypeList.size())) - .str()); - } - resultSegmentLengths.reserve(resultTypeList.size()); - for (auto it : - llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { - int segmentSpec = std::get<1>(it.value()); - if (segmentSpec == 1 || segmentSpec == 0) { - // Unpack unary element. - try { - auto resultType = py::cast(std::get<0>(it.value())); - if (resultType) { - resultTypes.push_back(resultType); - resultSegmentLengths.push_back(1); - } else if (segmentSpec == 0) { - // Allowed to be optional. - resultSegmentLengths.push_back(0); - } else { - throw py::cast_error("was None and result is not optional"); - } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Type (" + err.what() + - ")") - .str()); - } - } else if (segmentSpec == -1) { - // Unpack sequence by appending. - try { - if (std::get<0>(it.value()).is_none()) { - // Treat it as an empty list. - resultSegmentLengths.push_back(0); - } else { - // Unpack the list. 
- auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - resultTypes.push_back(py::cast(segmentItem)); - if (!resultTypes.back()) { - throw py::cast_error("contained a None item"); - } - } - resultSegmentLengths.push_back(segment.size()); - } - } catch (std::exception &err) { - // NOTE: Sloppy to be using a catch-all here, but there are at least - // three different unrelated exceptions that can be thrown in the - // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Result ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Sequence of Types (" + - err.what() + ")") - .str()); - } - } else { - throw py::value_error("Unexpected segment spec"); - } - } - } - - // Unpack operands. - std::vector operands; - operands.reserve(operands.size()); - if (operandSegmentSpecObj.is_none()) { - // Non-sized operand unpacking. - for (auto it : llvm::enumerate(operandList)) { - try { - operands.push_back(py::cast(it.value())); - if (!operands.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Value (" + err.what() + ")") - .str()); - } - } - } else { - // Sized operand unpacking. - auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); - if (operandSegmentSpec.size() != operandList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + - "\" requires " + - llvm::Twine(operandSegmentSpec.size()) + - "operand segments but was provided " + - llvm::Twine(operandList.size())) - .str()); - } - operandSegmentLengths.reserve(operandList.size()); - for (auto it : - llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { - int segmentSpec = std::get<1>(it.value()); - if (segmentSpec == 1 || segmentSpec == 0) { - // Unpack unary element. 
- try { - auto operandValue = py::cast(std::get<0>(it.value())); - if (operandValue) { - operands.push_back(operandValue); - operandSegmentLengths.push_back(1); - } else if (segmentSpec == 0) { - // Allowed to be optional. - operandSegmentLengths.push_back(0); - } else { - throw py::cast_error("was None and operand is not optional"); - } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Value (" + err.what() + - ")") - .str()); - } - } else if (segmentSpec == -1) { - // Unpack sequence by appending. - try { - if (std::get<0>(it.value()).is_none()) { - // Treat it as an empty list. - operandSegmentLengths.push_back(0); - } else { - // Unpack the list. - auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - operands.push_back(py::cast(segmentItem)); - if (!operands.back()) { - throw py::cast_error("contained a None item"); - } - } - operandSegmentLengths.push_back(segment.size()); - } - } catch (std::exception &err) { - // NOTE: Sloppy to be using a catch-all here, but there are at least - // three different unrelated exceptions that can be thrown in the - // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Operand ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Sequence of Values (" + - err.what() + ")") - .str()); - } - } else { - throw py::value_error("Unexpected segment spec"); - } - } - } - - // Merge operand/result segment lengths into attributes if needed. - if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { - // Dup. 
- if (attributes) { - attributes = py::dict(*attributes); - } else { - attributes = py::dict(); - } - if (attributes->contains("result_segment_sizes") || - attributes->contains("operand_segment_sizes")) { - throw py::value_error("Manually setting a 'result_segment_sizes' or " - "'operand_segment_sizes' attribute is unsupported. " - "Use Operation.create for such low-level access."); - } - - // Add result_segment_sizes attribute. - if (!resultSegmentLengths.empty()) { - int64_t size = resultSegmentLengths.size(); - MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( - mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)), - resultSegmentLengths.size(), resultSegmentLengths.data()); - (*attributes)["result_segment_sizes"] = - PyAttribute(context, segmentLengthAttr); - } - - // Add operand_segment_sizes attribute. - if (!operandSegmentLengths.empty()) { - int64_t size = operandSegmentLengths.size(); - MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( - mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)), - operandSegmentLengths.size(), operandSegmentLengths.data()); - (*attributes)["operand_segment_sizes"] = - PyAttribute(context, segmentLengthAttr); - } - } - - // Delegate to create. - return PyOperation::create(std::move(name), - /*results=*/std::move(resultTypes), - /*operands=*/std::move(operands), - /*attributes=*/std::move(attributes), - /*successors=*/std::move(successors), - /*regions=*/*regions, location, maybeIp); -} - -PyOpView::PyOpView(py::object operationObject) - // Casting through the PyOperationBase base-class and then back to the - // Operation lets us accept any PyOperationBase subclass. - : operation(py::cast(operationObject).getOperation()), - operationObject(operation.getRef().getObject()) {} - -py::object PyOpView::createRawSubclass(py::object userClass) { - // This is... a little gross. 
The typical pattern is to have a pure python - // class that extends OpView like: - // class AddFOp(_cext.ir.OpView): - // def __init__(self, loc, lhs, rhs): - // operation = loc.context.create_operation( - // "addf", lhs, rhs, results=[lhs.type]) - // super().__init__(operation) - // - // I.e. The goal of the user facing type is to provide a nice constructor - // that has complete freedom for the op under construction. This is at odds - // with our other desire to sometimes create this object by just passing an - // operation (to initialize the base class). We could do *arg and **kwargs - // munging to try to make it work, but instead, we synthesize a new class - // on the fly which extends this user class (AddFOp in this example) and - // *give it* the base class's __init__ method, thus bypassing the - // intermediate subclass's __init__ method entirely. While slightly, - // underhanded, this is safe/legal because the type hierarchy has not changed - // (we just added a new leaf) and we aren't mucking around with __new__. - // Typically, this new class will be stored on the original as "_Raw" and will - // be used for casts and other things that need a variant of the class that - // is initialized purely from an operation. - py::object parentMetaclass = - py::reinterpret_borrow((PyObject *)&PyType_Type); - py::dict attributes; - // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from - // now. - // auto opViewType = py::type::of(); - auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); - attributes["__init__"] = opViewType.attr("__init__"); - py::str origName = userClass.attr("__name__"); - py::str newName = py::str("_") + origName; - return parentMetaclass(newName, py::make_tuple(userClass), attributes); -} - -//------------------------------------------------------------------------------ -// PyInsertionPoint. 
-//------------------------------------------------------------------------------ - -PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} - -PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) - : refOperation(beforeOperationBase.getOperation().getRef()), - block((*refOperation)->getBlock()) {} - -void PyInsertionPoint::insert(PyOperationBase &operationBase) { - PyOperation &operation = operationBase.getOperation(); - if (operation.isAttached()) - throw SetPyError(PyExc_ValueError, - "Attempt to insert operation that is already attached"); - block.getParentOperation()->checkValid(); - MlirOperation beforeOp = {nullptr}; - if (refOperation) { - // Insert before operation. - (*refOperation)->checkValid(); - beforeOp = (*refOperation)->get(); - } else { - // Insert at end (before null) is only valid if the block does not - // already end in a known terminator (violating this will cause assertion - // failures later). - if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { - throw py::index_error("Cannot insert operation at the end of a block " - "that already has a terminator. Did you mean to " - "use 'InsertionPoint.at_block_terminator(block)' " - "versus 'InsertionPoint(block)'?"); - } - } - mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); - operation.setAttached(); -} - -PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { - MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); - if (mlirOperationIsNull(firstOp)) { - // Just insert at end. - return PyInsertionPoint(block); - } - - // Insert before first op. 
- PyOperationRef firstOpRef = PyOperation::forOperation( - block.getParentOperation()->getContext(), firstOp); - return PyInsertionPoint{block, std::move(firstOpRef)}; -} - -PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { - MlirOperation terminator = mlirBlockGetTerminator(block.get()); - if (mlirOperationIsNull(terminator)) - throw SetPyError(PyExc_ValueError, "Block has no terminator"); - PyOperationRef terminatorOpRef = PyOperation::forOperation( - block.getParentOperation()->getContext(), terminator); - return PyInsertionPoint{block, std::move(terminatorOpRef)}; -} - -py::object PyInsertionPoint::contextEnter() { - return PyThreadContextEntry::pushInsertionPoint(*this); -} - -void PyInsertionPoint::contextExit(pybind11::object excType, - pybind11::object excVal, - pybind11::object excTb) { - PyThreadContextEntry::popInsertionPoint(*this); -} - -//------------------------------------------------------------------------------ -// PyAttribute. -//------------------------------------------------------------------------------ - -bool PyAttribute::operator==(const PyAttribute &other) { - return mlirAttributeEqual(attr, other.attr); -} - -py::object PyAttribute::getCapsule() { - return py::reinterpret_steal(mlirPythonAttributeToCapsule(*this)); -} - -PyAttribute PyAttribute::createFromCapsule(py::object capsule) { - MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); - if (mlirAttributeIsNull(rawAttr)) - throw py::error_already_set(); - return PyAttribute( - PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); -} - -//------------------------------------------------------------------------------ -// PyNamedAttribute. 
-//------------------------------------------------------------------------------ - -PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) - : ownedName(new std::string(std::move(ownedName))) { - namedAttr = mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(attr), - toMlirStringRef(*this->ownedName)), - attr); -} - -//------------------------------------------------------------------------------ -// PyType. -//------------------------------------------------------------------------------ - -bool PyType::operator==(const PyType &other) { - return mlirTypeEqual(type, other.type); -} - -py::object PyType::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeToCapsule(*this)); -} - -PyType PyType::createFromCapsule(py::object capsule) { - MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); - if (mlirTypeIsNull(rawType)) - throw py::error_already_set(); - return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), - rawType); -} - -//------------------------------------------------------------------------------ -// PyValue and subclases. -//------------------------------------------------------------------------------ - -namespace { -/// CRTP base class for Python MLIR values that subclass Value and should be -/// castable from it. The value hierarchy is one level deep and is not supposed -/// to accommodate other levels unless core MLIR changes. -template class PyConcreteValue : public PyValue { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. 
- using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirValue); - - PyConcreteValue() = default; - PyConcreteValue(PyOperationRef operationRef, MlirValue value) - : PyValue(operationRef, value) {} - PyConcreteValue(PyValue &orig) - : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} - - /// Attempts to cast the original value to the derived type and throws on - /// type mismatches. - static MlirValue castFrom(PyValue &orig) { - if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig.get(); - } - - /// Binds the Python module objects to functions of this class. - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init(), py::keep_alive<0, 1>()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - -/// Python wrapper for MlirBlockArgument. -class PyBlockArgument : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; - static constexpr const char *pyClassName = "BlockArgument"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyBlockArgument &self) { - return PyBlock(self.getParentOperation(), - mlirBlockArgumentGetOwner(self.get())); - }); - c.def_property_readonly("arg_number", [](PyBlockArgument &self) { - return mlirBlockArgumentGetArgNumber(self.get()); - }); - c.def("set_type", [](PyBlockArgument &self, PyType type) { - return mlirBlockArgumentSetType(self.get(), type); - }); - } -}; - -/// Python wrapper for MlirOpResult. 
-class PyOpResult : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; - static constexpr const char *pyClassName = "OpResult"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyOpResult &self) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation(); - }); - c.def_property_readonly("result_number", [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }); - } -}; - -/// A list of block arguments. Internally, these are stored as consecutive -/// elements, random access is cheap. The argument list is associated with the -/// operation that contains the block (detached blocks are not allowed in -/// Python bindings) and extends its lifetime. -class PyBlockArgumentList { -public: - PyBlockArgumentList(PyOperationRef operation, MlirBlock block) - : operation(std::move(operation)), block(block) {} - - /// Returns the length of the block argument list. - intptr_t dunderLen() { - operation->checkValid(); - return mlirBlockGetNumArguments(block); - } - - /// Returns `index`-th element of the block argument list. - PyBlockArgument dunderGetItem(intptr_t index) { - if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds region"); - } - PyValue value(operation, mlirBlockGetArgument(block, index)); - return PyBlockArgument(value); - } - - /// Defines a Python class in the bindings. - static void bind(py::module &m) { - py::class_(m, "BlockArgumentList") - .def("__len__", &PyBlockArgumentList::dunderLen) - .def("__getitem__", &PyBlockArgumentList::dunderGetItem); - } - -private: - PyOperationRef operation; - MlirBlock block; -}; - -/// A list of operation operands. 
Internally, these are stored as consecutive -/// elements, random access is cheap. The result list is associated with the -/// operation whose results these are, and extends the lifetime of this -/// operation. -class PyOpOperandList : public Sliceable { -public: - static constexpr const char *pyClassName = "OpOperandList"; - - PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirOperationGetNumOperands(operation->get()) - : length, - step), - operation(operation) {} - - intptr_t getNumElements() { - operation->checkValid(); - return mlirOperationGetNumOperands(operation->get()); - } - - PyValue getElement(intptr_t pos) { - return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); - } - - PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpOperandList(operation, startIndex, length, step); - } - -private: - PyOperationRef operation; -}; - -/// A list of operation results. Internally, these are stored as consecutive -/// elements, random access is cheap. The result list is associated with the -/// operation whose results these are, and extends the lifetime of this -/// operation. -class PyOpResultList : public Sliceable { -public: - static constexpr const char *pyClassName = "OpResultList"; - - PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirOperationGetNumResults(operation->get()) - : length, - step), - operation(operation) {} - - intptr_t getNumElements() { - operation->checkValid(); - return mlirOperationGetNumResults(operation->get()); - } - - PyOpResult getElement(intptr_t index) { - PyValue value(operation, mlirOperationGetResult(operation->get(), index)); - return PyOpResult(value); - } - - PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpResultList(operation, startIndex, length, step); - } - -private: - PyOperationRef operation; -}; - -/// A list of operation attributes. Can be indexed by name, producing -/// attributes, or by index, producing named attributes. -class PyOpAttributeMap { -public: - PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} - - PyAttribute dunderGetItemNamed(const std::string &name) { - MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), - toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); - } - return PyAttribute(operation->getContext(), attr); - } - - PyNamedAttribute dunderGetItemIndexed(intptr_t index) { - if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); - } - MlirNamedAttribute namedAttr = - mlirOperationGetAttribute(operation->get(), index); - return PyNamedAttribute( - namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); - } - - void dunderSetItem(const std::string &name, PyAttribute attr) { - mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), - attr); - } - - void dunderDelItem(const std::string &name) { - int removed = mlirOperationRemoveAttributeByName(operation->get(), - toMlirStringRef(name)); - if (!removed) - throw SetPyError(PyExc_KeyError, - "attempt to delete a non-existent attribute"); - } - - intptr_t dunderLen() { - return 
mlirOperationGetNumAttributes(operation->get()); - } - - bool dunderContains(const std::string &name) { - return !mlirAttributeIsNull(mlirOperationGetAttributeByName( - operation->get(), toMlirStringRef(name))); - } - - static void bind(py::module &m) { - py::class_(m, "OpAttributeMap") - .def("__contains__", &PyOpAttributeMap::dunderContains) - .def("__len__", &PyOpAttributeMap::dunderLen) - .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) - .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) - .def("__setitem__", &PyOpAttributeMap::dunderSetItem) - .def("__delitem__", &PyOpAttributeMap::dunderDelItem); - } - -private: - PyOperationRef operation; -}; - -} // end namespace - -//------------------------------------------------------------------------------ -// Builtin attribute subclasses. -//------------------------------------------------------------------------------ - -namespace { - -/// CRTP base classes for Python attributes that subclass Attribute and should -/// be castable from it (i.e. via something like StringAttr(attr)). -/// By default, attribute class hierarchies are one level deep (i.e. a -/// concrete attribute class extends PyAttribute); however, intermediate -/// python-visible base classes can be modeled by specifying a BaseTy. 
-template -class PyConcreteAttribute : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAttribute); - - PyConcreteAttribute() = default; - PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) - : BaseTy(std::move(contextRef), attr) {} - PyConcreteAttribute(PyAttribute &orig) - : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} - - static MlirAttribute castFrom(PyAttribute &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); - cls.def(py::init(), py::keep_alive<0, 1>()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyAffineMapAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; - static constexpr const char *pyClassName = "AffineMapAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyAffineMap &affineMap) { - MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); - return PyAffineMapAttribute(affineMap.getContext(), attr); - }, - py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - } -}; - -class PyArrayAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; - static constexpr const char *pyClassName = "ArrayAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - class PyArrayAttributeIterator { - public: - PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} - - PyArrayAttributeIterator &dunderIter() { return *this; } - - PyAttribute dunderNext() { - if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { - throw py::stop_iteration(); - } - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); - } - - static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator") - .def("__iter__", &PyArrayAttributeIterator::dunderIter) - .def("__next__", &PyArrayAttributeIterator::dunderNext); - } - - private: - PyAttribute attr; - int nextIndex = 0; - }; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](py::list attributes, DefaultingPyMlirContext context) { - SmallVector mlirAttributes; - mlirAttributes.reserve(py::len(attributes)); - for (auto attribute : attributes) { - try { - mlirAttributes.push_back(attribute.cast()); - } catch (py::cast_error &err) { - std::string msg = std::string("Invalid attribute when attempting " - "to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch 
(py::reference_cast_error &err) { - // This exception seems thrown when the value is "None". - std::string msg = - std::string("Invalid attribute (None?) when attempting to " - "create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } - } - MlirAttribute attr = mlirArrayAttrGet( - context->get(), mlirAttributes.size(), mlirAttributes.data()); - return PyArrayAttribute(context->getRef(), attr); - }, - py::arg("attributes"), py::arg("context") = py::none(), - "Gets a uniqued Array attribute"); - c.def("__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { - if (i >= mlirArrayAttrGetNumElements(arr)) - throw py::index_error("ArrayAttribute index out of range"); - return PyAttribute(arr.getContext(), - mlirArrayAttrGetElement(arr, i)); - }) - .def("__len__", - [](const PyArrayAttribute &arr) { - return mlirArrayAttrGetNumElements(arr); - }) - .def("__iter__", [](const PyArrayAttribute &arr) { - return PyArrayAttributeIterator(arr); - }); - } -}; - -/// Float Point Attribute subclass - FloatAttr. -class PyFloatAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; - static constexpr const char *pyClassName = "FloatAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, double value, DefaultingPyLocation loc) { - MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(type)).cast() + - "' and expected floating point type."); - } - return PyFloatAttribute(type.getContext(), attr); - }, - py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), - "Gets an uniqued float point attribute associated to a type"); - c.def_static( - "get_f32", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF32TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued float point attribute associated to a f32 type"); - c.def_static( - "get_f64", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF64TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly( - "value", - [](PyFloatAttribute &self) { - return mlirFloatAttrGetValueDouble(self); - }, - "Returns the value of the float point attribute"); - } -}; - -/// Integer Attribute subclass - IntegerAttr. 
-class PyIntegerAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; - static constexpr const char *pyClassName = "IntegerAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, int64_t value) { - MlirAttribute attr = mlirIntegerAttrGet(type, value); - return PyIntegerAttribute(type.getContext(), attr); - }, - py::arg("type"), py::arg("value"), - "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyIntegerAttribute &self) { - return mlirIntegerAttrGetValueInt(self); - }, - "Returns the value of the integer attribute"); - } -}; - -/// Bool Attribute subclass - BoolAttr. -class PyBoolAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; - static constexpr const char *pyClassName = "BoolAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](bool value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirBoolAttrGet(context->get(), value); - return PyBoolAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued bool attribute"); - c.def_property_readonly( - "value", - [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, - "Returns the value of the bool attribute"); - } -}; - -class PyFlatSymbolRefAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; - static constexpr const char *pyClassName = "FlatSymbolRefAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); - 
return PyFlatSymbolRefAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued FlatSymbolRef attribute"); - c.def_property_readonly( - "value", - [](PyFlatSymbolRefAttribute &self) { - MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the FlatSymbolRef attribute as a string"); - } -}; - -class PyStringAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; - static constexpr const char *pyClassName = "StringAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get_typed", - [](PyType &type, std::string value) { - MlirAttribute attr = - mlirStringAttrTypedGet(type, toMlirStringRef(value)); - return PyStringAttribute(type.getContext(), attr); - }, - - "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute"); - } -}; - -// TODO: Support construction of bool elements. -// TODO: Support construction of string elements. 
-class PyDenseElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; - static constexpr const char *pyClassName = "DenseElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, - DefaultingPyMlirContext contextWrapper) { - // Request a contiguous view. In exotic cases, this will cause a copy. - int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; - Py_buffer *view = new Py_buffer(); - if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { - delete view; - throw py::error_already_set(); - } - py::buffer_info arrayInfo(view); - - MlirContext context = contextWrapper->get(); - // Switch on the types that can be bulk loaded between the Python and - // MLIR-C APIs. - // See: https://docs.python.org/3/library/struct.html#format-characters - if (arrayInfo.format == "f") { - // f32 - assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrFloatGet, - mlirF32TypeGet(context), arrayInfo)); - } else if (arrayInfo.format == "d") { - // f64 - assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrDoubleGet, - mlirF64TypeGet(context), arrayInfo)); - } else if (isSignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // i32 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt32Get, - elementType, arrayInfo)); - } else if (arrayInfo.itemsize == 8) { - // i64 - MlirType elementType = signless ? 
mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt64Get, - elementType, arrayInfo)); - } - } else if (isUnsignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // unsigned i32 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt32Get, - elementType, arrayInfo)); - } else if (arrayInfo.itemsize == 8) { - // unsigned i64 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt64Get, - elementType, arrayInfo)); - } - } - - // TODO: Fall back to string-based get. - std::string message = "unimplemented array format conversion from format: "; - message.append(arrayInfo.format); - throw SetPyError(PyExc_ValueError, message); - } - - static PyDenseElementsAttribute getSplat(PyType shapedType, - PyAttribute &elementAttr) { - auto contextWrapper = - PyMlirContext::forContext(mlirTypeGetContext(shapedType)); - if (!mlirAttributeIsAInteger(elementAttr) && - !mlirAttributeIsAFloat(elementAttr)) { - std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); - } - if (!mlirTypeIsAShaped(shapedType) || - !mlirShapedTypeHasStaticShape(shapedType)) { - std::string message = - "Expected a static ShapedType for the shaped_type parameter: "; - message.append(py::repr(py::cast(shapedType))); - throw SetPyError(PyExc_ValueError, message); - } - MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); - MlirType attrType = mlirAttributeGetType(elementAttr); - if 
(!mlirTypeEqual(shapedElementType, attrType)) { - std::string message = - "Shaped element type and attribute type must be equal: shaped="; - message.append(py::repr(py::cast(shapedType))); - message.append(", element="); - message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); - } - - MlirAttribute elements = - mlirDenseElementsAttrSplatGet(shapedType, elementAttr); - return PyDenseElementsAttribute(contextWrapper->getRef(), elements); - } - - intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } - - py::buffer_info accessBuffer() { - MlirType shapedType = mlirAttributeGetType(*this); - MlirType elementType = mlirShapedTypeGetElementType(shapedType); - - if (mlirTypeIsAF32(elementType)) { - // f32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); - } else if (mlirTypeIsAF64(elementType)) { - // f64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 32) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 64) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); - } - } - - std::string message = "unimplemented array format."; - throw SetPyError(PyExc_ValueError, message); - } - - static void bindDerived(ClassTy &c) { - c.def("__len__", &PyDenseElementsAttribute::dunderLen) - 
.def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("context") = py::none(), - "Gets from a buffer or ndarray") - .def_static("get_splat", PyDenseElementsAttribute::getSplat, - py::arg("shaped_type"), py::arg("element_attr"), - "Gets a DenseElementsAttr where all values are the same") - .def_property_readonly("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def_buffer(&PyDenseElementsAttribute::accessBuffer); - } - -private: - template - static MlirAttribute - bulkLoad(MlirContext context, - MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), - MlirType mlirElementType, py::buffer_info &arrayInfo) { - SmallVector shape(arrayInfo.shape.begin(), - arrayInfo.shape.begin() + arrayInfo.ndim); - auto shapedType = - mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); - intptr_t numElements = arrayInfo.size; - const ElementTy *contents = static_cast(arrayInfo.ptr); - return ctor(shapedType, numElements, contents); - } - - static bool isUnsignedIntegerFormat(const std::string &format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'I' || code == 'B' || code == 'H' || code == 'L' || - code == 'Q'; - } - - static bool isSignedIntegerFormat(const std::string &format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'i' || code == 'b' || code == 'h' || code == 'l' || - code == 'q'; - } - - template - py::buffer_info bufferInfo(MlirType shapedType, - Type (*value)(MlirAttribute, intptr_t)) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); - // Prepare the data for the buffer_info. - // Buffer is configured for read-only access below. - Type *data = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(*this))); - // Prepare the shape for the buffer_info. 
- SmallVector shape; - for (intptr_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); - // Prepare the strides for the buffer_info. - SmallVector strides; - intptr_t strideFactor = 1; - for (intptr_t i = 1; i < rank; ++i) { - strideFactor = 1; - for (intptr_t j = i; j < rank; ++j) { - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); - } - strides.push_back(sizeof(Type) * strideFactor); - } - strides.push_back(sizeof(Type)); - return py::buffer_info(data, sizeof(Type), - py::format_descriptor::format(), rank, shape, - strides, /*readonly=*/true); - } -}; // namespace - -/// Refinement of the PyDenseElementsAttribute for attributes containing integer -/// (and boolean) values. Supports element access. -class PyDenseIntElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; - static constexpr const char *pyClassName = "DenseIntElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - /// Returns the element at the given linear position. Asserts if the index is - /// out of range. - py::int_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); - } - - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - assert(mlirTypeIsAInteger(type) && - "expected integer element type in dense int elements attribute"); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::int_ is implicitly constructible - // from any C++ integral type and handles bitwidth correctly. - // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. 
- unsigned width = mlirIntegerTypeGetWidth(type); - bool isUnsigned = mlirIntegerTypeIsUnsigned(type); - if (isUnsigned) { - if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); - } - if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(*this, pos); - } - if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(*this, pos); - } - } else { - if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); - } - if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(*this, pos); - } - if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(*this, pos); - } - } - throw SetPyError(PyExc_TypeError, "Unsupported integer type"); - } - - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); - } -}; - -class PyDictAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; - static constexpr const char *pyClassName = "DictAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } - - static void bindDerived(ClassTy &c) { - c.def("__len__", &PyDictAttribute::dunderLen); - c.def_static( - "get", - [](py::dict attributes, DefaultingPyMlirContext context) { - SmallVector mlirNamedAttributes; - mlirNamedAttributes.reserve(attributes.size()); - for (auto &it : attributes) { - auto &mlir_attr = it.second.cast(); - auto name = it.first.cast(); - mlirNamedAttributes.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), - toMlirStringRef(name)), - mlir_attr)); - } - MlirAttribute attr = - mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), - mlirNamedAttributes.data()); - return PyDictAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued dict attribute"); - c.def("__getitem__", [](PyDictAttribute &self, const std::string 
&name) { - MlirAttribute attr = - mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); - } - return PyAttribute(self.getContext(), attr); - }); - c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { - if (index < 0 || index >= self.dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); - } - MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); - return PyNamedAttribute( - namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); - }); - } -}; - -/// Refinement of PyDenseElementsAttribute for attributes containing -/// floating-point values. Supports element access. -class PyDenseFPElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; - static constexpr const char *pyClassName = "DenseFPElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - py::float_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); - } - - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::float_ is implicitly constructible - // from float and double. - // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. 
- if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(*this, pos); - } - if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(*this, pos); - } - throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); - } - - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); - } -}; - -class PyTypeAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; - static constexpr const char *pyClassName = "TypeAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirTypeAttrGet(value.get()); - return PyTypeAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued Type attribute"); - c.def_property_readonly("value", [](PyTypeAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirTypeAttrGetValue(self.get())); - }); - } -}; - -/// Unit Attribute subclass. Unit attributes don't have values. -class PyUnitAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; - static constexpr const char *pyClassName = "UnitAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - return PyUnitAttribute(context->getRef(), - mlirUnitAttrGet(context->get())); - }, - py::arg("context") = py::none(), "Create a Unit attribute."); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// Builtin type subclasses. -//------------------------------------------------------------------------------ - -namespace { - -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. 
via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. -template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init(), py::keep_alive<0, 1>()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create an unsigned integer type"); - c.def_property_readonly( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_property_readonly( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_property_readonly( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_property_readonly( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. 
-class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a index type."); - } -}; - -/// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - F32Type. 
-class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f32 type."); - } -}; - -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. 
- if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - }, - "Create a complex type"); - c.def_property_readonly( - "element_type", - [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns element type."); - } -}; - -class PyShapedType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; - static constexpr const char *pyClassName = "ShapedType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly( - "element_type", - [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); - c.def_property_readonly( - "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, - "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( - "has_static_shape", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); - }, - "Returns whether the given shaped type has a static shape."); - c.def( - "is_dynamic_dim", - [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); - }, - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); - c.def( - "get_dim_size", - [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); - }, - "Returns the 
dim-th dimension of the given ranked shaped type."); - c.def_static( - "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - "Returns whether the given dimension size indicates a dynamic " - "dimension."); - c.def( - "is_dynamic_stride_or_offset", - [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); - }, - "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } - } -}; - -/// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - } - return PyVectorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), - "Create a vector type"); - } -}; - -/// Ranked Tensor Type subclass - RankedTensorType. 
-class PyRankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyRankedTensorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - } -}; - -/// Unranked Tensor Type subclass - UnrankedTensorType. -class PyUnrankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedTensorType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("loc") = py::none(), - "Create a unranked tensor type"); - } -}; - -class PyMemRefLayoutMapList; - -/// Ranked MemRef Type subclass - MemRefType. 
-class PyMemRefType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; - - PyMemRefLayoutMapList getLayout(); - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - std::vector layout, unsigned memorySpace, - DefaultingPyLocation loc) { - SmallVector maps; - maps.reserve(layout.size()); - for (PyAffineMap &map : layout) - maps.push_back(map); - - MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), maps.size(), - maps.data(), memorySpace); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyMemRefType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = 0, - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly("layout", &PyMemRefType::getLayout, - "The list of layout maps of the MemRef type.") - .def_property_readonly( - "memory_space", - [](PyMemRefType &self) -> unsigned { - return mlirMemRefTypeGetMemorySpace(self); - }, - "Returns the memory space of the given MemRef type."); - } -}; - -/// A list of affine layout maps in a memref type. Internally, these are stored -/// as consecutive elements, random access is cheap. Both the type and the maps -/// are owned by the context, no need to worry about lifetime extension. 
-class PyMemRefLayoutMapList - : public Sliceable { -public: - static constexpr const char *pyClassName = "MemRefLayoutMapList"; - - PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length, - step), - memref(type) {} - - intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } - - PyAffineMap getElement(intptr_t index) { - return PyAffineMap(memref.getContext(), - mlirMemRefTypeGetAffineMap(memref, index)); - } - - PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyMemRefLayoutMapList(memref, startIndex, length, step); - } - -private: - PyMemRefType memref; -}; - -PyMemRefLayoutMapList PyMemRefType::getLayout() { - return PyMemRefLayoutMapList(*this); -} - -/// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, unsigned memorySpace, - DefaultingPyLocation loc) { - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( - "memory_space", - [](PyUnrankedMemRefType &self) -> unsigned { - return mlirUnrankedMemrefGetMemorySpace(self); - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; - -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](py::list elementList, DefaultingPyMlirContext context) { - intptr_t num = py::len(elementList); - // Mapping py::list to SmallVector. - SmallVector elements; - for (auto element : elementList) - elements.push_back(element.cast()); - MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); - return PyTupleType(context->getRef(), t); - }, - py::arg("elements"), py::arg("context") = py::none(), - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self, pos); - return PyType(self.getContext(), t); - }, - "Returns the pos-th type in the tuple type."); - c.def_property_readonly( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; - -/// Function type. 
-class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - SmallVector inputsRaw(inputs.begin(), inputs.end()); - SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), - inputsRaw.data(), resultsRaw.size(), - resultsRaw.data()); - return PyFunctionType(context->getRef(), t); - }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_property_readonly( - "results", - [](PyFunctionType &self) { - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append( - PyType(contextRef, mlirFunctionTypeGetResult(self, i))); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// PyAffineExpr and subclasses. -//------------------------------------------------------------------------------ - -namespace { -/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr -/// and should be castable from it. Intermediate hierarchy classes can be -/// modeled by specifying BaseTy. 
-template -class PyConcreteAffineExpr : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAffineExpr); - - PyConcreteAffineExpr() = default; - PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) - : BaseTy(std::move(contextRef), affineExpr) {} - PyConcreteAffineExpr(PyAffineExpr &orig) - : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} - - static MlirAffineExpr castFrom(PyAffineExpr &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - Twine("Cannot cast affine expression to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyAffineConstantExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; - static constexpr const char *pyClassName = "AffineConstantExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineConstantExpr get(intptr_t value, - DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = - mlirAffineConstantExprGet(context->get(), static_cast(value)); - return PyAffineConstantExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none()); - c.def_property_readonly("value", [](PyAffineConstantExpr &self) { - return mlirAffineConstantExprGetValue(self); - }); - } -}; - -class PyAffineDimExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; - static constexpr const char *pyClassName = "AffineDimExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); - return PyAffineDimExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineDimExpr &self) { - return mlirAffineDimExprGetPosition(self); - }); - } -}; - -class PyAffineSymbolExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; - static constexpr const char *pyClassName = "AffineSymbolExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); - return 
PyAffineSymbolExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { - return mlirAffineSymbolExprGetPosition(self); - }); - } -}; - -class PyAffineBinaryExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; - static constexpr const char *pyClassName = "AffineBinaryExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - PyAffineExpr lhs() { - MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); - return PyAffineExpr(getContext(), lhsExpr); - } - - PyAffineExpr rhs() { - MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); - return PyAffineExpr(getContext(), rhsExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); - c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); - } -}; - -class PyAffineAddExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; - static constexpr const char *pyClassName = "AffineAddExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); - return PyAffineAddExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineAddExpr::get); - } -}; - -class PyAffineMulExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; - static constexpr const char *pyClassName = "AffineMulExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); - return PyAffineMulExpr(lhs.getContext(), expr); - } - - 
static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineMulExpr::get); - } -}; - -class PyAffineModExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; - static constexpr const char *pyClassName = "AffineModExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); - return PyAffineModExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineModExpr::get); - } -}; - -class PyAffineFloorDivExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; - static constexpr const char *pyClassName = "AffineFloorDivExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); - return PyAffineFloorDivExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineFloorDivExpr::get); - } -}; - -class PyAffineCeilDivExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; - static constexpr const char *pyClassName = "AffineCeilDivExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); - return PyAffineCeilDivExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineCeilDivExpr::get); - } -}; -} // namespace - -bool PyAffineExpr::operator==(const PyAffineExpr &other) { - return mlirAffineExprEqual(affineExpr, other.affineExpr); -} - -py::object PyAffineExpr::getCapsule() { - return py::reinterpret_steal( - mlirPythonAffineExprToCapsule(*this)); -} - 
-PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { - MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); - if (mlirAffineExprIsNull(rawAffineExpr)) - throw py::error_already_set(); - return PyAffineExpr( - PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), - rawAffineExpr); -} - -//------------------------------------------------------------------------------ -// PyAffineMap and utilities. -//------------------------------------------------------------------------------ - -namespace { -/// A list of expressions contained in an affine map. Internally these are -/// stored as a consecutive array leading to inexpensive random access. Both -/// the map and the expression are owned by the context so we need not bother -/// with lifetime extension. -class PyAffineMapExprList - : public Sliceable { -public: - static constexpr const char *pyClassName = "AffineExprList"; - - PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirAffineMapGetNumResults(map) : length, - step), - affineMap(map) {} - - intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } - - PyAffineExpr getElement(intptr_t pos) { - return PyAffineExpr(affineMap.getContext(), - mlirAffineMapGetResult(affineMap, pos)); - } - - PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyAffineMapExprList(affineMap, startIndex, length, step); - } - -private: - PyAffineMap affineMap; -}; -} // end namespace - -bool PyAffineMap::operator==(const PyAffineMap &other) { - return mlirAffineMapEqual(affineMap, other.affineMap); -} - -py::object PyAffineMap::getCapsule() { - return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); -} - -PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { - MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); - if (mlirAffineMapIsNull(rawAffineMap)) - throw py::error_already_set(); - return PyAffineMap( - PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), - rawAffineMap); -} - -//------------------------------------------------------------------------------ -// PyIntegerSet and utilities. 
-//------------------------------------------------------------------------------ - -class PyIntegerSetConstraint { -public: - PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} - - PyAffineExpr getExpr() { - return PyAffineExpr(set.getContext(), - mlirIntegerSetGetConstraint(set, pos)); - } - - bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - - static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint") - .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) - .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); - } - -private: - PyIntegerSet set; - intptr_t pos; -}; - -class PyIntegerSetConstraintList - : public Sliceable { -public: - static constexpr const char *pyClassName = "IntegerSetConstraintList"; - - PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, - step), - set(set) {} - - intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } - - PyIntegerSetConstraint getElement(intptr_t pos) { - return PyIntegerSetConstraint(set, pos); - } - - PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyIntegerSetConstraintList(set, startIndex, length, step); - } - -private: - PyIntegerSet set; -}; - -bool PyIntegerSet::operator==(const PyIntegerSet &other) { - return mlirIntegerSetEqual(integerSet, other.integerSet); -} - -py::object PyIntegerSet::getCapsule() { - return py::reinterpret_steal( - mlirPythonIntegerSetToCapsule(*this)); -} - -PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { - MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); - if (mlirIntegerSetIsNull(rawIntegerSet)) - throw py::error_already_set(); - return PyIntegerSet( - PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), - rawIntegerSet); -} - -/// 
Attempts to populate `result` with the content of `list` casted to the -/// appropriate type (Python and C types are provided as template arguments). -/// Throws errors in case of failure, using "action" to describe what the caller -/// was attempting to do. -template -static void pyListToVector(py::list list, llvm::SmallVectorImpl &result, - StringRef action) { - result.reserve(py::len(list)); - for (py::handle item : list) { - try { - result.push_back(item.cast()); - } catch (py::cast_error &err) { - std::string msg = (llvm::Twine("Invalid expression when ") + action + - " (" + err.what() + ")") - .str(); - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { - std::string msg = (llvm::Twine("Invalid expression (None?) when ") + - action + " (" + err.what() + ")") - .str(); - throw py::cast_error(msg); - } - } -} - -//------------------------------------------------------------------------------ -// Populates the pybind11 IR submodule. -//------------------------------------------------------------------------------ - -void mlir::python::populateIRSubmodule(py::module &m) { - //---------------------------------------------------------------------------- - // Mapping of MlirContext - //---------------------------------------------------------------------------- - py::class_(m, "Context") - .def(py::init<>(&PyMlirContext::createNewContextForInit)) - .def_static("_get_live_count", &PyMlirContext::getLiveCount) - .def("_get_context_again", - [](PyMlirContext &self) { - PyMlirContextRef ref = PyMlirContext::forContext(self.get()); - return ref.releaseObject(); - }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyMlirContext::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) - .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", 
&PyMlirContext::contextExit) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *context = PyThreadContextEntry::getDefaultContext(); - if (!context) - throw SetPyError(PyExc_ValueError, "No current Context"); - return context; - }, - "Gets the Context bound to the current thread or raises ValueError") - .def_property_readonly( - "dialects", - [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Gets a container for accessing dialects by name") - .def_property_readonly( - "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Alias for 'dialect'") - .def( - "get_dialect_descriptor", - [=](PyMlirContext &self, std::string &name) { - MlirDialect dialect = mlirContextGetOrLoadDialect( - self.get(), {name.data(), name.size()}); - if (mlirDialectIsNull(dialect)) { - throw SetPyError(PyExc_ValueError, - Twine("Dialect '") + name + "' not found"); - } - return PyDialectDescriptor(self.getRef(), dialect); - }, - "Gets or loads a dialect by name, returning its descriptor object") - .def_property( - "allow_unregistered_dialects", - [](PyMlirContext &self) -> bool { - return mlirContextGetAllowUnregisteredDialects(self.get()); - }, - [](PyMlirContext &self, bool value) { - mlirContextSetAllowUnregisteredDialects(self.get(), value); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialectDescriptor - //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor") - .def_property_readonly("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = - mlirDialectGetNamespace(self.get()); - return py::str(ns.data, ns.length); - }) - .def("__repr__", [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - std::string repr(""); - return repr; - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialects - 
//---------------------------------------------------------------------------- - py::class_(m, "Dialects") - .def("__getitem__", - [=](PyDialects &self, std::string keyName) { - MlirDialect dialect = - self.getDialectForKey(keyName, /*attrError=*/false); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(keyName, std::move(descriptor)); - }) - .def("__getattr__", [=](PyDialects &self, std::string attrName) { - MlirDialect dialect = - self.getDialectForKey(attrName, /*attrError=*/true); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(attrName, std::move(descriptor)); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialect - //---------------------------------------------------------------------------- - py::class_(m, "Dialect") - .def(py::init(), "descriptor") - .def_property_readonly( - "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](py::object self) { - auto clazz = self.attr("__class__"); - return py::str(""); - }); - - //---------------------------------------------------------------------------- - // Mapping of Location - //---------------------------------------------------------------------------- - py::class_(m, "Location") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) - .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit) - .def("__eq__", - [](PyLocation &self, PyLocation &other) -> bool { - return mlirLocationEqual(self, other); - }) - .def("__eq__", [](PyLocation &self, py::object other) { return false; }) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *loc = PyThreadContextEntry::getDefaultLocation(); - if (!loc) - throw 
SetPyError(PyExc_ValueError, "No current Location"); - return loc; - }, - "Gets the Location bound to the current thread or raises ValueError") - .def_static( - "unknown", - [](DefaultingPyMlirContext context) { - return PyLocation(context->getRef(), - mlirLocationUnknownGet(context->get())); - }, - py::arg("context") = py::none(), - "Gets a Location representing an unknown location") - .def_static( - "file", - [](std::string filename, int line, int col, - DefaultingPyMlirContext context) { - return PyLocation( - context->getRef(), - mlirLocationFileLineColGet( - context->get(), toMlirStringRef(filename), line, col)); - }, - py::arg("filename"), py::arg("line"), py::arg("col"), - py::arg("context") = py::none(), kContextGetFileLocationDocstring) - .def_property_readonly( - "context", - [](PyLocation &self) { return self.getContext().getObject(); }, - "Context that owns the Location") - .def("__repr__", [](PyLocation &self) { - PyPrintAccumulator printAccum; - mlirLocationPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }); - - //---------------------------------------------------------------------------- - // Mapping of Module - //---------------------------------------------------------------------------- - py::class_(m, "Module") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) - .def_static( - "parse", - [](const std::string moduleAsm, DefaultingPyMlirContext context) { - MlirModule module = mlirModuleCreateParse( - context->get(), toMlirStringRef(moduleAsm)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirModuleIsNull(module)) { - throw SetPyError( - PyExc_ValueError, - "Unable to parse module assembly (see diagnostics)"); - } - return PyModule::forModule(module).releaseObject(); - }, - py::arg("asm"), py::arg("context") = py::none(), - kModuleParseDocstring) - .def_static( - "create", - [](DefaultingPyLocation loc) { - MlirModule module = mlirModuleCreateEmpty(loc); - return PyModule::forModule(module).releaseObject(); - }, - py::arg("loc") = py::none(), "Creates an empty module") - .def_property_readonly( - "context", - [](PyModule &self) { return self.getContext().getObject(); }, - "Context that created the Module") - .def_property_readonly( - "operation", - [](PyModule &self) { - return PyOperation::forOperation(self.getContext(), - mlirModuleGetOperation(self.get()), - self.getRef().releaseObject()) - .releaseObject(); - }, - "Accesses the module as an operation") - .def_property_readonly( - "body", - [](PyModule &self) { - PyOperationRef module_op = PyOperation::forOperation( - self.getContext(), mlirModuleGetOperation(self.get()), - self.getRef().releaseObject()); - PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); - return returnBlock; - }, - "Return the block for this module") - .def( - "dump", - [](PyModule &self) { - mlirOperationDump(mlirModuleGetOperation(self.get())); - }, - kDumpDocstring) - .def( - "__str__", - [](PyModule &self) { - MlirOperation operation = mlirModuleGetOperation(self.get()); - PyPrintAccumulator printAccum; - mlirOperationPrint(operation, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }, - kOperationStrDunderDocstring); - - //---------------------------------------------------------------------------- - // Mapping of Operation. 
- //---------------------------------------------------------------------------- - py::class_(m, "_OperationBase") - .def("__eq__", - [](PyOperationBase &self, PyOperationBase &other) { - return &self.getOperation() == &other.getOperation(); - }) - .def("__eq__", - [](PyOperationBase &self, py::object other) { return false; }) - .def_property_readonly("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap( - self.getOperation().getRef()); - }) - .def_property_readonly("operands", - [](PyOperationBase &self) { - return PyOpOperandList( - self.getOperation().getRef()); - }) - .def_property_readonly("regions", - [](PyOperationBase &self) { - return PyRegionList( - self.getOperation().getRef()); - }) - .def_property_readonly( - "results", - [](PyOperationBase &self) { - return PyOpResultList(self.getOperation().getRef()); - }, - "Returns the list of Operation results.") - .def_property_readonly( - "result", - [](PyOperationBase &self) { - auto &operation = self.getOperation(); - auto numResults = mlirOperationGetNumResults(operation); - if (numResults != 1) { - auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw SetPyError( - PyExc_ValueError, - Twine("Cannot call .result on operation ") + - StringRef(name.data, name.length) + " which has " + - Twine(numResults) + - " results (it is only valid for operations with a " - "single result)"); - } - return PyOpResult(operation.getRef(), - mlirOperationGetResult(operation, 0)); - }, - "Shortcut to get an op result if it has only one (throws an error " - "otherwise).") - .def("__iter__", - [](PyOperationBase &self) { - return PyRegionIterator(self.getOperation().getRef()); - }) - .def( - "__str__", - [](PyOperationBase &self) { - return self.getAsm(/*binary=*/false, - /*largeElementsLimit=*/llvm::None, - /*enableDebugInfo=*/false, - /*prettyDebugInfo=*/false, - /*printGenericOpForm=*/false, - /*useLocalScope=*/false); - }, - "Returns the assembly form of the operation.") - .def("print", 
&PyOperationBase::print, - // Careful: Lots of arguments must match up with print method. - py::arg("file") = py::none(), py::arg("binary") = false, - py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, kOperationPrintDocstring) - .def("get_asm", &PyOperationBase::getAsm, - // Careful: Lots of arguments must match up with get_asm method. - py::arg("binary") = false, - py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, kOperationGetAsmDocstring) - .def( - "verify", - [](PyOperationBase &self) { - return mlirOperationVerify(self.getOperation()); - }, - "Verify the operation and return true if it passes, false if it " - "fails."); - - py::class_(m, "Operation") - .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("results") = py::none(), - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - kOperationCreateDocstring) - .def_property_readonly("name", - [](PyOperation &self) { - MlirOperation operation = self.get(); - MlirStringRef name = mlirIdentifierStr( - mlirOperationGetName(operation)); - return py::str(name.data, name.length); - }) - .def_property_readonly( - "context", - [](PyOperation &self) { return self.getContext().getObject(); }, - "Context that owns the Operation") - .def_property_readonly("opview", &PyOperation::createOpView); - - auto opViewClass = - py::class_(m, "OpView") - .def(py::init()) - .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly( - "context", - [](PyOpView &self) { - return self.getOperation().getContext().getObject(); - }, - 
"Context that owns the Operation") - .def("__str__", [](PyOpView &self) { - return py::str(self.getOperationObject()); - }); - opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); - opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); - opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); - opViewClass.attr("build_generic") = classmethod( - &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), - py::arg("operands") = py::none(), py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = py::none(), - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - "Builds a specific, generated OpView based on class level attributes."); - - //---------------------------------------------------------------------------- - // Mapping of PyRegion. - //---------------------------------------------------------------------------- - py::class_(m, "Region") - .def_property_readonly( - "blocks", - [](PyRegion &self) { - return PyBlockList(self.getParentOperation(), self.get()); - }, - "Returns a forward-optimized sequence of blocks.") - .def( - "__iter__", - [](PyRegion &self) { - self.checkValid(); - MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); - return PyBlockIterator(self.getParentOperation(), firstBlock); - }, - "Iterates over blocks in the region.") - .def("__eq__", - [](PyRegion &self, PyRegion &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); - - //---------------------------------------------------------------------------- - // Mapping of PyBlock. 
- //---------------------------------------------------------------------------- - py::class_(m, "Block") - .def_property_readonly( - "arguments", - [](PyBlock &self) { - return PyBlockArgumentList(self.getParentOperation(), self.get()); - }, - "Returns a list of block arguments.") - .def_property_readonly( - "operations", - [](PyBlock &self) { - return PyOperationList(self.getParentOperation(), self.get()); - }, - "Returns a forward-optimized sequence of operations.") - .def( - "__iter__", - [](PyBlock &self) { - self.checkValid(); - MlirOperation firstOperation = - mlirBlockGetFirstOperation(self.get()); - return PyOperationIterator(self.getParentOperation(), - firstOperation); - }, - "Iterates over operations in the block.") - .def("__eq__", - [](PyBlock &self, PyBlock &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) - .def( - "__str__", - [](PyBlock &self) { - self.checkValid(); - PyPrintAccumulator printAccum; - mlirBlockPrint(self.get(), printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }, - "Returns the assembly form of the block."); - - //---------------------------------------------------------------------------- - // Mapping of PyInsertionPoint. 
- //---------------------------------------------------------------------------- - - py::class_(m, "InsertionPoint") - .def(py::init(), py::arg("block"), - "Inserts after the last operation but still inside the block.") - .def("__enter__", &PyInsertionPoint::contextEnter) - .def("__exit__", &PyInsertionPoint::contextExit) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); - if (!ip) - throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); - return ip; - }, - "Gets the InsertionPoint bound to the current thread or raises " - "ValueError if none has been set") - .def(py::init(), py::arg("beforeOperation"), - "Inserts before a referenced operation.") - .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - py::arg("block"), "Inserts at the beginning of the block.") - .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - py::arg("block"), "Inserts before the block terminator.") - .def("insert", &PyInsertionPoint::insert, py::arg("operation"), - "Inserts an operation."); - - //---------------------------------------------------------------------------- - // Mapping of PyAttribute. - //---------------------------------------------------------------------------- - py::class_(m, "Attribute") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAttribute::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) - .def_static( - "parse", - [](std::string attrSpec, DefaultingPyMlirContext context) { - MlirAttribute type = mlirAttributeParseGet( - context->get(), toMlirStringRef(attrSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirAttributeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse attribute: '") + - attrSpec + "'"); - } - return PyAttribute(context->getRef(), type); - }, - py::arg("asm"), py::arg("context") = py::none(), - "Parses an attribute from an assembly form") - .def_property_readonly( - "context", - [](PyAttribute &self) { return self.getContext().getObject(); }, - "Context that owns the Attribute") - .def_property_readonly("type", - [](PyAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirAttributeGetType(self)); - }) - .def( - "get_named", - [](PyAttribute &self, std::string name) { - return PyNamedAttribute(self, std::move(name)); - }, - py::keep_alive<0, 1>(), "Binds a name to the attribute") - .def("__eq__", - [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) - .def( - "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, - kDumpDocstring) - .def( - "__str__", - [](PyAttribute &self) { - PyPrintAccumulator printAccum; - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }, - "Returns the assembly form of the Attribute.") - .def("__repr__", [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and are - // printed. This may need to be re-evaluated if debug dumps end up - // being excessive. 
- PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyNamedAttribute - //---------------------------------------------------------------------------- - py::class_(m, "NamedAttribute") - .def("__repr__", - [](PyNamedAttribute &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("NamedAttribute("); - printAccum.parts.append( - mlirIdentifierStr(self.namedAttr.name).data); - printAccum.parts.append("="); - mlirAttributePrint(self.namedAttr.attribute, - printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "name", - [](PyNamedAttribute &self) { - return py::str(mlirIdentifierStr(self.namedAttr.name).data, - mlirIdentifierStr(self.namedAttr.name).length); - }, - "The name of the NamedAttribute binding") - .def_property_readonly( - "attr", - [](PyNamedAttribute &self) { - // TODO: When named attribute is removed/refactored, also remove - // this constructor (it does an inefficient table lookup). - auto contextRef = PyMlirContext::forContext( - mlirAttributeGetContext(self.namedAttr.attribute)); - return PyAttribute(std::move(contextRef), self.namedAttr.attribute); - }, - py::keep_alive<0, 1>(), - "The underlying generic attribute of the NamedAttribute binding"); - - // Builtin attribute bindings. 
- PyAffineMapAttribute::bind(m); - PyArrayAttribute::bind(m); - PyArrayAttribute::PyArrayAttributeIterator::bind(m); - PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m); - PyDenseFPElementsAttribute::bind(m); - PyDenseIntElementsAttribute::bind(m); - PyDictAttribute::bind(m); - PyFlatSymbolRefAttribute::bind(m); - PyFloatAttribute::bind(m); - PyIntegerAttribute::bind(m); - PyStringAttribute::bind(m); - PyTypeAttribute::bind(m); - PyUnitAttribute::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyType. - //---------------------------------------------------------------------------- - py::class_(m, "Type") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) - .def_static( - "parse", - [](std::string typeSpec, DefaultingPyMlirContext context) { - MlirType type = - mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse type: '") + typeSpec + - "'"); - } - return PyType(context->getRef(), type); - }, - py::arg("asm"), py::arg("context") = py::none(), - kContextParseTypeDocstring) - .def_property_readonly( - "context", [](PyType &self) { return self.getContext().getObject(); }, - "Context that owns the Type") - .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) - .def("__eq__", [](PyType &self, py::object &other) { return false; }) - .def( - "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) - .def( - "__str__", - [](PyType &self) { - PyPrintAccumulator printAccum; - mlirTypePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }, - "Returns the assembly form of the type.") - .def("__repr__", [](PyType &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, types are an exception as they typically have compact - // assembly forms and printing them is useful. - PyPrintAccumulator printAccum; - printAccum.parts.append("Type("); - mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }); - - // Builtin type bindings. - PyIntegerType::bind(m); - PyIndexType::bind(m); - PyBF16Type::bind(m); - PyF16Type::bind(m); - PyF32Type::bind(m); - PyF64Type::bind(m); - PyNoneType::bind(m); - PyComplexType::bind(m); - PyShapedType::bind(m); - PyVectorType::bind(m); - PyRankedTensorType::bind(m); - PyUnrankedTensorType::bind(m); - PyMemRefType::bind(m); - PyMemRefLayoutMapList::bind(m); - PyUnrankedMemRefType::bind(m); - PyTupleType::bind(m); - PyFunctionType::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of Value. 
- //---------------------------------------------------------------------------- - py::class_(m, "Value") - .def_property_readonly( - "context", - [](PyValue &self) { return self.getParentOperation()->getContext(); }, - "Context in which the value lives.") - .def( - "dump", [](PyValue &self) { mlirValueDump(self.get()); }, - kDumpDocstring) - .def("__eq__", - [](PyValue &self, PyValue &other) { - return self.get().ptr == other.get().ptr; - }) - .def("__eq__", [](PyValue &self, py::object other) { return false; }) - .def( - "__str__", - [](PyValue &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("Value("); - mlirValuePrint(self.get(), printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }, - kValueDunderStrDocstring) - .def_property_readonly("type", [](PyValue &self) { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }); - PyBlockArgument::bind(m); - PyOpResult::bind(m); - - // Container bindings. - PyBlockArgumentList::bind(m); - PyBlockIterator::bind(m); - PyBlockList::bind(m); - PyOperationIterator::bind(m); - PyOperationList::bind(m); - PyOpAttributeMap::bind(m); - PyOpOperandList::bind(m); - PyOpResultList::bind(m); - PyRegionIterator::bind(m); - PyRegionList::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyAffineExpr and derived classes. 
- //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineExpr::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) - .def("__add__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineAddExpr::get(self, other); - }) - .def("__mul__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineMulExpr::get(self, other); - }) - .def("__mod__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineModExpr::get(self, other); - }) - .def("__sub__", - [](PyAffineExpr &self, PyAffineExpr &other) { - auto negOne = - PyAffineConstantExpr::get(-1, *self.getContext().get()); - return PyAffineAddExpr::get(self, - PyAffineMulExpr::get(negOne, other)); - }) - .def("__eq__", [](PyAffineExpr &self, - PyAffineExpr &other) { return self == other; }) - .def("__eq__", - [](PyAffineExpr &self, py::object &other) { return false; }) - .def("__str__", - [](PyAffineExpr &self) { - PyPrintAccumulator printAccum; - mlirAffineExprPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyAffineExpr &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("AffineExpr("); - mlirAffineExprPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyAffineExpr &self) { return self.getContext().getObject(); }) - .def_static( - "get_add", &PyAffineAddExpr::get, - "Gets an affine expression containing a sum of two expressions.") - .def_static( - "get_mul", &PyAffineMulExpr::get, - "Gets an affine expression containing a product of two expressions.") - .def_static("get_mod", &PyAffineModExpr::get, - "Gets an affine expression containing the modulo of dividing " - "one expression by another.") - .def_static("get_floor_div", 
&PyAffineFloorDivExpr::get, - "Gets an affine expression containing the rounded-down " - "result of dividing one expression by another.") - .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, - "Gets an affine expression containing the rounded-up result " - "of dividing one expression by another.") - .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none(), - "Gets a constant affine expression with the given value.") - .def_static( - "get_dim", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none(), - "Gets an affine expression of a dimension at the given position.") - .def_static( - "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none(), - "Gets an affine expression of a symbol at the given position.") - .def( - "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, - kDumpDocstring); - PyAffineConstantExpr::bind(m); - PyAffineDimExpr::bind(m); - PyAffineSymbolExpr::bind(m); - PyAffineBinaryExpr::bind(m); - PyAffineAddExpr::bind(m); - PyAffineMulExpr::bind(m); - PyAffineModExpr::bind(m); - PyAffineFloorDivExpr::bind(m); - PyAffineCeilDivExpr::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyAffineMap. 
- //---------------------------------------------------------------------------- - py::class_(m, "AffineMap") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineMap::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) - .def("__eq__", - [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) - .def("__str__", - [](PyAffineMap &self) { - PyPrintAccumulator printAccum; - mlirAffineMapPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyAffineMap &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("AffineMap("); - mlirAffineMapPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyAffineMap &self) { return self.getContext().getObject(); }, - "Context that owns the Affine Map") - .def( - "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, - kDumpDocstring) - .def_static( - "get", - [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, - DefaultingPyMlirContext context) { - SmallVector affineExprs; - pyListToVector( - exprs, affineExprs, "attempting to create an AffineMap"); - MlirAffineMap map = - mlirAffineMapGet(context->get(), dimCount, symbolCount, - affineExprs.size(), affineExprs.data()); - return PyAffineMap(context->getRef(), map); - }, - py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), - py::arg("context") = py::none(), - "Gets a map with the given expressions as results.") - .def_static( - "get_constant", - [](intptr_t value, DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapConstantGet(context->get(), value); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an affine map with a single constant 
result") - .def_static( - "get_empty", - [](DefaultingPyMlirContext context) { - MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("context") = py::none(), "Gets an empty affine map.") - .def_static( - "get_identity", - [](intptr_t nDims, DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapMultiDimIdentityGet(context->get(), nDims); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("n_dims"), py::arg("context") = py::none(), - "Gets an identity map with the given number of dimensions.") - .def_static( - "get_minor_identity", - [](intptr_t nDims, intptr_t nResults, - DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("n_dims"), py::arg("n_results"), - py::arg("context") = py::none(), - "Gets a minor identity map with the given number of dimensions and " - "results.") - .def_static( - "get_permutation", - [](std::vector permutation, - DefaultingPyMlirContext context) { - if (!isPermutation(permutation)) - throw py::cast_error("Invalid permutation when attempting to " - "create an AffineMap"); - MlirAffineMap affineMap = mlirAffineMapPermutationGet( - context->get(), permutation.size(), permutation.data()); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("permutation"), py::arg("context") = py::none(), - "Gets an affine map that permutes its inputs.") - .def("get_submap", - [](PyAffineMap &self, std::vector &resultPos) { - intptr_t numResults = mlirAffineMapGetNumResults(self); - for (intptr_t pos : resultPos) { - if (pos < 0 || pos >= numResults) - throw py::value_error("result position out of bounds"); - } - MlirAffineMap affineMap = mlirAffineMapGetSubMap( - self, resultPos.size(), resultPos.data()); - return PyAffineMap(self.getContext(), affineMap); - }) - 
.def("get_major_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMajorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("get_minor_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMinorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def_property_readonly( - "is_permutation", - [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_property_readonly("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_property_readonly( - "n_dims", - [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_property_readonly( - "n_inputs", - [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_property_readonly( - "n_symbols", - [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_property_readonly("results", [](PyAffineMap &self) { - return PyAffineMapExprList(self); - }); - PyAffineMapExprList::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyIntegerSet. 
- //---------------------------------------------------------------------------- - py::class_(m, "IntegerSet") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyIntegerSet::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) - .def("__eq__", [](PyIntegerSet &self, - PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) - .def("__str__", - [](PyIntegerSet &self) { - PyPrintAccumulator printAccum; - mlirIntegerSetPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyIntegerSet &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("IntegerSet("); - mlirIntegerSetPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyIntegerSet &self) { return self.getContext().getObject(); }) - .def( - "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, - kDumpDocstring) - .def_static( - "get", - [](intptr_t numDims, intptr_t numSymbols, py::list exprs, - std::vector eqFlags, DefaultingPyMlirContext context) { - if (exprs.size() != eqFlags.size()) - throw py::value_error( - "Expected the number of constraints to match " - "that of equality flags"); - if (exprs.empty()) - throw py::value_error("Expected non-empty list of constraints"); - - // Copy over to a SmallVector because std::vector has a - // specialization for booleans that packs data and does not - // expose a `bool *`. 
- SmallVector flags(eqFlags.begin(), eqFlags.end()); - - SmallVector affineExprs; - pyListToVector(exprs, affineExprs, - "attempting to create an IntegerSet"); - MlirIntegerSet set = mlirIntegerSetGet( - context->get(), numDims, numSymbols, exprs.size(), - affineExprs.data(), flags.data()); - return PyIntegerSet(context->getRef(), set); - }, - py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), - py::arg("eq_flags"), py::arg("context") = py::none()) - .def_static( - "get_empty", - [](intptr_t numDims, intptr_t numSymbols, - DefaultingPyMlirContext context) { - MlirIntegerSet set = - mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); - return PyIntegerSet(context->getRef(), set); - }, - py::arg("num_dims"), py::arg("num_symbols"), - py::arg("context") = py::none()) - .def("get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, - intptr_t numResultDims, intptr_t numResultSymbols) { - if (static_cast(dimExprs.size()) != - mlirIntegerSetGetNumDims(self)) - throw py::value_error( - "Expected the number of dimension replacement expressions " - "to match that of dimensions"); - if (static_cast(symbolExprs.size()) != - mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( - "Expected the number of symbol replacement expressions " - "to match that of symbols"); - - SmallVector dimAffineExprs, symbolAffineExprs; - pyListToVector( - dimExprs, dimAffineExprs, - "attempting to create an IntegerSet by replacing dimensions"); - pyListToVector( - symbolExprs, symbolAffineExprs, - "attempting to create an IntegerSet by replacing symbols"); - MlirIntegerSet set = mlirIntegerSetReplaceGet( - self, dimAffineExprs.data(), symbolAffineExprs.data(), - numResultDims, numResultSymbols); - return PyIntegerSet(self.getContext(), set); - }) - .def_property_readonly("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_property_readonly( - "n_dims", - [](PyIntegerSet &self) { return 
mlirIntegerSetGetNumDims(self); }) - .def_property_readonly( - "n_symbols", - [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_property_readonly( - "n_inputs", - [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_property_readonly("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_property_readonly("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_property_readonly("constraints", [](PyIntegerSet &self) { - return PyIntegerSetConstraintList(self); - }); - PyIntegerSetConstraint::bind(m); - PyIntegerSetConstraintList::bind(m); -} diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModules.h deleted file mode 100644 index 8140d7043..000000000 --- a/mlir/lib/Bindings/Python/IRModules.h +++ /dev/null @@ -1,768 +0,0 @@ -//===- IRModules.h - IR Submodules of pybind module -----------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H -#define MLIR_BINDINGS_PYTHON_IRMODULES_H - -#include - -#include "PybindUtils.h" - -#include "mlir-c/AffineExpr.h" -#include "mlir-c/AffineMap.h" -#include "mlir-c/IR.h" -#include "mlir-c/IntegerSet.h" -#include "llvm/ADT/DenseMap.h" - -namespace mlir { -namespace python { - -class PyBlock; -class PyInsertionPoint; -class PyLocation; -class DefaultingPyLocation; -class PyMlirContext; -class DefaultingPyMlirContext; -class PyModule; -class PyOperation; -class PyType; -class PyValue; - -/// Template for a reference to a concrete type which captures a python -/// reference to its underlying python object. 
-template -class PyObjectRef { -public: - PyObjectRef(T *referrent, pybind11::object object) - : referrent(referrent), object(std::move(object)) { - assert(this->referrent && - "cannot construct PyObjectRef with null referrent"); - assert(this->object && "cannot construct PyObjectRef with null object"); - } - PyObjectRef(PyObjectRef &&other) - : referrent(other.referrent), object(std::move(other.object)) { - other.referrent = nullptr; - assert(!other.object); - } - PyObjectRef(const PyObjectRef &other) - : referrent(other.referrent), object(other.object /* copies */) {} - ~PyObjectRef() {} - - int getRefCount() { - if (!object) - return 0; - return object.ref_count(); - } - - /// Releases the object held by this instance, returning it. - /// This is the proper thing to return from a function that wants to return - /// the reference. Note that this does not work from initializers. - pybind11::object releaseObject() { - assert(referrent && object); - referrent = nullptr; - auto stolen = std::move(object); - return stolen; - } - - T *get() { return referrent; } - T *operator->() { - assert(referrent && object); - return referrent; - } - pybind11::object getObject() { - assert(referrent && object); - return object; - } - operator bool() const { return referrent && object; } - -private: - T *referrent; - pybind11::object object; -}; - -/// Tracks an entry in the thread context stack. New entries are pushed onto -/// here for each with block that activates a new InsertionPoint, Context or -/// Location. -/// -/// Pushing either a Location or InsertionPoint also pushes its associated -/// Context. Pushing a Context will not modify the Location or InsertionPoint -/// unless if they are from a different context, in which case, they are -/// cleared. 
-class PyThreadContextEntry { -public: - enum class FrameKind { - Context, - InsertionPoint, - Location, - }; - - PyThreadContextEntry(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, - pybind11::object location) - : context(std::move(context)), insertionPoint(std::move(insertionPoint)), - location(std::move(location)), frameKind(frameKind) {} - - /// Gets the top of stack context and return nullptr if not defined. - static PyMlirContext *getDefaultContext(); - - /// Gets the top of stack insertion point and return nullptr if not defined. - static PyInsertionPoint *getDefaultInsertionPoint(); - - /// Gets the top of stack location and returns nullptr if not defined. - static PyLocation *getDefaultLocation(); - - PyMlirContext *getContext(); - PyInsertionPoint *getInsertionPoint(); - PyLocation *getLocation(); - FrameKind getFrameKind() { return frameKind; } - - /// Stack management. - static PyThreadContextEntry *getTopOfStack(); - static pybind11::object pushContext(PyMlirContext &context); - static void popContext(PyMlirContext &context); - static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); - static void popInsertionPoint(PyInsertionPoint &insertionPoint); - static pybind11::object pushLocation(PyLocation &location); - static void popLocation(PyLocation &location); - - /// Gets the thread local stack. - static std::vector &getStack(); - -private: - static void push(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, pybind11::object location); - - /// An object reference to the PyContext. - pybind11::object context; - /// An object reference to the current insertion point. - pybind11::object insertionPoint; - /// An object reference to the current location. - pybind11::object location; - // The kind of push that was performed. - FrameKind frameKind; -}; - -/// Wrapper around MlirContext. 
-using PyMlirContextRef = PyObjectRef; -class PyMlirContext { -public: - PyMlirContext() = delete; - PyMlirContext(const PyMlirContext &) = delete; - PyMlirContext(PyMlirContext &&) = delete; - - /// For the case of a python __init__ (py::init) method, pybind11 is quite - /// strict about needing to return a pointer that is not yet associated to - /// an py::object. Since the forContext() method acts like a pool, possibly - /// returning a recycled context, it does not satisfy this need. The usual - /// way in python to accomplish such a thing is to override __new__, but - /// that is also not supported by pybind11. Instead, we use this entry - /// point which always constructs a fresh context (which cannot alias an - /// existing one because it is fresh). - static PyMlirContext *createNewContextForInit(); - - /// Returns a context reference for the singleton PyMlirContext wrapper for - /// the given context. - static PyMlirContextRef forContext(MlirContext context); - ~PyMlirContext(); - - /// Accesses the underlying MlirContext. - MlirContext get() { return context; } - - /// Gets a strong reference to this context, which will ensure it is kept - /// alive for the life of the reference. - PyMlirContextRef getRef() { - return PyMlirContextRef(this, pybind11::cast(this)); - } - - /// Gets a capsule wrapping the void* within the MlirContext. - pybind11::object getCapsule(); - - /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. - /// Note that PyMlirContext instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirContext - /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); - - /// Gets the count of live context objects. Used for testing. - static size_t getLiveCount(); - - /// Gets the count of live operations associated with this context. - /// Used for testing. 
- size_t getLiveOperationCount(); - - /// Gets the count of live modules associated with this context. - /// Used for testing. - size_t getLiveModuleCount(); - - /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb); - -private: - PyMlirContext(MlirContext context); - // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, - // preserving the relationship that an MlirContext maps to a single - // PyMlirContext wrapper. This could be replaced in the future with an - // extension mechanism on the MlirContext for stashing user pointers. - // Note that this holds a handle, which does not imply ownership. - // Mappings will be removed when the context is destructed. - using LiveContextMap = llvm::DenseMap; - static LiveContextMap &getLiveContexts(); - - // Interns all live modules associated with this context. Modules tracked - // in this map are valid. When a module is invalidated, it is removed - // from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveModuleMap = - llvm::DenseMap>; - LiveModuleMap liveModules; - - // Interns all live operations associated with this context. Operations - // tracked in this map are valid. When an operation is invalidated, it is - // removed from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveOperationMap = - llvm::DenseMap>; - LiveOperationMap liveOperations; - - MlirContext context; - friend class PyModule; - friend class PyOperation; -}; - -/// Used in function arguments when None should resolve to the current context -/// manager set instance. 
-class DefaultingPyMlirContext - : public Defaulting { -public: - using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - "[ThreadContextAware] mlir.ir.Context"; - static PyMlirContext &resolve(); -}; - -/// Base class for all objects that directly or indirectly depend on an -/// MlirContext. The lifetime of the context will extend at least to the -/// lifetime of these instances. -/// Immutable objects that depend on a context extend this directly. -class BaseContextObject { -public: - BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { - assert(this->contextRef && - "context object constructed with null context ref"); - } - - /// Accesses the context reference. - PyMlirContextRef &getContext() { return contextRef; } - -private: - PyMlirContextRef contextRef; -}; - -/// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in -/// order to differentiate it from the `Dialect` base class which is extended by -/// plugins which extend dialect functionality through extension python code. -/// This should be seen as the "low-level" object and `Dialect` as the -/// high-level, user facing object. -class PyDialectDescriptor : public BaseContextObject { -public: - PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) - : BaseContextObject(std::move(contextRef)), dialect(dialect) {} - - MlirDialect get() { return dialect; } - -private: - MlirDialect dialect; -}; - -/// User-level object for accessing dialects with dotted syntax such as: -/// ctx.dialect.std -class PyDialects : public BaseContextObject { -public: - PyDialects(PyMlirContextRef contextRef) - : BaseContextObject(std::move(contextRef)) {} - - MlirDialect getDialectForKey(const std::string &key, bool attrError); -}; - -/// User-level dialect object. For dialects that have a registered extension, -/// this will be the base class of the extension dialect type. For un-extended, -/// objects of this type will be returned directly. 
-class PyDialect { -public: - PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} - - pybind11::object getDescriptor() { return descriptor; } - -private: - pybind11::object descriptor; -}; - -/// Wrapper around an MlirLocation. -class PyLocation : public BaseContextObject { -public: - PyLocation(PyMlirContextRef contextRef, MlirLocation loc) - : BaseContextObject(std::move(contextRef)), loc(loc) {} - - operator MlirLocation() const { return loc; } - MlirLocation get() const { return loc; } - - /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb); - - /// Gets a capsule wrapping the void* within the MlirLocation. - pybind11::object getCapsule(); - - /// Creates a PyLocation from the MlirLocation wrapped by a capsule. - /// Note that PyLocation instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirLocation - /// is taken by calling this function. - static PyLocation createFromCapsule(pybind11::object capsule); - -private: - MlirLocation loc; -}; - -/// Used in function arguments when None should resolve to the current context -/// manager set instance. -class DefaultingPyLocation - : public Defaulting { -public: - using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - "[ThreadContextAware] mlir.ir.Location"; - static PyLocation &resolve(); - - operator MlirLocation() const { return *get(); } -}; - -/// Wrapper around MlirModule. -/// This is the top-level, user-owned object that contains regions/ops/blocks. -class PyModule; -using PyModuleRef = PyObjectRef; -class PyModule : public BaseContextObject { -public: - /// Returns a PyModule reference for the given MlirModule. This may return - /// a pre-existing or new object. 
- static PyModuleRef forModule(MlirModule module); - PyModule(PyModule &) = delete; - PyModule(PyMlirContext &&) = delete; - ~PyModule(); - - /// Gets the backing MlirModule. - MlirModule get() { return module; } - - /// Gets a strong reference to this module. - PyModuleRef getRef() { - return PyModuleRef(this, - pybind11::reinterpret_borrow(handle)); - } - - /// Gets a capsule wrapping the void* within the MlirModule. - /// Note that the module does not (yet) provide a corresponding factory for - /// constructing from a capsule as that would require uniquing PyModule - /// instances, which is not currently done. - pybind11::object getCapsule(); - - /// Creates a PyModule from the MlirModule wrapped by a capsule. - /// Note that PyModule instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirModule - /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); - -private: - PyModule(PyMlirContextRef contextRef, MlirModule module); - MlirModule module; - pybind11::handle handle; -}; - -/// Base class for PyOperation and PyOpView which exposes the primary, user -/// visible methods for manipulating it. -class PyOperationBase { -public: - virtual ~PyOperationBase() = default; - /// Implements the bound 'print' method and helps with others. - void print(pybind11::object fileObject, bool binary, - llvm::Optional largeElementsLimit, bool enableDebugInfo, - bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); - pybind11::object getAsm(bool binary, - llvm::Optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope); - - /// Each must provide access to the raw Operation. - virtual PyOperation &getOperation() = 0; -}; - -/// Wrapper around PyOperation. -/// Operations exist in either an attached (dependent) or detached (top-level) -/// state. 
In the detached state (as on creation), an operation is owned by -/// the creator and its lifetime extends either until its reference count -/// drops to zero or it is attached to a parent, at which point its lifetime -/// is bounded by its top-level parent reference. -class PyOperation; -using PyOperationRef = PyObjectRef; -class PyOperation : public PyOperationBase, public BaseContextObject { -public: - ~PyOperation(); - PyOperation &getOperation() override { return *this; } - - /// Returns a PyOperation for the given MlirOperation, optionally associating - /// it with a parentKeepAlive. - static PyOperationRef - forOperation(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); - - /// Creates a detached operation. The operation must not be associated with - /// any existing live operation. - static PyOperationRef - createDetached(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); - - /// Gets the backing operation. - operator MlirOperation() const { return get(); } - MlirOperation get() const { - checkValid(); - return operation; - } - - PyOperationRef getRef() { - return PyOperationRef( - this, pybind11::reinterpret_borrow(handle)); - } - - bool isAttached() { return attached; } - void setAttached() { - assert(!attached && "operation already attached"); - attached = true; - } - void checkValid() const; - - /// Gets the owning block or raises an exception if the operation has no - /// owning block. - PyBlock getBlock(); - - /// Gets the parent operation or raises an exception if the operation has - /// no parent. - PyOperationRef getParentOperation(); - - /// Creates an operation. See corresponding python docstring. 
- static pybind11::object - create(std::string name, llvm::Optional> results, - llvm::Optional> operands, - llvm::Optional attributes, - llvm::Optional> successors, int regions, - DefaultingPyLocation location, pybind11::object ip); - - /// Creates an OpView suitable for this operation. - pybind11::object createOpView(); - -private: - PyOperation(PyMlirContextRef contextRef, MlirOperation operation); - static PyOperationRef createInstance(PyMlirContextRef contextRef, - MlirOperation operation, - pybind11::object parentKeepAlive); - - MlirOperation operation; - pybind11::handle handle; - // Keeps the parent alive, regardless of whether it is an Operation or - // Module. - // TODO: As implemented, this facility is only sufficient for modeling the - // trivial module parent back-reference. Generalize this to also account for - // transitions from detached to attached and address TODOs in the - // ir_operation.py regarding testing corresponding lifetime guarantees. - pybind11::object parentKeepAlive; - bool attached = true; - bool valid = true; -}; - -/// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for -/// providing more instance-specific accessors and serve as the base class for -/// custom ODS-style operation classes. Since this class is subclass on the -/// python side, it must present an __init__ method that operates in pure -/// python types. 
-class PyOpView : public PyOperationBase { -public: - PyOpView(pybind11::object operationObject); - PyOperation &getOperation() override { return operation; } - - static pybind11::object createRawSubclass(pybind11::object userClass); - - pybind11::object getOperationObject() { return operationObject; } - - static pybind11::object - buildGeneric(pybind11::object cls, pybind11::list resultTypeList, - pybind11::list operandList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, DefaultingPyLocation location, - pybind11::object maybeIp); - -private: - PyOperation &operation; // For efficient, cast-free access from C++ - pybind11::object operationObject; // Holds the reference. -}; - -/// Wrapper around an MlirRegion. -/// Regions are managed completely by their containing operation. Unlike the -/// C++ API, the python API does not support detached regions. -class PyRegion { -public: - PyRegion(PyOperationRef parentOperation, MlirRegion region) - : parentOperation(std::move(parentOperation)), region(region) { - assert(!mlirRegionIsNull(region) && "python region cannot be null"); - } - - MlirRegion get() { return region; } - PyOperationRef &getParentOperation() { return parentOperation; } - - void checkValid() { return parentOperation->checkValid(); } - -private: - PyOperationRef parentOperation; - MlirRegion region; -}; - -/// Wrapper around an MlirBlock. -/// Blocks are managed completely by their containing operation. Unlike the -/// C++ API, the python API does not support detached blocks. 
-class PyBlock { -public: - PyBlock(PyOperationRef parentOperation, MlirBlock block) - : parentOperation(std::move(parentOperation)), block(block) { - assert(!mlirBlockIsNull(block) && "python block cannot be null"); - } - - MlirBlock get() { return block; } - PyOperationRef &getParentOperation() { return parentOperation; } - - void checkValid() { return parentOperation->checkValid(); } - -private: - PyOperationRef parentOperation; - MlirBlock block; -}; - -/// An insertion point maintains a pointer to a Block and a reference operation. -/// Calls to insert() will insert a new operation before the -/// reference operation. If the reference operation is null, then appends to -/// the end of the block. -class PyInsertionPoint { -public: - /// Creates an insertion point positioned after the last operation in the - /// block, but still inside the block. - PyInsertionPoint(PyBlock &block); - /// Creates an insertion point positioned before a reference operation. - PyInsertionPoint(PyOperationBase &beforeOperationBase); - - /// Shortcut to create an insertion point at the beginning of the block. - static PyInsertionPoint atBlockBegin(PyBlock &block); - /// Shortcut to create an insertion point before the block terminator. - static PyInsertionPoint atBlockTerminator(PyBlock &block); - - /// Inserts an operation. - void insert(PyOperationBase &operationBase); - - /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb); - - PyBlock &getBlock() { return block; } - -private: - // Trampoline constructor that avoids null initializing members while - // looking up parents. - PyInsertionPoint(PyBlock block, llvm::Optional refOperation) - : refOperation(std::move(refOperation)), block(std::move(block)) {} - - llvm::Optional refOperation; - PyBlock block; -}; - -/// Wrapper around the generic MlirAttribute. 
-/// The lifetime of a type is bound by the PyContext that created it. -class PyAttribute : public BaseContextObject { -public: - PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) - : BaseContextObject(std::move(contextRef)), attr(attr) {} - bool operator==(const PyAttribute &other); - operator MlirAttribute() const { return attr; } - MlirAttribute get() const { return attr; } - - /// Gets a capsule wrapping the void* within the MlirAttribute. - pybind11::object getCapsule(); - - /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. - /// Note that PyAttribute instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirAttribute - /// is taken by calling this function. - static PyAttribute createFromCapsule(pybind11::object capsule); - -private: - MlirAttribute attr; -}; - -/// Represents a Python MlirNamedAttr, carrying an optional owned name. -/// TODO: Refactor this and the C-API to be based on an Identifier owned -/// by the context so as to avoid ownership issues here. -class PyNamedAttribute { -public: - /// Constructs a PyNamedAttr that retains an owned name. This should be - /// used in any code that originates an MlirNamedAttribute from a python - /// string. - /// The lifetime of the PyNamedAttr must extend to the lifetime of the - /// passed attribute. - PyNamedAttribute(MlirAttribute attr, std::string ownedName); - - MlirNamedAttribute namedAttr; - -private: - // Since the MlirNamedAttr contains an internal pointer to the actual - // memory of the owned string, it must be heap allocated to remain valid. - // Otherwise, strings that fit within the small object optimization threshold - // will have their memory address change as the containing object is moved, - // resulting in an invalid aliased pointer. - std::unique_ptr ownedName; -}; - -/// Wrapper around the generic MlirType. -/// The lifetime of a type is bound by the PyContext that created it. 
-class PyType : public BaseContextObject { -public: - PyType(PyMlirContextRef contextRef, MlirType type) - : BaseContextObject(std::move(contextRef)), type(type) {} - bool operator==(const PyType &other); - operator MlirType() const { return type; } - MlirType get() const { return type; } - - /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); - - /// Creates a PyType from the MlirType wrapped by a capsule. - /// Note that PyType instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirType - /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); - -private: - MlirType type; -}; - -/// Wrapper around the generic MlirValue. -/// Values are managed completely by the operation that resulted in their -/// definition. For op result value, this is the operation that defines the -/// value. For block argument values, this is the operation that contains the -/// block to which the value is an argument (blocks cannot be detached in Python -/// bindings so such operation always exists). -class PyValue { -public: - PyValue(PyOperationRef parentOperation, MlirValue value) - : parentOperation(parentOperation), value(value) {} - - MlirValue get() { return value; } - PyOperationRef &getParentOperation() { return parentOperation; } - - void checkValid() { return parentOperation->checkValid(); } - -private: - PyOperationRef parentOperation; - MlirValue value; -}; - -/// Wrapper around MlirAffineExpr. Affine expressions are owned by the context. 
-class PyAffineExpr : public BaseContextObject { -public: - PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) - : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} - bool operator==(const PyAffineExpr &other); - operator MlirAffineExpr() const { return affineExpr; } - MlirAffineExpr get() const { return affineExpr; } - - /// Gets a capsule wrapping the void* within the MlirAffineExpr. - pybind11::object getCapsule(); - - /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. - /// Note that PyAffineExpr instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr - /// is taken by calling this function. - static PyAffineExpr createFromCapsule(pybind11::object capsule); - - PyAffineExpr add(const PyAffineExpr &other) const; - PyAffineExpr mul(const PyAffineExpr &other) const; - PyAffineExpr floorDiv(const PyAffineExpr &other) const; - PyAffineExpr ceilDiv(const PyAffineExpr &other) const; - PyAffineExpr mod(const PyAffineExpr &other) const; - -private: - MlirAffineExpr affineExpr; -}; - -class PyAffineMap : public BaseContextObject { -public: - PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) - : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} - bool operator==(const PyAffineMap &other); - operator MlirAffineMap() const { return affineMap; } - MlirAffineMap get() const { return affineMap; } - - /// Gets a capsule wrapping the void* within the MlirAffineMap. - pybind11::object getCapsule(); - - /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. - /// Note that PyAffineMap instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirAffineMap - /// is taken by calling this function. 
- static PyAffineMap createFromCapsule(pybind11::object capsule); - -private: - MlirAffineMap affineMap; -}; - -class PyIntegerSet : public BaseContextObject { -public: - PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) - : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} - bool operator==(const PyIntegerSet &other); - operator MlirIntegerSet() const { return integerSet; } - MlirIntegerSet get() const { return integerSet; } - - /// Gets a capsule wrapping the void* within the MlirIntegerSet. - pybind11::object getCapsule(); - - /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. - /// Note that PyIntegerSet instances may be uniqued, so the returned object - /// may be a pre-existing object. Integer sets are owned by the context. - static PyIntegerSet createFromCapsule(pybind11::object capsule); - -private: - MlirIntegerSet integerSet; -}; - -void populateIRSubmodule(pybind11::module &m); - -} // namespace python -} // namespace mlir - -namespace pybind11 { -namespace detail { - -template <> -struct type_caster - : MlirDefaultingCaster {}; -template <> -struct type_caster - : MlirDefaultingCaster {}; - -} // namespace detail -} // namespace pybind11 - -#endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp new file mode 100644 index 000000000..b11e3f75b --- /dev/null +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -0,0 +1,1026 @@ +//===- IRTypes.cpp - Exports builtin and standard types -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +// clang-format off +#include "IRModule.h" +#include "mlir/Bindings/Python/IRTypes.h" +// clang-format on + +#include + +#include "IRModule.h" +#include "NanobindUtils.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" +#include "mlir-c/Support.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::Twine; + +namespace { + +/// Checks whether the given type is an integer or float type. +static int mlirTypeIsAIntegerOrFloat(MlirType type) { + return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || + mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); +} + +class PyIntegerType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; + static constexpr const char *pyClassName = "IntegerType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nb::arg("width"), nb::arg("context").none() = nb::none(), + "Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nb::arg("width"), nb::arg("context").none() = nb::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nb::arg("width"), 
nb::arg("context").none() = nb::none(), + "Create an unsigned integer type"); + c.def_prop_ro( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_prop_ro( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_prop_ro( + "is_signed", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSigned(self); + }, + "Returns whether this is a signed integer"); + c.def_prop_ro( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); + } +}; + +/// Index Type subclass - IndexType. +class PyIndexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a index type."); + } +}; + +class PyFloatType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; + static constexpr const char *pyClassName = "FloatType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_prop_ro( + "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, + "Returns the width of the floating-point type"); + } +}; + +/// Floating Point Type subclass - Float4E2M1FNType. 
+class PyFloat4E2M1FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat4E2M1FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float4E2M1FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); + return PyFloat4E2M1FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); + } +}; + +/// Floating Point Type subclass - Float6E2M3FNType. +class PyFloat6E2M3FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E2M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E2M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); + return PyFloat6E2M3FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); + } +}; + +/// Floating Point Type subclass - Float6E3M2FNType. 
+class PyFloat6E3M2FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E3M2FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E3M2FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); + return PyFloat6E3M2FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); + } +}; + +/// Floating Point Type subclass - Float8E4M3FNType. +class PyFloat8E4M3FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2Type. 
+class PyFloat8E5M2Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E5M2Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); + } +}; + +/// Floating Point Type subclass - Float8E4M3Type. +class PyFloat8E4M3Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3TypeGet(context->get()); + return PyFloat8E4M3Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); + } +}; + +/// Floating Point Type subclass - Float8E4M3FNUZ. 
+class PyFloat8E4M3FNUZType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3fnuz type."); + } +}; + +/// Floating Point Type subclass - Float8E4M3B11FNUZ. +class PyFloat8E4M3B11FNUZType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3b11fnuz type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2FNUZ. 
+class PyFloat8E5M2FNUZType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), + "Create a float8_e5m2fnuz type."); + } +}; + +/// Floating Point Type subclass - Float8E3M4Type. +class PyFloat8E3M4Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E3M4TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E3M4Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E3M4TypeGet(context->get()); + return PyFloat8E3M4Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); + } +}; + +/// Floating Point Type subclass - Float8E8M0FNUType. 
+class PyFloat8E8M0FNUType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E8M0FNUTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E8M0FNUType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); + return PyFloat8E8M0FNUType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), + "Create a float8_e8m0fnu type."); + } +}; + +/// Floating Point Type subclass - BF16Type. +class PyBF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a bf16 type."); + } +}; + +/// Floating Point Type subclass - F16Type. +class PyF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a f16 type."); + } +}; + +/// Floating Point Type subclass - TF32Type. 
+class PyTF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatTF32TypeGetTypeID; + static constexpr const char *pyClassName = "FloatTF32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirTF32TypeGet(context->get()); + return PyTF32Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a tf32 type."); + } +}; + +/// Floating Point Type subclass - F32Type. +class PyF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a f32 type."); + } +}; + +/// Floating Point Type subclass - F64Type. +class PyF64Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a f64 type."); + } +}; + +/// None Type subclass - NoneType. 
+class PyNoneType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + nb::arg("context").none() = nb::none(), "Create a none type."); + } +}; + +/// Complex Type subclass - ComplexType. +class PyComplexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; + static constexpr const char *pyClassName = "ComplexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. 
+ if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw nb::value_error( + (Twine("invalid '") + + nb::cast(nb::repr(nb::cast(elementType))) + + "' and expected floating point or integer type.") + .str() + .c_str()); + }, + "Create a complex type"); + c.def_prop_ro( + "element_type", + [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, + "Returns element type."); + } +}; + +} // namespace + +// Shaped Type Interface - ShapedType +void mlir::PyShapedType::bindDerived(ClassTy &c) { + c.def_prop_ro( + "element_type", + [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, + "Returns the element type of the shaped type."); + c.def_prop_ro( + "has_rank", + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, + "Returns whether the given shaped type is ranked."); + c.def_prop_ro( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self); + }, + "Returns the rank of the given ranked shaped type."); + c.def_prop_ro( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self); + }, + "Returns whether the given shaped type has a static shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self, dim); + }, + nb::arg("dim"), + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "is_static_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsStaticDim(self, dim); + }, + nb::arg("dim"), + "Returns whether the dim-th dimension of the given shaped type is " + "static."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self, dim); + }, + nb::arg("dim"), + "Returns the dim-th 
dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + nb::arg("dim_size"), + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def_static( + "is_static_size", + [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); }, + nb::arg("dim_size"), + "Returns whether the given dimension size indicates a static " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + nb::arg("dim_size"), + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + c.def( + "is_static_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsStaticStrideOrOffset(val); + }, + nb::arg("dim_size"), + "Returns whether the given shaped type stride or offset value is " + "statically-sized."); + c.def_prop_ro( + "shape", + [](PyShapedType &self) { + self.requireHasRank(); + + std::vector shape; + int64_t rank = mlirShapedTypeGetRank(self); + shape.reserve(rank); + for (int64_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(self, i)); + return shape; + }, + "Returns the shape of the ranked shaped type as a list of integers."); + c.def_static( + "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in shaped " + "types."); + c.def_static( + "get_dynamic_stride_or_offset", + []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + "Returns the value used to indicate dynamic strides or offsets in " + "shaped types."); +} + +void mlir::PyShapedType::requireHasRank() { + if (!mlirShapedTypeHasRank(*this)) { + throw nb::value_error( + "calling this method requires that the type has a rank."); + } +} + +const 
mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = + mlirTypeIsAShaped; + +namespace { + +/// Vector Type subclass - VectorType. +class PyVectorType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; + static constexpr const char *pyClassName = "VectorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable").none() = nb::none(), + nb::arg("scalable_dims").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a vector type") + .def_prop_ro( + "scalable", + [](MlirType self) { return mlirVectorTypeIsScalable(self); }) + .def_prop_ro("scalable_dims", [](MlirType self) { + std::vector scalableDims; + size_t rank = static_cast(mlirShapedTypeGetRank(self)); + scalableDims.reserve(rank); + for (size_t i = 0; i < rank; ++i) + scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); + return scalableDims; + }); + } + +private: + static PyVectorType get(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyLocation loc) { + if (scalable && scalableDims) { + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); + } + + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw nb::value_error("Expected len(scalable) == len(shape)."); + + SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( + *scalable, [](const nb::handle &h) { return nb::cast(h); })); + type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), + scalableDimFlags.data(), + elementType); + } else if (scalableDims) { + SmallVector scalableDimFlags(shape.size(), false); + for (int64_t dim : 
*scalableDims) { + if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) + throw nb::value_error("Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; + } + type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), + scalableDimFlags.data(), + elementType); + } else { + type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + } + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); + } +}; + +/// Ranked Tensor Type subclass - RankedTensorType. +class PyRankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + std::optional &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("encoding").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); + c.def_prop_ro("encoding", + [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = + mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return encoding; + }); + } +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. 
+class PyUnrankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("loc").none() = nb::none(), + "Create a unranked tensor type"); + } +}; + +/// Ranked MemRef Type subclass - MemRefType. +class PyMemRefType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; + static constexpr const char *pyClassName = "MemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + PyAttribute *layout, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? 
*memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout").none() = nb::none(), + nb::arg("memory_space").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a memref type") + .def_prop_ro( + "layout", + [](PyMemRefType &self) -> MlirAttribute { + return mlirMemRefTypeGetLayout(self); + }, + "The layout of the MemRef type.") + .def( + "get_strides_and_offset", + [](PyMemRefType &self) -> std::pair, int64_t> { + std::vector strides(mlirShapedTypeGetRank(self)); + int64_t offset; + if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( + self, strides.data(), &offset))) + throw std::runtime_error( + "Failed to extract strides and offset from memref."); + return {strides, offset}; + }, + "The strides and offset of the MemRef type.") + .def_prop_ro( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") + .def_prop_ro( + "memory_space", + [](PyMemRefType &self) -> std::optional { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; + }, + "Returns the memory space of the given MemRef type."); + } +}; + +/// Unranked MemRef Type subclass - UnrankedMemRefType. 
+class PyUnrankedMemRefType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("loc").none() = nb::none(), "Create a unranked memref type") + .def_prop_ro( + "memory_space", + [](PyUnrankedMemRefType &self) -> std::optional { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; + }, + "Returns the memory space of the given Unranked MemRef type."); + } +}; + +/// Tuple Type subclass - TupleType. 
+class PyTupleType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; + static constexpr const char *pyClassName = "TupleType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](std::vector elements, DefaultingPyMlirContext context) { + MlirType t = mlirTupleTypeGet(context->get(), elements.size(), + elements.data()); + return PyTupleType(context->getRef(), t); + }, + nb::arg("elements"), nb::arg("context").none() = nb::none(), + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, intptr_t pos) { + return mlirTupleTypeGetType(self, pos); + }, + nb::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_prop_ro( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); + } +}; + +/// Function type. 
+class PyFunctionType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; + static constexpr const char *pyClassName = "FunctionType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + MlirType t = + mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), + results.size(), results.data()); + return PyFunctionType(context->getRef(), t); + }, + nb::arg("inputs"), nb::arg("results"), + nb::arg("context").none() = nb::none(), + "Gets a FunctionType from a list of input and result types"); + c.def_prop_ro( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + nb::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetInput(t, i)); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_prop_ro( + "results", + [](PyFunctionType &self) { + nb::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetResult(self, i)); + } + return types; + }, + "Returns the list of result types in the FunctionType."); + } +}; + +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +/// Opaque Type subclass - OpaqueType. 
+class PyOpaqueType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; + static constexpr const char *pyClassName = "OpaqueType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string dialectNamespace, std::string typeData, + DefaultingPyMlirContext context) { + MlirType type = mlirOpaqueTypeGet(context->get(), + toMlirStringRef(dialectNamespace), + toMlirStringRef(typeData)); + return PyOpaqueType(context->getRef(), type); + }, + nb::arg("dialect_namespace"), nb::arg("buffer"), + nb::arg("context").none() = nb::none(), + "Create an unregistered (opaque) dialect type."); + c.def_prop_ro( + "dialect_namespace", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + c.def_prop_ro( + "data", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); + } +}; + +} // namespace + +void mlir::python::populateIRTypes(nb::module_ &m) { + PyIntegerType::bind(m); + PyFloatType::bind(m); + PyIndexType::bind(m); + PyFloat4E2M1FNType::bind(m); + PyFloat6E2M3FNType::bind(m); + PyFloat6E3M2FNType::bind(m); + PyFloat8E4M3FNType::bind(m); + PyFloat8E5M2Type::bind(m); + PyFloat8E4M3Type::bind(m); + PyFloat8E4M3FNUZType::bind(m); + PyFloat8E4M3B11FNUZType::bind(m); + PyFloat8E5M2FNUZType::bind(m); + PyFloat8E3M4Type::bind(m); + PyFloat8E8M0FNUType::bind(m); + PyBF16Type::bind(m); + PyF16Type::bind(m); + PyTF32Type::bind(m); + PyF32Type::bind(m); + PyF64Type::bind(m); + PyNoneType::bind(m); + PyComplexType::bind(m); + PyShapedType::bind(m); + PyVectorType::bind(m); + 
PyRankedTensorType::bind(m); + PyUnrankedTensorType::bind(m); + PyMemRefType::bind(m); + PyUnrankedMemRefType::bind(m); + PyTupleType::bind(m); + PyFunctionType::bind(m); + PyOpaqueType::bind(m); +} diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp similarity index 58% rename from mlir/lib/Bindings/Python/Conversions/Conversions.cpp rename to mlir/lib/Bindings/Python/LinalgPasses.cpp index f8b3b2041..49f2ea941 100644 --- a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp +++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp @@ -1,4 +1,4 @@ -//===- Conversions.cpp - Pybind module for the Conversionss library -------===// +//===- LinalgPasses.cpp - Pybind module for the Linalg passes -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,19 +6,17 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Conversion.h" +#include "mlir-c/Dialect/Linalg.h" -#include - -namespace py = pybind11; +#include "mlir/Bindings/Python/Nanobind.h" // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirConversions, m) { - m.doc() = "MLIR Conversions library"; +NB_MODULE(_mlirLinalgPasses, m) { + m.doc() = "MLIR Linalg Dialect Passes"; - // Register all the passes in the Conversions library on load. - mlirRegisterConversionPasses(); + // Register all Linalg passes on load. 
+ mlirRegisterLinalgPasses(); } diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 9bfe8b09f..278847e7a 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,220 +6,137 @@ // //===----------------------------------------------------------------------===// -#include - -#include "PybindUtils.h" - -#include "ExecutionEngine.h" #include "Globals.h" -#include "IRModules.h" +#include "IRModule.h" +#include "NanobindUtils.h" #include "Pass.h" +#include "Rewrite.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; +using namespace nb::literals; using namespace mlir::python; -// ----------------------------------------------------------------------------- -// PyGlobals -// ----------------------------------------------------------------------------- - -PyGlobals *PyGlobals::instance = nullptr; - -PyGlobals::PyGlobals() { - assert(!instance && "PyGlobals already constructed"); - instance = this; -} - -PyGlobals::~PyGlobals() { instance = nullptr; } - -void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - py::gil_scoped_acquire(); - if (loadedDialectModulesCache.contains(dialectNamespace)) - return; - // Since re-entrancy is possible, make a copy of the search prefixes. - std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded; - for (std::string moduleName : localSearchPrefixes) { - moduleName.push_back('.'); - moduleName.append(dialectNamespace.data(), dialectNamespace.size()); - - try { - py::gil_scoped_release(); - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { - if (e.matches(PyExc_ModuleNotFoundError)) { - continue; - } else { - throw; - } - } - break; - } - - // Note: Iterator cannot be shared from prior to loading, since re-entrancy - // may have occurred, which may do anything. 
- loadedDialectModulesCache.insert(dialectNamespace); -} - -void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::gil_scoped_acquire(); - py::object &found = dialectClassMap[dialectNamespace]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + - dialectNamespace + - "' is already registered."); - } - found = std::move(pyClass); -} - -void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, - py::object rawOpViewClass) { - py::gil_scoped_acquire(); - py::object &found = operationClassMap[operationName]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + - operationName + - "' is already registered."); - } - found = std::move(pyClass); - rawOpViewClassMap[operationName] = std::move(rawOpViewClass); -} - -llvm::Optional -PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { - py::gil_scoped_acquire(); - loadDialectModule(dialectNamespace); - // Fast match against the class map first (common case). - const auto foundIt = dialectClassMap.find(dialectNamespace); - if (foundIt != dialectClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - - // Not found and loading did not yield a registration. Negative cache. - dialectClassMap[dialectNamespace] = py::none(); - return llvm::None; -} - -llvm::Optional -PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMapCache.find(operationName); - if (foundIt != rawOpViewClassMapCache.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. 
- auto split = operationName.split('.'); - llvm::StringRef dialectNamespace = split.first; - loadDialectModule(dialectNamespace); - - // Attempt to find from the canonical map and cache. - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMap.find(operationName); - if (foundIt != rawOpViewClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - rawOpViewClassMapCache[operationName] = foundIt->second; - return foundIt->second; - } else { - // Negative cache. - rawOpViewClassMap[operationName] = py::none(); - return llvm::None; - } - } -} - -void PyGlobals::clearImportCache() { - py::gil_scoped_acquire(); - loadedDialectModulesCache.clear(); - rawOpViewClassMapCache.clear(); -} - // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlir, m) { +NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - py::class_(m, "_Globals") - .def_property("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) - .def("append_dialect_search_prefix", - [](PyGlobals &self, std::string moduleName) { - self.getDialectSearchPrefixes().push_back(std::move(moduleName)); - self.clearImportCache(); - }) + nb::class_(m, "_Globals") + .def_prop_rw("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) + .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, + "module_name"_a) + .def( + "_check_dialect_module_loaded", + [](PyGlobals &self, const std::string &dialectNamespace) { + return self.loadDialectModule(dialectNamespace); + }, + "dialect_namespace"_a) .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, + "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") 
.def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "Testing hook for directly registering an operation"); + "operation_name"_a, "operation_class"_a, nb::kw_only(), + "replace"_a = false, + "Testing hook for directly registering an operation") + .def("loc_tracebacks_enabled", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebacksEnabled(); + }) + .def("set_loc_tracebacks_enabled", + [](PyGlobals &self, bool enabled) { + self.getTracebackLoc().setLocTracebacksEnabled(enabled); + }) + .def("set_loc_tracebacks_frame_limit", + [](PyGlobals &self, int n) { + self.getTracebackLoc().setLocTracebackFramesLimit(n); + }) + .def("register_traceback_file_inclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileInclusion(filename); + }) + .def("register_traceback_file_exclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileExclusion(filename); + }); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python // resources) properly. - m.attr("globals") = - py::cast(new PyGlobals, py::return_value_policy::take_ownership); + m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); // Registration decorators. 
m.def( "register_dialect", - [](py::object pyClass) { + [](nb::type_object pyClass) { std::string dialectNamespace = - pyClass.attr("DIALECT_NAMESPACE").cast(); + nanobind::cast(pyClass.attr("DIALECT_NAMESPACE")); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, + "dialect_class"_a, "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](py::object dialectClass) -> py::cpp_function { - return py::cpp_function( - [dialectClass](py::object opClass) -> py::object { + [](const nb::type_object &dialectClass, bool replace) -> nb::object { + return nb::cpp_function( + [dialectClass, + replace](nb::type_object opClass) -> nb::type_object { std::string operationName = - opClass.attr("OPERATION_NAME").cast(); - auto rawSubclass = PyOpView::createRawSubclass(opClass); + nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, - rawSubclass); - + replace); // Dict-stuff the new opClass by name onto the dialect class. - py::object opClassName = opClass.attr("__name__"); + nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; - - // Now create a special "Raw" subclass that passes through - // construction to the OpView parent (bypasses the intermediate - // child's __init__). 
- opClass.attr("_Raw") = rawSubclass; return opClass; }); }, - "Class decorator for registering a custom Operation wrapper"); + "dialect_class"_a, nb::kw_only(), "replace"_a = false, + "Produce a class decorator for registering an Operation class as part of " + "a dialect"); + m.def( + MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { + PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); + return typeCaster; + }); + }, + "typeid"_a, nb::kw_only(), "replace"_a = false, + "Register a type caster for casting MLIR types to custom user types."); + m.def( + MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function( + [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { + PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, + replace); + return valueCaster; + }); + }, + "typeid"_a, nb::kw_only(), "replace"_a = false, + "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); - populateIRSubmodule(irModule); + populateIRCore(irModule); + populateIRAffine(irModule); + populateIRAttributes(irModule); + populateIRInterfaces(irModule); + populateIRTypes(irModule); + + auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings"); + populateRewriteSubmodule(rewriteModule); // Define and populate PassManager submodule. auto passModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passModule); - - // Define and populate ExecutionEngine submodule. 
- auto executionEngineModule = - m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); - populateExecutionEngineSubmodule(executionEngineModule); } diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h new file mode 100644 index 000000000..64ea4329f --- /dev/null +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -0,0 +1,427 @@ +//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ +//-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H +#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H + +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/DataTypes.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include + +template <> +struct std::iterator_traits { + using value_type = nanobind::handle; + using reference = const value_type; + using pointer = void; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; +}; + +namespace mlir { +namespace python { + +/// CRTP template for special wrapper types that are allowed to be passed in as +/// 'None' function arguments and can be resolved by some global mechanic if +/// so. Such types will raise an error if this global resolution fails, and +/// it is actually illegal for them to ever be unresolved. From a user +/// perspective, they behave like a smart ptr to the underlying type (i.e. +/// 'get' method and operator-> overloaded). +/// +/// Derived types must provide a method, which is called when an environmental +/// resolution is required. 
It must raise an exception if resolution fails: +/// static ReferrentTy &resolve() +/// +/// They must also provide a parameter description that will be used in +/// error messages about mismatched types: +/// static constexpr const char kTypeDescription[] = ""; + +template +class Defaulting { +public: + using ReferrentTy = T; + /// Type casters require the type to be default constructible, but using + /// such an instance is illegal. + Defaulting() = default; + Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} + + ReferrentTy *get() const { return referrent; } + ReferrentTy *operator->() { return referrent; } + +private: + ReferrentTy *referrent = nullptr; +}; + +} // namespace python +} // namespace mlir + +namespace nanobind { +namespace detail { + +template +struct MlirDefaultingCaster { + NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)) + + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + if (src.is_none()) { + // Note that we do want an exception to propagate from here as it will be + // the most informative. + value = DefaultingTy{DefaultingTy::resolve()}; + return true; + } + + // Unlike many casters that chain, these casters are expected to always + // succeed, so instead of doing an isinstance check followed by a cast, + // just cast in one step and handle the exception. Returning false (vs + // letting the exception propagate) causes higher level signature parsing + // code to produce nice error messages (other than "Cannot cast..."). + try { + value = DefaultingTy{ + nanobind::cast(src)}; + return true; + } catch (std::exception &) { + return false; + } + } + + static handle from_cpp(DefaultingTy src, rv_policy policy, + cleanup_list *cleanup) noexcept { + return nanobind::cast(src, policy); + } +}; +} // namespace detail +} // namespace nanobind + +//------------------------------------------------------------------------------ +// Conversion utilities. 
+//------------------------------------------------------------------------------ + +namespace mlir { + +/// Accumulates into a python string from a method that accepts an +/// MlirStringCallback. +struct PyPrintAccumulator { + nanobind::list parts; + + void *getUserData() { return this; } + + MlirStringCallback getCallback() { + return [](MlirStringRef part, void *userData) { + PyPrintAccumulator *printAccum = + static_cast(userData); + nanobind::str pyPart(part.data, + part.length); // Decodes as UTF-8 by default. + printAccum->parts.append(std::move(pyPart)); + }; + } + + nanobind::str join() { + nanobind::str delim("", 0); + return nanobind::cast(delim.attr("join")(parts)); + } +}; + +/// Accumulates into a file, either writing text (default) +/// or binary. The file may be a Python file-like object or a path to a file. +class PyFileAccumulator { +public: + PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary) + : binary(binary) { + std::string filePath; + if (nanobind::try_cast(fileOrStringObject, filePath)) { + std::error_code ec; + writeTarget.emplace(filePath, ec); + if (ec) { + throw nanobind::value_error( + (std::string("Unable to open file for writing: ") + ec.message()) + .c_str()); + } + } else { + writeTarget.emplace(fileOrStringObject.attr("write")); + } + } + + MlirStringCallback getCallback() { + return writeTarget.index() == 0 ? getPyWriteCallback() + : getOstreamCallback(); + } + + void *getUserData() { return this; } + +private: + MlirStringCallback getPyWriteCallback() { + return [](MlirStringRef part, void *userData) { + nanobind::gil_scoped_acquire acquire; + PyFileAccumulator *accum = static_cast(userData); + if (accum->binary) { + // Note: Still has to copy and not avoidable with this API. + nanobind::bytes pyBytes(part.data, part.length); + std::get(accum->writeTarget)(pyBytes); + } else { + nanobind::str pyStr(part.data, + part.length); // Decodes as UTF-8 by default. 
+ std::get(accum->writeTarget)(pyStr); + } + }; + } + + MlirStringCallback getOstreamCallback() { + return [](MlirStringRef part, void *userData) { + PyFileAccumulator *accum = static_cast(userData); + std::get(accum->writeTarget) + .write(part.data, part.length); + }; + } + + std::variant writeTarget; + bool binary; +}; + +/// Accumulates into a python string from a method that is expected to make +/// one (no more, no less) call to the callback (asserts internally on +/// violation). +struct PySinglePartStringAccumulator { + void *getUserData() { return this; } + + MlirStringCallback getCallback() { + return [](MlirStringRef part, void *userData) { + PySinglePartStringAccumulator *accum = + static_cast(userData); + assert(!accum->invoked && + "PySinglePartStringAccumulator called back multiple times"); + accum->invoked = true; + accum->value = nanobind::str(part.data, part.length); + }; + } + + nanobind::str takeValue() { + assert(invoked && "PySinglePartStringAccumulator not called back"); + return std::move(value); + } + +private: + nanobind::str value; + bool invoked = false; +}; + +/// A CRTP base class for pseudo-containers willing to support Python-type +/// slicing access on top of indexed access. Calling ::bind on this class +/// will define `__len__` as well as `__getitem__` with integer and slice +/// arguments. +/// +/// This is intended for pseudo-containers that can refer to arbitrary slices of +/// underlying storage indexed by a single integer. Indexing those with an +/// integer produces an instance of ElementTy. Indexing those with a slice +/// produces a new instance of Derived, which can be sliced further. 
+/// +/// A derived class must provide the following: +/// - a `static const char *pyClassName ` field containing the name of the +/// Python class to bind; +/// - an instance method `intptr_t getRawNumElements()` that returns the +/// number +/// of elements in the backing container (NOT that of the slice); +/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a +/// single element at the given linear index (NOT slice index); +/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that +/// constructs a new instance of the derived pseudo-container with the +/// given slice parameters (to be forwarded to the Sliceable constructor). +/// +/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not +/// throw. +/// +/// A derived class may additionally define: +/// - a `static void bindDerived(ClassTy &)` method to bind additional methods +/// the python class. +template +class Sliceable { +protected: + using ClassTy = nanobind::class_; + + /// Transforms `index` into a legal value to access the underlying sequence. + /// Returns <0 on failure. + intptr_t wrapIndex(intptr_t index) { + if (index < 0) + index = length + index; + if (index < 0 || index >= length) + return -1; + return index; + } + + /// Computes the linear index given the current slice properties. + intptr_t linearizeIndex(intptr_t index) { + intptr_t linearIndex = index * step + startIndex; + assert(linearIndex >= 0 && + linearIndex < static_cast(this)->getRawNumElements() && + "linear index out of bounds, the slice is ill-formed"); + return linearIndex; + } + + /// Trait to check if T provides a `maybeDownCast` method. + /// Note, you need the & to detect inherited members. + template + using has_maybe_downcast = decltype(&T::maybeDownCast); + + /// Returns the element at the given slice index. Supports negative indices + /// by taking elements in inverse order. Returns a nullptr object if out + /// of bounds. 
+ nanobind::object getItem(intptr_t index) { + // Negative indices mean we count from the end. + index = wrapIndex(index); + if (index < 0) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return {}; + } + + if constexpr (llvm::is_detected::value) + return static_cast(this) + ->getRawElement(linearizeIndex(index)) + .maybeDownCast(); + else + return nanobind::cast( + static_cast(this)->getRawElement(linearizeIndex(index))); + } + + /// Returns a new instance of the pseudo-container restricted to the given + /// slice. Returns a nullptr object on failure. + nanobind::object getItemSlice(PyObject *slice) { + ssize_t start, stop, extraStep, sliceLength; + if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, + &sliceLength) != 0) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return {}; + } + return nanobind::cast(static_cast(this)->slice( + startIndex + start * step, sliceLength, step * extraStep)); + } + +public: + explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) + : startIndex(startIndex), length(length), step(step) { + assert(length >= 0 && "expected non-negative slice length"); + } + + /// Returns the `index`-th element in the slice, supports negative indices. + /// Throws if the index is out of bounds. + ElementTy getElement(intptr_t index) { + // Negative indices mean we count from the end. + index = wrapIndex(index); + if (index < 0) { + throw nanobind::index_error("index out of range"); + } + + return static_cast(this)->getRawElement(linearizeIndex(index)); + } + + /// Returns the size of slice. + intptr_t size() { return length; } + + /// Returns a new vector (mapped to Python list) containing elements from two + /// slices. The new vector is necessary because slices may not be contiguous + /// or even come from the same original sequence. 
+ std::vector dunderAdd(Derived &other) { + std::vector elements; + elements.reserve(length + other.length); + for (intptr_t i = 0; i < length; ++i) { + elements.push_back(static_cast(this)->getElement(i)); + } + for (intptr_t i = 0; i < other.length; ++i) { + elements.push_back(static_cast(&other)->getElement(i)); + } + return elements; + } + + /// Binds the indexing and length methods in the Python class. + static void bind(nanobind::module_ &m) { + auto clazz = nanobind::class_(m, Derived::pyClassName) + .def("__add__", &Sliceable::dunderAdd); + Derived::bindDerived(clazz); + + // Manually implement the sequence protocol via the C API. We do this + // because it is approx 4x faster than via nanobind, largely because that + // formulation requires a C++ exception to be thrown to detect end of + // sequence. + // Since we are in a C-context, any C++ exception that happens here + // will terminate the program. There is nothing in this implementation + // that should throw in a non-terminal way, so we forgo further + // exception marshalling. + // See: https://github.com/pybind/nanobind/issues/2842 + auto heap_type = reinterpret_cast(clazz.ptr()); + assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && + "must be heap type"); + heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { + auto self = nanobind::cast(nanobind::handle(rawSelf)); + return self->length; + }; + // sq_item is called as part of the sequence protocol for iteration, + // list construction, etc. + heap_type->as_sequence.sq_item = + +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { + auto self = nanobind::cast(nanobind::handle(rawSelf)); + return self->getItem(index).release().ptr(); + }; + // mp_subscript is used for both slices and integer lookups. 
+ heap_type->as_mapping.mp_subscript = + +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { + auto self = nanobind::cast(nanobind::handle(rawSelf)); + Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); + if (!PyErr_Occurred()) { + // Integer indexing. + return self->getItem(index).release().ptr(); + } + PyErr_Clear(); + + // Assume slice-based indexing. + if (PySlice_Check(rawSubscript)) { + return self->getItemSlice(rawSubscript).release().ptr(); + } + + PyErr_SetString(PyExc_ValueError, "expected integer or slice"); + return nullptr; + }; + } + + /// Hook for derived classes willing to bind more methods. + static void bindDerived(ClassTy &) {} + +private: + intptr_t startIndex; + intptr_t length; + intptr_t step; +}; + +} // namespace mlir + +namespace llvm { + +template <> +struct DenseMapInfo { + static inline MlirTypeID getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlirTypeIDCreate(pointer); + } + static inline MlirTypeID getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlirTypeIDCreate(pointer); + } + static inline unsigned getHashValue(const MlirTypeID &val) { + return mlirTypeIDHashValue(val); + } + static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { + return mlirTypeIDEqual(lhs, rhs); + } +}; +} // namespace llvm + +#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index dd57647f0..20017e25b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,11 +8,13 @@ #include "Pass.h" -#include "IRModules.h" -#include "mlir-c/Bindings/Python/Interop.h" +#include "IRModule.h" #include "mlir-c/Pass.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
-namespace py = pybind11; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -22,7 +24,8 @@ namespace { class PyPassManager { public: PyPassManager(MlirPassManager passManager) : passManager(passManager) {} - PyPassManager(PyPassManager &&other) : passManager(other.passManager) { + PyPassManager(PyPassManager &&other) noexcept + : passManager(other.passManager) { other.passManager.ptr = nullptr; } ~PyPassManager() { @@ -32,70 +35,146 @@ class PyPassManager { MlirPassManager get() { return passManager; } void release() { passManager.ptr = nullptr; } - pybind11::object getCapsule() { - return py::reinterpret_steal( - mlirPythonPassManagerToCapsule(get())); + nb::object getCapsule() { + return nb::steal(mlirPythonPassManagerToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) - throw py::error_already_set(); - return py::cast(PyPassManager(rawPm), py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); } private: MlirPassManager passManager; }; -} // anonymous namespace +} // namespace /// Create the `mlir.passmanager` here. 
-void mlir::python::populatePassManagerSubmodule(py::module &m) { +void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "PassManager") - .def(py::init<>([](DefaultingPyMlirContext context) { - MlirPassManager passManager = - mlirPassManagerCreate(context->get()); - return new PyPassManager(passManager); - }), - py::arg("context") = py::none(), - "Create a new PassManager for the current (or provided) Context.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyPassManager::getCapsule) + nb::class_(m, "PassManager") + .def( + "__init__", + [](PyPassManager &self, const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); + new (&self) PyPassManager(passManager); + }, + "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), + "Create a new PassManager for the current (or provided) Context.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") + .def( + "enable_ir_printing", + [](PyPassManager &passManager, bool printBeforeAll, + bool printAfterAll, bool printModuleScope, bool printAfterChange, + bool printAfterFailure, std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, + bool printGenericOpForm, + std::optional optionalTreePrintingPath) { + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (largeElementsLimit) { + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, + *largeElementsLimit); + mlirOpPrintingFlagsElideLargeResourceString(flags, 
+ *largeElementsLimit); + } + if (largeResourceLimit) + mlirOpPrintingFlagsElideLargeResourceString(flags, + *largeResourceLimit); + if (enableDebugInfo) + mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + /*prettyForm=*/false); + if (printGenericOpForm) + mlirOpPrintingFlagsPrintGenericOpForm(flags); + std::string treePrintingPath = ""; + if (optionalTreePrintingPath.has_value()) + treePrintingPath = optionalTreePrintingPath.value(); + mlirPassManagerEnableIRPrinting( + passManager.get(), printBeforeAll, printAfterAll, + printModuleScope, printAfterChange, printAfterFailure, flags, + mlirStringRefCreate(treePrintingPath.data(), + treePrintingPath.size())); + mlirOpPrintingFlagsDestroy(flags); + }, + "print_before_all"_a = false, "print_after_all"_a = true, + "print_module_scope"_a = false, "print_after_change"_a = false, + "print_after_failure"_a = false, + "large_elements_limit"_a.none() = nb::none(), + "large_resource_limit"_a.none() = nb::none(), + "enable_debug_info"_a = false, "print_generic_op_form"_a = false, + "tree_printing_dir_path"_a.none() = nb::none(), + "Enable IR printing, default as mlir-print-ir-after-all.") + .def( + "enable_verifier", + [](PyPassManager &passManager, bool enable) { + mlirPassManagerEnableVerifier(passManager.get(), enable); + }, + "enable"_a, "Enable / disable verify-each.") + .def( + "enable_timing", + [](PyPassManager &passManager) { + mlirPassManagerEnableTiming(passManager.get()); + }, + "Enable pass timing.") .def_static( "parse", - [](const std::string pipeline, DefaultingPyMlirContext context) { + [](const std::string &pipeline, DefaultingPyMlirContext context) { MlirPassManager passManager = mlirPassManagerCreate(context->get()); + PyPrintAccumulator errorMsg; MlirLogicalResult status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(passManager), - mlirStringRefCreate(pipeline.data(), pipeline.size())); + mlirStringRefCreate(pipeline.data(), pipeline.size()), + errorMsg.getCallback(), 
errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, - llvm::Twine("invalid pass pipeline '") + - pipeline + "'."); + throw nb::value_error(errorMsg.join().c_str()); return new PyPassManager(passManager); }, - py::arg("pipeline"), py::arg("context") = py::none(), + "pipeline"_a, "context"_a.none() = nb::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") + .def( + "add", + [](PyPassManager &passManager, const std::string &pipeline) { + PyPrintAccumulator errorMsg; + MlirLogicalResult status = mlirOpPassManagerAddPipeline( + mlirPassManagerGetAsOpPassManager(passManager.get()), + mlirStringRefCreate(pipeline.data(), pipeline.size()), + errorMsg.getCallback(), errorMsg.getUserData()); + if (mlirLogicalResultIsFailure(status)) + throw nb::value_error(errorMsg.join().c_str()); + }, + "pipeline"_a, + "Add textual pipeline elements to the pass manager. Throws a " + "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyModule &module) { - MlirLogicalResult status = - mlirPassManagerRun(passManager.get(), module.get()); + [](PyPassManager &passManager, PyOperationBase &op, + bool invalidateOps) { + if (invalidateOps) { + op.getOperation().getContext()->clearOperationsInside(op); + } + // Actually run the pass manager. 
+ PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); + MlirLogicalResult status = mlirPassManagerRunOnOp( + passManager.get(), op.getOperation().get()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_RuntimeError, - "Failure while executing pass pipeline."); + throw MLIRError("Failure while executing pass pipeline", + errors.take()); }, - "Run the pass manager on the provided module, throw a RuntimeError " - "on failure.") + "operation"_a, "invalidate_ops"_a = true, + "Run the pass manager on the provided operation, raising an " + "MLIRError on failure.") .def( "__str__", [](PyPassManager &self) { diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index 550ff47c3..bc4094352 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,14 +9,14 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populatePassManagerSubmodule(pybind11::module &m); +void populatePassManagerSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir -#endif // MLIR_BINDINGS_PYTHON_PASS_H \ No newline at end of file +#endif // MLIR_BINDINGS_PYTHON_PASS_H diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp deleted file mode 100644 index bd80b8c14..000000000 --- a/mlir/lib/Bindings/Python/PybindUtils.cpp +++ /dev/null @@ -1,18 +0,0 @@ -//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PybindUtils.h" - -namespace py = pybind11; - -pybind11::error_already_set -mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) { - auto messageStr = message.str(); - PyErr_SetString(excClass, messageStr.c_str()); - return pybind11::error_already_set(); -} diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h deleted file mode 100644 index 0cea24482..000000000 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ /dev/null @@ -1,278 +0,0 @@ -//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H -#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H - -#include "mlir-c/Support.h" -#include "llvm/ADT/Optional.h" -#include "llvm/ADT/Twine.h" - -#include -#include - -namespace mlir { -namespace python { - -// Sets a python error, ready to be thrown to return control back to the -// python runtime. -// Correct usage: -// throw SetPyError(PyExc_ValueError, "Foobar'd"); -pybind11::error_already_set SetPyError(PyObject *excClass, - const llvm::Twine &message); - -/// CRTP template for special wrapper types that are allowed to be passed in as -/// 'None' function arguments and can be resolved by some global mechanic if -/// so. Such types will raise an error if this global resolution fails, and -/// it is actually illegal for them to ever be unresolved. From a user -/// perspective, they behave like a smart ptr to the underlying type (i.e. -/// 'get' method and operator-> overloaded). 
-/// -/// Derived types must provide a method, which is called when an environmental -/// resolution is required. It must raise an exception if resolution fails: -/// static ReferrentTy &resolve() -/// -/// They must also provide a parameter description that will be used in -/// error messages about mismatched types: -/// static constexpr const char kTypeDescription[] = ""; - -template -class Defaulting { -public: - using ReferrentTy = T; - /// Type casters require the type to be default constructible, but using - /// such an instance is illegal. - Defaulting() = default; - Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} - - ReferrentTy *get() const { return referrent; } - ReferrentTy *operator->() { return referrent; } - -private: - ReferrentTy *referrent = nullptr; -}; - -} // namespace python -} // namespace mlir - -namespace pybind11 { -namespace detail { - -template -struct MlirDefaultingCaster { - PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); - - bool load(pybind11::handle src, bool) { - if (src.is_none()) { - // Note that we do want an exception to propagate from here as it will be - // the most informative. - value = DefaultingTy{DefaultingTy::resolve()}; - return true; - } - - // Unlike many casters that chain, these casters are expected to always - // succeed, so instead of doing an isinstance check followed by a cast, - // just cast in one step and handle the exception. Returning false (vs - // letting the exception propagate) causes higher level signature parsing - // code to produce nice error messages (other than "Cannot cast..."). 
- try { - value = DefaultingTy{ - pybind11::cast(src)}; - return true; - } catch (std::exception &) { - return false; - } - } - - static handle cast(DefaultingTy src, return_value_policy policy, - handle parent) { - return pybind11::cast(src, policy); - } -}; - -template -struct type_caster> : optional_caster> {}; -} // namespace detail -} // namespace pybind11 - -//------------------------------------------------------------------------------ -// Conversion utilities. -//------------------------------------------------------------------------------ - -namespace mlir { - -/// Accumulates into a python string from a method that accepts an -/// MlirStringCallback. -struct PyPrintAccumulator { - pybind11::list parts; - - void *getUserData() { return this; } - - MlirStringCallback getCallback() { - return [](MlirStringRef part, void *userData) { - PyPrintAccumulator *printAccum = - static_cast(userData); - pybind11::str pyPart(part.data, - part.length); // Decodes as UTF-8 by default. - printAccum->parts.append(std::move(pyPart)); - }; - } - - pybind11::str join() { - pybind11::str delim("", 0); - return delim.attr("join")(parts); - } -}; - -/// Accumulates int a python file-like object, either writing text (default) -/// or binary. -class PyFileAccumulator { -public: - PyFileAccumulator(pybind11::object fileObject, bool binary) - : pyWriteFunction(fileObject.attr("write")), binary(binary) {} - - void *getUserData() { return this; } - - MlirStringCallback getCallback() { - return [](MlirStringRef part, void *userData) { - pybind11::gil_scoped_acquire(); - PyFileAccumulator *accum = static_cast(userData); - if (accum->binary) { - // Note: Still has to copy and not avoidable with this API. - pybind11::bytes pyBytes(part.data, part.length); - accum->pyWriteFunction(pyBytes); - } else { - pybind11::str pyStr(part.data, - part.length); // Decodes as UTF-8 by default. 
- accum->pyWriteFunction(pyStr); - } - }; - } - -private: - pybind11::object pyWriteFunction; - bool binary; -}; - -/// Accumulates into a python string from a method that is expected to make -/// one (no more, no less) call to the callback (asserts internally on -/// violation). -struct PySinglePartStringAccumulator { - void *getUserData() { return this; } - - MlirStringCallback getCallback() { - return [](MlirStringRef part, void *userData) { - PySinglePartStringAccumulator *accum = - static_cast(userData); - assert(!accum->invoked && - "PySinglePartStringAccumulator called back multiple times"); - accum->invoked = true; - accum->value = pybind11::str(part.data, part.length); - }; - } - - pybind11::str takeValue() { - assert(invoked && "PySinglePartStringAccumulator not called back"); - return std::move(value); - } - -private: - pybind11::str value; - bool invoked = false; -}; - -/// A CRTP base class for pseudo-containers willing to support Python-type -/// slicing access on top of indexed access. Calling ::bind on this class -/// will define `__len__` as well as `__getitem__` with integer and slice -/// arguments. -/// -/// This is intended for pseudo-containers that can refer to arbitrary slices of -/// underlying storage indexed by a single integer. Indexing those with an -/// integer produces an instance of ElementTy. Indexing those with a slice -/// produces a new instance of Derived, which can be sliced further. -/// -/// A derived class must provide the following: -/// - a `static const char *pyClassName ` field containing the name of the -/// Python class to bind; -/// - an instance method `intptr_t getNumElements()` that returns the number -/// of elements in the backing container (NOT that of the slice); -/// - an instance method `ElementTy getElement(intptr_t)` that returns a -/// single element at the given index. 
-/// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that -/// constructs a new instance of the derived pseudo-container with the -/// given slice parameters (to be forwarded to the Sliceable constructor). -/// -/// A derived class may additionally define: -/// - a `static void bindDerived(ClassTy &)` method to bind additional methods -/// the python class. -template -class Sliceable { -protected: - using ClassTy = pybind11::class_; - -public: - explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) - : startIndex(startIndex), length(length), step(step) { - assert(length >= 0 && "expected non-negative slice length"); - } - - /// Returns the length of the slice. - intptr_t dunderLen() const { return length; } - - /// Returns the element at the given slice index. Supports negative indices - /// by taking elements in inverse order. Throws if the index is out of bounds. - ElementTy dunderGetItem(intptr_t index) { - // Negative indices mean we count from the end. - if (index < 0) - index = length + index; - if (index < 0 || index >= length) { - throw python::SetPyError(PyExc_IndexError, - "attempt to access out of bounds"); - } - - // Compute the linear index given the current slice properties. - int linearIndex = index * step + startIndex; - assert(linearIndex >= 0 && - linearIndex < static_cast(this)->getNumElements() && - "linear index out of bounds, the slice is ill-formed"); - return static_cast(this)->getElement(linearIndex); - } - - /// Returns a new instance of the pseudo-container restricted to the given - /// slice. 
- Derived dunderGetItemSlice(pybind11::slice slice) { - ssize_t start, stop, extraStep, sliceLength; - if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) { - throw python::SetPyError(PyExc_IndexError, - "attempt to access out of bounds"); - } - return static_cast(this)->slice(startIndex + start * step, - sliceLength, step * extraStep); - } - - /// Binds the indexing and length methods in the Python class. - static void bind(pybind11::module &m) { - auto clazz = pybind11::class_(m, Derived::pyClassName) - .def("__len__", &Sliceable::dunderLen) - .def("__getitem__", &Sliceable::dunderGetItem) - .def("__getitem__", &Sliceable::dunderGetItemSlice); - Derived::bindDerived(clazz); - } - - /// Hook for derived classes willing to bind more methods. - static void bindDerived(ClassTy &) {} - -private: - intptr_t startIndex; - intptr_t length; - intptr_t step; -}; - -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp new file mode 100644 index 000000000..3ba42bec5 --- /dev/null +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -0,0 +1,24 @@ +//===- RegisterEverything.cpp - API to register all dialects/passes -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/RegisterEverything.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +NB_MODULE(_mlirRegisterEverything, m) { + m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration"; + + m.def("register_dialects", [](MlirDialectRegistry registry) { + mlirRegisterAllDialects(registry); + }); + m.def("register_llvm_translations", + [](MlirContext context) { mlirRegisterAllLLVMTranslations(context); }); + + // Register all passes on load. + mlirRegisterAllPasses(); +} diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp new file mode 100644 index 000000000..0373f9c7a --- /dev/null +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -0,0 +1,112 @@ +//===- Rewrite.cpp - Rewrite ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Rewrite.h" + +#include "IRModule.h" +#include "mlir-c/Rewrite.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "mlir/Config/mlir-config.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace nb::literals; +using namespace mlir::python; + +namespace { + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +/// Owning Wrapper around a PDLPatternModule. 
+class PyPDLPatternModule { +public: + PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} + PyPDLPatternModule(PyPDLPatternModule &&other) noexcept + : module(other.module) { + other.module.ptr = nullptr; + } + ~PyPDLPatternModule() { + if (module.ptr != nullptr) + mlirPDLPatternModuleDestroy(module); + } + MlirPDLPatternModule get() { return module; } + +private: + MlirPDLPatternModule module; +}; +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + +/// Owning Wrapper around a FrozenRewritePatternSet. +class PyFrozenRewritePatternSet { +public: + PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {} + PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept + : set(other.set) { + other.set.ptr = nullptr; + } + ~PyFrozenRewritePatternSet() { + if (set.ptr != nullptr) + mlirFrozenRewritePatternSetDestroy(set); + } + MlirFrozenRewritePatternSet get() { return set; } + + nb::object getCapsule() { + return nb::steal( + mlirPythonFrozenRewritePatternSetToCapsule(get())); + } + + static nb::object createFromCapsule(nb::object capsule) { + MlirFrozenRewritePatternSet rawPm = + mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + if (rawPm.ptr == nullptr) + throw nb::python_error(); + return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); + } + +private: + MlirFrozenRewritePatternSet set; +}; + +} // namespace + +/// Create the `mlir.rewrite` here. 
+void mlir::python::populateRewriteSubmodule(nb::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of the top-level PassManager + //---------------------------------------------------------------------------- +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "PDLModule") + .def( + "__init__", + [](PyPDLPatternModule &self, MlirModule module) { + new (&self) + PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); + }, + "module"_a, "Create a PDL module from the given module.") + .def("freeze", [](PyPDLPatternModule &self) { + return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + mlirRewritePatternSetFromPDLPatternModule(self.get()))); + }); +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "FrozenRewritePatternSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyFrozenRewritePatternSet::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyFrozenRewritePatternSet::createFromCapsule); + m.def( + "apply_patterns_and_fold_greedily", + [](MlirModule module, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); + if (mlirLogicalResultIsFailure(status)) + // FIXME: Not sure this is the right error to throw here. 
+ throw nb::value_error("pattern application failed to converge"); + }, + "module"_a, "set"_a, + "Applys the given patterns to the given module greedily while folding " + "results."); +} diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.h b/mlir/lib/Bindings/Python/Rewrite.h similarity index 54% rename from mlir/lib/Bindings/Python/ExecutionEngine.h rename to mlir/lib/Bindings/Python/Rewrite.h index cc61648b5..ae89e2b95 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -1,4 +1,4 @@ -//===- ExecutionEngine.h - ExecutionEngine submodule of pybind module -----===// +//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,17 +6,17 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H -#define MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H +#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H +#define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populateExecutionEngineSubmodule(pybind11::module &m); +void populateRewriteSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir -#endif // MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H +#endif // MLIR_BINDINGS_PYTHON_REWRITE_H diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp new file mode 100644 index 000000000..8242f0973 --- /dev/null +++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp @@ -0,0 +1,22 @@ +//===- SparseTensorPasses.cpp - Pybind module for the SparseTensor passes -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SparseTensor.h" + +#include "mlir/Bindings/Python/Nanobind.h" + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +NB_MODULE(_mlirSparseTensorPasses, m) { + m.doc() = "MLIR SparseTensor Dialect Passes"; + + // Register all SparseTensor passes on load. + mlirRegisterSparseTensorPasses(); +} diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp new file mode 100644 index 000000000..f9b0fed62 --- /dev/null +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -0,0 +1,106 @@ +//===- TransformInterpreter.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Pybind classes for the transform dialect interpreter. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform/Interpreter.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; + +namespace { +struct PyMlirTransformOptions { + PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; + PyMlirTransformOptions(PyMlirTransformOptions &&other) { + options = other.options; + other.options.ptr = nullptr; + } + PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; + + ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); } + + MlirTransformOptions options; +}; +} // namespace + +static void populateTransformInterpreterSubmodule(nb::module_ &m) { + nb::class_(m, "TransformOptions") + .def(nb::init<>()) + .def_prop_rw( + "expensive_checks", + [](const PyMlirTransformOptions &self) { + return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); + }, + [](PyMlirTransformOptions &self, bool value) { + mlirTransformOptionsEnableExpensiveChecks(self.options, value); + }) + .def_prop_rw( + "enforce_single_top_level_transform_op", + [](const PyMlirTransformOptions &self) { + return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + self.options); + }, + [](PyMlirTransformOptions &self, bool value) { + mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, + value); + }); + + m.def( + "apply_named_sequence", + [](MlirOperation payloadRoot, MlirOperation transformRoot, + MlirOperation transformModule, const PyMlirTransformOptions &options) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(transformRoot)); + + // Calling back into Python to invalidate everything under the payload + // root. This is awkward, but we don't have access to PyMlirContext + // object here otherwise. 
+ nb::object obj = nb::cast(payloadRoot); + obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); + + MlirLogicalResult result = mlirTransformApplyNamedSequence( + payloadRoot, transformRoot, transformModule, options.options); + if (mlirLogicalResultIsSuccess(result)) + return; + + throw nb::value_error( + ("Failed to apply named transform sequence.\nDiagnostic message " + + scope.takeMessage()) + .c_str()); + }, + nb::arg("payload_root"), nb::arg("transform_root"), + nb::arg("transform_module"), + nb::arg("transform_options") = PyMlirTransformOptions()); + + m.def( + "copy_symbols_and_merge_into", + [](MlirOperation target, MlirOperation other) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(target)); + + MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); + if (mlirLogicalResultIsFailure(result)) { + throw nb::value_error( + ("Failed to merge symbols.\nDiagnostic message " + + scope.takeMessage()) + .c_str()); + } + }, + nb::arg("target"), nb::arg("other")); +} + +NB_MODULE(_mlirTransformInterpreter, m) { + m.doc() = "MLIR Transform dialect interpreter functionality."; + populateTransformInterpreterSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt deleted file mode 100644 index 8b53f03d4..000000000 --- a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -################################################################################ -# Build python extension -################################################################################ - -add_mlir_python_extension(MLIRTransformsBindingsPythonExtension _mlirTransforms - INSTALL_DIR - python - SOURCES - Transforms.cpp -) \ No newline at end of file diff --git a/mlir/lib/Bindings/Python/mlir/_cext_loader.py b/mlir/lib/Bindings/Python/mlir/_cext_loader.py deleted file mode 100644 index 35847efa9..000000000 --- 
a/mlir/lib/Bindings/Python/mlir/_cext_loader.py +++ /dev/null @@ -1,55 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Common module for looking up and manipulating C-Extensions.""" - -# Packaged installs have a top-level _mlir_libs package with symbols: -# load_extension(name): Loads a named extension module -# preload_dependency(public_name): Loads a shared-library/DLL into the -# namespace. TODO: Remove this in favor of a more robust mechanism. -# Conditionally switch based on whether we are in a package context. -try: - import _mlir_libs -except ModuleNotFoundError: - # Assume that we are in-tree. - # The _dlloader takes care of platform specific setup before we try to - # load a shared library. - from ._dlloader import preload_dependency as _preload_dependency - - def _load_extension(name): - import importlib - return importlib.import_module(name) # i.e. '_mlir' at the top level -else: - # Packaged distribution. - _load_extension = _mlir_libs.load_extension - _preload_dependency = _mlir_libs.preload_dependency - -_preload_dependency("MLIRPublicAPI") - -# Expose the corresponding C-Extension module with a well-known name at this -# top-level module. This allows relative imports like the following to -# function: -# from .._cext_loader import _cext -# This reduces coupling, allowing embedding of the python sources into another -# project that can just vary based on this top-level loader module. -_cext = _load_extension("_mlir") - - -def _reexport_cext(cext_module_name, target_module_name): - """Re-exports a named sub-module of the C-Extension into another module. 
- - Typically: - from ._cext_loader import _reexport_cext - _reexport_cext("ir", __name__) - del _reexport_cext - """ - import sys - target_module = sys.modules[target_module_name] - source_module = getattr(_cext, cext_module_name) - for attr_name in dir(source_module): - if not attr_name.startswith("__"): - setattr(target_module, attr_name, getattr(source_module, attr_name)) - - -# Add our 'dialects' parent module to the search path for implementations. -_cext.globals.append_dialect_search_prefix("mlir.dialects") diff --git a/mlir/lib/Bindings/Python/mlir/_dlloader.py b/mlir/lib/Bindings/Python/mlir/_dlloader.py deleted file mode 100644 index 454a7b7f1..000000000 --- a/mlir/lib/Bindings/Python/mlir/_dlloader.py +++ /dev/null @@ -1,59 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import os -import platform - -_is_windows = platform.system() == "Windows" -_this_directory = os.path.dirname(__file__) - -# The standard LLVM build/install tree for Windows is laid out as: -# bin/ -# MLIRPublicAPI.dll -# python/ -# _mlir.*.pyd (dll extension) -# mlir/ -# _dlloader.py (this file) -# First check the python/ directory level for DLLs co-located with the pyd -# file, and then fall back to searching the bin/ directory. -# TODO: This should be configurable at some point. -_dll_search_path = [ - os.path.join(_this_directory, ".."), - os.path.join(_this_directory, "..", "..", "bin"), -] - -# Stash loaded DLLs to keep them alive. -_loaded_dlls = [] - -def preload_dependency(public_name): - """Preloads a dylib by its soname or DLL name. - - On Windows and Linux, doing this prior to loading a dependency will populate - the library in the flat namespace so that a subsequent library that depend - on it will resolve to this preloaded version. - - On OSX, resolution is completely path based so this facility no-ops. 
On - Linux, as long as RPATHs are setup properly, resolution is path based but - this facility can still act as an escape hatch for relocatable distributions. - """ - if _is_windows: - _preload_dependency_windows(public_name) - - -def _preload_dependency_windows(public_name): - dll_basename = public_name + ".dll" - found_path = None - for search_dir in _dll_search_path: - candidate_path = os.path.join(search_dir, dll_basename) - if os.path.exists(candidate_path): - found_path = candidate_path - break - - if found_path is None: - raise RuntimeError( - f"Unable to find dependency DLL {dll_basename} in search " - f"path {_dll_search_path}") - - import ctypes - _loaded_dlls.append(ctypes.CDLL(found_path)) diff --git a/mlir/lib/Bindings/Python/mlir/conversions/__init__.py b/mlir/lib/Bindings/Python/mlir/conversions/__init__.py deleted file mode 100644 index 0989449a4..000000000 --- a/mlir/lib/Bindings/Python/mlir/conversions/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._cext_loader import _load_extension -_cextConversions = _load_extension("_mlirConversions") diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py deleted file mode 100644 index b07892991..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py +++ /dev/null @@ -1,95 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ..ir import * - - -class ModuleOp: - """Specialization for the module op class.""" - - def __init__(self, *, loc=None, ip=None): - super().__init__(self.build_generic(results=[], operands=[], loc=loc, - ip=ip)) - body = self.regions[0].blocks.append() - with InsertionPoint(body): - Operation.create("module_terminator") - - @property - def body(self): - return self.regions[0].blocks[0] - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. 
- if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self): - return self.attributes["sym_name"] - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. - Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py deleted file mode 100644 index 74390d487..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py +++ /dev/null @@ -1,27 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - - -class StructuredOpMixin: - """All structured ops use the same mixin class.""" - - def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - if outputs and results: - raise ValueError( - "Structured ops must have outputs or results, but not both.") - super().__init__( - self.build_generic(results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip)) - - -def select_opview_mixin(parent_opview_cls): - # TODO: This shouldn't be a heuristic: we should have a way to annotate - # the OpView to note that it is a structured op. - if ("__init__" not in parent_opview_cls.__dict__ and - hasattr(parent_opview_cls, "inputs") and - hasattr(parent_opview_cls, "outputs") and - hasattr(parent_opview_cls, "result_tensors")): - return StructuredOpMixin diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py deleted file mode 100644 index 6d37700ec..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py +++ /dev/null @@ -1,116 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Re-export the parent _cext so that every level of the API can get it locally. -from .._cext_loader import _cext - -__all__ = [ - "equally_sized_accessor", - "extend_opview_class", - "get_default_loc_context", - "segmented_accessor", -] - - -def extend_opview_class(ext_module): - """Decorator to extend an OpView class from an extension module. - - Extension modules can expose various entry-points: - def select_opview_mixin(parent_opview_cls): - If defined, allows an appropriate mixin class to be selected dynamically - based on the parent OpView class. Should return NotImplemented if a - decision is not made. 
- - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). - - Args: - ext_module: A module from which to locate extensions. Can be None if not - available. - - Returns: - A decorator that takes an OpView subclass and further extends it as - needed. - """ - - def class_decorator(parent_opview_cls: type): - if ext_module is None: - return parent_opview_cls - mixin_cls = NotImplemented - try: - select_mixin = getattr(ext_module, "select_opview_mixin") - except AttributeError: - # Try to default resolve it. - try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) - except AttributeError: - pass - else: - mixin_cls = select_mixin(parent_opview_cls) - if mixin_cls is NotImplemented or mixin_cls is None: - return parent_opview_cls - - # Have a mixin_cls. Create an appropriate subclass. - try: - - class LocalOpView(mixin_cls, parent_opview_cls): - pass - except TypeError as e: - raise TypeError( - f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e - LocalOpView.__name__ = parent_opview_cls.__name__ - LocalOpView.__qualname__ = parent_opview_cls.__qualname__ - return LocalOpView - - return class_decorator - - -def segmented_accessor(elements, raw_segments, idx): - """ - Returns a slice of elements corresponding to the idx-th segment. - - elements: a sliceable container (operands or results). - raw_segments: an mlir.ir.Attribute, of DenseIntElements subclass containing - sizes of the segments. - idx: index of the segment. - """ - segments = _cext.ir.DenseIntElementsAttr(raw_segments) - start = sum(segments[i] for i in range(idx)) - end = start + segments[idx] - return elements[start:end] - - -def equally_sized_accessor(elements, n_variadic, n_preceding_simple, - n_preceding_variadic): - """ - Returns a starting position and a number of elements per variadic group - assuming equally-sized groups and the given numbers of preceding groups. - - elements: a sequential container. 
- n_variadic: the number of variadic groups in the container. - n_preceding_simple: the number of non-variadic groups preceding the current - group. - n_preceding_variadic: the number of variadic groups preceding the current - group. - """ - - total_variadic_length = len(elements) - n_variadic + 1 - # This should be enforced by the C++-side trait verifier. - assert total_variadic_length % n_variadic == 0 - - elements_per_group = total_variadic_length // n_variadic - start = n_preceding_simple + n_preceding_variadic * elements_per_group - return start, elements_per_group - - -def get_default_loc_context(location=None): - """ - Returns a context in which the defaulted location is created. If the location - is None, takes the current location from the stack, raises ValueError if there - is no location on the stack. - """ - if location is None: - # Location.current raises ValueError if there is no current location. - return _cext.ir.Location.current.context - return location.context diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py deleted file mode 100644 index 98bf2e247..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ /dev/null @@ -1,91 +0,0 @@ -#!/usr/bin/which python -# Command line tool to load an oplib module and dump all of the operations -# it contains in some format. -"""Loads one or more modules containing op definitions and dumps them. - -The dump format can be: - -* `--dump_format=yaml` (default) -* `--dump_format=repr` - -Positional arguments are interpreted as module names (optionally, relative to -this module). Loose module files can be specified via `--file `. - -Sample usage: - # Dump the YAML op definitions for the core named ops (as in the dialect - # source tree). 
- python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops - -Note: YAML output is emitted in "document list" format with each operation -as its own "document". Practically, this means that each operation (or group -of composite ops) is emitted with a "---" preceding it, which can be useful -for testing. -""" - -import argparse -import importlib - -from .lang import * -from .lang.config import * -from .lang.yaml_helper import * - - -def create_arg_parser() -> argparse.ArgumentParser: - p = argparse.ArgumentParser(description="Dump an oplib in various formats") - p.add_argument("modules", - metavar="M", - type=str, - nargs="*", - help="Op module to dump") - p.add_argument("--file", - metavar="F", - type=str, - nargs="*", - help="Python op file to dump") - p.add_argument("--format", - type=str, - dest="format", - default="yaml", - choices=("yaml", "repr"), - help="Format in which to dump") - return p - - -def load_module_from_file(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - m = importlib.util.module_from_spec(spec) - spec.loader.exec_module(m) - return m - - -def main(args): - # Load all configs. - configs = [] - modules = [] - for module_name in args.modules: - modules.append( - importlib.import_module(module_name, - package="mlir.dialects.linalg.opdsl")) - for i, file_path in enumerate(args.file or []): - modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) - for m in modules: - for attr_name, value in m.__dict__.items(): - # TODO: This class layering is awkward. - if isinstance(value, DefinedOpCallable): - try: - linalg_config = LinalgOpConfig.from_linalg_op_def(value.model) - except Exception as e: - raise ValueError( - f"Could not create LinalgOpConfig from {value.model}") from e - configs.extend(linalg_config) - - # Print. 
- if args.format == "yaml": - print(yaml_dump_all(configs)) - elif args.format == "repr": - for config in configs: - print(repr(config)) - - -if __name__ == "__main__": - main(create_arg_parser().parse_args()) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py deleted file mode 100644 index 34a8d6d30..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py +++ /dev/null @@ -1,312 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""DSL for constructing affine expressions and maps. - -These python wrappers allow construction of affine expressions in a more -pythonic fashion that is later instantiated as an IR AffineExpr. Separating the -AST from construction of the map allows for manipulations of symbols and dims -beyond the scope of one expression. - -Affine expression construction: - >>> with _ir.Context(): - ... s = AffineBuildState() - ... (S.K + S.M).build(s) - ... (S.K * S.M).build(s) - ... (S.K // S.M).build(s) - ... (S.K / S.M).build(s) - ... (S.K % 4).build(s) - ... (D.i + D.j * 4).build(s) - ... s - AffineExpr(s0 + s1) - AffineExpr(s0 * s1) - AffineExpr(s0 floordiv s1) - AffineExpr(s0 ceildiv s1) - AffineExpr(s0 mod 4) - AffineExpr(d0 + d1 * 4) - AffineBuildState< - symbols={'K': 0, 'M': 1} - dims={'i': 0, 'j': 1}> - -In the DSL, dimensions and symbols are name-uniqued instances of DimDef and -SymbolDef. 
There are shortcut "expando" instances that will create a -corresponding DimDef/SymbolDef upon accessing an attribute: - -Referencing a named dimension: - - >>> D.i - Dim(i) - >>> D.a is D.b - False - >>> D.a is D.a - True - -Referencing a named symbol: - - >>> S.foobar - Symbol(foobar) - >>> S.a is S.b - False - >>> S.a is S.a - True -""" - -from typing import Callable, Dict, Optional, Tuple, Union - -from mlir import ir as _ir - -__all__ = [ - "AffineBuildState", - "AffineExprDef", - "D", - "DimDef", - "S", - "SymbolDef", -] - -# Type aliases. -SymbolPosMap = Dict[str, int] - - -class AffineBuildState: - """Internal state for the AffineExprDef._create impls. - - Note that a "local" AffineBuildState can be created relative to a "global" - AffineBuildState. In that case, any affine expressions built will inherit - symbol and dim bindings from the global state and will update both as new - ones are discovered. This allows for building expressions across contexts - which share a common symbol and dim space. - """ - - def __init__(self, - *, - global_state: "AffineBuildState" = None, - allow_new_symbols: bool = True, - allow_new_dims: bool = True): - if not global_state: - self.all_symbols = dict() # type: Dict[str, int] - self.all_dims = dict() # type: Dict[str, int] - else: - # Alias the global dict. - self.all_symbols = global_state.all_symbols - self.all_dims = global_state.all_dims - - # Map of symbols and dims in the current build. 
- self.local_symbols = dict() # type: Dict[str, int] - self.local_dims = dict() # type: Dict[str, int] - self.allow_new_symbols = allow_new_symbols - self.allow_new_dims = allow_new_dims - - def get_dim(self, dimname: str) -> int: - """Gets the dim position given a name.""" - pos = self.all_dims.get(dimname) - if pos is None: - if not self.allow_new_dims: - raise ValueError( - f"New dimensions not allowed in the current affine expression: " - f"Requested '{dimname}', Availble: {self.all_dims}") - pos = len(self.all_dims) - self.all_dims[dimname] = pos - self.local_dims[dimname] = pos - return pos - - def get_symbol(self, symname: str) -> int: - """Geta a symbol position given a name.""" - pos = self.all_symbols.get(symname) - if pos is None: - if not self.allow_new_symbols: - raise ValueError( - f"New symbols not allowed in the current affine expression: " - f"Requested '{symname}', Availble: {self.all_symbols}") - pos = len(self.all_symbols) - self.all_symbols[symname] = pos - self.local_symbols[symname] = pos - return pos - - @property - def local_dim_count(self) -> int: - return len(self.local_dims) - - @property - def local_symbol_count(self) -> int: - return len(self.local_symbols) - - @property - def dim_count(self) -> int: - return len(self.all_dims) - - @property - def symbol_count(self) -> int: - return len(self.all_symbols) - - def __repr__(self): - lines = [f"AffineBuildState<"] - lines.append(f" symbols={self.local_symbols}") - lines.append(f" dims={self.local_dims}>") - return "\n".join(lines) - - -class AffineExprDef: - """Base class for an affine expression being defined.""" - - def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: - """Builds the corresponding _ir.AffineExpr from the definitions. 
- """ - state = AffineBuildState() if state is None else state - expr = self._create(state) - return expr - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - raise NotImplementedError() - - @staticmethod - def coerce_from(py_value): - if isinstance(py_value, int): - return AffineConstantExpr(py_value) - assert isinstance(py_value, AffineExprDef) - return py_value - - def visit_affine_exprs(self, callback): - """Visits all AffineExprDefs including self.""" - callback(self) - - def __add__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs) - - def __mul__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) - - def __mod__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) - - def __floordiv__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) - - def __truediv__(lhs, rhs): - # TODO: Not really a ceil div - taking liberties for the DSL. 
- rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) - - -class AffineConstantExpr(AffineExprDef): - """An affine constant being defined.""" - - def __init__(self, value: int): - assert isinstance(value, int) - self.value = value - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - return _ir.AffineConstantExpr.get(self.value) - - def __repr__(self): - return f"Const({self.value})" - - -class AffineBinaryExprDef(AffineExprDef): - """An affine binary expression being defined.""" - - def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): - self.ir_ctor = ir_ctor - self.lhs = lhs - self.rhs = rhs - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) - - def visit_affine_exprs(self, callback): - """Visits all AffineExprDefs including self.""" - super().visit_affine_exprs(callback) - self.lhs.visit_affine_exprs(callback) - self.rhs.visit_affine_exprs(callback) - - def __repr__(self): - return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" - - -class DimDef(AffineExprDef): - """Represents a named dimension. - - """ - ALL_DIMS = dict() # type: Dict[str, "DimDef"] - dimname: str - - def __new__(cls, dimname: str): - existing = cls.ALL_DIMS.get(dimname) - if existing is not None: - return existing - new = super().__new__(cls) - new.dimname = dimname - cls.ALL_DIMS[dimname] = new - return new - - def __repr__(self): - return f"Dim({self.dimname})" - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - pos = state.get_dim(self.dimname) - return _ir.AffineDimExpr.get(position=pos) - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access. - """ - - class ExpandoDims: - - def __getattr__(self, n): - return cls(n) - - return ExpandoDims() - - -class SymbolDef(AffineExprDef): - """Represents a named symbol. 
- - >>> s1 = SymbolDef("s1") - >>> s1 - Symbol(s1) - >>> s2 = SymbolDef("s2") - >>> s1 is s2 - False - >>> s1 is SymbolDef("s1") - True - """ - ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] - symname: str - - def __new__(cls, symname: str): - existing = cls.ALL_SYMBOLS.get(symname) - if existing is not None: - return existing - new = super().__new__(cls) - new.symname = symname - cls.ALL_SYMBOLS[symname] = new - return new - - def __repr__(self): - return f"Symbol({self.symname})" - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - pos = state.get_symbol(self.symname) - return _ir.AffineSymbolExpr.get(position=pos) - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access. - """ - - class ExpandoSymbols: - - def __getattr__(self, n): - return cls(n) - - return ExpandoSymbols() - - -# Global accessor for on-demand dims and symbols. -D = DimDef.create_expando() -S = SymbolDef.create_expando() diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py deleted file mode 100644 index 6bc6ff979..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ /dev/null @@ -1,425 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Model classes representing a tensor comprehension. - -These classes model the language more at an AST level as evaluated. Reasoning -about it typically involves processing this form into config objects that -represent actual op definitions (i.e. YAML). 
-""" - -from typing import Callable, Dict, List, Optional, Sequence, Set, Tuple, Union - -from mlir import ir as _ir - -from .affine import * -from .scalar_expr import * -from .types import * -from .yaml_helper import * - -# Type aliases. -AffineDimList = Dict[str, _ir.AffineExpr] - - -class TensorExpression: - """An expression that can appear on the RHS of a comprehension.""" - - def to_scalar_expression(self) -> ScalarExpression: - raise NotImplementedError() - - def visit_affine_exprs(self, callback): - """Visits all affine expressions reachable by the expression.""" - pass - - def _get_all_dim_defs(self) -> Set[DimDef]: - """Recursively gets all DimDef affine expressions that are referenced.""" - results = set() - - def visitor(affine_expr): - if isinstance(affine_expr, DimDef): - results.add(affine_expr) - - self.visit_affine_exprs(visitor) - return results - - def collect_uses(self, uses: Set["TensorUse"]): - """Collects all TensorUses reachable through this expression.""" - pass - - def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return PrimFn.add(self, rhs) - - def __mul__(self, rhs) -> "TensorExpression": - return PrimFn.mul(self, rhs) - - def __sub__(self, rhs) -> "TensorExpression": - return PrimFn.sub(self, rhs) - - def __hash__(self): - return hash(id(self)) - - -class TensorUse(TensorExpression): - """A used tensor represented by its (tensor_name, indices). - - Note that forming a comprehension via direct assignment is performed through - __setitem__ on the TensorDef level. 
However, performing a reduction with - compound ops (+=, *=, etc) is done by doing a: - TensorDef.__getitem__ - TensorUse.__iadd__ - TensorDef.__setitem__ - """ - - def __init__(self, tensor_def: "TensorDef", indices: Sequence[AffineExprDef]): - self.tensor_def = tensor_def - self.indices = tuple(indices) - - def to_scalar_expression(self) -> ScalarExpression: - assert self.tensor_def.tensor_name is not None - return ScalarArg(self.tensor_def.tensor_name).expr() - - @property - def tensor_name(self) -> str: - n = self.tensor_def.tensor_name - assert n is not None, "TensorDef not attached" - return n - - def visit_affine_exprs(self, callback): - for ind in self.indices: - ind.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): - uses.add(self) - - def __iadd__(self, rhs: TensorExpression) -> TensorExpression: - return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs) - - def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: - """For implicit reductions, computes default reduction dims. - - Assumes that the rhs is the expression being reduced and self is being - reduced into. Any indices referenced on the rhs and not in self are - considered reduction dims and will be ordered as encountered on the rhs. - """ - rhs_dims = rhs._get_all_dim_defs() - lhs_dims = self._get_all_dim_defs() - return rhs_dims - lhs_dims - - def __repr__(self): - return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" - - -class TensorDef: - """Bookkeeping of a single registered tensor, held in dict by name.""" - - def __init__(self, - type_var: TypeVar, - *shape: AffineExprDef, - indexing_map: Optional[_ir.AffineMap] = None, - output: bool = False): - if not isinstance(type_var, TypeVar): - raise ValueError(f"TensorDef requires a TypeVar. 
Got: {repr(type_var)}") - self.owner = None # type: Optional["LinalgOpDef"] - self.type_var = type_var - self.shape = shape - self.indexing_map = indexing_map - self.output = output - self.tensor_name = None # type: Optional[str] - self.registered_index = -1 # type: int - - @property - def rank(self) -> int: - """The rank of the tensor.""" - return len(self.shape) - - def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"): - if self.owner: - raise ValueError(f"TensorDef already registered with op: {self}") - self.registered_index = index - self.tensor_name = tensor_name - self.owner = owner - - def __getitem__(self, dims) -> TensorUse: - assert self.owner, "TensorDef is not attached to an op" - state = AffineBuildState(global_state=self.owner._affine_state, - allow_new_symbols=False) - if not isinstance(dims, tuple): - dims = (dims,) # Handle single subscript case. - # Special case: (None) is a 0d-scalar use. - if dims == (None,): - dims = () - - exprs = [] - for expr_def in dims: - if not isinstance(expr_def, AffineExprDef): - raise KeyError( - "A TensorDef can only be subscripted by a tuple of affine dims") - exprs.append(expr_def) - return TensorUse(self, exprs) - - def __setitem__(self, dims, value): - """Creates a new 1:1 comprehension by binding this tensor to an expression. - - Note that due to the way assignment works in Python, we have to capture - direct assignment as a setitem on the TensorDef. - """ - if not isinstance(value, TensorExpression): - raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. 
" - f"Got: {repr(value)}") - use = self[dims] - comp = Comprehension((use, value)) - self.owner.comprehensions.append(comp) - - def __hash__(self): - return hash(id(self)) - - def __repr__(self): - output = "OUTPUT " if self.output else "" - return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, " - f"shape={self.shape})") - - -class Comprehension: - """Represents a single comprehension.""" - - def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): - self.definitions = list() # List[TensorUse] - self.values = list() # List[TensorExpression] - - # Find the lhs to reduction rhs. - for assign, value in bindings: - if isinstance(value, ReduceApply): - if value.lhs: - raise ValueError(f"Reduction expression already assigns: {value}") - value.lhs = assign - self.definitions.append(assign) - self.values.append(value) - - @property - def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: - """Gets the reduction dims for the comprehension or None.""" - result = set() - for use in self.values: - if isinstance(use, ReduceApply): - result.add(use.reduce.reduce_dims) - else: - result.add(tuple()) - return result - - def __repr__(self): - if len(self.definitions) > 1: - defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" - values_repr = f"({', '.join(repr(v) for v in self.values)})" - else: - defs_repr = f"{repr(self.definitions[0])}" - values_repr = f"{repr(self.values[0])}" - - return f"{defs_repr} = {values_repr}" - - -class PrimFnType: - """Primitive operations.""" - - def __init__(self, prim_name: str): - self.prim_name = prim_name - - def __call__(self, *args): - return PrimApply(self, args) - - def reduce(self, *reduce_dims: DimDef): - """Shortcut to create a Reduce operation from this primitive.""" - return ReduceFnType(self, *reduce_dims) - - def __repr__(self): - return f"{self.prim_name}" - - -class PrimFn: - add = PrimFnType("add") - exp = PrimFnType("exp") - log = PrimFnType("log") - mul = PrimFnType("mul") - max = 
PrimFnType("max") - sub = PrimFnType("sub") - - -class ReduceFnType: - """A reduction operator that reduces into its LHS from its RHS.""" - - def __init__(self, operator: PrimFnType, *reduce_dims: DimDef): - """Initializes the ReduceFn with a primitive function and dims.""" - if not isinstance(operator, PrimFnType): - raise ValueError(f"Reduce expected a Prim operator. Got: {operator}") - self.operator = operator - self.reduce_dims = tuple(reduce_dims) - - def __call__(self, *args: TensorExpression): - return ReduceApply(self, args) - - def __repr__(self): - return (f"reduce_{self.operator.prim_name}" - f"({', '.join(repr(d) for d in self.reduce_dims)})") - - -class ReduceFn: - add = PrimFn.add.reduce - mul = PrimFn.mul.reduce - max = PrimFn.max.reduce - - -class PrimApply(TensorExpression): - """Application of a primitive.""" - - def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]): - self.prim = prim - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarApplyFn(self.prim.prim_name, - *[arg.to_scalar_expression() for arg in self.args - ]).expr() - - def visit_affine_exprs(self, callback): - for arg in self.args: - arg.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): - for arg in self.args: - arg.collect_uses(uses) - - def __repr__(self): - return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" - - -class cast(TensorExpression): - """Casts the element type to a type (typically symbolic TypeVar).""" - - def __init__(self, to_type: TypeVar, operand: TensorExpression): - self.to_type = to_type - self.operand = operand - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarSymbolicCast(self.to_type, - self.operand.to_scalar_expression()).expr() - - def visit_affine_exprs(self, callback): - self.operand.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): - self.operand.collect_uses(uses) - - def __repr__(self): - return 
f"cast({self.to_type}, {repr(self.operand)})" - - -class ReduceApply(TensorExpression): - """Application of a reduction. - - This captures the lhs separately (initial value) separately from the rhs. - """ - - def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]): - self.reduce = reduce - self.lhs = None # type: Optional[TensorUse] - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - if self.lhs is None: - raise ValueError(f"Cannot scalarize a ReduceApply that has not been " - f"bound to its lhs: {self}") - full_args = [self.lhs.to_scalar_expression() - ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr() - - def visit_affine_exprs(self, callback): - for ind in self.reduce.reduce_dims: - ind.visit_affine_exprs(callback) - for arg in self.args: - arg.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): - for arg in self.args: - arg.collect_uses(uses) - - def __repr__(self): - return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})" - - -class OpInterfaceDef: - """An interface that an op implements.""" - - def __init__(self, cpp_name: str): - self.cpp_name = cpp_name - - -ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") - - -class OpMetadataDef(YAMLObject): - """Metadata about the op (generally not behavior impacting).""" - yaml_tag = "!LinalgOpMetadata" - - def __init__(self, name: str, cpp_op_name: Optional[str], doc: Optional[str]): - self.name = name - self.cpp_op_name = cpp_op_name if cpp_op_name is not None else name - self.doc = doc - self.implements = [] # type: List[OpInterfaceDef] - - def to_yaml_custom_dict(self): - d = dict( - name=self.name, - cpp_op_name=self.cpp_op_name, - doc=self.doc, - ) - if self.implements: - d["implements"] = [intr.cpp_name for intr in self.implements] - return d - - -class LinalgOpDef: - """Definition of a linalg op.""" - - def 
__init__(self, - name: str, - cpp_op_name: Optional[str] = None, - doc: Optional[str] = None): - self.metadata = OpMetadataDef(name=name, cpp_op_name=cpp_op_name, doc=doc) - self.registered_tensors = dict() # type: Dict[str, TensorDef] - self.comprehensions = list() # type: List[Comprehension] - self._affine_state = AffineBuildState() - - @property - def inputs(self) -> Sequence[TensorDef]: - return [t for t in self.registered_tensors.values() if not t.output] - - @property - def outputs(self) -> Sequence[TensorDef]: - return [t for t in self.registered_tensors.values() if t.output] - - def add_tensor(self, tensor_name: str, tensor: TensorDef): - """Registers a tensor.""" - if tensor_name in self.registered_tensors: - raise ValueError(f"Tensor {tensor_name} is already registered " - f"to {self.registered_tensors['tensor_name']}") - tensor.attach(len(self.registered_tensors), tensor_name, self) - self.registered_tensors[tensor_name] = tensor - - def tensor(self, name): - """Gets a registered tensor by name.""" - try: - return self.registered_tensors[name] - except KeyError: - raise KeyError(f"Tensor {name} is not registered") - - def __repr__(self): - lines = [ - f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_op_name}," - ] - for name, tensor in self.registered_tensors.items(): - lines.append(f" {tensor}") - if self.comprehensions: - lines[-1] += " {" - for comprehension in self.comprehensions: - lines.append(f" {comprehension}") - lines.append("}") - return "\n".join(lines) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py deleted file mode 100644 index 115ea4061..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py +++ /dev/null @@ -1,321 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Represents configured ops as emitted for code generation. - -Classes in this module generally are directly serializable to YAML for use -by the code generator. - -TODO: These should just be dumb containers or serialization code but they -currently encode too many details of how the language is interpreted. Move this -to helpers on the comprehension objects themselves. -""" - -from typing import Any, Dict, Optional - -from mlir import ir as _ir - -from .comprehension import * -from .yaml_helper import * - -__all__ = [ - "LinalgStructuredOpConfig", - "LinalgOpConfig", -] - - -def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: - with affine_map.context: - # Affine map printing/parsing is via an AffineMap attr. - attr = _ir.AffineMapAttr.get(affine_map) - return str(attr) - - -class TensorUseConfig: - """Wrapper around a TensorUse with additional context-bound state.""" - - def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): - self.tensor_use = tensor_use - self.indexing_map = indexing_map - - def __repr__(self): - return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" - - -class TensorDefConfig(YAMLObject): - """Wrapper around a TensorDef with additional context-bound state.""" - yaml_tag = "LinalgTensorDef" - - def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap): - self.tensor_def = tensor_def - self.shape_map = shape_map - self.indexing_map = None # type: Optional[_ir.AffineMap] - - def to_yaml_custom_dict(self): - - def get_usage(): - if self.tensor_def.output: - return "output" - else: - return "input" - - return dict( - name=self.tensor_def.tensor_name, - usage=get_usage(), - shape=_serialize_affine_map(self.shape_map), - element_type_var=self.tensor_def.type_var.name, - ) - - def __repr__(self): - return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})" - - -class LinalgIndexingMapsConfig(YAMLObject): - 
"""Abstracts the style of indexing maps that the op exports. - - Presently only static (tied to the op name) indexing maps are supported. In - the future, it is expected that we will have additional variants: - - Dynamic based on attributes - - Dynamic based on operands - Each is expected to require a different variant of specification. - """ - yaml_tag = "!LinalgIndexingMapsConfig" - - def __init__(self, - static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): - self.static_indexing_maps = static_indexing_maps - - def to_yaml_custom_dict(self): - if self.static_indexing_maps is not None: - return dict(static_indexing_maps=[ - _serialize_affine_map(m) for m in self.static_indexing_maps - ]) - raise ValueError( - f"LinalgIndexingMapsConfig must have one type of indexing map" - f"(got none)") - - -class LinalgStructuredOpConfig(YAMLObject): - """Configuration for metadata sufficient to construct a linalg single - contraction named op.""" - - yaml_tag = "!LinalgStructuredOpConfig" - - def __init__(self, - comprehension: Comprehension, - context: Optional[_ir.Context] = None): - self.context = context if context is not None else _ir.Context() - self.affine_state = AffineBuildState() - self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] - self.tensor_args = dict() # type: Dict[TensorDef, TensorDefConfig] - self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] - - # Compute the ordered set of writes. - collected_uses = set() - for write_use, read_use in zip(comprehension.definitions, - comprehension.values): - self.writes.append((write_use, read_use)) - - for write_use, read_use in self.writes: - collected_uses.add(write_use) - read_use.collect_uses(collected_uses) - - # Need to add all definitions before uses, so process twice. 
- for use in collected_uses: - self.add_tensor_arg(use.tensor_def) - for use in collected_uses: - self.add_use(use) - - # Now normalize all defs and uses indexing maps now that full count of - # dims and symbols are known. - for cuse in self.uses.values(): - cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) - for cdef in self.tensor_args.values(): - cdef.shape_map = self._normalize_affine_map(cdef.shape_map, - with_dims=False) - - # Now for each write use, propagate the indexing maps from the use to the - # tensor, ensuring that there are not conflicts. - for write_use, _ in self.writes: - write_tensor_def = self.tensor_args[write_use.tensor_def] - if write_tensor_def.indexing_map: - raise ValueError( - f"Unexpected multi-write to a single tensor: {write_tensor_def}") - write_tensor_def.indexing_map = self.uses[write_use].indexing_map - - # For each read use, propagate the indexing maps from the use to the - # tensor, ensuring that there are not conflicts. - for _, read_expr in self.writes: - read_uses = set() # type: Set[TensorUse] - read_expr.collect_uses(read_uses) - for read_use in read_uses: - read_tensor_def = self.tensor_args[read_use.tensor_def] - if (read_tensor_def.indexing_map and - read_tensor_def.indexing_map != self.uses[read_use].indexing_map): - raise ValueError( - f"Unexpected multi-read of a tensor with different accesses:" - f"{read_tensor_def} vs {read_use}") - read_tensor_def.indexing_map = self.uses[read_use].indexing_map - - # Sanity check that all defs have an indexing map. - assert all(d.indexing_map for d in self.tensor_args.values()), ( - f"Missing indexing map on TensorDef: {self.tensor_args}") - - # Collect reduction dims and ensure all the same. - all_reduction_dims = set(comprehension.all_reduction_dims) - if len(all_reduction_dims) != 1: - raise ValueError( - f"All writes within a generic must have the same reduction " - f"dims. 
Got: {all_reduction_dims}") - self.reduction_dims = next(iter(all_reduction_dims)) - - # Generate the scalar assignments (used to build a body). - self.assignments = [ - ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) - for write_use, read_expr in self.writes - ] - - @property - def ordered_tensor_args(self) -> Sequence[TensorDefConfig]: - return sorted(self.tensor_args.values(), - key=lambda tdc: tdc.tensor_def.registered_index) - - @property - def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]: - return sorted(self.uses.values(), - key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) - - @property - def ordered_dims(self) -> Sequence[Tuple[str, int]]: - """Gets the ordered list of dim bindings (symbolic name, position). - - TODO: The original parser relies on parse ordering to arrive at the - iterator types, but that ordering is not defined on the Python side, so - this may be ambiguous. - """ - return list(self.affine_state.all_dims.items()) - - @property - def indexing_maps(self) -> Sequence[_ir.AffineMap]: - return [use.indexing_map for use in self.ordered_tensor_uses] - - @property - def iterator_types(self) -> Sequence[str]: - - def get_type(symbolic_name, position): - for reduction_dim_expr in self.reduction_dims: - if reduction_dim_expr.dimname == symbolic_name: - return "reduction" - return "parallel" - - return [get_type(*dim) for dim in self.ordered_dims] - - def add_tensor_arg(self, tensor_def: TensorDef): - if tensor_def in self.tensor_args: - return - with self.context: - local_state = AffineBuildState(global_state=self.affine_state, - allow_new_dims=False) - exprs = [] - for expr in tensor_def.shape: - exprs.append(expr.build(state=local_state)) - assert local_state.local_dim_count == 0 - indexing_map = _ir.AffineMap.get(dim_count=0, - symbol_count=local_state.symbol_count, - exprs=exprs) - - def_config = TensorDefConfig(tensor_def, indexing_map) - self.tensor_args[tensor_def] = def_config - - def add_use(self, 
tensor_use: TensorUse): - if tensor_use in self.uses: - return - with self.context: - local_state = AffineBuildState(global_state=self.affine_state, - allow_new_symbols=False) - exprs = [] - for expr in tensor_use.indices: - exprs.append(expr.build(state=local_state)) - assert local_state.local_symbol_count == 0 - indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count, - symbol_count=local_state.symbol_count, - exprs=exprs) - - use_config = TensorUseConfig(tensor_use, indexing_map) - self.uses[tensor_use] = use_config - - def _normalize_affine_map(self, - affine_map: _ir.AffineMap, - with_dims: bool = True) -> _ir.AffineMap: - """Normalizes an indexing map to have the max known symbols and dims.""" - with self.context: - return _ir.AffineMap.get( - dim_count=self.affine_state.dim_count if with_dims else 0, - symbol_count=self.affine_state.symbol_count, - exprs=list(affine_map.results)) - - def to_yaml_custom_dict(self): - self_dict = dict( - args=self.ordered_tensor_args, - # TODO: Refactor the hierarchy internally when supporting more - # than static (preserving this serialized form). - indexing_maps=LinalgIndexingMapsConfig( - static_indexing_maps=self.indexing_maps), - iterator_types=self.iterator_types, - assignments=self.assignments, - ) - return self_dict - - def __repr__(self): - lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] - lines.append("tensor_args=[") - for def_config in self.ordered_tensor_args: - lines.append(f" {repr(def_config)}") - lines.append("], indexing_maps=[") - for m in self.indexing_maps: - lines.append(f" {repr(m)}") - lines.append(f"], iterator_types=[") - for t in self.iterator_types: - lines.append(f" {t}") - lines.append("])") - return "\n".join(lines) - - -class LinalgOpConfig(YAMLObject): - """Container for any supported linalg op type. - - This includes the concrete type by name for ease of parsing by systems - that ignore tags. 
- """ - yaml_tag = "!LinalgOpConfig" - - def __init__(self, - metadata: OpMetadataDef, - *, - structured_op: Optional[LinalgStructuredOpConfig] = None): - self.metadata = metadata - self.structured_op = structured_op - - def to_yaml_custom_dict(self): - self_dict = dict(metadata=self.metadata,) - if self.structured_op: - self_dict["structured_op"] = self.structured_op - return self_dict - - @staticmethod - def from_linalg_op_def( - tc_op_def: LinalgOpDef, - context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: - """Expands a LinalgOpDef into corresponding Linalg configured ops.""" - # TODO: Many LinalgOpDef patterns need to expand to multiple generics. - assert len( - tc_op_def.comprehensions) == 1, "Only one comprehension supported" - return [ - LinalgOpConfig(tc_op_def.metadata, - structured_op=LinalgStructuredOpConfig( - tc_op_def.comprehensions[0], context)), - ] - - def __repr__(self): - return (f"LinalgOpConfig(metadata={self.metadata},\n" - f"structured_op={self.structured_op})") diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py deleted file mode 100644 index d367c5bdd..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ /dev/null @@ -1,91 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from typing import Dict, List - -from contextlib import contextmanager -import functools -import inspect -import threading - -from mlir import ir -from .comprehension import * - -_CONTEXT = threading.local() - - -@contextmanager -def bind_op_def(model: LinalgOpDef): - if hasattr(_CONTEXT, "current_op_def"): - raise ValueError("Cannot recursively define an operation") - _CONTEXT.current_op_def = model - try: - yield model - finally: - del _CONTEXT.current_op_def - - -def current_op_def() -> LinalgOpDef: - try: - return _CONTEXT.current_op_def - except AttributeError: - raise ValueError( - "Attempt to access the current op definition being defined " - "but none is set. Did you mean to call this in an op definition?") - - -class DefinedOpCallable: - """Callable that wraps any defined op function.""" - - def __init__(self, op_name: str, model: LinalgOpDef): - self.op_name = op_name - self.model = model - - def __call__(self, *args, **kwargs): - # TODO: Upstream the emitter and invoke here - raise NotImplementedError("Linalg generic emission not yet implemented") - - -def linalg_structured_op(dsl_func=None, - *, - op_name=None, - op_class_name=None) -> DefinedOpCallable: - if dsl_func is None: - # Curry the keyword args in for delayed application. - return functools.partial(tc_def_op, - op_name=op_name, - op_class_name=op_class_name) - # Determine default names by introspecting the function. - if op_name is None: - op_name = dsl_func.__name__ - if op_class_name is None: - # Camel case it. - op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - - tc_model = LinalgOpDef(name=op_name, - cpp_op_name=op_class_name, - doc=inspect.getdoc(dsl_func)) - - # Extract arguments and TensorDefs from the signature. 
- dsl_func_args = list() - sig = inspect.signature(dsl_func) - for param_name, param in sig.parameters.items(): - param_default = param.default - if not isinstance(param_default, TensorDef): - raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...): Found {param_name}: {param_default}") - dsl_func_args.append(param_default) - tc_model.add_tensor(param_name, param_default) - - # Invoke the DSL func to finish populating the model. - with bind_op_def(tc_model): - dsl_func(*dsl_func_args) - - # TODO: The returned callable should be an IR emitter but that is not - # upstreamed yet. - return DefinedOpCallable(op_name, tc_model) - - -def implements(*interfaces: OpInterfaceDef): - current_op_def().metadata.implements.extend(interfaces) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py deleted file mode 100644 index 9ebf7a9a0..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ /dev/null @@ -1,124 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Models DAGs of scalar math expressions. - -Used for generating region bodies at the "math" level where they are still type -polymorphic. This is modeled to be polymorphic by attribute name for interop -with serialization schemes that are just plain-old-dicts. - -These classes are typically not user accessed and are created as a by-product -of interpreting a comprehension DSL and model the operations to perform in the -op body. The class hierarchy is laid out to map well to a form of YAML that -can be easily consumed from the C++ side, not necessarily for ergonomics. 
-""" - -from typing import Optional, Sequence - -from .yaml_helper import * -from .types import * - -__all__ = [ - "ScalarAssign", - "ScalarApplyFn", - "ScalarArg", - "ScalarExpression", - "ScalarSymbolicCast", -] - - -class ScalarApplyFn: - """A type of ScalarExpression that applies a named function to operands.""" - - def __init__(self, fn_name: str, *operands: "ScalarExpression"): - self.fn_name = fn_name - self.operands = operands - - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_apply=self) - - def __repr__(self): - return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})" - - -class ScalarArg: - """A type of ScalarExpression that references a named argument.""" - - def __init__(self, arg: str): - self.arg = arg - - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_arg=self) - - def __repr__(self): - return f"(ScalarArg({self.arg})" - - -class ScalarSymbolicCast: - """A type of ScalarExpression that symbolically casts an operand to a TypeVar. - """ - - def __init__(self, to_type: TypeVar, operand: "ScalarExpression"): - self.to_type = to_type - self.operand = operand - - def expr(self) -> "ScalarExpression": - return ScalarExpression(symbolic_cast=self) - - def __repr__(self): - return f"ScalarSymbolicCast({self.to_type}, {self.operand})" - - -class ScalarExpression(YAMLObject): - """An expression on scalar values. 
- - Can be one of: - - ScalarApplyFn - - ScalarArg - - ScalarSymbolicCast - """ - yaml_tag = "!ScalarExpression" - - def __init__(self, - scalar_apply: Optional[ScalarApplyFn] = None, - scalar_arg: Optional[ScalarArg] = None, - symbolic_cast: Optional[ScalarSymbolicCast] = None): - if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1: - raise ValueError( - "One of 'scalar_apply', 'scalar_block_arg', 'symbolic_cast' must be " - "specified") - self.scalar_apply = scalar_apply - self.scalar_arg = scalar_arg - self.symbolic_cast = symbolic_cast - - def to_yaml_custom_dict(self): - if self.scalar_apply: - return dict(scalar_apply=dict( - fn_name=self.scalar_apply.fn_name, - operands=list(self.scalar_apply.operands), - )) - elif self.scalar_arg: - return dict(scalar_arg=self.scalar_arg.arg) - elif self.symbolic_cast: - # Note that even though operands must be arity 1, we write it the - # same way as for apply because it allows handling code to be more - # generic vs having a special form. - return dict(symbolic_cast=dict(type_var=self.symbolic_cast.to_type.name, - operands=[self.symbolic_cast.operand])) - else: - raise ValueError(f"Unexpected ScalarExpression type: {self}") - - -class ScalarAssign(YAMLObject): - """An assignment to a named argument (LHS of a comprehension).""" - yaml_tag = "!ScalarAssign" - - def __init__(self, arg: str, value: ScalarExpression): - self.arg = arg - self.value = value - - def to_yaml_custom_dict(self): - return dict(arg=self.arg, value=self.value) - - def __repr__(self): - return f"ScalarAssign({self.arg}, {self.value})" diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py deleted file mode 100644 index 35bbfe712..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py +++ /dev/null @@ -1,69 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Facility for symbolically referencing type variables. - -Type variables are instances of the TypeVar class, which is uniqued by name. -An "expando" accessor `TV` is provided that generates a named TypeVar for -any attribute access: - - >>> TV.T - TypeVar(T) - >>> TV.T is TV.U - False - >>> TV.T is TV.T - True -""" - -from enum import Enum -from typing import Dict - -__all__ = [ - "TypeVar", - "TV", - - # TypeVar aliases. - "T", - "U", - "V", -] - - -class TypeVar: - """A replaceable type variable. - - Type variables are uniqued by name. - """ - ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] - - def __new__(cls, name: str): - existing = cls.ALL_TYPEVARS.get(name) - if existing is not None: - return existing - new = super().__new__(cls) - new.name = name - cls.ALL_TYPEVARS[name] = new - return new - - def __repr__(self): - return f"TypeVar({self.name})" - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique type vars on attr access.""" - - class ExpandoTypeVars: - - def __getattr__(self, n): - return cls(n) - - return ExpandoTypeVars() - - -# Expando access via TV.foo -TV = TypeVar.create_expando() - -# Some common type name aliases. -T = TV.T -U = TV.U -V = TV.V diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py deleted file mode 100644 index 1945eea53..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py +++ /dev/null @@ -1,54 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""YAML serialization is routed through here to centralize common logic.""" - -import sys - -try: - import yaml -except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"This tool requires PyYAML but it was not installed. " - f"Recommend: {sys.executable} -m pip install PyYAML") from e - -__all__ = [ - "yaml_dump", - "yaml_dump_all", - "YAMLObject", -] - - -class YAMLObject(yaml.YAMLObject): - - @classmethod - def to_yaml(cls, dumper, self): - """Default to a custom dictionary mapping.""" - return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) - - def to_yaml_custom_dict(self): - raise NotImplementedError() - - def as_linalg_yaml(self): - return yaml_dump(self) - - -def multiline_str_representer(dumper, data): - if len(data.splitlines()) > 1: - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') - else: - return dumper.represent_scalar('tag:yaml.org,2002:str', data) - - -yaml.add_representer(str, multiline_str_representer) - - -def yaml_dump(data, sort_keys=False, **kwargs): - return yaml.dump(data, sort_keys=sort_keys, **kwargs) - - -def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): - return yaml.dump_all(data, - sort_keys=sort_keys, - explicit_start=explicit_start, - **kwargs) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py deleted file mode 100644 index 229458855..000000000 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ /dev/null @@ -1,70 +0,0 @@ -from ..lang import * - -T1 = TV.T1 -T2 = TV.T2 - -Batch = S.Batch - - -@linalg_structured_op -def matmul(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs a matrix multiplacation of two 2D inputs. 
- - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) - - -@linalg_structured_op -def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplacation of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) - - -@linalg_structured_op -def matvec(A=TensorDef(T1, S.M, S.N), - y=TensorDef(T2, S.N), - x=TensorDef(U, S.M, output=True)): - """Performs a matrix-vector multiplication. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ContractionOpInterface) - x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n]) - - -@linalg_structured_op -def vecmat(y=TensorDef(T1, S.M), - A=TensorDef(T2, S.M, S.N), - x=TensorDef(U, S.N, output=True)): - """Performs a vector-matrix multiplacation. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ContractionOpInterface) - x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) - - -@linalg_structured_op -def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, - output=True)): - """Performs a dot product of two vectors to a scalar result. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ContractionOpInterface) - C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) diff --git a/mlir/lib/Bindings/Python/mlir/execution_engine.py b/mlir/lib/Bindings/Python/mlir/execution_engine.py deleted file mode 100644 index 89bd4aad5..000000000 --- a/mlir/lib/Bindings/Python/mlir/execution_engine.py +++ /dev/null @@ -1,31 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Simply a wrapper around the extension module of the same name. -from ._cext_loader import _cext -import ctypes - -class ExecutionEngine(_cext.execution_engine.ExecutionEngine): - - def lookup(self, name): - """Lookup a function emitted with the `llvm.emit_c_interface` - attribute and returns a ctype callable. - Raise a RuntimeError if the function isn't found. - """ - func = self.raw_lookup("_mlir_ciface_" + name) - if not func: - raise RuntimeError("Unknown function " + name) - prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - return prototype(func) - - def invoke(self, name, *ctypes_args): - """Invoke a function with the list of ctypes arguments. - All arguments must be pointers. - Raise a RuntimeError if the function isn't found. - """ - func = self.lookup(name) - packed_args = (ctypes.c_void_p * len(ctypes_args))() - for argNum in range(len(ctypes_args)): - packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) - func(packed_args) diff --git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/lib/Bindings/Python/mlir/ir.py deleted file mode 100644 index e5ba1bdb0..000000000 --- a/mlir/lib/Bindings/Python/mlir/ir.py +++ /dev/null @@ -1,8 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Simply a wrapper around the extension module of the same name. 
-from ._cext_loader import _reexport_cext -_reexport_cext("ir", __name__) -del _reexport_cext diff --git a/mlir/lib/Bindings/Python/mlir/passmanager.py b/mlir/lib/Bindings/Python/mlir/passmanager.py deleted file mode 100644 index 6b267b76e..000000000 --- a/mlir/lib/Bindings/Python/mlir/passmanager.py +++ /dev/null @@ -1,8 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Simply a wrapper around the extension module of the same name. -from ._cext_loader import _reexport_cext -_reexport_cext("passmanager", __name__) -del _reexport_cext diff --git a/mlir/lib/Bindings/Python/mlir/transforms/__init__.py b/mlir/lib/Bindings/Python/mlir/transforms/__init__.py deleted file mode 100644 index 2149933d0..000000000 --- a/mlir/lib/Bindings/Python/mlir/transforms/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._cext_loader import _load_extension -_cextTransforms = _load_extension("_mlirTransforms") diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index ba58d99a7..6c4385084 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,43 +1,39 @@ +# For upstream, we accumulate all libraries into the MLIR_CAPI_LIBS +# property via a custom wrapper function. This is then used to create an +# aggregate below. 
+set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBS) +function(add_mlir_upstream_c_api_library name) + add_mlir_public_c_api_library(${name} ${ARGN}) + set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBS ${name}) +endfunction() + +add_subdirectory(Debug) add_subdirectory(Dialect) add_subdirectory(Conversion) -add_subdirectory(ExecutionEngine) +add_subdirectory(Interfaces) add_subdirectory(IR) -add_subdirectory(Registration) +add_subdirectory(RegisterEverything) add_subdirectory(Transforms) +add_subdirectory(Target) +if(MLIR_ENABLE_EXECUTION_ENGINE) + add_subdirectory(ExecutionEngine) +endif() -################################################################################ -# libMLIRPublicAPI shared library/DLL. -################################################################################ - -get_property(public_api_libs GLOBAL PROPERTY MLIR_PUBLIC_C_API_LIBS) - -foreach(lib ${public_api_libs}) - if(XCODE) - # Xcode doesn't support object libraries, so we have to trick it into - # linking the static libraries instead. - list(APPEND _DEPS "-force_load" ${lib}) +# Build the optional CAPI dylib. +if(MLIR_BUILD_MLIR_C_DYLIB) + message(STATUS "Building MLIR-C dylib") + get_property(_capi_libraries GLOBAL PROPERTY MLIR_CAPI_LIBS) + add_mlir_aggregate(MLIR-C + SHARED + EMBED_LIBS + ${_capi_libraries} + ) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_options(MLIR-C PRIVATE "-Wl,-exclude-libs,ALL") else() - list(APPEND _OBJECTS $) + if(NOT CMAKE_C_VISIBILITY_PRESET STREQUAL "hidden" OR NOT CMAKE_CXX_VISIBILITY_PRESET STREQUAL "hidden") + message(STATUS "MLIR-C on this platform exports all symbols. Recommend building with CMAKE_(C|CXX)_VISIBILITY_PRESET=hidden or implement filtering support.") + endif() endif() - # Accumulate transitive deps of each exported lib into _DEPS. - list(APPEND _DEPS $) -endforeach() - -add_mlir_library(MLIRPublicAPI - SHARED - ${_OBJECTS} - EXCLUDE_FROM_LIBMLIR - LINK_LIBS - # Dependency on the implementation shared library. 
- $<$:MLIR> - ${_DEPS} -) - -target_link_options( - MLIRPublicAPI - PRIVATE - # On Linux, disable re-export of any static linked libraries that - # came through. - $<$:LINKER:--exclude-libs,ALL> -) +endif() diff --git a/mlir/lib/CAPI/Conversion/CMakeLists.txt b/mlir/lib/CAPI/Conversion/CMakeLists.txt index 83435cd19..8cafc09d3 100644 --- a/mlir/lib/CAPI/Conversion/CMakeLists.txt +++ b/mlir/lib/CAPI/Conversion/CMakeLists.txt @@ -1,7 +1,10 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -add_mlir_public_c_api_library(MLIRCAPIConversion +add_mlir_upstream_c_api_library(MLIRCAPIConversion Passes.cpp + DEPENDS + MLIRConversionPassIncGen + LINK_LIBS PUBLIC ${conversion_libs} ) diff --git a/mlir/lib/CAPI/Debug/CMakeLists.txt b/mlir/lib/CAPI/Debug/CMakeLists.txt new file mode 100644 index 000000000..7b32f3ae0 --- /dev/null +++ b/mlir/lib/CAPI/Debug/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_upstream_c_api_library(MLIRCAPIDebug + Debug.cpp + + LINK_LIBS PUBLIC + MLIRSupport +) diff --git a/mlir/lib/CAPI/Debug/Debug.cpp b/mlir/lib/CAPI/Debug/Debug.cpp new file mode 100644 index 000000000..320ece499 --- /dev/null +++ b/mlir/lib/CAPI/Debug/Debug.cpp @@ -0,0 +1,36 @@ +//===- Debug.cpp - C Interface for MLIR/LLVM Debugging Functions ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Debug.h" +#include "mlir-c/Support.h" + +#include "mlir/CAPI/Support.h" + +#include "llvm/Support/Debug.h" + +void mlirEnableGlobalDebug(bool enable) { llvm::DebugFlag = enable; } + +bool mlirIsGlobalDebugEnabled() { return llvm::DebugFlag; } + +void mlirSetGlobalDebugType(const char *type) { + // Depending on the NDEBUG flag, this name can be either a function or a macro + // that expands to something that isn't a funciton call, so we cannot + // explicitly prefix it with `llvm::` or declare `using` it. + using namespace llvm; + setCurrentDebugType(type); +} + +void mlirSetGlobalDebugTypes(const char **types, intptr_t n) { + using namespace llvm; + setCurrentDebugTypes(types, n); +} + +bool mlirIsCurrentDebugType(const char *type) { + using namespace llvm; + return isCurrentDebugType(type); +} diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp new file mode 100644 index 000000000..d877ca2df --- /dev/null +++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp @@ -0,0 +1,14 @@ +//===- AMDGPU.cpp - C Interface for AMDGPU dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/AMDGPU.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu, + mlir::amdgpu::AMDGPUDialect) diff --git a/mlir/lib/CAPI/Dialect/Standard.cpp b/mlir/lib/CAPI/Dialect/Arith.cpp similarity index 60% rename from mlir/lib/CAPI/Dialect/Standard.cpp rename to mlir/lib/CAPI/Dialect/Arith.cpp index 57083a8a2..993f77e55 100644 --- a/mlir/lib/CAPI/Dialect/Standard.cpp +++ b/mlir/lib/CAPI/Dialect/Arith.cpp @@ -1,4 +1,4 @@ -//===- Standard.cpp - C Interface for Standard dialect --------------------===// +//===- Arith.cpp - C Interface for Arith dialect --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Dialect/Standard.h" +#include "mlir-c/Dialect/Arith.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Arith/IR/Arith.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Standard, std, mlir::StandardOpsDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Arith, arith, mlir::arith::ArithDialect) diff --git a/mlir/lib/CAPI/Dialect/Async.cpp b/mlir/lib/CAPI/Dialect/Async.cpp new file mode 100644 index 000000000..182cbf9df --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Async.cpp @@ -0,0 +1,13 @@ +//===- Async.cpp - C Interface for Async dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir-c/Dialect/Async.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Async, async, mlir::async::AsyncDialect) diff --git a/mlir/lib/CAPI/Dialect/AsyncPasses.cpp b/mlir/lib/CAPI/Dialect/AsyncPasses.cpp new file mode 100644 index 000000000..aa2074dcd --- /dev/null +++ b/mlir/lib/CAPI/Dialect/AsyncPasses.cpp @@ -0,0 +1,26 @@ +//===- AsyncPasses.cpp - C API for Async Dialect Passes -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/Async/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/Async/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index d256309bf..bb1fdf8be 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,48 +1,280 @@ -# TODO: Make the check source feature optional as an argument on *_add_library. 
-set(LLVM_OPTIONAL_SOURCES - Linalg.cpp - SCF.cpp - Shape.cpp - Standard.cpp - Tensor.cpp +add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU + AMDGPU.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRAMDGPUDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIArith + Arith.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRArithDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIAsync + Async.cpp + AsyncPasses.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + MLIRAsyncPassIncGen + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRAsyncDialect + MLIRAsyncTransforms + MLIRPass +) + +add_mlir_upstream_c_api_library(MLIRCAPIControlFlow + ControlFlow.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRControlFlowDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIEmitC + EmitC.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIREmitCDialect ) -add_mlir_public_c_api_library(MLIRCAPILinalg +add_mlir_upstream_c_api_library(MLIRCAPIMath + Math.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRMathDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIMemRef + MemRef.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRMemRefDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIGPU + GPU.cpp + GPUPasses.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + MLIRGPUPassIncGen + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRGPUTransforms + MLIRPass +) + +add_mlir_upstream_c_api_library(MLIRCAPIIndex + Index.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRIndexDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIIRDL + IRDL.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRIRDL +) + +add_mlir_upstream_c_api_library(MLIRCAPILLVM + LLVM.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRLLVMDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPILinalg Linalg.cpp + LinalgPasses.cpp + + PARTIAL_SOURCES_INTENDED + DEPENDS + 
MLIRLinalgPassIncGen + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRLinalgDialect + MLIRPass + MLIRLinalgTransforms +) + +add_mlir_upstream_c_api_library(MLIRCAPIMLProgram + MLProgram.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRMLProgramDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPINVGPU + NVGPU.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRNVGPUDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPINVVM + NVVM.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRNVVMDialect +) +add_mlir_upstream_c_api_library(MLIRCAPIROCDL + ROCDL.cpp + + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRLinalg + MLIRROCDLDialect ) -add_mlir_public_c_api_library(MLIRCAPISCF + +add_mlir_upstream_c_api_library(MLIRCAPISCF SCF.cpp + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRSCF + MLIRSCFDialect ) -add_mlir_public_c_api_library(MLIRCAPIShape +add_mlir_upstream_c_api_library(MLIRCAPIShape Shape.cpp + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRShapeDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPISparseTensor + SparseTensor.cpp + SparseTensorPasses.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRSparseTensorDialect + MLIRSparseTensorTransforms +) + +add_mlir_upstream_c_api_library(MLIRCAPIFunc + Func.cpp + + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRShape + MLIRFuncDialect ) -add_mlir_public_c_api_library(MLIRCAPIStandard - Standard.cpp +add_mlir_upstream_c_api_library(MLIRCAPISPIRV + SPIRV.cpp + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRStandard + MLIRSPIRVDialect ) -add_mlir_public_c_api_library(MLIRCAPITensor +add_mlir_upstream_c_api_library(MLIRCAPITensor Tensor.cpp + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRTensorDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPITransformDialect + Transform.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRTransformDialect +) + 
+add_mlir_upstream_c_api_library(MLIRCAPITransformDialectTransforms + TransformInterpreter.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRTransformDialectTransforms +) + +add_mlir_upstream_c_api_library(MLIRCAPIQuant + Quant.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRQuantDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIOpenMP + OpenMP.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIROpenMPDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIPDL + PDL.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRPDLDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIVector + Vector.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRVectorDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPISMT + SMT.cpp + + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRTensor + MLIRSMT ) diff --git a/mlir/lib/CAPI/Dialect/ControlFlow.cpp b/mlir/lib/CAPI/Dialect/ControlFlow.cpp new file mode 100644 index 000000000..1e5b2de1c --- /dev/null +++ b/mlir/lib/CAPI/Dialect/ControlFlow.cpp @@ -0,0 +1,14 @@ +//===- ControlFlow.cpp - C Interface for ControlFlow dialect --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/ControlFlow.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(ControlFlow, cf, + mlir::cf::ControlFlowDialect) diff --git a/mlir/lib/CAPI/Dialect/EmitC.cpp b/mlir/lib/CAPI/Dialect/EmitC.cpp new file mode 100644 index 000000000..b6d197366 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/EmitC.cpp @@ -0,0 +1,189 @@ +//===- EmitC.cpp - C Interface for EmitC dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/EmitC.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(EmitC, emitc, mlir::emitc::EmitCDialect) + +// Ensure the C-API enums are uint64_t-castable to C++ equivalents. 
+static_assert(static_cast(MLIR_EMITC_CMP_PREDICATE_EQ) == + static_cast(emitc::CmpPredicate::eq) && + static_cast(MLIR_EMITC_CMP_PREDICATE_NE) == + static_cast(emitc::CmpPredicate::ne) && + static_cast(MLIR_EMITC_CMP_PREDICATE_LT) == + static_cast(emitc::CmpPredicate::lt) && + static_cast(MLIR_EMITC_CMP_PREDICATE_LE) == + static_cast(emitc::CmpPredicate::le) && + static_cast(MLIR_EMITC_CMP_PREDICATE_GT) == + static_cast(emitc::CmpPredicate::gt) && + static_cast(MLIR_EMITC_CMP_PREDICATE_GE) == + static_cast(emitc::CmpPredicate::ge) && + static_cast(MLIR_EMITC_CMP_PREDICATE_THREE_WAY) == + static_cast(emitc::CmpPredicate::three_way), + "MlirEmitCCmpPredicate (C-API) and CmpPredicate (C++) mismatch"); + +//===---------------------------------------------------------------------===// +// ArrayType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCArrayType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCArrayTypeGetTypeID(void) { + return wrap(emitc::ArrayType::getTypeID()); +} + +MlirType mlirEmitCArrayTypeGet(intptr_t nDims, int64_t *shape, + MlirType elementType) { + return wrap( + emitc::ArrayType::get(llvm::ArrayRef(shape, nDims), unwrap(elementType))); +} + +//===---------------------------------------------------------------------===// +// LValueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCLValueType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCLValueTypeGetTypeID(void) { + return wrap(emitc::LValueType::getTypeID()); +} + +MlirType mlirEmitCLValueTypeGet(MlirType valueType) { + return wrap(emitc::LValueType::get(unwrap(valueType))); +} + +//===---------------------------------------------------------------------===// +// OpaqueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCOpaqueType(MlirType type) { + return isa(unwrap(type)); +} + 
+MlirTypeID mlirEmitCOpaqueTypeGetTypeID(void) { + return wrap(emitc::OpaqueType::getTypeID()); +} + +MlirType mlirEmitCOpaqueTypeGet(MlirContext ctx, MlirStringRef value) { + return wrap(emitc::OpaqueType::get(unwrap(ctx), unwrap(value))); +} + +//===---------------------------------------------------------------------===// +// PointerType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCPointerType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCPointerTypeGetTypeID(void) { + return wrap(emitc::PointerType::getTypeID()); +} + +MlirType mlirEmitCPointerTypeGet(MlirType pointee) { + return wrap(emitc::PointerType::get(unwrap(pointee))); +} + +//===---------------------------------------------------------------------===// +// PtrDiffTType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCPtrDiffTType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCPtrDiffTTypeGetTypeID(void) { + return wrap(emitc::PtrDiffTType::getTypeID()); +} + +MlirType mlirEmitCPtrDiffTTypeGet(MlirContext ctx) { + return wrap(emitc::PtrDiffTType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// SignedSizeTType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCSignedSizeTType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCSignedSizeTTypeGetTypeID(void) { + return wrap(emitc::SignedSizeTType::getTypeID()); +} + +MlirType mlirEmitCSignedSizeTTypeGet(MlirContext ctx) { + return wrap(emitc::SignedSizeTType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// SizeTType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCSizeTType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID 
mlirEmitCSizeTTypeGetTypeID(void) { + return wrap(emitc::SizeTType::getTypeID()); +} + +MlirType mlirEmitCSizeTTypeGet(MlirContext ctx) { + return wrap(emitc::SizeTType::get(unwrap(ctx))); +} + +//===----------------------------------------------------------------------===// +// CmpPredicate attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAEmitCCmpPredicate(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirEmitCCmpPredicateAttrGet(MlirContext ctx, + MlirEmitCCmpPredicate val) { + return wrap((Attribute)emitc::CmpPredicateAttr::get( + unwrap(ctx), static_cast(val))); +} + +MlirEmitCCmpPredicate mlirEmitCCmpPredicateAttrGetValue(MlirAttribute attr) { + return static_cast( + llvm::cast(unwrap(attr)).getValue()); +} + +MlirTypeID mlirEmitCCmpPredicateAttrGetTypeID(void) { + return wrap(emitc::CmpPredicateAttr::getTypeID()); +} + +//===----------------------------------------------------------------------===// +// Opaque attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAEmitCOpaque(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirEmitCOpaqueAttrGet(MlirContext ctx, MlirStringRef value) { + return wrap((Attribute)emitc::OpaqueAttr::get(unwrap(ctx), unwrap(value))); +} + +MlirStringRef mlirEmitCOpaqueAttrGetValue(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getValue()); +} + +MlirTypeID mlirEmitCOpaqueAttrGetTypeID(void) { + return wrap(emitc::OpaqueAttr::getTypeID()); +} diff --git a/mlir/lib/CAPI/Dialect/Func.cpp b/mlir/lib/CAPI/Dialect/Func.cpp new file mode 100644 index 000000000..8265b61b9 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Func.cpp @@ -0,0 +1,27 @@ +//===- Func.cpp - C Interface for Func dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Func.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Func, func, mlir::func::FuncDialect) + +void mlirFuncSetArgAttr(MlirOperation op, intptr_t pos, MlirStringRef name, + MlirAttribute attr) { + llvm::cast(unwrap(op)) + .setArgAttr(pos, unwrap(name), unwrap(attr)); +} + +void mlirFuncSetResultAttr(MlirOperation op, intptr_t pos, MlirStringRef name, + MlirAttribute attr) { + llvm::cast(unwrap(op)) + .setResultAttr(pos, unwrap(name), unwrap(attr)); +} diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp new file mode 100644 index 000000000..e4796ed14 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/GPU.cpp @@ -0,0 +1,111 @@ +//===- GPU.cpp - C Interface for GPU dialect ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/GPU.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "llvm/Support/Casting.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect) + +//===-------------------------------------------------------------------===// +// AsyncTokenType +//===-------------------------------------------------------------------===// + +bool mlirTypeIsAGPUAsyncTokenType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) { + return wrap(gpu::AsyncTokenType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// ObjectAttr +//===---------------------------------------------------------------------===// + +bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, + uint32_t format, MlirStringRef objectStrRef, + MlirAttribute mlirObjectProps) { + MLIRContext *ctx = unwrap(mlirCtx); + llvm::StringRef object = unwrap(objectStrRef); + DictionaryAttr objectProps; + if (mlirObjectProps.ptr != nullptr) + objectProps = llvm::cast(unwrap(mlirObjectProps)); + return wrap(gpu::ObjectAttr::get( + ctx, unwrap(target), static_cast(format), + StringAttr::get(ctx, object), objectProps, nullptr)); +} + +MlirAttribute mlirGPUObjectAttrGetWithKernels(MlirContext mlirCtx, + MlirAttribute target, + uint32_t format, + MlirStringRef objectStrRef, + MlirAttribute mlirObjectProps, + MlirAttribute mlirKernelsAttr) { + MLIRContext *ctx = unwrap(mlirCtx); + llvm::StringRef object = unwrap(objectStrRef); + DictionaryAttr objectProps; + if (mlirObjectProps.ptr != nullptr) + objectProps = llvm::cast(unwrap(mlirObjectProps)); + gpu::KernelTableAttr 
kernels; + if (mlirKernelsAttr.ptr != nullptr) + kernels = llvm::cast(unwrap(mlirKernelsAttr)); + return wrap(gpu::ObjectAttr::get( + ctx, unwrap(target), static_cast(format), + StringAttr::get(ctx, object), objectProps, kernels)); +} + +MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return wrap(objectAttr.getTarget()); +} + +uint32_t mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return static_cast(objectAttr.getFormat()); +} + +MlirStringRef mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + llvm::StringRef object = objectAttr.getObject(); + return mlirStringRefCreate(object.data(), object.size()); +} + +bool mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return objectAttr.getProperties() != nullptr; +} + +MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return wrap(objectAttr.getProperties()); +} + +bool mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return objectAttr.getKernels() != nullptr; +} + +MlirAttribute mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return wrap(objectAttr.getKernels()); +} diff --git a/mlir/lib/CAPI/Dialect/GPUPasses.cpp b/mlir/lib/CAPI/Dialect/GPUPasses.cpp new file mode 100644 index 000000000..5128c63ec --- /dev/null +++ b/mlir/lib/CAPI/Dialect/GPUPasses.cpp @@ -0,0 +1,26 @@ +//===- GPUPasses.cpp - C API for GPU Dialect Passes ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/GPU/Transforms/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/GPU/Transforms/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/lib/CAPI/Dialect/IRDL.cpp b/mlir/lib/CAPI/Dialect/IRDL.cpp new file mode 100644 index 000000000..cb9dc8ceb --- /dev/null +++ b/mlir/lib/CAPI/Dialect/IRDL.cpp @@ -0,0 +1,18 @@ +//===- IRDL.cpp - C Interface for IRDL dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/IRDL.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/IRDL/IRDLLoading.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IRDL, irdl, mlir::irdl::IRDLDialect) + +MlirLogicalResult mlirLoadIRDLDialects(MlirModule module) { + return wrap(mlir::irdl::loadDialects(unwrap(module))); +} diff --git a/mlir/lib/CAPI/Dialect/Index.cpp b/mlir/lib/CAPI/Dialect/Index.cpp new file mode 100644 index 000000000..845791436 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Index.cpp @@ -0,0 +1,13 @@ +//===- Index.cpp - C Interface for Index dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Index.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Index/IR/IndexDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Index, index, mlir::index::IndexDialect) diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp new file mode 100644 index 000000000..69c804b76 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -0,0 +1,409 @@ +//===- LLVM.cpp - C Interface for LLVM dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "llvm-c/Core.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" + +using namespace mlir; +using namespace mlir::LLVM; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(LLVM, llvm, LLVMDialect) + +MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { + return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); +} + +bool mlirTypeIsALLVMPointerType(MlirType type) { + return isa(unwrap(type)); +} + +unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) { + return cast(unwrap(pointerType)).getAddressSpace(); +} + +MlirType mlirLLVMVoidTypeGet(MlirContext ctx) { + return wrap(LLVMVoidType::get(unwrap(ctx))); +} + +MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements) { + return 
wrap(LLVMArrayType::get(unwrap(elementType), numElements)); +} + +MlirType mlirLLVMArrayTypeGetElementType(MlirType type) { + return wrap(cast(unwrap(type)).getElementType()); +} + +MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, + MlirType const *argumentTypes, bool isVarArg) { + SmallVector argumentStorage; + return wrap(LLVMFunctionType::get( + unwrap(resultType), + unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg)); +} + +intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) { + return llvm::cast(unwrap(type)).getNumParams(); +} + +MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { + assert(pos >= 0 && "pos in array must be positive"); + return wrap(llvm::cast(unwrap(type)) + .getParamType(static_cast(pos))); +} + +MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getReturnType()); +} + +bool mlirTypeIsALLVMStructType(MlirType type) { + return isa(unwrap(type)); +} + +bool mlirLLVMStructTypeIsLiteral(MlirType type) { + return !cast(unwrap(type)).isIdentified(); +} + +intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) { + return cast(unwrap(type)).getBody().size(); +} + +MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) { + return wrap(cast(unwrap(type)).getBody()[position]); +} + +bool mlirLLVMStructTypeIsPacked(MlirType type) { + return cast(unwrap(type)).isPacked(); +} + +MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) { + return wrap(cast(unwrap(type)).getName()); +} + +bool mlirLLVMStructTypeIsOpaque(MlirType type) { + return cast(unwrap(type)).isOpaque(); +} + +MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fieldStorage; + return wrap(LLVMStructType::getLiteral( + unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage), + isPacked)); +} + +MlirType 
mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, + intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fieldStorage; + return wrap(LLVMStructType::getLiteralChecked( + [loc]() { return emitError(unwrap(loc)); }, unwrap(loc)->getContext(), + unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked)); +} + +MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) { + return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx))); +} + +MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) { + return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name))); +} + +MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name, + intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fields; + return wrap(LLVMStructType::getNewIdentified( + unwrap(ctx), unwrap(name), unwrapList(nFieldTypes, fieldTypes, fields), + isPacked)); +} + +MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType, + intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fields; + return wrap( + cast(unwrap(structType)) + .setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked)); +} + +MlirAttribute mlirLLVMDIExpressionElemAttrGet(MlirContext ctx, + unsigned int opcode, + intptr_t nArguments, + uint64_t const *arguments) { + auto list = ArrayRef(arguments, nArguments); + return wrap(DIExpressionElemAttr::get(unwrap(ctx), opcode, list)); +} + +MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations, + MlirAttribute const *operations) { + SmallVector attrStorage; + attrStorage.reserve(nOperations); + + return wrap(DIExpressionAttr::get( + unwrap(ctx), + llvm::map_to_vector( + unwrapList(nOperations, operations, attrStorage), + [](Attribute a) { return cast(a); }))); +} + +MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) { + return wrap(DINullTypeAttr::get(unwrap(ctx))); +} + +MlirAttribute 
mlirLLVMDIBasicTypeAttrGet(MlirContext ctx, unsigned int tag, + MlirAttribute name, + uint64_t sizeInBits, + MlirLLVMTypeEncoding encoding) { + + return wrap(DIBasicTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), sizeInBits, encoding)); +} + +MlirAttribute mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId) { + return wrap( + DICompositeTypeAttr::getRecSelf(cast(unwrap(recId)))); +} + +MlirAttribute mlirLLVMDICompositeTypeAttrGet( + MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, + MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, + uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, + MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, + MlirAttribute associated) { + SmallVector elementsStorage; + elementsStorage.reserve(nElements); + + return wrap(DICompositeTypeAttr::get( + unwrap(ctx), cast(unwrap(recId)), isRecSelf, tag, + cast(unwrap(name)), cast(unwrap(file)), line, + cast(unwrap(scope)), cast(unwrap(baseType)), + DIFlags(flags), sizeInBits, alignInBits, + llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), + [](Attribute a) { return cast(a); }), + cast(unwrap(dataLocation)), + cast(unwrap(rank)), + cast(unwrap(allocated)), + cast(unwrap(associated)))); +} + +MlirAttribute mlirLLVMDIDerivedTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, + MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, + uint64_t offsetInBits, int64_t dwarfAddressSpace, MlirAttribute extraData) { + std::optional addressSpace = std::nullopt; + if (dwarfAddressSpace >= 0) + addressSpace = (unsigned)dwarfAddressSpace; + return wrap(DIDerivedTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), + cast(unwrap(baseType)), sizeInBits, alignInBits, offsetInBits, + addressSpace, cast(unwrap(extraData)))); +} + +MlirAttribute mlirLLVMDIStringTypeAttrGet( + MlirContext ctx, 
unsigned int tag, MlirAttribute name, uint64_t sizeInBits, + uint32_t alignInBits, MlirAttribute stringLength, + MlirAttribute stringLengthExp, MlirAttribute stringLocationExp, + MlirLLVMTypeEncoding encoding) { + return wrap(DIStringTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), sizeInBits, alignInBits, + cast(unwrap(stringLength)), + cast(unwrap(stringLengthExp)), + cast(unwrap(stringLocationExp)), encoding)); +} + +MlirAttribute +mlirLLVMDIDerivedTypeAttrGetBaseType(MlirAttribute diDerivedType) { + return wrap(cast(unwrap(diDerivedType)).getBaseType()); +} + +MlirAttribute mlirLLVMCConvAttrGet(MlirContext ctx, MlirLLVMCConv cconv) { + return wrap(CConvAttr::get(unwrap(ctx), CConv(cconv))); +} + +MlirAttribute mlirLLVMComdatAttrGet(MlirContext ctx, MlirLLVMComdat comdat) { + return wrap(ComdatAttr::get(unwrap(ctx), comdat::Comdat(comdat))); +} + +MlirAttribute mlirLLVMLinkageAttrGet(MlirContext ctx, MlirLLVMLinkage linkage) { + return wrap(LinkageAttr::get(unwrap(ctx), linkage::Linkage(linkage))); +} + +MlirAttribute mlirLLVMDIFileAttrGet(MlirContext ctx, MlirAttribute name, + MlirAttribute directory) { + return wrap(DIFileAttr::get(unwrap(ctx), cast(unwrap(name)), + cast(unwrap(directory)))); +} + +MlirAttribute +mlirLLVMDICompileUnitAttrGet(MlirContext ctx, MlirAttribute id, + unsigned int sourceLanguage, MlirAttribute file, + MlirAttribute producer, bool isOptimized, + MlirLLVMDIEmissionKind emissionKind, + MlirLLVMDINameTableKind nameTableKind) { + return wrap(DICompileUnitAttr::get( + unwrap(ctx), cast(unwrap(id)), sourceLanguage, + cast(unwrap(file)), cast(unwrap(producer)), + isOptimized, DIEmissionKind(emissionKind), + DINameTableKind(nameTableKind))); +} + +MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, uint64_t value) { + return wrap(DIFlagsAttr::get(unwrap(ctx), DIFlags(value))); +} + +MlirAttribute mlirLLVMDILexicalBlockAttrGet(MlirContext ctx, + MlirAttribute scope, + MlirAttribute file, + unsigned int line, + unsigned int column) { + 
return wrap( + DILexicalBlockAttr::get(unwrap(ctx), cast(unwrap(scope)), + cast(unwrap(file)), line, column)); +} + +MlirAttribute mlirLLVMDILexicalBlockFileAttrGet(MlirContext ctx, + MlirAttribute scope, + MlirAttribute file, + unsigned int discriminator) { + return wrap(DILexicalBlockFileAttr::get( + unwrap(ctx), cast(unwrap(scope)), + cast(unwrap(file)), discriminator)); +} + +MlirAttribute mlirLLVMDILocalVariableAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute name, + MlirAttribute diFile, unsigned int line, unsigned int arg, + unsigned int alignInBits, MlirAttribute diType, int64_t flags) { + return wrap(DILocalVariableAttr::get( + unwrap(ctx), cast(unwrap(scope)), + cast(unwrap(name)), cast(unwrap(diFile)), line, + arg, alignInBits, cast(unwrap(diType)), DIFlags(flags))); +} + +MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, + unsigned int callingConvention, + intptr_t nTypes, + MlirAttribute const *types) { + SmallVector attrStorage; + attrStorage.reserve(nTypes); + + return wrap(DISubroutineTypeAttr::get( + unwrap(ctx), callingConvention, + llvm::map_to_vector(unwrapList(nTypes, types, attrStorage), + [](Attribute a) { return cast(a); }))); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) { + return wrap(DISubprogramAttr::getRecSelf(cast(unwrap(recId)))); +} + +MlirAttribute mlirLLVMDISubprogramAttrGet( + MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, + MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, + MlirAttribute linkageName, MlirAttribute file, unsigned int line, + unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes, + intptr_t nAnnotations, MlirAttribute const *annotations) { + SmallVector nodesStorage; + nodesStorage.reserve(nRetainedNodes); + + SmallVector annotationsStorage; + annotationsStorage.reserve(nAnnotations); + + return wrap(DISubprogramAttr::get( + unwrap(ctx), 
cast(unwrap(recId)), isRecSelf, + cast(unwrap(id)), + cast(unwrap(compileUnit)), + cast(unwrap(scope)), cast(unwrap(name)), + cast(unwrap(linkageName)), cast(unwrap(file)), + line, scopeLine, DISubprogramFlags(subprogramFlags), + cast(unwrap(type)), + llvm::map_to_vector( + unwrapList(nRetainedNodes, retainedNodes, nodesStorage), + [](Attribute a) { return cast(a); }), + llvm::map_to_vector( + unwrapList(nAnnotations, annotations, annotationsStorage), + [](Attribute a) { return cast(a); }))); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getScope()); +} + +unsigned int mlirLLVMDISubprogramAttrGetLine(MlirAttribute diSubprogram) { + return cast(unwrap(diSubprogram)).getLine(); +} + +unsigned int mlirLLVMDISubprogramAttrGetScopeLine(MlirAttribute diSubprogram) { + return cast(unwrap(diSubprogram)).getScopeLine(); +} + +MlirAttribute +mlirLLVMDISubprogramAttrGetCompileUnit(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getCompileUnit()); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetFile(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getFile()); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetType(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getType()); +} + +MlirAttribute mlirLLVMDIModuleAttrGet(MlirContext ctx, MlirAttribute file, + MlirAttribute scope, MlirAttribute name, + MlirAttribute configMacros, + MlirAttribute includePath, + MlirAttribute apinotes, unsigned int line, + bool isDecl) { + return wrap(DIModuleAttr::get( + unwrap(ctx), cast(unwrap(file)), + cast(unwrap(scope)), cast(unwrap(name)), + cast(unwrap(configMacros)), + cast(unwrap(includePath)), cast(unwrap(apinotes)), + line, isDecl)); +} + +MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule) { + return wrap(cast(unwrap(diModule)).getScope()); +} + +MlirAttribute mlirLLVMDIImportedEntityAttrGet( + MlirContext ctx, unsigned int tag, 
MlirAttribute scope, + MlirAttribute entity, MlirAttribute file, unsigned int line, + MlirAttribute name, intptr_t nElements, MlirAttribute const *elements) { + SmallVector elementsStorage; + elementsStorage.reserve(nElements); + return wrap(DIImportedEntityAttr::get( + unwrap(ctx), tag, cast(unwrap(scope)), + cast(unwrap(entity)), cast(unwrap(file)), line, + cast(unwrap(name)), + llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), + [](Attribute a) { return cast(a); }))); +} + +MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name, + MlirAttribute value) { + return wrap(DIAnnotationAttr::get(unwrap(ctx), cast(unwrap(name)), + cast(unwrap(value)))); +} diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index da6fd4846..21db18dfd 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -8,7 +8,126 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, - mlir::linalg::LinalgDialect) +using namespace mlir; +using namespace mlir::linalg; + +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `op` is a builtin named Linalg op. 
+void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { + Operation *op = unwrap(mlirOp); + auto linalgOp = cast(op); + auto *dialect = static_cast(linalgOp->getDialect()); + LinalgDialect::RegionBuilderFunType fun = + dialect->getRegionBuilder(op->getName().getStringRef()); + + assert(fun && "Expected a builtin named Linalg op."); + assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); + assert(op->getRegion(0).getBlocks().empty() && + "Expected Linalg op with 0 blocks"); + + SmallVector argTypes; + SmallVector argLocs; + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType())); + argLocs.push_back(opOperand.get().getLoc()); + } + + ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); + Region ®ion = op->getRegion(0); + Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); + b.setInsertionPointToStart(body); + fun(b, *body, op->getAttrs(), /*emitError=*/{}); +} + +MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + // isaContractionOpInterface handles null linalgOp internally. 
+ return linalg::isaContractionOpInterface(linalgOp); +} + +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensions(MlirOperation op) { + MlirLinalgContractionDimensions result{}; + auto linalgOp = dyn_cast(unwrap(op)); + if (!linalgOp) + return result; + + FailureOr maybeDims = + linalg::inferContractionDims(linalgOp); + if (failed(maybeDims)) + return result; + + linalg::ContractionDimensions contractionDims = *maybeDims; + MLIRContext *ctx = linalgOp.getContext(); + + auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap( + DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + result.batch = toAttr(contractionDims.batch); + result.m = toAttr(contractionDims.m); + result.n = toAttr(contractionDims.n); + result.k = toAttr(contractionDims.k); + + return result; +} + +MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return false; + + return linalg::isaConvolutionOpInterface(linalgOp); +} + +MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions +mlirLinalgInferConvolutionDimensions(MlirOperation op) { + MlirLinalgConvolutionDimensions result{}; + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return result; + + FailureOr maybeDims = + linalg::inferConvolutionDims(linalgOp); + if (failed(maybeDims)) + return result; + + linalg::ConvolutionDimensions dims = *maybeDims; + MLIRContext *ctx = linalgOp.getContext(); + + auto toI32Attr = + [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + auto toI64Attr = + [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap(DenseI64ArrayAttr::get(ctx, vals)); + }; + + result.batch = toI32Attr(dims.batch); + result.outputImage = toI32Attr(dims.outputImage); + result.outputChannel = toI32Attr(dims.outputChannel); + result.filterLoop = toI32Attr(dims.filterLoop); + result.inputChannel 
= toI32Attr(dims.inputChannel); + result.depth = toI32Attr(dims.depth); + result.strides = toI64Attr(dims.strides); + result.dilations = toI64Attr(dims.dilations); + + return result; +} + +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return MlirAttribute{nullptr}; + + ArrayAttr attr = linalgOp.getIndexingMaps(); + return wrap(attr); +} + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/lib/CAPI/Dialect/LinalgPasses.cpp b/mlir/lib/CAPI/Dialect/LinalgPasses.cpp new file mode 100644 index 000000000..6677476d8 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/LinalgPasses.cpp @@ -0,0 +1,26 @@ +//===- LinalgPasses.cpp - C API for Linalg Dialect Passes -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/Linalg/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/Linalg/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/lib/CAPI/Dialect/MLProgram.cpp b/mlir/lib/CAPI/Dialect/MLProgram.cpp new file mode 100644 index 000000000..525b958d9 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/MLProgram.cpp @@ -0,0 +1,14 @@ +//===- MLProgram.cpp - C Interface for MLProgram dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir-c/Dialect/MLProgram.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MLProgram, ml_program, + mlir::ml_program::MLProgramDialect) diff --git a/mlir/lib/CAPI/Dialect/Math.cpp b/mlir/lib/CAPI/Dialect/Math.cpp new file mode 100644 index 000000000..483e549a3 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Math.cpp @@ -0,0 +1,13 @@ +//===- Math.cpp - C Interface for Math dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Math.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Math/IR/Math.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Math, math, mlir::math::MathDialect) diff --git a/mlir/lib/CAPI/Dialect/MemRef.cpp b/mlir/lib/CAPI/Dialect/MemRef.cpp new file mode 100644 index 000000000..cfcdea974 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/MemRef.cpp @@ -0,0 +1,14 @@ +//===- MemRef.cpp - C Interface for MemRef dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/MemRef.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MemRef, memref, + mlir::memref::MemRefDialect) diff --git a/mlir/lib/CAPI/Dialect/NVGPU.cpp b/mlir/lib/CAPI/Dialect/NVGPU.cpp new file mode 100644 index 000000000..e6da529e1 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/NVGPU.cpp @@ -0,0 +1,31 @@ +//===- NVGPU.cpp - C Interface for NVGPU dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; +using namespace mlir::nvgpu; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu, mlir::nvgpu::NVGPUDialect) + +bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirNVGPUTensorMapDescriptorTypeGet(MlirContext ctx, + MlirType tensorMemrefType, + int swizzle, int l2promo, + int oobFill, int interleave) { + return wrap(nvgpu::TensorMapDescriptorType::get( + unwrap(ctx), cast(unwrap(tensorMemrefType)), + TensorMapSwizzleKind(swizzle), TensorMapL2PromoKind(l2promo), + TensorMapOOBKind(oobFill), TensorMapInterleaveKind(interleave))); +} diff --git a/mlir/lib/CAPI/Dialect/NVVM.cpp b/mlir/lib/CAPI/Dialect/NVVM.cpp new file mode 100644 index 000000000..a87581664 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/NVVM.cpp @@ -0,0 +1,13 @@ +//===- NVVM.cpp - C Interface for NVVM dialect ------------------===// +// +// Part of the LLVM Project, 
under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/NVVM.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVVM, nvvm, mlir::NVVM::NVVMDialect) diff --git a/mlir/lib/CAPI/Dialect/OpenMP.cpp b/mlir/lib/CAPI/Dialect/OpenMP.cpp new file mode 100644 index 000000000..3ffa57ab5 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/OpenMP.cpp @@ -0,0 +1,16 @@ +//===- OPENMP.cpp - C Interface for OPENMP dialect +//------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/OpenMP.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(OpenMP, omp, omp::OpenMPDialect) diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp new file mode 100644 index 000000000..bd8b13c65 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -0,0 +1,89 @@ +//===- PDL.cpp - C Interface for PDL dialect ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/PDL.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PDL, pdl, pdl::PDLDialect) + +//===---------------------------------------------------------------------===// +// PDLType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLType(MlirType type) { + return isa(unwrap(type)); +} + +//===---------------------------------------------------------------------===// +// AttributeType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLAttributeType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { + return wrap(pdl::AttributeType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLOperationType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPDLOperationTypeGet(MlirContext ctx) { + return wrap(pdl::OperationType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// RangeType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLRangeType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPDLRangeTypeGet(MlirType elementType) { + return wrap(pdl::RangeType::get(unwrap(elementType))); +} + +MlirType mlirPDLRangeTypeGetElementType(MlirType type) { + return wrap(cast(unwrap(type)).getElementType()); +} + 
+//===---------------------------------------------------------------------===// +// TypeType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLTypeType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPDLTypeTypeGet(MlirContext ctx) { + return wrap(pdl::TypeType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// ValueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLValueType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirPDLValueTypeGet(MlirContext ctx) { + return wrap(pdl::ValueType::get(unwrap(ctx))); +} diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp new file mode 100644 index 000000000..01a6a948f --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -0,0 +1,273 @@ +//===- Quant.cpp - C Interface for Quant dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Quant.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) + +//===---------------------------------------------------------------------===// +// QuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAQuantizedType(MlirType type) { + return isa(unwrap(type)); +} + +unsigned mlirQuantizedTypeGetSignedFlag() { + return quant::QuantizationFlags::Signed; +} + +int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, + unsigned integralWidth) { + return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, + integralWidth); +} + +int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, + unsigned integralWidth) { + return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, + integralWidth); +} + +MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { + return wrap(cast(unwrap(type)).getExpressedType()); +} + +unsigned mlirQuantizedTypeGetFlags(MlirType type) { + return cast(unwrap(type)).getFlags(); +} + +bool mlirQuantizedTypeIsSigned(MlirType type) { + return cast(unwrap(type)).isSigned(); +} + +MlirType mlirQuantizedTypeGetStorageType(MlirType type) { + return wrap(cast(unwrap(type)).getStorageType()); +} + +int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { + return cast(unwrap(type)).getStorageTypeMin(); +} + +int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { + return cast(unwrap(type)).getStorageTypeMax(); +} + +unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { + return cast(unwrap(type)).getStorageTypeIntegralWidth(); +} + +bool 
mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, + MlirType candidate) { + return cast(unwrap(type)) + .isCompatibleExpressedType(unwrap(candidate)); +} + +MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { + return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type))); +} + +MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, + MlirType candidate) { + return wrap(cast(unwrap(type)) + .castFromStorageType(unwrap(candidate))); +} + +MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { + return wrap(quant::QuantizedType::castToStorageType( + cast(unwrap(type)))); +} + +MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, + MlirType candidate) { + return wrap(cast(unwrap(type)) + .castFromExpressedType(unwrap(candidate))); +} + +MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { + return wrap(quant::QuantizedType::castToExpressedType(unwrap(type))); +} + +MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, + MlirType candidate) { + return wrap(cast(unwrap(type)) + .castExpressedToStorageType(unwrap(candidate))); +} + +//===---------------------------------------------------------------------===// +// AnyQuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAAnyQuantizedType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, + MlirType expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType), + unwrap(expressedType), + storageTypeMin, storageTypeMax)); +} + +//===---------------------------------------------------------------------===// +// UniformQuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType 
mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, + MlirType expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax) { + return wrap(quant::UniformQuantizedType::get( + flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint, + storageTypeMin, storageTypeMax)); +} + +double mlirUniformQuantizedTypeGetScale(MlirType type) { + return cast(unwrap(type)).getScale(); +} + +int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { + return cast(unwrap(type)).getZeroPoint(); +} + +bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { + return cast(unwrap(type)).isFixedPoint(); +} + +//===---------------------------------------------------------------------===// +// UniformQuantizedPerAxisType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirUniformQuantizedPerAxisTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + intptr_t nDims, double *scales, int64_t *zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { + return wrap(quant::UniformQuantizedPerAxisType::get( + flags, unwrap(storageType), unwrap(expressedType), + llvm::ArrayRef(scales, nDims), llvm::ArrayRef(zeroPoints, nDims), + quantizedDimension, storageTypeMin, storageTypeMax)); +} + +intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { + return cast(unwrap(type)) + .getScales() + .size(); +} + +double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { + return cast(unwrap(type)) + .getScales()[pos]; +} + +int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getZeroPoints()[pos]; +} + +int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { + return cast(unwrap(type)) + .getQuantizedDimension(); +} + +bool 
mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { + return cast(unwrap(type)).isFixedPoint(); +} + +//===---------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirUniformQuantizedSubChannelTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims, + int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + auto scales = dyn_cast(unwrap(scalesAttr)); + auto zeroPoints = dyn_cast(unwrap(zeroPointsAttr)); + + if (!scales || !zeroPoints) { + return {}; + } + + return wrap(quant::UniformQuantizedSubChannelType::get( + flags, unwrap(storageType), unwrap(expressedType), scales, zeroPoints, + llvm::ArrayRef(quantizedDimensions, nDims), + llvm::ArrayRef(blockSizes, nDims), storageTypeMin, + storageTypeMax)); +} + +intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) { + return cast(unwrap(type)) + .getBlockSizes() + .size(); +} + +int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getQuantizedDimensions()[pos]; +} + +int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getBlockSizes()[pos]; +} + +MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) { + return wrap( + cast(unwrap(type)).getScales()); +} + +MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) { + return wrap(cast(unwrap(type)) + .getZeroPoints()); +} + +//===---------------------------------------------------------------------===// +// CalibratedQuantizedType 
+//===---------------------------------------------------------------------===// + +bool mlirTypeIsACalibratedQuantizedType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, + double max) { + return wrap( + quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max)); +} + +double mlirCalibratedQuantizedTypeGetMin(MlirType type) { + return cast(unwrap(type)).getMin(); +} + +double mlirCalibratedQuantizedTypeGetMax(MlirType type) { + return cast(unwrap(type)).getMax(); +} diff --git a/mlir/lib/CAPI/Dialect/ROCDL.cpp b/mlir/lib/CAPI/Dialect/ROCDL.cpp new file mode 100644 index 000000000..63e2fa881 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/ROCDL.cpp @@ -0,0 +1,13 @@ +//===- ROCDL.cpp - C Interface for ROCDL dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/ROCDL.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(ROCDL, rocdl, mlir::ROCDL::ROCDLDialect) diff --git a/mlir/lib/CAPI/Dialect/SCF.cpp b/mlir/lib/CAPI/Dialect/SCF.cpp index c1dca6d21..17751b1c9 100644 --- a/mlir/lib/CAPI/Dialect/SCF.cpp +++ b/mlir/lib/CAPI/Dialect/SCF.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir-c/Dialect/SCF.h" #include "mlir/CAPI/Registration.h" diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp new file mode 100644 index 000000000..7e96bbb07 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SMT.cpp @@ -0,0 +1,127 @@ +//===- SMT.cpp - C interface for the SMT dialect 
--------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SMT.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SMT/IR/SMTAttributes.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" + +using namespace mlir; +using namespace smt; + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect) + +//===----------------------------------------------------------------------===// +// Type API. +//===----------------------------------------------------------------------===// + +bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) { + return isAnyNonFuncSMTValueType(unwrap(type)); +} + +bool mlirSMTTypeIsAnySMTValueType(MlirType type) { + return isAnySMTValueType(unwrap(type)); +} + +bool mlirSMTTypeIsAArray(MlirType type) { return isa(unwrap(type)); } + +MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType, + MlirType rangeType) { + return wrap( + ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType))); +} + +bool mlirSMTTypeIsABitVector(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) { + return wrap(BitVectorType::get(unwrap(ctx), width)); +} + +bool mlirSMTTypeIsABool(MlirType type) { return isa(unwrap(type)); } + +MlirType mlirSMTTypeGetBool(MlirContext ctx) { + return wrap(BoolType::get(unwrap(ctx))); +} + +bool mlirSMTTypeIsAInt(MlirType type) { return isa(unwrap(type)); } + +MlirType mlirSMTTypeGetInt(MlirContext ctx) { + return 
wrap(IntType::get(unwrap(ctx))); +} + +bool mlirSMTTypeIsASMTFunc(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType) { + SmallVector domainTypesVec; + domainTypesVec.reserve(numberOfDomainTypes); + + for (size_t i = 0; i < numberOfDomainTypes; i++) + domainTypesVec.push_back(unwrap(domainTypes[i])); + + return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType))); +} + +bool mlirSMTTypeIsASort(MlirType type) { return isa(unwrap(type)); } + +MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams) { + SmallVector sortParamsVec; + sortParamsVec.reserve(numberOfSortParams); + + for (size_t i = 0; i < numberOfSortParams; i++) + sortParamsVec.push_back(unwrap(sortParams[i])); + + return wrap(SortType::get(unwrap(ctx), unwrap(identifier), sortParamsVec)); +} + +//===----------------------------------------------------------------------===// +// Attribute API. 
+//===----------------------------------------------------------------------===// + +bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { + return symbolizeBVCmpPredicate(unwrap(str)).has_value(); +} + +bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { + return symbolizeIntPredicate(unwrap(str)).has_value(); +} + +bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value, + unsigned width) { + return wrap(BitVectorAttr::get(unwrap(ctx), value, width)); +} + +MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { + auto predicate = symbolizeBVCmpPredicate(unwrap(str)); + assert(predicate.has_value() && "invalid predicate"); + + return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value())); +} + +MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { + auto predicate = symbolizeIntPredicate(unwrap(str)); + assert(predicate.has_value() && "invalid predicate"); + + return wrap(IntPredicateAttr::get(unwrap(ctx), predicate.value())); +} diff --git a/mlir/lib/CAPI/Dialect/SPIRV.cpp b/mlir/lib/CAPI/Dialect/SPIRV.cpp new file mode 100644 index 000000000..9bfe26b95 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SPIRV.cpp @@ -0,0 +1,13 @@ +//===- SPIRV.cpp - C Interface for SPIRV dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SPIRV.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SPIRV, spirv, mlir::spirv::SPIRVDialect) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp new file mode 100644 index 000000000..cf25b5263 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -0,0 +1,128 @@ +//===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SparseTensor.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/AffineMap.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Support/LLVM.h" + +using namespace llvm; +using namespace mlir::sparse_tensor; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, + mlir::sparse_tensor::SparseTensorDialect) + +// Ensure the C-API enums are int-castable to C++ equivalents. 
+static_assert( + static_cast(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == + static_cast(LevelFormat::Dense) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == + static_cast(LevelFormat::Compressed) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == + static_cast(LevelFormat::Singleton) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == + static_cast(LevelFormat::LooseCompressed) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == + static_cast(LevelFormat::NOutOfM), + "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); + +static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == + static_cast(LevelPropNonDefault::Nonordered) && + static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == + static_cast(LevelPropNonDefault::Nonunique) && + static_cast(MLIR_SPARSE_PROPERTY_SOA) == + static_cast(LevelPropNonDefault::SoA), + "MlirSparseTensorLevelProperty (C-API) and " + "LevelPropertyNondefault (C++) mismatch"); + +bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirAttribute mlirSparseTensorEncodingAttrGet( + MlirContext ctx, intptr_t lvlRank, + MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, + MlirAffineMap lvlToDim, int posWidth, int crdWidth, + MlirAttribute explicitVal, MlirAttribute implicitVal) { + SmallVector cppLvlTypes; + + cppLvlTypes.reserve(lvlRank); + for (intptr_t l = 0; l < lvlRank; ++l) + cppLvlTypes.push_back(static_cast(lvlTypes[l])); + + return wrap(SparseTensorEncodingAttr::get( + unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, + crdWidth, unwrap(explicitVal), unwrap(implicitVal))); +} + +MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getDimToLvl()); +} + +MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getLvlToDim()); +} + +intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { + return 
cast(unwrap(attr)).getLvlRank(); +} + +MlirSparseTensorLevelType +mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { + return static_cast( + cast(unwrap(attr)).getLvlType(lvl)); +} + +enum MlirSparseTensorLevelFormat +mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { + LevelType lt = + static_cast(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); + return static_cast(lt.getLvlFmt()); +} + +int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { + return cast(unwrap(attr)).getPosWidth(); +} + +int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { + return cast(unwrap(attr)).getCrdWidth(); +} + +MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getExplicitVal()); +} + +MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getImplicitVal()); +} + +MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( + enum MlirSparseTensorLevelFormat lvlFmt, + const enum MlirSparseTensorLevelPropertyNondefault *properties, + unsigned size, unsigned n, unsigned m) { + + std::vector props; + props.reserve(size); + for (unsigned i = 0; i < size; i++) + props.push_back(static_cast(properties[i])); + + return static_cast( + *buildLevelType(static_cast(lvlFmt), props, n, m)); +} + +unsigned +mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { + return getN(static_cast(lvlType)); +} + +unsigned +mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { + return getM(static_cast(lvlType)); +} diff --git a/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp b/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp new file mode 100644 index 000000000..5b2ba4ca7 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp @@ -0,0 +1,26 @@ +//===- SparseTensorPasses.cpp - C API for SparseTensor Dialect Passes -----===// +// +// Part of the LLVM Project, 
under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp new file mode 100644 index 000000000..5fd773572 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -0,0 +1,108 @@ +//===- Transform.cpp - C Interface for Transform dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform, + transform::TransformDialect) + +//===---------------------------------------------------------------------===// +// AnyOpType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformAnyOpType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) { + return wrap(transform::AnyOpType::getTypeID()); +} + +MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { + return wrap(transform::AnyOpType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// AnyParamType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformAnyParamType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) { + return wrap(transform::AnyParamType::getTypeID()); +} + +MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) { + return wrap(transform::AnyParamType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// AnyValueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformAnyValueType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) { + return wrap(transform::AnyValueType::getTypeID()); +} + +MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) { + return wrap(transform::AnyValueType::get(unwrap(ctx))); +} + 
+//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformOperationType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirTransformOperationTypeGetTypeID(void) { + return wrap(transform::OperationType::getTypeID()); +} + +MlirType mlirTransformOperationTypeGet(MlirContext ctx, + MlirStringRef operationName) { + return wrap( + transform::OperationType::get(unwrap(ctx), unwrap(operationName))); +} + +MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { + return wrap(cast(unwrap(type)).getOperationName()); +} + +//===---------------------------------------------------------------------===// +// ParamType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformParamType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirTransformParamTypeGetTypeID(void) { + return wrap(transform::ParamType::getTypeID()); +} + +MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) { + return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type))); +} + +MlirType mlirTransformParamTypeGetType(MlirType type) { + return wrap(cast(unwrap(type)).getType()); +} diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp new file mode 100644 index 000000000..145455e1c --- /dev/null +++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp @@ -0,0 +1,83 @@ +//===- TransformTransforms.cpp - C Interface for Transform dialect --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// C interface to transforms for the transform dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform/Interpreter.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Dialect/Transform/IR/Utils.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" + +using namespace mlir; + +DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions) + +extern "C" { + +MlirTransformOptions mlirTransformOptionsCreate() { + return wrap(new transform::TransformOptions); +} + +void mlirTransformOptionsEnableExpensiveChecks( + MlirTransformOptions transformOptions, bool enable) { + unwrap(transformOptions)->enableExpensiveChecks(enable); +} + +bool mlirTransformOptionsGetExpensiveChecksEnabled( + MlirTransformOptions transformOptions) { + return unwrap(transformOptions)->getExpensiveChecksEnabled(); +} + +void mlirTransformOptionsEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions, bool enable) { + unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable); +} + +bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions) { + return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp(); +} + +void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) { + delete unwrap(transformOptions); +} + +MlirLogicalResult mlirTransformApplyNamedSequence( + MlirOperation payload, MlirOperation transformRoot, + MlirOperation transformModule, MlirTransformOptions transformOptions) { + Operation *transformRootOp = unwrap(transformRoot); + Operation *transformModuleOp = unwrap(transformModule); + if (!isa(transformRootOp)) { + transformRootOp->emitError() + << "must implement TransformOpInterface to be used as transform root"; + return mlirLogicalResultFailure(); + } + if 
(!isa(transformModuleOp)) { + transformModuleOp->emitError() + << "must be a " << ModuleOp::getOperationName(); + return mlirLogicalResultFailure(); + } + return wrap(transform::applyTransformNamedSequence( + unwrap(payload), unwrap(transformRoot), + cast(unwrap(transformModule)), *unwrap(transformOptions))); +} + +MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target, + MlirOperation other) { + OwningOpRef otherOwning(unwrap(other)->clone()); + LogicalResult result = transform::detail::mergeSymbolsInto( + unwrap(target), std::move(otherOwning)); + return wrap(result); +} +} diff --git a/mlir/lib/CAPI/Dialect/Vector.cpp b/mlir/lib/CAPI/Dialect/Vector.cpp new file mode 100644 index 000000000..c744b83b6 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Vector.cpp @@ -0,0 +1,14 @@ +//===- Vector.cpp - C Interface for Vector dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Vector.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Vector, vector, + mlir::vector::VectorDialect) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 09dcb6143..bf7dff897 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,8 +1,16 @@ +set(LLVM_LINK_COMPONENTS + nativecodegen + native + orcjit + support +) + # Main API shared library. 
-add_mlir_public_c_api_library(MLIRCEXECUTIONENGINE +add_mlir_upstream_c_api_library(MLIRCAPIExecutionEngine ExecutionEngine.cpp LINK_LIBS PUBLIC + MLIRBuiltinToLLVMIRTranslation MLIRExecutionEngine MLIRLLVMToLLVMIRTranslation ) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 68137c067..306cebd23 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -10,21 +10,57 @@ #include "mlir/CAPI/ExecutionEngine.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" +#include "mlir/ExecutionEngine/OptUtils.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" #include "llvm/Support/TargetSelect.h" using namespace mlir; -extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op) { - static bool init_once = [] { +extern "C" MlirExecutionEngine +mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths, + bool enableObjectDump) { + static bool initOnce = [] { llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm llvm::InitializeNativeTargetAsmPrinter(); return true; }(); - (void)init_once; + (void)initOnce; - mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext()); - auto jitOrError = ExecutionEngine::create(unwrap(op)); + auto &ctx = *unwrap(op)->getContext(); + mlir::registerBuiltinDialectTranslation(ctx); + mlir::registerLLVMDialectTranslation(ctx); + mlir::registerOpenMPDialectTranslation(ctx); + + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!tmBuilderOrError) { + llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; + return 
MlirExecutionEngine{nullptr}; + } + auto tmOrError = tmBuilderOrError->createTargetMachine(); + if (!tmOrError) { + llvm::errs() << "Failed to create a TargetMachine for the host\n"; + return MlirExecutionEngine{nullptr}; + } + + SmallVector libPaths; + for (unsigned i = 0; i < static_cast(numPaths); ++i) + libPaths.push_back(sharedLibPaths[i].data); + + // Create a transformer to run all LLVM optimization passes at the + // specified optimization level. + auto transformer = mlir::makeOptimizingTransformer( + optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); + ExecutionEngineOptions jitOptions; + jitOptions.transformer = transformer; + jitOptions.jitCodeGenOptLevel = static_cast(optLevel); + jitOptions.sharedLibPaths = libPaths; + jitOptions.enableObjectDump = enableObjectDump; + auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions); if (!jitOrError) { consumeError(jitOrError.takeError()); return MlirExecutionEngine{nullptr}; @@ -47,10 +83,37 @@ mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name, return wrap(success()); } +extern "C" void *mlirExecutionEngineLookupPacked(MlirExecutionEngine jit, + MlirStringRef name) { + auto optionalFPtr = + llvm::expectedToOptional(unwrap(jit)->lookupPacked(unwrap(name))); + if (!optionalFPtr) + return nullptr; + return reinterpret_cast(*optionalFPtr); +} + extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit, MlirStringRef name) { - auto expectedFPtr = unwrap(jit)->lookup(unwrap(name)); - if (!expectedFPtr) + auto optionalFPtr = + llvm::expectedToOptional(unwrap(jit)->lookup(unwrap(name))); + if (!optionalFPtr) return nullptr; - return reinterpret_cast(*expectedFPtr); + return *optionalFPtr; +} + +extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, + MlirStringRef name, + void *sym) { + unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { + llvm::orc::SymbolMap symbolMap; + symbolMap[interner(unwrap(name))] = + { 
llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported }; + return symbolMap; + }); +} + +extern "C" void mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit, + MlirStringRef name) { + unwrap(jit)->dumpToObjectFile(unwrap(name)); } diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index 2d8bc3ce5..5a0a03b11 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -56,12 +56,34 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, return unwrap(affineExpr).isFunctionOfDim(position); } +MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, + MlirAffineMap affineMap) { + return wrap(unwrap(affineExpr).compose(unwrap(affineMap))); +} + +MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr, + uint32_t numDims, uint32_t shift, + uint32_t offset) { + return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset)); +} + +MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, + uint32_t numSymbols, uint32_t shift, + uint32_t offset) { + return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset)); +} + +MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr, uint32_t numDims, + uint32_t numSymbols) { + return wrap(simplifyAffineExpr(unwrap(expr), numDims, numSymbols)); +} + //===----------------------------------------------------------------------===// // Affine Dimension Expression. 
//===----------------------------------------------------------------------===// bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) { @@ -69,7 +91,7 @@ MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) { } intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).cast().getPosition(); + return cast(unwrap(affineExpr)).getPosition(); } //===----------------------------------------------------------------------===// @@ -77,7 +99,7 @@ intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) { //===----------------------------------------------------------------------===// bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) { @@ -85,7 +107,7 @@ MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) { } intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).cast().getPosition(); + return cast(unwrap(affineExpr)).getPosition(); } //===----------------------------------------------------------------------===// @@ -93,7 +115,7 @@ intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) { //===----------------------------------------------------------------------===// bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) { @@ -101,7 +123,7 @@ MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) { } int64_t mlirAffineConstantExprGetValue(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).cast().getValue(); + return cast(unwrap(affineExpr)).getValue(); } 
//===----------------------------------------------------------------------===// @@ -176,13 +198,13 @@ MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs, //===----------------------------------------------------------------------===// bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) { - return wrap(unwrap(affineExpr).cast().getLHS()); + return wrap(cast(unwrap(affineExpr)).getLHS()); } MlirAffineExpr mlirAffineBinaryOpExprGetRHS(MlirAffineExpr affineExpr) { - return wrap(unwrap(affineExpr).cast().getRHS()); + return wrap(cast(unwrap(affineExpr)).getRHS()); } diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp index f532d5dae..1889765ef 100644 --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -68,7 +68,7 @@ MlirAffineMap mlirAffineMapMinorIdentityGet(MlirContext ctx, intptr_t dims, MlirAffineMap mlirAffineMapPermutationGet(MlirContext ctx, intptr_t size, unsigned *permutation) { return wrap(AffineMap::getPermutationMap( - llvm::makeArrayRef(permutation, static_cast(size)), unwrap(ctx))); + llvm::ArrayRef(permutation, static_cast(size)), unwrap(ctx))); } bool mlirAffineMapIsIdentity(MlirAffineMap affineMap) { @@ -137,3 +137,23 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults) { return wrap(unwrap(affineMap).getMinorSubMap(numResults)); } + +MlirAffineMap mlirAffineMapReplace(MlirAffineMap affineMap, + MlirAffineExpr expression, + MlirAffineExpr replacement, + intptr_t numResultDims, + intptr_t numResultSyms) { + return wrap(unwrap(affineMap).replace(unwrap(expression), unwrap(replacement), + numResultDims, numResultSyms)); +} + +void mlirAffineMapCompressUnusedSymbols( + MlirAffineMap *affineMaps, intptr_t size, void *result, + void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) { + SmallVector maps; + 
for (intptr_t idx = 0; idx < size; ++idx) + maps.push_back(unwrap(affineMaps[idx])); + intptr_t idx = 0; + for (auto m : mlir::compressUnusedSymbols(maps)) + populateResult(result, idx++, wrap(m)); +} diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index a54006db2..8d57ab6b5 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -7,20 +7,34 @@ //===----------------------------------------------------------------------===// #include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/IntegerSet.h" #include "mlir/CAPI/Support.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" using namespace mlir; +MlirAttribute mlirAttributeGetNull() { return {nullptr}; } + +//===----------------------------------------------------------------------===// +// Location attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsALocation(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + //===----------------------------------------------------------------------===// // Affine map attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAAffineMap(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { @@ -28,7 +42,11 @@ MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); +} + +MlirTypeID mlirAffineMapAttrGetTypeID(void) { + return wrap(AffineMapAttr::getTypeID()); } //===----------------------------------------------------------------------===// @@ -36,7 +54,7 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, @@ -48,19 +66,21 @@ MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getValue()[pos]); + return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } +MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Dictionary attribute. 
//===----------------------------------------------------------------------===// bool mlirAttributeIsADictionary(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, @@ -68,26 +88,29 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, SmallVector attributes; attributes.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) - attributes.emplace_back( - Identifier::get(unwrap(elements[i].name), unwrap(ctx)), - unwrap(elements[i].attribute)); + attributes.emplace_back(unwrap(elements[i].name), + unwrap(elements[i].attribute)); return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = - unwrap(attr).cast().getValue()[pos]; - return {wrap(attribute.first), wrap(attribute.second)}; + llvm::cast(unwrap(attr)).getValue()[pos]; + return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name) { - return wrap(unwrap(attr).cast().get(unwrap(name))); + return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); +} + +MlirTypeID mlirDictionaryAttrGetTypeID(void) { + return wrap(DictionaryAttr::getTypeID()); } //===----------------------------------------------------------------------===// @@ -95,7 +118,7 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, //===----------------------------------------------------------------------===// bool mlirAttributeIsAFloat(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, @@ -109,15 +132,17 @@ MlirAttribute 
mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { - return unwrap(attr).cast().getValueAsDouble(); + return llvm::cast(unwrap(attr)).getValueAsDouble(); } +MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAInteger(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { @@ -125,7 +150,19 @@ MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { - return unwrap(attr).cast().getInt(); + return llvm::cast(unwrap(attr)).getInt(); +} + +int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getSInt(); +} + +uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getUInt(); +} + +MlirTypeID mlirIntegerAttrGetTypeID(void) { + return wrap(IntegerAttr::getTypeID()); } //===----------------------------------------------------------------------===// @@ -133,7 +170,7 @@ int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsABool(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { @@ -141,7 +178,7 @@ MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { } bool mlirBoolAttrGetValue(MlirAttribute attr) { - return unwrap(attr).cast().getValue(); + return llvm::cast(unwrap(attr)).getValue(); } //===----------------------------------------------------------------------===// @@ -149,7 +186,19 @@ bool mlirBoolAttrGetValue(MlirAttribute attr) { 
//===----------------------------------------------------------------------===// bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); +} + +MlirTypeID mlirIntegerSetAttrGetTypeID(void) { + return wrap(IntegerSetAttr::getTypeID()); +} + +MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) { + return wrap(IntegerSetAttr::get(unwrap(set))); +} + +MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -157,23 +206,28 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaque(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { return wrap( - OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), + OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), StringRef(data, dataLength), unwrap(type))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(attr)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getAttrData()); + return wrap(llvm::cast(unwrap(attr)).getAttrData()); +} + +MlirTypeID mlirOpaqueAttrGetTypeID(void) { + return wrap(OpaqueAttr::getTypeID()); } //===----------------------------------------------------------------------===// @@ -181,19 +235,23 @@ MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAString(MlirAttribute attr) { - return 
unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); + return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(str), unwrap(type))); + return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); +} + +MlirTypeID mlirStringAttrGetTypeID(void) { + return wrap(StringAttr::getTypeID()); } //===----------------------------------------------------------------------===// @@ -201,7 +259,7 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsASymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, @@ -210,26 +268,38 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) - refs.push_back(unwrap(references[i]).cast()); - return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); + refs.push_back(llvm::cast(unwrap(references[i]))); + auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); + return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getRootReference()); + return wrap( + llvm::cast(unwrap(attr)).getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getLeafReference()); + return wrap( + 
llvm::cast(unwrap(attr)).getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getNestedReferences().size()); + llvm::cast(unwrap(attr)).getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getNestedReferences()[pos]); + return wrap( + llvm::cast(unwrap(attr)).getNestedReferences()[pos]); +} + +MlirTypeID mlirSymbolRefAttrGetTypeID(void) { + return wrap(SymbolRefAttr::getTypeID()); +} + +MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) { + return wrap(mlir::DistinctAttr::create(unwrap(referencedAttr))); } //===----------------------------------------------------------------------===// @@ -237,7 +307,7 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, //===----------------------------------------------------------------------===// bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { @@ -245,7 +315,7 @@ MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -253,7 +323,7 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAType(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirTypeAttrGet(MlirType type) { @@ -261,43 +331,155 @@ MlirAttribute mlirTypeAttrGet(MlirType type) { } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { - return 
wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// bool mlirAttributeIsAUnit(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirUnitAttrGet(MlirContext ctx) { return wrap(UnitAttr::get(unwrap(ctx))); } +MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// bool mlirAttributeIsAElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr).cast().getValue( - llvm::makeArrayRef(idxs, rank))); + return wrap(llvm::cast(unwrap(attr)) + .getValues()[llvm::ArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return unwrap(attr).cast().isValidIndex( - llvm::makeArrayRef(idxs, rank)); + return llvm::cast(unwrap(attr)) + .isValidIndex(llvm::ArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().getNumElements(); + return llvm::cast(unwrap(attr)).getNumElements(); +} + +//===----------------------------------------------------------------------===// +// Dense array attribute. +//===----------------------------------------------------------------------===// + +MlirTypeID mlirDenseArrayAttrGetTypeID() { + return wrap(DenseArrayAttr::getTypeID()); +} + +//===----------------------------------------------------------------------===// +// IsA support. 
+//===----------------------------------------------------------------------===// + +bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} +bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} +bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} +bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} +bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} +bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} +bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +//===----------------------------------------------------------------------===// +// Constructors. +//===----------------------------------------------------------------------===// + +MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, + int const *values) { + SmallVector elements(values, values + size); + return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements)); +} +MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, + int8_t const *values) { + return wrap( + DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, + int16_t const *values) { + return wrap( + DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, + int32_t const *values) { + return wrap( + DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, + int64_t const *values) { + return wrap( + DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, + float const *values) { + return wrap( + DenseF32ArrayAttr::get(unwrap(ctx), 
ArrayRef(values, size))); +} +MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, + double const *values) { + return wrap( + DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} + +//===----------------------------------------------------------------------===// +// Accessors. +//===----------------------------------------------------------------------===// + +intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).size(); +} + +//===----------------------------------------------------------------------===// +// Indexed accessors. +//===----------------------------------------------------------------------===// + +bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; +} +int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; +} +int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; +} +int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; +} +int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; +} +float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; +} +double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr))[pos]; } //===----------------------------------------------------------------------===// @@ -306,76 +488,107 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { //===----------------------------------------------------------------------===// // IsA support. 
+//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); +} + +MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { + return wrap(DenseIntOrFPElementsAttr::getTypeID()); } //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, MlirAttribute const *elements) { SmallVector attributes; return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), + DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrapList(numElements, elements, attributes))); } +MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, + size_t rawBufferSize, + const void *rawBuffer) { + auto shapedTypeCpp = llvm::cast(unwrap(shapedType)); + ArrayRef rawBufferCpp(static_cast(rawBuffer), + rawBufferSize); + bool isSplat = false; + if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, + isSplat)) + return mlirAttributeGetNull(); + return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp)); +} + MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrap(element))); } MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return 
wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); +} +MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, + uint8_t element) { + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); +} +MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, + int8_t element) { + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, uint64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, int64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, float element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, 
intptr_t numElements, const int *elements) { SmallVector values(elements, elements + numElements); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } /// Creates a dense attribute with elements of the type deduced by templates. @@ -383,11 +596,30 @@ template static MlirAttribute getDenseAttribute(MlirType shapedType, intptr_t numElements, const T *elements) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), - llvm::makeArrayRef(elements, numElements))); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + llvm::ArrayRef(elements, numElements))); } +MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType, + intptr_t numElements, + const uint8_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} +MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType, + intptr_t numElements, + const int8_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} +MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType, + intptr_t numElements, + const uint16_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} +MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType, + intptr_t numElements, + const int16_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, intptr_t numElements, const uint32_t *elements) { @@ -418,6 +650,20 @@ MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, const double *elements) { return getDenseAttribute(shapedType, numElements, elements); } +MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType, + intptr_t numElements, + const uint16_t *elements) { + size_t bufferSize = numElements * 2; + const void *buffer = static_cast(elements); + return mlirDenseElementsAttrRawBufferGet(shapedType, 
bufferSize, buffer); +} +MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType, + intptr_t numElements, + const uint16_t *elements) { + size_t bufferSize = numElements * 2; + const void *buffer = static_cast(elements); + return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); +} MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, intptr_t numElements, @@ -427,106 +673,268 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, for (intptr_t i = 0; i < numElements; ++i) values.push_back(unwrap(strs[i])); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, MlirType shapedType) { - return wrap(unwrap(attr).cast().reshape( - unwrap(shapedType).cast())); + return wrap(llvm::cast(unwrap(attr)) + .reshape(llvm::cast(unwrap(shapedType)))); } //===----------------------------------------------------------------------===// // Splat accessors. 
+//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { - return unwrap(attr).cast().isSplat(); + return llvm::cast(unwrap(attr)).isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getSplatValue()); + return wrap( + llvm::cast(unwrap(attr)).getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); +} +int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getSplatValue(); +} +uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + 
llvm::cast(unwrap(attr)).getSplatValue()); } //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; +} +int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; +} +uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; +} +int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; +} +uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return *( - unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return *( - unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; +} +uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return 
*(unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return llvm::cast(unwrap(attr)).getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - *(unwrap(attr).cast().getValues().begin() + - pos)); + llvm::cast(unwrap(attr)).getValues()[pos]); } //===----------------------------------------------------------------------===// // Raw data accessors. +//===----------------------------------------------------------------------===// const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getRawData().data()); + llvm::cast(unwrap(attr)).getRawData().data()); } //===----------------------------------------------------------------------===// -// Opaque elements attribute. +// Resource blob attributes. 
//===----------------------------------------------------------------------===// -bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) { - return unwrap(attr).isa(); +bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, + size_t dataAlignment, bool dataIsMutable, + void (*deleter)(void *userData, const void *data, size_t size, + size_t align), + void *userData) { + AsmResourceBlob::DeleterFn cppDeleter = {}; + if (deleter) { + cppDeleter = [deleter, userData](void *data, size_t size, size_t align) { + deleter(userData, data, size, align); + }; + } + AsmResourceBlob blob( + llvm::ArrayRef(static_cast(data), dataLength), + dataAlignment, std::move(cppDeleter), dataIsMutable); + return wrap( + DenseResourceElementsAttr::get(llvm::cast(unwrap(shapedType)), + unwrap(name), std::move(blob))); +} + +template +static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, + intptr_t numElements, const T *elements) { + return wrap(U::get(llvm::cast(unwrap(shapedType)), unwrap(name), + UnmanagedAsmResourceBlob::allocateInferAlign( + llvm::ArrayRef(elements, numElements)))); +} + +MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint8_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseUInt16ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint16_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute 
mlirUnmanagedDenseUInt32ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint32_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint64_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int8_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int16_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int32_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int64_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const float *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const double *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +template +static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { + return (*llvm::cast(unwrap(attr)).tryGetAsArrayRef())[pos]; +} + +bool 
mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, + pos); +} +uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, + pos); +} +uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, + pos); +} +int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { + return getDenseResourceVal(attr, pos); } //===----------------------------------------------------------------------===// @@ -534,22 +942,58 @@ bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsASparseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { - return wrap( - 
SparseElementsAttr::get(unwrap(shapedType).cast(), - unwrap(denseIndices).cast(), - unwrap(denseValues).cast())); + return wrap(SparseElementsAttr::get( + llvm::cast(unwrap(shapedType)), + llvm::cast(unwrap(denseIndices)), + llvm::cast(unwrap(denseValues)))); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getIndices()); + return wrap(llvm::cast(unwrap(attr)).getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValues()); + return wrap(llvm::cast(unwrap(attr)).getValues()); +} + +MlirTypeID mlirSparseElementsAttrGetTypeID(void) { + return wrap(SparseElementsAttr::getTypeID()); +} + +//===----------------------------------------------------------------------===// +// Strided layout attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, + intptr_t numStrides, + const int64_t *strides) { + return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, + ArrayRef(strides, numStrides))); +} + +int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { + return llvm::cast(unwrap(attr)).getOffset(); +} + +intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { + return static_cast( + llvm::cast(unwrap(attr)).getStrides().size()); +} + +int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getStrides()[pos]; +} + +MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { + return wrap(StridedLayoutAttr::getTypeID()); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index e4442ac4c..9d8554aab 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -9,20 +9,26 @@ #include "mlir-c/BuiltinTypes.h" #include "mlir-c/AffineMap.h" 
#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include + using namespace mlir; //===----------------------------------------------------------------------===// // Integer types. //===----------------------------------------------------------------------===// +MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } + bool mlirTypeIsAInteger(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { @@ -38,26 +44,30 @@ MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { } unsigned mlirIntegerTypeGetWidth(MlirType type) { - return unwrap(type).cast().getWidth(); + return llvm::cast(unwrap(type)).getWidth(); } bool mlirIntegerTypeIsSignless(MlirType type) { - return unwrap(type).cast().isSignless(); + return llvm::cast(unwrap(type)).isSignless(); } bool mlirIntegerTypeIsSigned(MlirType type) { - return unwrap(type).cast().isSigned(); + return llvm::cast(unwrap(type)).isSigned(); } bool mlirIntegerTypeIsUnsigned(MlirType type) { - return unwrap(type).cast().isUnsigned(); + return llvm::cast(unwrap(type)).isUnsigned(); } //===----------------------------------------------------------------------===// // Index type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa(); } +MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } + +bool mlirTypeIsAIndex(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirIndexTypeGet(MlirContext ctx) { return wrap(IndexType::get(unwrap(ctx))); @@ -67,35 +77,209 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. 
//===----------------------------------------------------------------------===// -bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } +bool mlirTypeIsAFloat(MlirType type) { + return llvm::isa(unwrap(type)); +} + +unsigned mlirFloatTypeGetWidth(MlirType type) { + return llvm::cast(unwrap(type)).getWidth(); +} + +MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() { + return wrap(Float4E2M1FNType::getTypeID()); +} + +bool mlirTypeIsAFloat4E2M1FN(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) { + return wrap(Float4E2M1FNType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { + return wrap(Float6E2M3FNType::getTypeID()); +} + +bool mlirTypeIsAFloat6E2M3FN(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) { + return wrap(Float6E2M3FNType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { + return wrap(Float6E3M2FNType::getTypeID()); +} + +bool mlirTypeIsAFloat6E3M2FN(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) { + return wrap(Float6E3M2FNType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E5M2TypeGetTypeID() { + return wrap(Float8E5M2Type::getTypeID()); +} + +bool mlirTypeIsAFloat8E5M2(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { + return wrap(Float8E5M2Type::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E4M3TypeGetTypeID() { + return wrap(Float8E4M3Type::getTypeID()); +} + +bool mlirTypeIsAFloat8E4M3(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) { + return wrap(Float8E4M3Type::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { + return wrap(Float8E4M3FNType::getTypeID()); +} + +bool mlirTypeIsAFloat8E4M3FN(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E4M3FNTypeGet(MlirContext 
ctx) { + return wrap(Float8E4M3FNType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { + return wrap(Float8E5M2FNUZType::getTypeID()); +} + +bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { + return wrap(Float8E5M2FNUZType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { + return wrap(Float8E4M3FNUZType::getTypeID()); +} + +bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { + return wrap(Float8E4M3FNUZType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { + return wrap(Float8E4M3B11FNUZType::getTypeID()); +} + +bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { + return wrap(Float8E4M3B11FNUZType::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E3M4TypeGetTypeID() { + return wrap(Float8E3M4Type::getTypeID()); +} + +bool mlirTypeIsAFloat8E3M4(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { + return wrap(Float8E3M4Type::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() { + return wrap(Float8E8M0FNUType::getTypeID()); +} + +bool mlirTypeIsAFloat8E8M0FNU(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) { + return wrap(Float8E8M0FNUType::get(unwrap(ctx))); +} + +MlirTypeID mlirBFloat16TypeGetTypeID() { + return wrap(BFloat16Type::getTypeID()); +} + +bool mlirTypeIsABF16(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirBF16TypeGet(MlirContext ctx) { - return wrap(FloatType::getBF16(unwrap(ctx))); + return wrap(BFloat16Type::get(unwrap(ctx))); } -bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } +MlirTypeID mlirFloat16TypeGetTypeID() { return 
wrap(Float16Type::getTypeID()); } + +bool mlirTypeIsAF16(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirF16TypeGet(MlirContext ctx) { - return wrap(FloatType::getF16(unwrap(ctx))); + return wrap(Float16Type::get(unwrap(ctx))); +} + +MlirTypeID mlirFloatTF32TypeGetTypeID() { + return wrap(FloatTF32Type::getTypeID()); +} + +bool mlirTypeIsATF32(MlirType type) { + return llvm::isa(unwrap(type)); } -bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } +MlirType mlirTF32TypeGet(MlirContext ctx) { + return wrap(FloatTF32Type::get(unwrap(ctx))); +} + +MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } + +bool mlirTypeIsAF32(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirF32TypeGet(MlirContext ctx) { - return wrap(FloatType::getF32(unwrap(ctx))); + return wrap(Float32Type::get(unwrap(ctx))); } -bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } +MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } + +bool mlirTypeIsAF64(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirF64TypeGet(MlirContext ctx) { - return wrap(FloatType::getF64(unwrap(ctx))); + return wrap(Float64Type::get(unwrap(ctx))); } //===----------------------------------------------------------------------===// // None type. //===----------------------------------------------------------------------===// -bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa(); } +MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } + +bool mlirTypeIsANone(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirNoneTypeGet(MlirContext ctx) { return wrap(NoneType::get(unwrap(ctx))); @@ -105,8 +289,10 @@ MlirType mlirNoneTypeGet(MlirContext ctx) { // Complex type. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } + bool mlirTypeIsAComplex(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirComplexTypeGet(MlirType elementType) { @@ -114,95 +300,160 @@ MlirType mlirComplexTypeGet(MlirType elementType) { } MlirType mlirComplexTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } //===----------------------------------------------------------------------===// // Shaped type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAShaped(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirShapedTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } bool mlirShapedTypeHasRank(MlirType type) { - return unwrap(type).cast().hasRank(); + return llvm::cast(unwrap(type)).hasRank(); } int64_t mlirShapedTypeGetRank(MlirType type) { - return unwrap(type).cast().getRank(); + return llvm::cast(unwrap(type)).getRank(); } bool mlirShapedTypeHasStaticShape(MlirType type) { - return unwrap(type).cast().hasStaticShape(); + return llvm::cast(unwrap(type)).hasStaticShape(); } bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { - return unwrap(type).cast().isDynamicDim( - static_cast(dim)); + return llvm::cast(unwrap(type)) + .isDynamicDim(static_cast(dim)); +} + +bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) { + return llvm::cast(unwrap(type)) + .isStaticDim(static_cast(dim)); } int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { - return unwrap(type).cast().getDimSize(static_cast(dim)); + return llvm::cast(unwrap(type)) + 
.getDimSize(static_cast(dim)); } +int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } + bool mlirShapedTypeIsDynamicSize(int64_t size) { return ShapedType::isDynamic(size); } +bool mlirShapedTypeIsStaticSize(int64_t size) { + return ShapedType::isStatic(size); +} + bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { - return ShapedType::isDynamicStrideOrOffset(val); + return ShapedType::isDynamic(val); +} + +bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) { + return ShapedType::isStatic(val); +} + +int64_t mlirShapedTypeGetDynamicStrideOrOffset() { + return ShapedType::kDynamic; } //===----------------------------------------------------------------------===// // Vector type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa(); } +MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } + +bool mlirTypeIsAVector(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType) { - return wrap( - VectorType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType))); } MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType) { return wrap(VectorType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType))); } +MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, + const bool *scalable, MlirType elementType) { + return wrap(VectorType::get( + llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), + llvm::ArrayRef(scalable, static_cast(rank)))); +} + +MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, + const bool 
*scalable, + MlirType elementType) { + return wrap(VectorType::getChecked( + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), + llvm::ArrayRef(scalable, static_cast(rank)))); +} + +bool mlirVectorTypeIsScalable(MlirType type) { + return cast(unwrap(type)).isScalable(); +} + +bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) { + return cast(unwrap(type)).getScalableDims()[dim]; +} + //===----------------------------------------------------------------------===// // Ranked / Unranked tensor type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATensor(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirTypeID mlirRankedTensorTypeGetTypeID() { + return wrap(RankedTensorType::getTypeID()); +} bool mlirTypeIsARankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); +} + +MlirTypeID mlirUnrankedTensorTypeGetTypeID() { + return wrap(UnrankedTensorType::getTypeID()); } bool mlirTypeIsAUnrankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, - MlirType elementType) { - return wrap(RankedTensorType::get( - llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + MlirType elementType, MlirAttribute encoding) { + return wrap( + RankedTensorType::get(llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), unwrap(encoding))); } MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, - MlirType elementType) { + MlirType elementType, + MlirAttribute encoding) { return wrap(RankedTensorType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), unwrap(encoding))); +} + +MlirAttribute 
mlirRankedTensorTypeGetEncoding(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getEncoding()); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { @@ -214,89 +465,119 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); } +MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getElementType()); +} + //===----------------------------------------------------------------------===// // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } +MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } + +bool mlirTypeIsAMemRef(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, - const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, - unsigned memorySpace) { - SmallVector maps; - (void)unwrapList(numMaps, affineMaps, maps); - return wrap( - MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + const int64_t *shape, MlirAttribute layout, + MlirAttribute memorySpace) { + return wrap(MemRefType::get( + llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), + mlirAttributeIsNull(layout) + ? 
MemRefLayoutAttrInterface() + : llvm::cast(unwrap(layout)), + unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, - MlirAffineMap const *affineMaps, - unsigned memorySpace) { - SmallVector maps; - (void)unwrapList(numMaps, affineMaps, maps); + MlirAttribute layout, + MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), + mlirAttributeIsNull(layout) + ? MemRefLayoutAttrInterface() + : llvm::cast(unwrap(layout)), + unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { - return wrap( - MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + MlirAttribute memorySpace) { + return wrap(MemRefType::get(llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), MemRefLayoutAttrInterface(), + unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace))); +} + +MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getLayout()); } -intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { - return static_cast( - unwrap(type).cast().getAffineMaps().size()); +MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { + return 
wrap(llvm::cast(unwrap(type)).getLayout().getAffineMap()); } -MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getAffineMaps()[pos]); +MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } -unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, + int64_t *strides, + int64_t *offset) { + MemRefType memrefType = llvm::cast(unwrap(type)); + SmallVector strides_; + if (failed(memrefType.getStridesAndOffset(strides_, *offset))) + return mlirLogicalResultFailure(); + + (void)std::copy(strides_.begin(), strides_.end(), strides); + return mlirLogicalResultSuccess(); +} + +MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { + return wrap(UnrankedMemRefType::getTypeID()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } -MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { - return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); +MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, + MlirAttribute memorySpace) { + return wrap( + UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); } MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), - memorySpace)); + unwrap(memorySpace))); } -unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } //===----------------------------------------------------------------------===// // Tuple type. 
//===----------------------------------------------------------------------===// -bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa(); } +MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } + +bool mlirTypeIsATuple(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements) { @@ -306,19 +587,24 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, } intptr_t mlirTupleTypeGetNumTypes(MlirType type) { - return unwrap(type).cast().size(); + return llvm::cast(unwrap(type)).size(); } MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getType(static_cast(pos))); + return wrap( + llvm::cast(unwrap(type)).getType(static_cast(pos))); } //===----------------------------------------------------------------------===// // Function type. //===----------------------------------------------------------------------===// +MlirTypeID mlirFunctionTypeGetTypeID() { + return wrap(FunctionType::getTypeID()); +} + bool mlirTypeIsAFunction(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, @@ -332,21 +618,47 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, } intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { - return unwrap(type).cast().getNumInputs(); + return llvm::cast(unwrap(type)).getNumInputs(); } intptr_t mlirFunctionTypeGetNumResults(MlirType type) { - return unwrap(type).cast().getNumResults(); + return llvm::cast(unwrap(type)).getNumResults(); } MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in array must be positive"); - return wrap( - unwrap(type).cast().getInput(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getInput(static_cast(pos))); } MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { assert(pos >= 0 && 
"pos in array must be positive"); + return wrap(llvm::cast(unwrap(type)) + .getResult(static_cast(pos))); +} + +//===----------------------------------------------------------------------===// +// Opaque type. +//===----------------------------------------------------------------------===// + +MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } + +bool mlirTypeIsAOpaque(MlirType type) { + return llvm::isa(unwrap(type)); +} + +MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, + MlirStringRef typeData) { return wrap( - unwrap(type).cast().getResult(static_cast(pos))); + OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), + unwrap(typeData))); +} + +MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { + return wrap( + llvm::cast(unwrap(type)).getDialectNamespace().strref()); +} + +MlirStringRef mlirOpaqueTypeGetData(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getTypeData()); } diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt index 486ba6e0f..36f28520d 100644 --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -1,5 +1,5 @@ # Main API shared library. 
-add_mlir_public_c_api_library(MLIRCAPIIR +add_mlir_upstream_c_api_library(MLIRCAPIIR AffineExpr.cpp AffineMap.cpp BuiltinAttributes.cpp @@ -12,6 +12,7 @@ add_mlir_public_c_api_library(MLIRCAPIIR Support.cpp LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRIR MLIRParser MLIRSupport diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp index 2ed05a5a0..4a13ae576 100644 --- a/mlir/lib/CAPI/IR/Diagnostics.cpp +++ b/mlir/lib/CAPI/IR/Diagnostics.cpp @@ -57,13 +57,14 @@ MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler( MlirContext context, MlirDiagnosticHandler handler, void *userData, void (*deleteUserData)(void *)) { assert(handler && "unexpected null diagnostic handler"); - if (deleteUserData == NULL) + if (deleteUserData == nullptr) deleteUserData = deleteUserDataNoop; - std::shared_ptr sharedUserData(userData, deleteUserData); DiagnosticEngine::HandlerID id = unwrap(context)->getDiagEngine().registerHandler( - [handler, sharedUserData](Diagnostic &diagnostic) { - return unwrap(handler(wrap(diagnostic), sharedUserData.get())); + [handler, + ownedUserData = std::unique_ptr( + userData, deleteUserData)](Diagnostic &diagnostic) { + return unwrap(handler(wrap(diagnostic), ownedUserData.get())); }); return static_cast(id); } diff --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp index fb972316e..19f64d948 100644 --- a/mlir/lib/CAPI/IR/DialectHandle.cpp +++ b/mlir/lib/CAPI/IR/DialectHandle.cpp @@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) { return unwrap(handle)->getNamespaceHook(); } +void mlirDialectHandleInsertDialect(MlirDialectHandle handle, + MlirDialectRegistry registry) { + unwrap(handle)->insertHook(registry); +} + void mlirDialectHandleRegisterDialect(MlirDialectHandle handle, MlirContext ctx) { - unwrap(handle)->registerHook(ctx); + mlir::DialectRegistry registry; + mlirDialectHandleInsertDialect(handle, wrap(®istry)); + 
unwrap(ctx)->appendDialectRegistry(registry); } MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 67032a4b5..8491553da 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -9,17 +9,32 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Parser.h" +#include "mlir/Parser/Parser.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Support/ThreadPool.h" + +#include +#include +#include using namespace mlir; @@ -32,6 +47,23 @@ MlirContext mlirContextCreate() { return wrap(context); } +static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { + return threadingEnabled ? 
MLIRContext::Threading::ENABLED + : MLIRContext::Threading::DISABLED; +} + +MlirContext mlirContextCreateWithThreading(bool threadingEnabled) { + auto *context = new MLIRContext(toThreadingEnum(threadingEnabled)); + return wrap(context); +} + +MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry, + bool threadingEnabled) { + auto *context = + new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled)); + return wrap(context); +} + bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { return unwrap(ctx1) == unwrap(ctx2); } @@ -49,6 +81,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { return static_cast(unwrap(context)->getAvailableDialects().size()); } +void mlirContextAppendDialectRegistry(MlirContext ctx, + MlirDialectRegistry registry) { + unwrap(ctx)->appendDialectRegistry(*unwrap(registry)); +} + // TODO: expose a cheaper way than constructing + sorting a vector only to take // its size. intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { @@ -60,6 +97,31 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context, return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); } +bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { + return unwrap(context)->isOperationRegistered(unwrap(name)); +} + +void mlirContextEnableMultithreading(MlirContext context, bool enable) { + return unwrap(context)->enableMultithreading(enable); +} + +void mlirContextLoadAllAvailableDialects(MlirContext context) { + unwrap(context)->loadAllAvailableDialects(); +} + +void mlirContextSetThreadPool(MlirContext context, + MlirLlvmThreadPool threadPool) { + unwrap(context)->setThreadPool(*unwrap(threadPool)); +} + +unsigned mlirContextGetNumThreads(MlirContext context) { + return unwrap(context)->getNumThreads(); +} + +MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) { + return wrap(&unwrap(context)->getThreadPool()); +} + 
//===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// @@ -76,6 +138,63 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { return wrap(unwrap(dialect)->getNamespace()); } +//===----------------------------------------------------------------------===// +// DialectRegistry API. +//===----------------------------------------------------------------------===// + +MlirDialectRegistry mlirDialectRegistryCreate() { + return wrap(new DialectRegistry()); +} + +void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { + delete unwrap(registry); +} + +//===----------------------------------------------------------------------===// +// AsmState API. +//===----------------------------------------------------------------------===// + +MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, + MlirOpPrintingFlags flags) { + return wrap(new AsmState(unwrap(op), *unwrap(flags))); +} + +static Operation *findParent(Operation *op, bool shouldUseLocalScope) { + do { + // If we are printing local scope, stop at the first operation that is + // isolated from above. + if (shouldUseLocalScope && op->hasTrait()) + break; + + // Otherwise, traverse up to the next parent. + Operation *parentOp = op->getParentOp(); + if (!parentOp) + break; + op = parentOp; + } while (true); + return op; +} + +MlirAsmState mlirAsmStateCreateForValue(MlirValue value, + MlirOpPrintingFlags flags) { + Operation *op; + mlir::Value val = unwrap(value); + if (auto result = llvm::dyn_cast(val)) { + op = result.getOwner(); + } else { + op = llvm::cast(val).getOwner()->getParentOp(); + if (!op) { + emitError(val.getLoc()) << "<>"; + return {nullptr}; + } + } + op = findParent(op, unwrap(flags)->shouldUseLocalScope()); + return wrap(new AsmState(op, *unwrap(flags))); +} + +/// Destroys printing flags created with mlirAsmStateCreate. 
+void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); } + //===----------------------------------------------------------------------===// // Printing flags API. //===----------------------------------------------------------------------===// @@ -93,23 +212,64 @@ void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); } -void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, +void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, + intptr_t largeResourceLimit) { + unwrap(flags)->elideLargeResourceString(largeResourceLimit); +} + +void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm) { - unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); + unwrap(flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); } void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { unwrap(flags)->printGenericOpForm(); } +void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) { + unwrap(flags)->printNameLocAsPrefix(); +} + void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { unwrap(flags)->useLocalScope(); } +void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { + unwrap(flags)->assumeVerified(); +} + +void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) { + unwrap(flags)->skipRegions(); +} +//===----------------------------------------------------------------------===// +// Bytecode printing flags API. 
+//===----------------------------------------------------------------------===// + +MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() { + return wrap(new BytecodeWriterConfig()); +} + +void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) { + delete unwrap(config); +} + +void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, + int64_t version) { + unwrap(flags)->setDesiredBytecodeVersion(version); +} + //===----------------------------------------------------------------------===// // Location API. //===----------------------------------------------------------------------===// +MlirAttribute mlirLocationGetAttribute(MlirLocation location) { + return wrap(LocationAttr(unwrap(location))); +} + +MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { + return wrap(Location(llvm::dyn_cast(unwrap(attribute)))); +} + MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col) { @@ -117,10 +277,129 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context, FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); } +MlirLocation +mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, + unsigned startLine, unsigned startCol, + unsigned endLine, unsigned endCol) { + return wrap( + Location(FileLineColRange::get(unwrap(context), unwrap(filename), + startLine, startCol, endLine, endCol))); +} + +MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) { + return wrap(llvm::dyn_cast(unwrap(location)).getFilename()); +} + +int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getStartLine(); + return -1; +} + +int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getStartColumn(); + return -1; +} + +int 
mlirLocationFileLineColRangeGetEndLine(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getEndLine(); + return -1; +} + +int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getEndColumn(); + return -1; +} + +MlirTypeID mlirLocationFileLineColRangeGetTypeID() { + return wrap(FileLineColRange::getTypeID()); +} + +bool mlirLocationIsAFileLineColRange(MlirLocation location) { + return isa(unwrap(location)); +} + MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); } +MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) { + return wrap( + Location(llvm::dyn_cast(unwrap(location)).getCallee())); +} + +MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) { + return wrap( + Location(llvm::dyn_cast(unwrap(location)).getCaller())); +} + +MlirTypeID mlirLocationCallSiteGetTypeID() { + return wrap(CallSiteLoc::getTypeID()); +} + +bool mlirLocationIsACallSite(MlirLocation location) { + return isa(unwrap(location)); +} + +MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, + MlirLocation const *locations, + MlirAttribute metadata) { + SmallVector locs; + ArrayRef unwrappedLocs = unwrapList(nLocations, locations, locs); + return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); +} + +unsigned mlirLocationFusedGetNumLocations(MlirLocation location) { + if (auto locationsArrRef = llvm::dyn_cast(unwrap(location))) + return locationsArrRef.getLocations().size(); + return 0; +} + +void mlirLocationFusedGetLocations(MlirLocation location, + MlirLocation *locationsCPtr) { + if (auto locationsArrRef = llvm::dyn_cast(unwrap(location))) { + for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations())) + locationsCPtr[i] = wrap(location); + } +} + +MlirAttribute 
mlirLocationFusedGetMetadata(MlirLocation location) { + return wrap(llvm::dyn_cast(unwrap(location)).getMetadata()); +} + +MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); } + +bool mlirLocationIsAFused(MlirLocation location) { + return isa(unwrap(location)); +} + +MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, + MlirLocation childLoc) { + if (mlirLocationIsNull(childLoc)) + return wrap( + Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); + return wrap(Location(NameLoc::get( + StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); +} + +MlirIdentifier mlirLocationNameGetName(MlirLocation location) { + return wrap((llvm::dyn_cast(unwrap(location)).getName())); +} + +MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) { + return wrap( + Location(llvm::dyn_cast(unwrap(location)).getChildLoc())); +} + +MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); } + +bool mlirLocationIsAName(MlirLocation location) { + return isa(unwrap(location)); +} + MlirLocation mlirLocationUnknownGet(MlirContext context) { return wrap(Location(UnknownLoc::get(unwrap(context)))); } @@ -148,7 +427,17 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location) { } MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { - OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context)); + OwningOpRef owning = + parseSourceString(unwrap(module), unwrap(context)); + if (!owning) + return MlirModule{nullptr}; + return MlirModule{owning.release().getOperation()}; +} + +MlirModule mlirModuleCreateParseFromFile(MlirContext context, + MlirStringRef fileName) { + OwningOpRef owning = + parseSourceFile(unwrap(fileName), unwrap(context)); if (!owning) return MlirModule{nullptr}; return MlirModule{owning.release().getOperation()}; @@ -163,14 +452,19 @@ MlirBlock mlirModuleGetBody(MlirModule module) { } void mlirModuleDestroy(MlirModule module) { - 
// Transfer ownership to an OwningModuleRef so that its destructor is called. - OwningModuleRef(unwrap(module)); + // Transfer ownership to an OwningOpRef so that its destructor is + // called. + OwningOpRef(unwrap(module)); } MlirOperation mlirModuleGetOperation(MlirModule module) { return wrap(unwrap(module).getOperation()); } +MlirModule mlirModuleFromOperation(MlirOperation op) { + return wrap(dyn_cast(unwrap(op))); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// @@ -231,30 +525,53 @@ void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { static LogicalResult inferOperationTypes(OperationState &state) { MLIRContext *context = state.getContext(); - const AbstractOperation *abstractOp = - AbstractOperation::lookup(state.name.getStringRef(), context); - if (!abstractOp) { + std::optional info = state.name.getRegisteredInfo(); + if (!info) { emitError(state.location) << "type inference was requested for the operation " << state.name - << ", but the operation was not registered. Ensure that the dialect " + << ", but the operation was not registered; ensure that the dialect " "containing the operation is linked into MLIR and registered with " "the context"; return failure(); } - // Fallback to inference via an op interface. - auto *inferInterface = abstractOp->getInterface(); + auto *inferInterface = info->getInterface(); if (!inferInterface) { emitError(state.location) << "type inference was requested for the operation " << state.name - << ", but the operation does not support type inference. 
Result " - "types must be specified explicitly."; + << ", but the operation does not support type inference; result " + "types must be specified explicitly"; + return failure(); + } + + DictionaryAttr attributes = state.attributes.getDictionary(context); + OpaqueProperties properties = state.getRawProperties(); + + if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { + auto prop = std::make_unique(info->getOpPropertyByteSize()); + properties = OpaqueProperties(prop.get()); + if (properties) { + auto emitError = [&]() { + return mlir::emitError(state.location) + << " failed properties conversion while building " + << state.name.getStringRef() << " with `" << attributes << "`: "; + }; + if (failed(info->setOpPropertiesFromAttribute(state.name, properties, + attributes, emitError))) + return failure(); + } + if (succeeded(inferInterface->inferReturnTypes( + context, state.location, state.operands, attributes, properties, + state.regions, state.types))) { + return success(); + } + // Diagnostic emitted by interface. return failure(); } if (succeeded(inferInterface->inferReturnTypes( - context, state.location, state.operands, - state.attributes.getDictionary(context), state.regions, state.types))) + context, state.location, state.operands, attributes, properties, + state.regions, state.types))) return success(); // Diagnostic emitted by interface. 
@@ -295,12 +612,26 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) { return {nullptr}; } - MlirOperation result = wrap(Operation::create(cppState)); - return result; + return wrap(Operation::create(cppState)); +} + +MlirOperation mlirOperationCreateParse(MlirContext context, + MlirStringRef sourceStr, + MlirStringRef sourceName) { + + return wrap( + parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName)) + .release()); +} + +MlirOperation mlirOperationClone(MlirOperation op) { + return wrap(unwrap(op)->clone()); } void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } +void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } + bool mlirOperationEqual(MlirOperation op, MlirOperation other) { return unwrap(op) == unwrap(other); } @@ -309,6 +640,16 @@ MlirContext mlirOperationGetContext(MlirOperation op) { return wrap(unwrap(op)->getContext()); } +MlirLocation mlirOperationGetLocation(MlirOperation op) { + return wrap(unwrap(op)->getLoc()); +} + +MlirTypeID mlirOperationGetTypeID(MlirOperation op) { + if (auto info = unwrap(op)->getRegisteredInfo()) + return wrap(info->getTypeID()); + return {nullptr}; +} + MlirIdentifier mlirOperationGetName(MlirOperation op) { return wrap(unwrap(op)->getName().getIdentifier()); } @@ -329,6 +670,22 @@ MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { return wrap(&unwrap(op)->getRegion(static_cast(pos))); } +MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { + Operation *cppOp = unwrap(op); + if (cppOp->getNumRegions() == 0) + return wrap(static_cast(nullptr)); + return wrap(&cppOp->getRegion(0)); +} + +MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { + Region *cppRegion = unwrap(region); + Operation *parent = cppRegion->getParentOp(); + intptr_t next = cppRegion->getRegionNumber() + 1; + if (parent->getNumRegions() > next) + return wrap(&parent->getRegion(next)); + return wrap(static_cast(nullptr)); +} + MlirOperation 
mlirOperationGetNextInBlock(MlirOperation op) { return wrap(unwrap(op)->getNextNode()); } @@ -341,6 +698,17 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { return wrap(unwrap(op)->getOperand(static_cast(pos))); } +void mlirOperationSetOperand(MlirOperation op, intptr_t pos, + MlirValue newValue) { + unwrap(op)->setOperand(static_cast(pos), unwrap(newValue)); +} + +void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands, + MlirValue const *operands) { + SmallVector ops; + unwrap(op)->setOperands(unwrapList(nOperands, operands, ops)); +} + intptr_t mlirOperationGetNumResults(MlirOperation op) { return static_cast(unwrap(op)->getNumResults()); } @@ -357,13 +725,67 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { return wrap(unwrap(op)->getSuccessor(static_cast(pos))); } +MLIR_CAPI_EXPORTED bool +mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { + std::optional attr = unwrap(op)->getInherentAttr(unwrap(name)); + return attr.has_value(); +} + +MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, + MlirStringRef name) { + std::optional attr = unwrap(op)->getInherentAttr(unwrap(name)); + if (attr.has_value()) + return wrap(*attr); + return {}; +} + +void mlirOperationSetInherentAttributeByName(MlirOperation op, + MlirStringRef name, + MlirAttribute attr) { + unwrap(op)->setInherentAttr( + StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); +} + +intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { + return static_cast( + llvm::range_size(unwrap(op)->getDiscardableAttrs())); +} + +MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, + intptr_t pos) { + NamedAttribute attr = + *std::next(unwrap(op)->getDiscardableAttrs().begin(), pos); + return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; +} + +MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, + MlirStringRef name) { 
+ return wrap(unwrap(op)->getDiscardableAttr(unwrap(name))); +} + +void mlirOperationSetDiscardableAttributeByName(MlirOperation op, + MlirStringRef name, + MlirAttribute attr) { + unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr)); +} + +bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, + MlirStringRef name) { + return !!unwrap(op)->removeDiscardableAttr(unwrap(name)); +} + +void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, + MlirBlock block) { + unwrap(op)->setSuccessor(unwrap(block), static_cast(pos)); +} + intptr_t mlirOperationGetNumAttributes(MlirOperation op) { return static_cast(unwrap(op)->getAttrs().size()); } MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { NamedAttribute attr = unwrap(op)->getAttrs()[pos]; - return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)}; + return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; } MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, @@ -392,18 +814,88 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, unwrap(op)->print(stream, *unwrap(flags)); } +void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state, + MlirStringCallback callback, void *userData) { + detail::CallbackOstream stream(callback, userData); + if (state.ptr) + unwrap(op)->print(stream, *unwrap(state)); + unwrap(op)->print(stream); +} + +void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + // As no desired version is set, no failure can occur. 
+ (void)writeBytecodeToFile(unwrap(op), stream); +} + +MlirLogicalResult mlirOperationWriteBytecodeWithConfig( + MlirOperation op, MlirBytecodeWriterConfig config, + MlirStringCallback callback, void *userData) { + detail::CallbackOstream stream(callback, userData); + return wrap(writeBytecodeToFile(unwrap(op), stream, *unwrap(config))); +} + void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } bool mlirOperationVerify(MlirOperation op) { return succeeded(verify(unwrap(op))); } +void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { + return unwrap(op)->moveAfter(unwrap(other)); +} + +void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { + return unwrap(op)->moveBefore(unwrap(other)); +} + +bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other) { + return unwrap(op)->isBeforeInBlock(unwrap(other)); +} + +static mlir::WalkResult unwrap(MlirWalkResult result) { + switch (result) { + case MlirWalkResultAdvance: + return mlir::WalkResult::advance(); + + case MlirWalkResultInterrupt: + return mlir::WalkResult::interrupt(); + + case MlirWalkResultSkip: + return mlir::WalkResult::skip(); + } + llvm_unreachable("unknown result in WalkResult::unwrap"); +} + +void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, + void *userData, MlirWalkOrder walkOrder) { + switch (walkOrder) { + + case MlirWalkPreOrder: + unwrap(op)->walk( + [callback, userData](Operation *op) { + return unwrap(callback(wrap(op), userData)); + }); + break; + case MlirWalkPostOrder: + unwrap(op)->walk( + [callback, userData](Operation *op) { + return unwrap(callback(wrap(op), userData)); + }); + } +} + //===----------------------------------------------------------------------===// // Region API. 
//===----------------------------------------------------------------------===// MlirRegion mlirRegionCreate() { return wrap(new Region); } +bool mlirRegionEqual(MlirRegion region, MlirRegion other) { + return unwrap(region) == unwrap(other); +} + MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { Region *cppRegion = unwrap(region); if (cppRegion->empty()) @@ -450,14 +942,19 @@ void mlirRegionDestroy(MlirRegion region) { delete static_cast(region.ptr); } +void mlirRegionTakeBody(MlirRegion target, MlirRegion source) { + unwrap(target)->takeBody(*unwrap(source)); +} + //===----------------------------------------------------------------------===// // Block API. //===----------------------------------------------------------------------===// -MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) { +MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, + MlirLocation const *locs) { Block *b = new Block; for (intptr_t i = 0; i < nArgs; ++i) - b->addArgument(unwrap(args[i])); + b->addArgument(unwrap(args[i]), unwrap(locs[i])); return wrap(b); } @@ -469,6 +966,10 @@ MlirOperation mlirBlockGetParentOperation(MlirBlock block) { return wrap(unwrap(block)->getParentOp()); } +MlirRegion mlirBlockGetParentRegion(MlirBlock block) { + return wrap(unwrap(block)->getParent()); +} + MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { return wrap(unwrap(block)->getNextNode()); } @@ -529,12 +1030,27 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block, void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } +void mlirBlockDetach(MlirBlock block) { + Block *b = unwrap(block); + b->getParent()->getBlocks().remove(b); +} + intptr_t mlirBlockGetNumArguments(MlirBlock block) { return static_cast(unwrap(block)->getNumArguments()); } -MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type) { - return wrap(unwrap(block)->addArgument(unwrap(type))); +MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, + MlirLocation loc) { + return 
wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); +} + +void mlirBlockEraseArgument(MlirBlock block, unsigned index) { + return unwrap(block)->eraseArgument(index); +} + +MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, + MlirLocation loc) { + return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc))); } MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { @@ -547,6 +1063,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, unwrap(block)->print(stream); } +intptr_t mlirBlockGetNumSuccessors(MlirBlock block) { + return static_cast(unwrap(block)->getNumSuccessors()); +} + +MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) { + return wrap(unwrap(block)->getSuccessor(static_cast(pos))); +} + +intptr_t mlirBlockGetNumPredecessors(MlirBlock block) { + Block *b = unwrap(block); + return static_cast(std::distance(b->pred_begin(), b->pred_end())); +} + +MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) { + Block *b = unwrap(block); + Block::pred_iterator it = b->pred_begin(); + std::advance(it, pos); + return wrap(*it); +} + //===----------------------------------------------------------------------===// // Value API. 
//===----------------------------------------------------------------------===// @@ -556,39 +1092,44 @@ bool mlirValueEqual(MlirValue value1, MlirValue value2) { } bool mlirValueIsABlockArgument(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } bool mlirValueIsAOpResult(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::dyn_cast(unwrap(value)).getOwner()); } intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getArgNumber()); + llvm::dyn_cast(unwrap(value)).getArgNumber()); } void mlirBlockArgumentSetType(MlirValue value, MlirType type) { - unwrap(value).cast().setType(unwrap(type)); + if (auto blockArg = llvm::dyn_cast(unwrap(value))) + blockArg.setType(unwrap(type)); } MlirOperation mlirOpResultGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::dyn_cast(unwrap(value)).getOwner()); } intptr_t mlirOpResultGetResultNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getResultNumber()); + llvm::dyn_cast(unwrap(value)).getResultNumber()); } MlirType mlirValueGetType(MlirValue value) { return wrap(unwrap(value).getType()); } +void mlirValueSetType(MlirValue value, MlirType type) { + unwrap(value).setType(unwrap(type)); +} + void mlirValueDump(MlirValue value) { unwrap(value).dump(); } void mlirValuePrint(MlirValue value, MlirStringCallback callback, @@ -597,6 +1138,80 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback, unwrap(value).print(stream); } +void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, + MlirStringCallback callback, void *userData) { + detail::CallbackOstream stream(callback, userData); + Value cppValue = unwrap(value); + cppValue.printAsOperand(stream, *unwrap(state)); +} + +MlirOpOperand mlirValueGetFirstUse(MlirValue 
value) { + Value cppValue = unwrap(value); + if (cppValue.use_empty()) + return {}; + + OpOperand *opOperand = cppValue.use_begin().getOperand(); + + return wrap(opOperand); +} + +void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { + unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); +} + +void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue, + intptr_t numExceptions, + MlirOperation *exceptions) { + Value oldValueCpp = unwrap(oldValue); + Value newValueCpp = unwrap(newValue); + + llvm::SmallPtrSet exceptionSet; + for (intptr_t i = 0; i < numExceptions; ++i) { + exceptionSet.insert(unwrap(exceptions[i])); + } + + oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet); +} + +MlirLocation mlirValueGetLocation(MlirValue v) { + return wrap(unwrap(v).getLoc()); +} + +MlirContext mlirValueGetContext(MlirValue v) { + return wrap(unwrap(v).getContext()); +} + +//===----------------------------------------------------------------------===// +// OpOperand API. +//===----------------------------------------------------------------------===// + +bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; } + +MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { + return wrap(unwrap(opOperand)->getOwner()); +} + +MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) { + return wrap(unwrap(opOperand)->get()); +} + +unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { + return unwrap(opOperand)->getOperandNumber(); +} + +MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) { + if (mlirOpOperandIsNull(opOperand)) + return {}; + + OpOperand *nextOpOperand = static_cast( + unwrap(opOperand)->getNextOperandUsingThisValue()); + + if (!nextOpOperand) + return {}; + + return wrap(nextOpOperand); +} + //===----------------------------------------------------------------------===// // Type API. 
//===----------------------------------------------------------------------===// @@ -609,6 +1224,14 @@ MlirContext mlirTypeGetContext(MlirType type) { return wrap(unwrap(type).getContext()); } +MlirTypeID mlirTypeGetTypeID(MlirType type) { + return wrap(unwrap(type).getTypeID()); +} + +MlirDialect mlirTypeGetDialect(MlirType type) { + return wrap(&unwrap(type).getDialect()); +} + bool mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } @@ -633,7 +1256,18 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) { } MlirType mlirAttributeGetType(MlirAttribute attribute) { - return wrap(unwrap(attribute).getType()); + Attribute attr = unwrap(attribute); + if (auto typedAttr = llvm::dyn_cast(attr)) + return wrap(typedAttr.getType()); + return wrap(NoneType::get(attr.getContext())); +} + +MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { + return wrap(unwrap(attr).getTypeID()); +} + +MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { + return wrap(&unwrap(attr).getDialect()); } bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { @@ -658,7 +1292,7 @@ MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, //===----------------------------------------------------------------------===// MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { - return wrap(Identifier::get(unwrap(str), unwrap(context))); + return wrap(StringAttr::get(unwrap(context), unwrap(str))); } MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { @@ -672,3 +1306,62 @@ bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { return wrap(unwrap(ident).strref()); } + +//===----------------------------------------------------------------------===// +// Symbol and SymbolTable API. 
+//===----------------------------------------------------------------------===// + +MlirStringRef mlirSymbolTableGetSymbolAttributeName() { + return wrap(SymbolTable::getSymbolAttrName()); +} + +MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { + return wrap(SymbolTable::getVisibilityAttrName()); +} + +MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { + if (!unwrap(operation)->hasTrait()) + return wrap(static_cast(nullptr)); + return wrap(new SymbolTable(unwrap(operation))); +} + +void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { + delete unwrap(symbolTable); +} + +MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, + MlirStringRef name) { + return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); +} + +MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, + MlirOperation operation) { + return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); +} + +void mlirSymbolTableErase(MlirSymbolTable symbolTable, + MlirOperation operation) { + unwrap(symbolTable)->erase(unwrap(operation)); +} + +MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, + MlirStringRef newSymbol, + MlirOperation from) { + auto *cppFrom = unwrap(from); + auto *context = cppFrom->getContext(); + auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); + auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); + return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, + unwrap(from))); +} + +void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, + void (*callback)(MlirOperation, bool, + void *userData), + void *userData) { + SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, + [&](Operation *foundOpCpp, bool isVisible) { + callback(wrap(foundOpCpp), isVisible, + userData); + }); +} diff --git a/mlir/lib/CAPI/IR/IntegerSet.cpp b/mlir/lib/CAPI/IR/IntegerSet.cpp index 701d70353..43d48e415 100644 --- 
a/mlir/lib/CAPI/IR/IntegerSet.cpp +++ b/mlir/lib/CAPI/IR/IntegerSet.cpp @@ -49,7 +49,7 @@ MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims, return wrap(IntegerSet::get( static_cast(numDims), static_cast(numSymbols), mlirConstraints, - llvm::makeArrayRef(eqFlags, static_cast(numConstraints)))); + llvm::ArrayRef(eqFlags, static_cast(numConstraints)))); } MlirIntegerSet diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index b3685ddf4..3c499c3e4 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/Pass/PassManager.h" +#include using namespace mlir; @@ -24,6 +25,11 @@ MlirPassManager mlirPassManagerCreate(MlirContext ctx) { return wrap(new PassManager(unwrap(ctx))); } +MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, + MlirStringRef anchorOp) { + return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp))); +} + void mlirPassManagerDestroy(MlirPassManager passManager) { delete unwrap(passManager); } @@ -33,9 +39,44 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { return wrap(static_cast(unwrap(passManager))); } -MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, - MlirModule module) { - return wrap(unwrap(passManager)->run(unwrap(module))); +MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, + MlirOperation op) { + return wrap(unwrap(passManager)->run(unwrap(op))); +} + +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, + bool printBeforeAll, bool printAfterAll, + bool printModuleScope, + bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure, + MlirOpPrintingFlags flags, + MlirStringRef treePrintingPath) { + auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { + return printBeforeAll; + }; + auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { + return printAfterAll; + }; + if 
(unwrap(treePrintingPath).empty()) + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure, /*out=*/llvm::errs(), + *unwrap(flags)); + + unwrap(passManager) + ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure, + unwrap(treePrintingPath), *unwrap(flags)); +} + +void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { + unwrap(passManager)->enableVerifier(enable); +} + +void mlirPassManagerEnableTiming(MlirPassManager passManager) { + unwrap(passManager)->enableTiming(); } MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, @@ -57,6 +98,15 @@ void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, unwrap(passManager)->addPass(std::unique_ptr(unwrap(pass))); } +MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, + MlirStringRef pipelineElements, + MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager), + stream)); +} + void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); @@ -64,8 +114,104 @@ void mlirPrintPassPipeline(MlirOpPassManager passManager, } MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, - MlirStringRef pipeline) { - // TODO: errors are sent to std::errs() at the moment, we should pass in a - // stream and redirect to a diagnostic. 
- return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager))); + MlirStringRef pipeline, + MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + FailureOr pm = parsePassPipeline(unwrap(pipeline), stream); + if (succeeded(pm)) + *unwrap(passManager) = std::move(*pm); + return wrap(pm); +} + +//===----------------------------------------------------------------------===// +// External Pass API. +//===----------------------------------------------------------------------===// + +namespace mlir { +class ExternalPass; +} // namespace mlir +DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) + +namespace mlir { +/// This pass class wraps external passes defined in other languages using the +/// MLIR C-interface +class ExternalPass : public Pass { +public: + ExternalPass(TypeID passID, StringRef name, StringRef argument, + StringRef description, std::optional opName, + ArrayRef dependentDialects, + MlirExternalPassCallbacks callbacks, void *userData) + : Pass(passID, opName), id(passID), name(name), argument(argument), + description(description), dependentDialects(dependentDialects), + callbacks(callbacks), userData(userData) { + callbacks.construct(userData); + } + + ~ExternalPass() override { callbacks.destruct(userData); } + + StringRef getName() const override { return name; } + StringRef getArgument() const override { return argument; } + StringRef getDescription() const override { return description; } + + void getDependentDialects(DialectRegistry ®istry) const override { + MlirDialectRegistry cRegistry = wrap(®istry); + for (MlirDialectHandle dialect : dependentDialects) + mlirDialectHandleInsertDialect(dialect, cRegistry); + } + + void signalPassFailure() { Pass::signalPassFailure(); } + +protected: + LogicalResult initialize(MLIRContext *ctx) override { + if (callbacks.initialize) + return unwrap(callbacks.initialize(wrap(ctx), userData)); + return success(); + } + + bool 
canScheduleOn(RegisteredOperationName opName) const override { + if (std::optional specifiedOpName = getOpName()) + return opName.getStringRef() == specifiedOpName; + return true; + } + + void runOnOperation() override { + callbacks.run(wrap(getOperation()), wrap(this), userData); + } + + std::unique_ptr clonePass() const override { + void *clonedUserData = callbacks.clone(userData); + return std::make_unique(id, name, argument, description, + getOpName(), dependentDialects, + callbacks, clonedUserData); + } + +private: + TypeID id; + std::string name; + std::string argument; + std::string description; + std::vector dependentDialects; + MlirExternalPassCallbacks callbacks; + void *userData; +}; +} // namespace mlir + +MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, + MlirStringRef argument, + MlirStringRef description, MlirStringRef opName, + intptr_t nDependentDialects, + MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, + void *userData) { + return wrap(static_cast(new mlir::ExternalPass( + unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), + opName.length > 0 ? 
std::optional(unwrap(opName)) + : std::nullopt, + {dependentDialects, static_cast(nDependentDialects)}, callbacks, + userData))); +} + +void mlirExternalPassSignalFailure(MlirExternalPass pass) { + unwrap(pass)->signalPassFailure(); } diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp index e4b409906..3311131fc 100644 --- a/mlir/lib/CAPI/IR/Support.cpp +++ b/mlir/lib/CAPI/IR/Support.cpp @@ -6,10 +6,64 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Support.h" +#include "mlir/CAPI/Support.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ThreadPool.h" #include MlirStringRef mlirStringRefCreateFromCString(const char *str) { return mlirStringRefCreate(str, strlen(str)); } + +bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) { + return llvm::StringRef(string.data, string.length) == + llvm::StringRef(other.data, other.length); +} + +//===----------------------------------------------------------------------===// +// LLVM ThreadPool API. +//===----------------------------------------------------------------------===// +MlirLlvmThreadPool mlirLlvmThreadPoolCreate() { + return wrap(new llvm::DefaultThreadPool()); +} + +void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) { + delete unwrap(threadPool); +} + +//===----------------------------------------------------------------------===// +// TypeID API. 
+//===----------------------------------------------------------------------===// +MlirTypeID mlirTypeIDCreate(const void *ptr) { + assert(reinterpret_cast(ptr) % 8 == 0 && + "ptr must be 8 byte aligned"); + // This is essentially a no-op that returns back `ptr`, but by going through + // the `TypeID` functions we can get compiler errors in case the `TypeID` + // api/representation changes + return wrap(mlir::TypeID::getFromOpaquePointer(ptr)); +} + +bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { + return unwrap(typeID1) == unwrap(typeID2); +} + +size_t mlirTypeIDHashValue(MlirTypeID typeID) { + return hash_value(unwrap(typeID)); +} + +//===----------------------------------------------------------------------===// +// TypeIDAllocator API. +//===----------------------------------------------------------------------===// + +MlirTypeIDAllocator mlirTypeIDAllocatorCreate() { + return wrap(new mlir::TypeIDAllocator()); +} + +void mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator) { + delete unwrap(allocator); +} + +MlirTypeID mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator) { + return wrap(unwrap(allocator)->allocate()); +} diff --git a/mlir/lib/CAPI/Interfaces/CMakeLists.txt b/mlir/lib/CAPI/Interfaces/CMakeLists.txt new file mode 100644 index 000000000..7aefb56d9 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_upstream_c_api_library(MLIRCAPIInterfaces + Interfaces.cpp + + LINK_LIBS PUBLIC + MLIRInferTypeOpInterface) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp new file mode 100644 index 000000000..d3fd6b4c0 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -0,0 +1,169 @@ + + +//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Interfaces.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Interfaces.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/ScopeExit.h" +#include + +using namespace mlir; + +namespace { + +std::optional +getRegisteredOperationName(MlirContext context, MlirStringRef opName) { + StringRef name(opName.data, opName.length); + std::optional info = + RegisteredOperationName::lookup(name, unwrap(context)); + return info; +} + +std::optional maybeGetLocation(MlirLocation location) { + std::optional maybeLocation; + if (!mlirLocationIsNull(location)) + maybeLocation = unwrap(location); + return maybeLocation; +} + +SmallVector unwrapOperands(intptr_t nOperands, MlirValue *operands) { + SmallVector unwrappedOperands; + (void)unwrapList(nOperands, operands, unwrappedOperands); + return unwrappedOperands; +} + +DictionaryAttr unwrapAttributes(MlirAttribute attributes) { + DictionaryAttr attributeDict; + if (!mlirAttributeIsNull(attributes)) + attributeDict = llvm::cast(unwrap(attributes)); + return attributeDict; +} + +SmallVector> unwrapRegions(intptr_t nRegions, + MlirRegion *regions) { + // Create a vector of unique pointers to regions and make sure they are not + // deleted when exiting the scope. This is a hack caused by C++ API expecting + // an list of unique pointers to regions (without ownership transfer + // semantics) and C API making ownership transfer explicit. 
+ SmallVector> unwrappedRegions; + unwrappedRegions.reserve(nRegions); + for (intptr_t i = 0; i < nRegions; ++i) + unwrappedRegions.emplace_back(unwrap(*(regions + i))); + auto cleaner = llvm::make_scope_exit([&]() { + for (auto ®ion : unwrappedRegions) + region.release(); + }); + return unwrappedRegions; +} + +} // namespace + +bool mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID) { + std::optional info = + unwrap(operation)->getRegisteredInfo(); + return info && info->hasInterface(unwrap(interfaceTypeID)); +} + +bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID) { + std::optional info = RegisteredOperationName::lookup( + StringRef(operationName.data, operationName.length), unwrap(context)); + return info && info->hasInterface(unwrap(interfaceTypeID)); +} + +MlirTypeID mlirInferTypeOpInterfaceTypeID() { + return wrap(InferTypeOpInterface::getInterfaceID()); +} + +MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirTypesCallback callback, void *userData) { + StringRef name(opName.data, opName.length); + std::optional info = + getRegisteredOperationName(context, opName); + if (!info) + return mlirLogicalResultFailure(); + + std::optional maybeLocation = maybeGetLocation(location); + SmallVector unwrappedOperands = unwrapOperands(nOperands, operands); + DictionaryAttr attributeDict = unwrapAttributes(attributes); + SmallVector> unwrappedRegions = + unwrapRegions(nRegions, regions); + + SmallVector inferredTypes; + if (failed(info->getInterface()->inferReturnTypes( + unwrap(context), maybeLocation, unwrappedOperands, attributeDict, + properties, unwrappedRegions, inferredTypes))) + return mlirLogicalResultFailure(); + + SmallVector 
wrappedInferredTypes; + wrappedInferredTypes.reserve(inferredTypes.size()); + for (Type t : inferredTypes) + wrappedInferredTypes.push_back(wrap(t)); + callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); + return mlirLogicalResultSuccess(); +} + +MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() { + return wrap(InferShapedTypeOpInterface::getInterfaceID()); +} + +MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirShapedTypeComponentsCallback callback, void *userData) { + std::optional info = + getRegisteredOperationName(context, opName); + if (!info) + return mlirLogicalResultFailure(); + + std::optional maybeLocation = maybeGetLocation(location); + SmallVector unwrappedOperands = unwrapOperands(nOperands, operands); + DictionaryAttr attributeDict = unwrapAttributes(attributes); + SmallVector> unwrappedRegions = + unwrapRegions(nRegions, regions); + + SmallVector inferredTypeComponents; + if (failed(info->getInterface() + ->inferReturnTypeComponents( + unwrap(context), maybeLocation, + mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), + attributeDict, properties, unwrappedRegions, + inferredTypeComponents))) + return mlirLogicalResultFailure(); + + bool hasRank; + intptr_t rank; + const int64_t *shapeData; + for (const ShapedTypeComponents &t : inferredTypeComponents) { + if (t.hasRank()) { + hasRank = true; + rank = t.getDims().size(); + shapeData = t.getDims().data(); + } else { + hasRank = false; + rank = 0; + shapeData = nullptr; + } + callback(hasRank, rank, shapeData, wrap(t.getElementType()), + wrap(t.getAttribute()), userData); + } + return mlirLogicalResultSuccess(); +} diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt new file mode 100644 index 
000000000..ccda668ec --- /dev/null +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -0,0 +1,16 @@ +# Dialect registration. +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) +add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything + RegisterEverything.cpp + + LINK_LIBS PUBLIC + ${translation_libs} + + MLIRBuiltinToLLVMIRTranslation + MLIRCAPIIR + MLIRCAPITransforms + MLIRLLVMToLLVMIRTranslation + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses +) diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp new file mode 100644 index 000000000..c1c4a418b --- /dev/null +++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -0,0 +1,32 @@ +//===- RegisterEverything.cpp - Register all MLIR entities ----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/RegisterEverything.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" +#include "mlir/InitAllPasses.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" + +void mlirRegisterAllDialects(MlirDialectRegistry registry) { + mlir::registerAllDialects(*unwrap(registry)); + mlir::registerAllExtensions(*unwrap(registry)); +} + +void mlirRegisterAllLLVMTranslations(MlirContext context) { + auto &ctx = *unwrap(context); + mlir::DialectRegistry registry; + mlir::registerAllToLLVMIRTranslations(registry); + ctx.appendDialectRegistry(registry); +} + +void mlirRegisterAllPasses() { mlir::registerAllPasses(); } diff --git a/mlir/lib/CAPI/Registration/CMakeLists.txt b/mlir/lib/CAPI/Registration/CMakeLists.txt deleted file mode 100644 index 417140ac0..000000000 --- a/mlir/lib/CAPI/Registration/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -# Dialect registration. -get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) -add_mlir_public_c_api_library(MLIRCAPIRegistration - Registration.cpp - - LINK_LIBS PUBLIC - MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation - ${dialect_libs} -) diff --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/Registration/Registration.cpp deleted file mode 100644 index dea782453..000000000 --- a/mlir/lib/CAPI/Registration/Registration.cpp +++ /dev/null @@ -1,23 +0,0 @@ -//===- Registration.cpp - C Interface for MLIR Registration ---------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Registration.h" - -#include "mlir/CAPI/IR.h" -#include "mlir/InitAllDialects.h" -#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" - -void mlirRegisterAllDialects(MlirContext context) { - mlir::registerAllDialects(*unwrap(context)); - // TODO: we may not want to eagerly load here. - unwrap(context)->loadAllAvailableDialects(); -} - -void mlirRegisterAllLLVMTranslations(MlirContext context) { - mlir::registerLLVMDialectTranslation(*unwrap(context)); -} diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt new file mode 100644 index 000000000..8fbb7aa95 --- /dev/null +++ b/mlir/lib/CAPI/Target/CMakeLists.txt @@ -0,0 +1,25 @@ +add_mlir_upstream_c_api_library(MLIRCAPITarget + LLVMIR.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRToLLVMIRTranslationRegistration + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRLLVMIRToLLVMTranslation + MLIRSupport +) + +add_mlir_upstream_c_api_library(MLIRCAPIExportSMTLIB + ExportSMTLIB.cpp + + PARTIAL_SOURCES_INTENDED + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRExportSMTLIB +) diff --git a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp new file mode 100644 index 000000000..4326f9672 --- /dev/null +++ b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp @@ -0,0 +1,42 @@ +//===- ExportSMTLIB.cpp - C Interface to ExportSMTLIB ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements a C Interface for export SMTLIB. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Target/ExportSMTLIB.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Utils.h" +#include "mlir/Target/SMTLIB/ExportSMTLIB.h" + +using namespace mlir; + +MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module, + MlirStringCallback callback, + void *userData, + bool inlineSingleUseValues, + bool indentLetBody) { + mlir::detail::CallbackOstream stream(callback, userData); + smt::SMTEmissionOptions options; + options.inlineSingleUseValues = inlineSingleUseValues; + options.indentLetBody = indentLetBody; + return wrap(smt::exportSMTLIB(unwrap(module), stream)); +} + +MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module, + MlirStringCallback callback, + void *userData, + bool inlineSingleUseValues, + bool indentLetBody) { + return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module), + callback, userData, + inlineSingleUseValues, indentLetBody); +} diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp new file mode 100644 index 000000000..1c1912aec --- /dev/null +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -0,0 +1,79 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Target/LLVMIR.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Target/LLVMIR/TypeFromLLVM.h" + +using namespace mlir; + +LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, + LLVMContextRef context) { + Operation *moduleOp = unwrap(module); + + llvm::LLVMContext *ctx = llvm::unwrap(context); + + std::unique_ptr llvmModule = + mlir::translateModuleToLLVMIR(moduleOp, *ctx); + + LLVMModuleRef moduleRef = llvm::wrap(llvmModule.release()); + + return moduleRef; +} + +DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator, + mlir::LLVM::TypeFromLLVMIRTranslator) + +MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) { + MLIRContext *context = unwrap(ctx); + auto *translator = new LLVM::TypeFromLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeFromLLVMIRTranslatorDestroy( + MlirTypeFromLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType) { + LLVM::TypeFromLLVMIRTranslator *translator_ = unwrap(translator); + mlir::Type type = translator_->translateType(llvm::unwrap(llvmType)); + return wrap(type); +} + +DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator, + mlir::LLVM::TypeToLLVMIRTranslator) + +MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) { + llvm::LLVMContext *context = llvm::unwrap(ctx); + auto *translator = new LLVM::TypeToLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator) { + delete 
static_cast(unwrap(translator)); +} + +LLVMTypeRef +mlirTypeToLLVMIRTranslatorTranslateType(MlirTypeToLLVMIRTranslator translator, + MlirType mlirType) { + LLVM::TypeToLLVMIRTranslator *translator_ = unwrap(translator); + llvm::Type *type = translator_->translateType(unwrap(mlirType)); + return llvm::wrap(type); +} diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt index e5e1677ec..6c67aa09f 100644 --- a/mlir/lib/CAPI/Transforms/CMakeLists.txt +++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt @@ -1,6 +1,9 @@ -add_mlir_public_c_api_library(MLIRCAPITransforms +add_mlir_upstream_c_api_library(MLIRCAPITransforms Passes.cpp + Rewrite.cpp LINK_LIBS PUBLIC + MLIRIR MLIRTransforms + MLIRTransformUtils ) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp new file mode 100644 index 000000000..a4df97f7b --- /dev/null +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -0,0 +1,327 @@ +//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Rewrite.h" + +#include "mlir-c/Transforms.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Rewrite.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +/// RewriterBase API inherited from OpBuilder +//===----------------------------------------------------------------------===// + +MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getContext()); +} + +//===----------------------------------------------------------------------===// +/// Insertion points methods +//===----------------------------------------------------------------------===// + +void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { + unwrap(rewriter)->clearInsertionPoint(); +} + +void mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->setInsertionPoint(unwrap(op)); +} + +void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); +} + +void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, + MlirValue value) { + unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); +} + +void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, + MlirBlock block) { + unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); +} + +void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, + MlirBlock block) { + unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); +} + +MlirBlock 
mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getInsertionBlock()); +} + +MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getBlock()); +} + +//===----------------------------------------------------------------------===// +/// Block and operation creation/insertion/cloning +//===----------------------------------------------------------------------===// + +MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, + MlirBlock insertBefore, + intptr_t nArgTypes, + MlirType const *argTypes, + MlirLocation const *locations) { + SmallVector args; + ArrayRef unwrappedArgs = unwrapList(nArgTypes, argTypes, args); + SmallVector locs; + ArrayRef unwrappedLocs = unwrapList(nArgTypes, locations, locs); + return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, + unwrappedLocs)); +} + +MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->insert(unwrap(op))); +} + +// Other methods of OpBuilder + +MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->clone(*unwrap(op))); +} + +MlirOperation mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); +} + +void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, + MlirRegion region, MlirBlock before) { + + unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); +} + +//===----------------------------------------------------------------------===// +/// RewriterBase API +//===----------------------------------------------------------------------===// + +void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, + MlirRegion region, MlirBlock before) { + unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); +} + +void 
mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter,
+                                         MlirOperation op, intptr_t nValues,
+                                         MlirValue const *values) {
+  SmallVector<Value, 4> vals;
+  ArrayRef<Value> unwrappedVals = unwrapList(nValues, values, vals);
+  unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals);
+}
+
+void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter,
+                                            MlirOperation op,
+                                            MlirOperation newOp) {
+  unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp));
+}
+
+void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) {
+  unwrap(rewriter)->eraseOp(unwrap(op));
+}
+
+void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) {
+  unwrap(rewriter)->eraseBlock(unwrap(block));
+}
+
+void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter,
+                                       MlirBlock source, MlirOperation op,
+                                       intptr_t nArgValues,
+                                       MlirValue const *argValues) {
+  SmallVector<Value, 4> vals;
+  ArrayRef<Value> unwrappedVals = unwrapList(nArgValues, argValues, vals);
+
+  unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op),
+                                      unwrappedVals);
+}
+
+void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source,
+                                 MlirBlock dest, intptr_t nArgValues,
+                                 MlirValue const *argValues) {
+  SmallVector<Value, 4> args;
+  ArrayRef<Value> unwrappedArgs = unwrapList(nArgValues, argValues, args);
+  unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs);
+}
+
+void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op,
+                                  MlirOperation existingOp) {
+  unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp));
+}
+
+void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op,
+                                 MlirOperation existingOp) {
+  unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp));
+}
+
+void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block,
+                                     MlirBlock existingBlock) {
+  unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock));
+}
+
+void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter,
+                                         
MlirOperation op) {
+  unwrap(rewriter)->startOpModification(unwrap(op));
+}
+
+void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter,
+                                            MlirOperation op) {
+  unwrap(rewriter)->finalizeOpModification(unwrap(op));
+}
+
+void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter,
+                                          MlirOperation op) {
+  unwrap(rewriter)->cancelOpModification(unwrap(op));
+}
+
+void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter,
+                                        MlirValue from, MlirValue to) {
+  unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to));
+}
+
+void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter,
+                                                  intptr_t nValues,
+                                                  MlirValue const *from,
+                                                  MlirValue const *to) {
+  SmallVector<Value, 4> fromVals;
+  ArrayRef<Value> unwrappedFromVals = unwrapList(nValues, from, fromVals);
+  SmallVector<Value, 4> toVals;
+  ArrayRef<Value> unwrappedToVals = unwrapList(nValues, to, toVals);
+  unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals);
+}
+
+void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter,
+                                                    MlirOperation from,
+                                                    intptr_t nTo,
+                                                    MlirValue const *to) {
+  SmallVector<Value, 4> toVals;
+  ArrayRef<Value> unwrappedToVals = unwrapList(nTo, to, toVals);
+  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals);
+}
+
+void mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter,
+                                                   MlirOperation from,
+                                                   MlirOperation to) {
+  unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to));
+}
+
+void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter,
+                                              MlirOperation op,
+                                              intptr_t nNewValues,
+                                              MlirValue const *newValues,
+                                              MlirBlock block) {
+  SmallVector<Value, 4> vals;
+  ArrayRef<Value> unwrappedVals = unwrapList(nNewValues, newValues, vals);
+  unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals,
+                                             unwrap(block));
+}
+
+void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter,
+                                          MlirValue from, MlirValue to,
+                                          MlirOperation exceptedUser) {
+  unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), 
unwrap(to),
+                                         unwrap(exceptedUser));
+}
+
+//===----------------------------------------------------------------------===//
+/// IRRewriter API
+//===----------------------------------------------------------------------===//
+
+MlirRewriterBase mlirIRRewriterCreate(MlirContext context) {
+  return wrap(new IRRewriter(unwrap(context)));
+}
+
+MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) {
+  return wrap(new IRRewriter(unwrap(op)));
+}
+
+void mlirIRRewriterDestroy(MlirRewriterBase rewriter) {
+  delete static_cast<IRRewriter *>(unwrap(rewriter));
+}
+
+//===----------------------------------------------------------------------===//
+/// RewritePatternSet and FrozenRewritePatternSet API
+//===----------------------------------------------------------------------===//
+
+inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) {
+  assert(module.ptr && "unexpected null module");
+  return *(static_cast<mlir::RewritePatternSet *>(module.ptr));
+}
+
+inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) {
+  return {module};
+}
+
+inline mlir::FrozenRewritePatternSet *
+unwrap(MlirFrozenRewritePatternSet module) {
+  assert(module.ptr && "unexpected null module");
+  return static_cast<mlir::FrozenRewritePatternSet *>(module.ptr);
+}
+
+inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) {
+  return {module};
+}
+
+MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) {
+  auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op)));
+  op.ptr = nullptr;
+  return wrap(m);
+}
+
+void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) {
+  delete unwrap(op);
+  op.ptr = nullptr;
+}
+
+MlirLogicalResult
+mlirApplyPatternsAndFoldGreedily(MlirModule op,
+                                 MlirFrozenRewritePatternSet patterns,
+                                 MlirGreedyRewriteDriverConfig) {
+  return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns)));
+}
+
+//===----------------------------------------------------------------------===//
+/// PDLPatternModule API 
+//===----------------------------------------------------------------------===//
+
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) {
+  assert(module.ptr && "unexpected null module");
+  return static_cast<mlir::PDLPatternModule *>(module.ptr);
+}
+
+inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) {
+  return {module};
+}
+
+MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) {
+  return wrap(new mlir::PDLPatternModule(
+      mlir::OwningOpRef<mlir::ModuleOp>(unwrap(op))));
+}
+
+void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) {
+  delete unwrap(op);
+  op.ptr = nullptr;
+}
+
+MlirRewritePatternSet
+mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) {
+  auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op)));
+  op.ptr = nullptr;
+  return wrap(m);
+}
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
new file mode 100644
index 000000000..7a0c95ebb
--- /dev/null
+++ b/mlir/python/CMakeLists.txt
@@ -0,0 +1,850 @@
+include(AddMLIRPython)
+
+################################################################################
+# Structural groupings. 
+################################################################################ + +declare_mlir_python_sources(MLIRPythonSources) +declare_mlir_python_sources(MLIRPythonSources.Dialects + ADD_TO_PARENT MLIRPythonSources) +declare_mlir_python_sources(MLIRPythonSources.Core + ADD_TO_PARENT MLIRPythonSources) + +################################################################################ +# Pure python sources and generated code +################################################################################ + +declare_mlir_python_sources(MLIRPythonSources.Core.Python + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources.Core + SOURCES + _mlir_libs/__init__.py + ir.py + passmanager.py + rewrite.py + dialects/_ods_common.py + + # The main _mlir module has submodules: include stubs from each. + _mlir_libs/_mlir/__init__.pyi + _mlir_libs/_mlir/ir.pyi + _mlir_libs/_mlir/passmanager.pyi +) + +declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources.Core.Python + SOURCES + extras/types.py + extras/meta.py +) + +declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources + SOURCES + execution_engine.py + _mlir_libs/_mlirExecutionEngine.pyi + SOURCES_GLOB + runtime/*.py +) + +declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources + ROOT_DIR "${MLIR_SOURCE_DIR}/include" + SOURCES_GLOB "mlir-c/*.h" +) + +################################################################################ +# Dialect bindings +################################################################################ + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/AffineOps.td + SOURCES + dialects/affine.py + DIALECT_NAME affine + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT 
MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/AMDGPUOps.td + SOURCES + dialects/amdgpu.py + DIALECT_NAME amdgpu + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/AsyncOps.td + SOURCES_GLOB dialects/async_dialect/*.py + DIALECT_NAME async) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/BufferizationOps.td + SOURCES + dialects/bufferization.py + DIALECT_NAME bufferization + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td" +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/BuiltinOps.td + SOURCES + dialects/builtin.py + DIALECT_NAME builtin) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ComplexOps.td + SOURCES + dialects/complex.py + DIALECT_NAME complex) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/IndexOps.td + SOURCES + dialects/index.py + DIALECT_NAME index + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ControlFlowOps.td + SOURCES + dialects/cf.py + DIALECT_NAME cf) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/FuncOps.td + SOURCES + dialects/func.py + DIALECT_NAME func) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/GPUOps.td + 
SOURCES_GLOB dialects/gpu/*.py + DIALECT_NAME gpu + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/LinalgOps.td + SOURCES + SOURCES_GLOB + dialects/linalg/*.py + DIALECT_NAME linalg + DEPENDS LinalgOdsGen + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/LLVMOps.td + SOURCES + dialects/llvm.py + DIALECT_NAME llvm + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformPDLExtensionOps.td + SOURCES + dialects/transform/pdl.py + DIALECT_NAME transform + EXTENSION_NAME transform_pdl_extension) + +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformDebugExtensionOps.td + SOURCES + dialects/transform/debug.py + DIALECT_NAME transform + EXTENSION_NAME transform_debug_extension) + +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformTuneExtensionOps.td + SOURCES + dialects/transform/tune.py + DIALECT_NAME transform + EXTENSION_NAME transform_tune_extension) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformOps.td + SOURCES + dialects/transform/__init__.py + _mlir_libs/_mlir/dialects/transform/__init__.pyi + DIALECT_NAME transform + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td" +) + +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.extras + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" 
+ GEN_ENUM_BINDINGS + SOURCES + dialects/transform/extras/__init__.py) + +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.interpreter + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES + dialects/transform/interpreter/__init__.py) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/BufferizationTransformOps.td + SOURCES + dialects/transform/bufferization.py + DIALECT_NAME transform + EXTENSION_NAME bufferization_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/GPUTransformOps.td + SOURCES + dialects/transform/gpu.py + DIALECT_NAME transform + EXTENSION_NAME gpu_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SCFLoopTransformOps.td + SOURCES + dialects/transform/loop.py + DIALECT_NAME transform + EXTENSION_NAME loop_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MemRefTransformOps.td + SOURCES + dialects/transform/memref.py + DIALECT_NAME transform + EXTENSION_NAME memref_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/NVGPUTransformOps.td + SOURCES + dialects/transform/nvgpu.py + DIALECT_NAME transform + EXTENSION_NAME nvgpu_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/LinalgStructuredTransformOps.td + SOURCES + dialects/transform/structured.py + DIALECT_NAME transform + EXTENSION_NAME 
structured_transform + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" +) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SparseTensorTransformOps.td + SOURCES + dialects/transform/sparse_tensor.py + DIALECT_NAME transform + EXTENSION_NAME sparse_tensor_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TensorTransformOps.td + SOURCES + dialects/transform/tensor.py + DIALECT_NAME transform + EXTENSION_NAME tensor_transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/VectorTransformOps.td + SOURCES + dialects/transform/vector.py + DIALECT_NAME transform + EXTENSION_NAME vector_transform + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MathOps.td + SOURCES dialects/math.py + DIALECT_NAME math) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ArithOps.td + SOURCES + dialects/arith.py + DIALECT_NAME arith + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MemRefOps.td + SOURCES + dialects/memref.py + DIALECT_NAME memref) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MLProgramOps.td + SOURCES + dialects/ml_program.py + DIALECT_NAME ml_program) + 
+declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/NVGPUOps.td + SOURCES + dialects/nvgpu.py + DIALECT_NAME nvgpu + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/NVVMOps.td + SOURCES + dialects/nvvm.py + DIALECT_NAME nvvm + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ROCDLOps.td + SOURCES + dialects/rocdl.py + DIALECT_NAME rocdl) + +declare_mlir_python_sources( + MLIRPythonSources.Dialects.quant + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + GEN_ENUM_BINDINGS + SOURCES + dialects/quant.py + _mlir_libs/_mlir/dialects/quant.pyi) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/EmitC.td + SOURCES + dialects/emitc.py + DIALECT_NAME emitc) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/PDLOps.td + SOURCES + dialects/pdl.py + _mlir_libs/_mlir/dialects/pdl.pyi + DIALECT_NAME pdl) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/OpenMPOps.td + SOURCES + dialects/openmp.py + DIALECT_NAME omp + DEPENDS omp_common_td) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SCFOps.td + SOURCES + dialects/scf.py + DIALECT_NAME scf) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ShapeOps.td + SOURCES 
dialects/shape.py + DIALECT_NAME shape) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SparseTensorOps.td + SOURCES dialects/sparse_tensor.py + DIALECT_NAME sparse_tensor + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SMTOps.td + GEN_ENUM_BINDINGS + SOURCES + dialects/smt.py + DIALECT_NAME smt) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SPIRVOps.td + SOURCES dialects/spirv.py + DIALECT_NAME spirv) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TensorOps.td + SOURCES + dialects/tensor.py + DIALECT_NAME tensor) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TosaOps.td + SOURCES dialects/tosa.py + DIALECT_NAME tosa +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/VectorOps.td + SOURCES dialects/vector.py + DIALECT_NAME vector + GEN_ENUM_BINDINGS_TD_FILE + "dialects/VectorAttributes.td") + +################################################################################ +# Python extensions. +# The sources for these are all in lib/Bindings/Python, but since they have to +# be rebuilt for each package and integrate with the source setup here, we +# just reference them here instead of having ordered, cross package target +# dependencies. 
+################################################################################ + +set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") +declare_mlir_python_extension(MLIRPythonExtension.Core + MODULE_NAME _mlir + ADD_TO_PARENT MLIRPythonSources.Core + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + MainModule.cpp + IRAffine.cpp + IRAttributes.cpp + IRCore.cpp + IRInterfaces.cpp + IRModule.cpp + IRTypes.cpp + Pass.cpp + Rewrite.cpp + + # Headers must be included explicitly so they are installed. + Globals.h + IRModule.h + Pass.h + NanobindUtils.h + Rewrite.h + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIDebug + MLIRCAPIIR + MLIRCAPIInterfaces + + # Dialects + MLIRCAPIFunc +) + +# This extension exposes an API to register all dialects, extensions, and passes +# packaged in upstream MLIR and it is used for the upstream "mlir" Python +# package. Downstreams will likely want to provide their own and not depend +# on this one, since it links in the world. +# Note that this is not added to any top-level source target for transitive +# inclusion: It must be included explicitly by downstreams if desired. Note that +# this has a very large impact on what gets built/packaged. 
+declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything + MODULE_NAME _mlirRegisterEverything + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + RegisterEverything.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIConversion + MLIRCAPITransforms + MLIRCAPIRegisterEverything +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind + MODULE_NAME _mlirDialectsLinalg + ADD_TO_PARENT MLIRPythonSources.Dialects.linalg + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectLinalg.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPILinalg +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind + MODULE_NAME _mlirDialectsGPU + ADD_TO_PARENT MLIRPythonSources.Dialects.gpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectGPU.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIGPU +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind + MODULE_NAME _mlirDialectsLLVM + ADD_TO_PARENT MLIRPythonSources.Dialects.llvm + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectLLVM.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPILLVM +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind + MODULE_NAME _mlirDialectsQuant + ADD_TO_PARENT MLIRPythonSources.Dialects.quant + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectQuant.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIQuant +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind + MODULE_NAME _mlirDialectsNVGPU + ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectNVGPU.cpp + PRIVATE_LINK_LIBS + LLVMSupport + 
EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPINVGPU +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind + MODULE_NAME _mlirDialectsPDL + ADD_TO_PARENT MLIRPythonSources.Dialects.pdl + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectPDL.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIPDL +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind + MODULE_NAME _mlirDialectsSparseTensor + ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectSparseTensor.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPISparseTensor +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind + MODULE_NAME _mlirDialectsTransform + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectTransform.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPITransformDialect +) + +declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses + MODULE_NAME _mlirAsyncPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.async + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + AsyncPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIAsync +) + +if(MLIR_ENABLE_EXECUTION_ENGINE) + declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine + MODULE_NAME _mlirExecutionEngine + ADD_TO_PARENT MLIRPythonSources.ExecutionEngine + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + ExecutionEngineModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIExecutionEngine + ) +endif() + +declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses + MODULE_NAME _mlirGPUPasses + ADD_TO_PARENT 
MLIRPythonSources.Dialects.gpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + GPUPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIGPU +) + +declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses + MODULE_NAME _mlirLinalgPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.linalg + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + LinalgPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPILinalg +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind + MODULE_NAME _mlirDialectsSMT + ADD_TO_PARENT MLIRPythonSources.Dialects.smt + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectSMT.cpp + # Headers must be included explicitly so they are installed. + NanobindUtils.h + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPISMT + MLIRCAPIExportSMTLIB +) + +declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses + MODULE_NAME _mlirSparseTensorPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + SparseTensorPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPISparseTensor +) + +declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter + MODULE_NAME _mlirTransformInterpreter + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + TransformInterpreter.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPITransformDialectTransforms +) + +# TODO: Figure out how to put this in the test tree. +# This should not be included in the main Python extension. However, +# putting it into MLIRPythonTestSources along with the dialect declaration +# above confuses Python module loader when running under lit. 
+set(_ADDL_TEST_SOURCES) +if(MLIR_INCLUDE_TESTS) + set(_ADDL_TEST_SOURCES MLIRPythonTestSources) + declare_mlir_python_sources(MLIRPythonTestSources) + declare_mlir_python_sources(MLIRPythonTestSources.Dialects + ADD_TO_PARENT MLIRPythonTestSources) + + # TODO: this uses a tablegen file from the test directory and should be + # decoupled from here. + declare_mlir_python_sources( + MLIRPythonTestSources.Dialects.PythonTest + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonTestSources.Dialects + SOURCES + dialects/python_test.py + ) + set(LLVM_TARGET_DEFINITIONS + "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") + mlir_tablegen( + "dialects/_python_test_ops_gen.py" + -gen-python-op-bindings + -bind-dialect=python_test) + add_public_tablegen_target(PythonTestDialectPyIncGen) + declare_mlir_python_sources( + MLIRPythonTestSources.Dialects.PythonTest.ops_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest + SOURCES "dialects/_python_test_ops_gen.py") + + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11 + MODULE_NAME _mlirPythonTestPybind11 + ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY pybind11 + SOURCES + PythonTestModulePybind11.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect + ) + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind + MODULE_NAME _mlirPythonTestNanobind + ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + PythonTestModuleNanobind.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect + ) +endif() + +################################################################################ +# Common CAPI dependency DSO. 
+# All python extensions must link through one DSO which exports the CAPI, and +# this must have a globally unique name amongst all embeddors of the python +# library since it will effectively have global scope. +# +# The presence of this aggregate library is part of the long term plan, but its +# use needs to be made more flexible. +# +# TODO: Upgrade to the aggregate utility in https://reviews.llvm.org/D106419 +# once ready. +################################################################################ + +add_mlir_python_common_capi_library(MLIRPythonCAPI + INSTALL_COMPONENT MLIRPythonModules + INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" + OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs" + RELATIVE_INSTALL_ROOT "../../../.." + DECLARED_HEADERS + MLIRPythonCAPI.HeaderSources + DECLARED_SOURCES + MLIRPythonSources + MLIRPythonExtension.RegisterEverything + ${_ADDL_TEST_SOURCES} +) + +################################################################################ +# Custom targets. +################################################################################ + +_flatten_mlir_python_targets(mlir_python_sources_deps MLIRPythonSources) +add_custom_target("mlir-python-sources" DEPENDS ${mlir_python_sources_deps}) +if(NOT LLVM_ENABLE_IDE) + add_llvm_install_targets(install-mlir-python-sources + DEPENDS mlir-python-sources + COMPONENT mlir-python-sources + ) +endif() + +################################################################################ +# The fully assembled package of modules. +# This must come last. 
+################################################################################ + +add_mlir_python_modules(MLIRPythonModules + ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir" + INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}" + DECLARED_SOURCES + MLIRPythonSources + MLIRPythonExtension.RegisterEverything + ${_ADDL_TEST_SOURCES} + COMMON_CAPI_LINK_LIBS + MLIRPythonCAPI +) + diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py new file mode 100644 index 000000000..083a9075f --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -0,0 +1,235 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Any, Sequence + +import os + +_this_dir = os.path.dirname(__file__) + + +def get_lib_dirs() -> Sequence[str]: + """Gets the lib directory for linking to shared libraries. + + On some platforms, the package may need to be built specially to export + development libraries. + """ + return [_this_dir] + + +def get_include_dirs() -> Sequence[str]: + """Gets the include directory for compiling against exported C libraries. + + Depending on how the package was build, development C libraries may or may + not be present. + """ + return [os.path.join(_this_dir, "include")] + + +# Perform Python level site initialization. This involves: +# 1. Attempting to load initializer modules, specific to the distribution. +# 2. Defining the concrete mlir.ir.Context that does site specific +# initialization. +# +# Aside from just being far more convenient to do this at the Python level, +# it is actually quite hard/impossible to have such __init__ hooks, given +# the pybind memory model (i.e. there is not a Python reference to the object +# in the scope of the base class __init__). +# +# For #1, we: +# a. 
Probe for modules named '_mlirRegisterEverything' and +# '_site_initialize_{i}', where 'i' is a number starting at zero and +# proceeding so long as a module with the name is found. +# b. If the module has a 'register_dialects' attribute, it will be called +# immediately with a DialectRegistry to populate. +# c. If the module has a 'context_init_hook', it will be added to a list +# of callbacks that are invoked as the last step of Context +# initialization (and passed the Context under construction). +# d. If the module has a 'disable_multithreading' attribute, it will be +# taken as a boolean. If it is True for any initializer, then the +# default behavior of enabling multithreading on the context +# will be suppressed. This complies with the original behavior of all +# contexts being created with multithreading enabled while allowing +# this behavior to be changed if needed (i.e. if a context_init_hook +# explicitly sets up multithreading). +# +# This facility allows downstreams to customize Context creation to their +# needs. + +_dialect_registry = None +_load_on_create_dialects = None + + +def get_dialect_registry(): + global _dialect_registry + + if _dialect_registry is None: + from ._mlir import ir + + _dialect_registry = ir.DialectRegistry() + + return _dialect_registry + + +def append_load_on_create_dialect(dialect: str): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [dialect] + else: + _load_on_create_dialects.append(dialect) + + +def get_load_on_create_dialects(): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [] + return _load_on_create_dialects + + +def _site_initialize(): + import importlib + import itertools + import logging + from ._mlir import ir + + logger = logging.getLogger(__name__) + post_init_hooks = [] + disable_multithreading = False + # This flag disables eagerly loading all dialects. 
Eagerly loading is often + # not the desired behavior (see + # https://github.com/llvm/llvm-project/issues/56037), and the logic is that + # if any module has this attribute set, then we don't load all (e.g., it's + # being used in a solution where the loading is controlled). + disable_load_all_available_dialects = False + + def process_initializer_module(module_name): + nonlocal disable_multithreading + nonlocal disable_load_all_available_dialects + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError: + return False + except ImportError: + message = ( + f"Error importing mlir initializer {module_name}. This may " + "happen in unclean incremental builds but is likely a real bug if " + "encountered otherwise and the MLIR Python API may not function." + ) + logger.warning(message, exc_info=True) + return False + + logger.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logger.debug("Registering dialects from initializer %r", m) + m.register_dialects(get_dialect_registry()) + if hasattr(m, "context_init_hook"): + logger.debug("Adding context init hook from %r", m) + post_init_hooks.append(m.context_init_hook) + if hasattr(m, "disable_multithreading"): + if bool(m.disable_multithreading): + logger.debug("Disabling multi-threading for context") + disable_multithreading = True + if hasattr(m, "disable_load_all_available_dialects"): + disable_load_all_available_dialects = True + return True + + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + init_module = None + if process_initializer_module("_mlirRegisterEverything"): + init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0. 
+ for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not process_initializer_module(module_name): + break + + class Context(ir._BaseContext): + def __init__( + self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs + ): + super().__init__(*args, **kwargs) + self.append_dialect_registry(get_dialect_registry()) + for hook in post_init_hooks: + hook(self) + if disable_multithreading and thread_pool is not None: + raise ValueError( + "Context constructor has given thread_pool argument, " + "but disable_multithreading flag is True. " + "Please, set thread_pool argument to None or " + "set disable_multithreading flag to False." + ) + if not disable_multithreading: + if thread_pool is None: + self.enable_multithreading(True) + else: + self.set_thread_pool(thread_pool) + if load_on_create_dialects is not None: + logger.debug( + "Loading all dialects from load_on_create_dialects arg %r", + load_on_create_dialects, + ) + for dialect in load_on_create_dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + if disable_load_all_available_dialects: + dialects = get_load_on_create_dialects() + if dialects: + logger.debug( + "Loading all dialects from global load_on_create_dialects %r", + dialects, + ) + for dialect in dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + logger.debug("Loading all available dialects") + self.load_all_available_dialects() + if init_module: + logger.debug( + "Registering translations from initializer %r", init_module + ) + init_module.register_llvm_translations(self) + + ir.Context = Context + + class MLIRError(Exception): + """ + An exception with diagnostic information. 
Has the following fields: + message: str + error_diagnostics: List[ir.DiagnosticInfo] + """ + + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ":" + for diag in self.error_diagnostics: + s += ( + "\nerror: " + + str(diag.location)[4:-1] + + ": " + + diag.message.replace("\n", "\n ") + ) + for note in diag.notes: + s += ( + "\n note: " + + str(note.location)[4:-1] + + ": " + + note.message.replace("\n", "\n ") + ) + return s + + ir.MLIRError = MLIRError + + +_site_initialize() diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi new file mode 100644 index 000000000..03449b70b --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -0,0 +1,12 @@ + +globals: "_Globals" + +class _Globals: + dialect_search_modules: list[str] + def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... + def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ... + def append_dialect_search_prefix(self, module_name: str) -> None: ... + def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... + +def register_dialect(dialect_class: type) -> type: ... +def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi new file mode 100644 index 000000000..d12c6839d --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi @@ -0,0 +1,63 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from mlir.ir import Type, Context + +__all__ = [ + 'PDLType', + 'AttributeType', + 'OperationType', + 'RangeType', + 'TypeType', + 'ValueType', +] + + +class PDLType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + +class AttributeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> AttributeType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> OperationType: ... + + +class RangeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(element_type: Type) -> RangeType: ... + + @property + def element_type(self) -> Type: ... + + +class TypeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> TypeType: ... + + +class ValueType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> ValueType: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi new file mode 100644 index 000000000..3f5304584 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -0,0 +1,142 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from mlir.ir import DenseElementsAttr, Type + +__all__ = [ + "QuantizedType", + "AnyQuantizedType", + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "CalibratedQuantizedType", +] + +class QuantizedType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... 
+ + @staticmethod + def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @staticmethod + def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @property + def expressed_type(self) -> Type: ... + + @property + def flags(self) -> int: ... + + @property + def is_signed(self) -> bool: ... + + @property + def storage_type(self) -> Type: ... + + @property + def storage_type_min(self) -> int: ... + + @property + def storage_type_max(self) -> int: ... + + @property + def storage_type_integral_width(self) -> int: ... + + def is_compatible_expressed_type(self, candidate: Type) -> bool: ... + + @property + def quantized_element_type(self) -> Type: ... + + def cast_from_storage_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_storage_type(type: Type) -> Type: ... + + def cast_from_expressed_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_expressed_type(type: Type) -> Type: ... + + def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ... + + +class AnyQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + storage_type_min: int, storage_type_max: int) -> Type: + ... + + +class UniformQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scale: float, zero_point: int, storage_type_min: int, + storage_type_max: int) -> Type: ... + + @property + def scale(self) -> float: ... + + @property + def zero_point(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + + +class UniformQuantizedPerAxisType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: list[float], zero_points: list[int], quantized_dimension: int, + storage_type_min: int, storage_type_max: int): + ... + + @property + def scales(self) -> list[float]: ... 
+ + @property + def zero_points(self) -> list[int]: ... + + @property + def quantized_dimension(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + +class UniformQuantizedSubChannelType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: DenseElementsAttr, zero_points: DenseElementsAttr, + quantized_dimensions: list[int], block_sizes: list[int], + storage_type_min: int, storage_type_max: int): + ... + + @property + def quantized_dimensions(self) -> list[int]: ... + + @property + def block_sizes(self) -> list[int]: ... + + @property + def scales(self) -> DenseElementsAttr: ... + + @property + def zero_points(self) -> DenseElementsAttr: ... + +def CalibratedQuantizedType(QuantizedType): + + @classmethod + def get(cls, expressed_type: Type, min: float, max: float): ... + + @property + def min(self) -> float: ... + + @property + def max(self) -> float: ... \ No newline at end of file diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi new file mode 100644 index 000000000..a3f1b0910 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi @@ -0,0 +1,25 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from mlir.ir import Type, Context + + +class AnyOpType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Context | None = None) -> AnyOpType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(operation_name: str, context: Context | None = None) -> OperationType: ... + + @property + def operation_name(self) -> str: ... 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi new file mode 100644 index 000000000..dcae3dd74 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -0,0 +1,2846 @@ +# Originally imported via: +# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir +# but with the following diff (in order to remove pipes from types, +# which we won't support until bumping minimum python to 3.10) +# +# --------------------- diff begins ------------------------------------ +# +# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py +# index 1f755aa..4924927 100644 +# --- a/pybind11_stubgen/printer.py +# +++ b/pybind11_stubgen/printer.py +# @@ -283,14 +283,6 @@ class Printer: +# return split[0] + "..." +# +# def print_type(self, type_: ResolvedType) -> str: +# - if ( +# - str(type_.name) == "typing.Optional" +# - and type_.parameters is not None +# - and len(type_.parameters) == 1 +# - ): +# - return f"{self.print_annotation(type_.parameters[0])} | None" +# - if str(type_.name) == "typing.Union" and type_.parameters is not None: +# - return " | ".join(self.print_annotation(p) for p in type_.parameters) +# if type_.parameters: +# param_str = ( +# "[" +# +# --------------------- diff ends ------------------------------------ +# +# Local modifications: +# * Rewrite references to 'mlir.ir' to local types. +# * Drop `typing.` everywhere (top-level import instead). +# * List -> List, dict -> Dict, Tuple -> Tuple. +# * copy-paste Buffer type from typing_extensions. +# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top. +# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`. +# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute. +# * Local edits to signatures and types that pybind11-stubgen did not auto detect (or detected incorrectly). +# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand. 
+# * Fill in `Any`s where possible. +# * black formatting. + +from __future__ import annotations + +import abc +import collections +from collections.abc import Callable, Sequence +from pathlib import Path +from typing import Any, BinaryIO, ClassVar, Literal, TypeVar, overload + +__all__ = [ + "AffineAddExpr", + "AffineBinaryExpr", + "AffineCeilDivExpr", + "AffineConstantExpr", + "AffineDimExpr", + "AffineExpr", + "AffineExprList", + "AffineFloorDivExpr", + "AffineMap", + "AffineMapAttr", + "AffineModExpr", + "AffineMulExpr", + "AffineSymbolExpr", + "ArrayAttr", + "ArrayAttributeIterator", + "AsmState", + "AttrBuilder", + "Attribute", + "BF16Type", + "Block", + "BlockArgument", + "BlockArgumentList", + "BlockIterator", + "BlockList", + "BoolAttr", + "ComplexType", + "Context", + "DenseBoolArrayAttr", + "DenseBoolArrayIterator", + "DenseElementsAttr", + "DenseF32ArrayAttr", + "DenseF32ArrayIterator", + "DenseF64ArrayAttr", + "DenseF64ArrayIterator", + "DenseFPElementsAttr", + "DenseI16ArrayAttr", + "DenseI16ArrayIterator", + "DenseI32ArrayAttr", + "DenseI32ArrayIterator", + "DenseI64ArrayAttr", + "DenseI64ArrayIterator", + "DenseI8ArrayAttr", + "DenseI8ArrayIterator", + "DenseIntElementsAttr", + "DenseResourceElementsAttr", + "Diagnostic", + "DiagnosticHandler", + "DiagnosticInfo", + "DiagnosticSeverity", + "Dialect", + "DialectDescriptor", + "DialectRegistry", + "Dialects", + "DictAttr", + "F16Type", + "F32Type", + "F64Type", + "FlatSymbolRefAttr", + "Float4E2M1FNType", + "Float6E2M3FNType", + "Float6E3M2FNType", + "Float8E3M4Type", + "Float8E4M3B11FNUZType", + "Float8E4M3FNType", + "Float8E4M3FNUZType", + "Float8E4M3Type", + "Float8E5M2FNUZType", + "Float8E5M2Type", + "Float8E8M0FNUType", + "FloatAttr", + "FloatTF32Type", + "FloatType", + "FunctionType", + "IndexType", + "InferShapedTypeOpInterface", + "InferTypeOpInterface", + "InsertionPoint", + "IntegerAttr", + "IntegerSet", + "IntegerSetAttr", + "IntegerSetConstraint", + "IntegerSetConstraintList", + 
"IntegerType", + "Location", + "MemRefType", + "Module", + "NamedAttribute", + "NoneType", + "OpAttributeMap", + "OpOperand", + "OpOperandIterator", + "OpOperandList", + "OpResult", + "OpResultList", + "OpSuccessors", + "OpView", + "OpaqueAttr", + "OpaqueType", + "Operation", + "OperationIterator", + "OperationList", + "RankedTensorType", + "Region", + "RegionIterator", + "RegionSequence", + "ShapedType", + "ShapedTypeComponents", + "StridedLayoutAttr", + "StringAttr", + "SymbolRefAttr", + "SymbolTable", + "TupleType", + "Type", + "TypeAttr", + "TypeID", + "UnitAttr", + "UnrankedMemRefType", + "UnrankedTensorType", + "Value", + "VectorType", + "_GlobalDebug", + "_OperationBase", +] + +if hasattr(collections.abc, "Buffer"): + Buffer = collections.abc.Buffer +else: + class Buffer(abc.ABC): + pass + +class _OperationBase: + @overload + def __eq__(self, arg0: _OperationBase) -> bool: ... + @overload + def __eq__(self, arg0: _OperationBase) -> bool: ... + def __hash__(self) -> int: ... + def __str__(self) -> str: + """ + Returns the assembly form of the operation. + """ + def clone(self, ip: InsertionPoint = None) -> OpView: ... + def detach_from_parent(self) -> OpView: + """ + Detaches the operation from its parent block. + """ + + @property + def attached(self) -> bool: + """ + Reports if the operation is attached to its parent block. + """ + + def erase(self) -> None: ... + + @overload + def get_asm( + binary: Literal[True], + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + skip_regions: bool = False, + ) -> bytes: ... 
+ @overload + def get_asm( + self, + binary: bool = False, + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + skip_regions: bool = False, + ) -> str: + """ + Returns the assembly form of the operation. + + See the print() method for common keyword arguments for configuring + the output. + """ + + def move_after(self, other: _OperationBase) -> None: + """ + Puts self immediately after the other operation in its parent block. + """ + def move_before(self, other: _OperationBase) -> None: + """ + Puts self immediately before the other operation in its parent block. + """ + @overload + def print( + self, + state: AsmState, + file: Any | None = None, + binary: bool = False, + ) -> None: + """ + Prints the assembly form of the operation to a file like object. + + Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + state: AsmState capturing the operation numbering and flags. + """ + @overload + def print( + self, + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + file: Any | None = None, + binary: bool = False, + skip_regions: bool = False, + ) -> None: + """ + Prints the assembly form of the operation to a file like object. + + Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + large_elements_limit: Whether to elide elements attributes above this + number of elements. Defaults to None (no limit). 
+ large_resource_limit: Whether to elide resource strings above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_Scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. + skip_regions: Whether to skip printing regions. Defaults to False. + """ + def verify(self) -> bool: + """ + Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. + """ + def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None: + """ + Write the bytecode form of the operation to a file like object. + + Args: + file: The file like object or path to write to. + desired_version: The version of bytecode to emit. + Returns: + The bytecode writer status. + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def attributes(self) -> OpAttributeMap: ... 
+ @property + def context(self) -> Context: + """ + Context that owns the Operation + """ + @property + def location(self) -> Location: + """ + Returns the source location the operation was defined or derived from. + """ + @property + def name(self) -> str: ... + @property + def operands(self) -> OpOperandList: ... + @property + def parent(self) -> _OperationBase | None: ... + @property + def regions(self) -> RegionSequence: ... + @property + def result(self) -> OpResult: + """ + Shortcut to get an op result if it has only one (throws an error otherwise). + """ + @property + def results(self) -> OpResultList: + """ + Returns the List of Operation results. + """ + +_TOperation = TypeVar("_TOperation", bound=_OperationBase) + +class AffineExpr: + @staticmethod + @overload + def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of two expressions. + """ + @staticmethod + @overload + def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of a constant and another expression. + """ + @staticmethod + @overload + def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of an expression and a constant. + """ + @staticmethod + @overload + def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: + """ + Gets an affine expression containing the rounded-up result of dividing one expression by another. + """ + @staticmethod + @overload + def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr: + """ + Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr: + """ + Gets an affine expression containing the rounded-up result of dividing an expression by a constant. 
+ """ + @staticmethod + def get_constant( + value: int, context: Context | None = None + ) -> AffineConstantExpr: + """ + Gets a constant affine expression with the given value. + """ + @staticmethod + def get_dim(position: int, context: Context | None = None) -> AffineDimExpr: + """ + Gets an affine expression of a dimension at the given position. + """ + @staticmethod + @overload + def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: + """ + Gets an affine expression containing the rounded-down result of dividing one expression by another. + """ + @staticmethod + @overload + def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr: + """ + Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr: + """ + Gets an affine expression containing the rounded-down result of dividing an expression by a constant. + """ + @staticmethod + @overload + def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: + """ + Gets an affine expression containing the modulo of dividing one expression by another. + """ + @staticmethod + @overload + def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr: + """ + Gets a semi-affine expression containing the modulo of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr: + """ + Gets an affine expression containing the module of dividingan expression by a constant. + """ + @staticmethod + @overload + def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: + """ + Gets an affine expression containing a product of two expressions. + """ + @staticmethod + @overload + def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr: + """ + Gets an affine expression containing a product of a constant and another expression. 
+ """ + @staticmethod + @overload + def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr: + """ + Gets an affine expression containing a product of an expression and a constant. + """ + @staticmethod + def get_symbol( + position: int, context: Context | None = None + ) -> AffineSymbolExpr: + """ + Gets an affine expression of a symbol at the given position. + """ + def _CAPICreate(self) -> AffineExpr: ... + @overload + def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ... + @overload + def __add__(self, arg0: int) -> AffineAddExpr: ... + @overload + def __eq__(self, arg0: AffineExpr) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + @overload + def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ... + @overload + def __mod__(self, arg0: int) -> AffineModExpr: ... + @overload + def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ... + @overload + def __mul__(self, arg0: int) -> AffineMulExpr: ... + def __radd__(self, arg0: int) -> AffineAddExpr: ... + def __rmod__(self, arg0: int) -> AffineModExpr: ... + def __rmul__(self, arg0: int) -> AffineMulExpr: ... + def __rsub__(self, arg0: int) -> AffineAddExpr: ... + @overload + def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ... + @overload + def __sub__(self, arg0: int) -> AffineAddExpr: ... + def compose(self, arg0: AffineMap) -> AffineExpr: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: ... + +class Attribute: + @staticmethod + def parse(asm: str | bytes, context: Context | None = None) -> Attribute: + """ + Parses an attribute from an assembly form. Raises an MLIRError on failure. + """ + def _CAPICreate(self) -> Attribute: ... + @overload + def __eq__(self, arg0: Attribute) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... 
+ def __init__(self, cast_from_type: Attribute) -> None: + """ + Casts the passed attribute to the generic Attribute + """ + def __str__(self) -> str: + """ + Returns the assembly form of the Attribute. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_named(self, arg0: str) -> NamedAttribute: + """ + Binds a name to the attribute + """ + def maybe_downcast(self) -> Any: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context that owns the Attribute + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class Type: + @staticmethod + def parse(asm: str | bytes, context: Context | None = None) -> Type: + """ + Parses the assembly form of a type. + + Returns a Type object or raises an MLIRError if the type cannot be parsed. + + See also: https://mlir.llvm.org/docs/LangRef/#type-system + """ + def _CAPICreate(self) -> Type: ... + @overload + def __eq__(self, arg0: Type) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def __init__(self, cast_from_type: Type) -> None: + """ + Casts the passed type to the generic Type + """ + def __str__(self) -> str: + """ + Returns the assembly form of the type. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def maybe_downcast(self) -> Any: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context that owns the Type + """ + @property + def typeid(self) -> TypeID: ... + +class Value: + def _CAPICreate(self) -> Value: ... + @overload + def __eq__(self, arg0: Value) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def __init__(self, value: Value) -> None: ... + def __str__(self) -> str: + """ + Returns the string form of the value. 
+ + If the value is a block argument, this is the assembly form of its type and the + position in the argument List. If the value is an operation result, this is + equivalent to printing the operation that produced it. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @overload + def get_name(self, use_local_scope: bool = False, use_name_loc_as_prefix: bool = True) -> str: ... + @overload + def get_name(self, state: AsmState) -> str: + """ + Returns the string form of value as an operand (i.e., the ValueID). + """ + def maybe_downcast(self) -> Any: ... + def replace_all_uses_with(self, arg0: Value) -> None: + """ + Replace all uses of value with the new value, updating anything in + the IR that uses 'self' to use the other value instead. + """ + def set_type(self, type: Type) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context in which the value lives. + """ + @property + def owner(self) -> _OperationBase: ... + @property + def type(self) -> Type: ... + @property + def uses(self) -> OpOperandIterator: ... + +class AffineAddExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineBinaryExpr(AffineExpr): + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def lhs(self) -> AffineExpr: ... + @property + def rhs(self) -> AffineExpr: ... + +class AffineCeilDivExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... 
+ +class AffineConstantExpr(AffineExpr): + @staticmethod + def get(value: int, context: Context | None = None) -> AffineConstantExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def value(self) -> int: ... + +class AffineDimExpr(AffineExpr): + @staticmethod + def get(position: int, context: Context | None = None) -> AffineDimExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def position(self) -> int: ... + +class AffineExprList: + def __add__(self, arg0: AffineExprList) -> list[AffineExpr]: ... + +class AffineFloorDivExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineMap: + @staticmethod + def compress_unused_symbols( + arg0: list, arg1: Context | None + ) -> list[AffineMap]: ... + @staticmethod + def get( + dim_count: int, + symbol_count: int, + exprs: list, + context: Context | None = None, + ) -> AffineMap: + """ + Gets a map with the given expressions as results. + """ + @staticmethod + def get_constant(value: int, context: Context | None = None) -> AffineMap: + """ + Gets an affine map with a single constant result + """ + @staticmethod + def get_empty(context: Context | None = None) -> AffineMap: + """ + Gets an empty affine map. + """ + @staticmethod + def get_identity(n_dims: int, context: Context | None = None) -> AffineMap: + """ + Gets an identity map with the given number of dimensions. + """ + @staticmethod + def get_minor_identity( + n_dims: int, n_results: int, context: Context | None = None + ) -> AffineMap: + """ + Gets a minor identity map with the given number of dimensions and results. 
+ """ + @staticmethod + def get_permutation( + permutation: list[int], context: Context | None = None + ) -> AffineMap: + """ + Gets an affine map that permutes its inputs. + """ + def _CAPICreate(self) -> AffineMap: ... + @overload + def __eq__(self, arg0: AffineMap) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_major_submap(self, n_results: int) -> AffineMap: ... + def get_minor_submap(self, n_results: int) -> AffineMap: ... + def get_submap(self, result_positions: list[int]) -> AffineMap: ... + def replace( + self, + expr: AffineExpr, + replacement: AffineExpr, + n_result_dims: int, + n_result_syms: int, + ) -> AffineMap: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Context: + """ + Context that owns the Affine Map + """ + @property + def is_permutation(self) -> bool: ... + @property + def is_projected_permutation(self) -> bool: ... + @property + def n_dims(self) -> int: ... + @property + def n_inputs(self) -> int: ... + @property + def n_symbols(self) -> int: ... + @property + def results(self) -> AffineMapExprList: ... + +class AffineMapAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(affine_map: AffineMap) -> AffineMapAttr: + """ + Gets an attribute wrapping an AffineMap. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class AffineModExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... 
+ +class AffineMulExpr(AffineBinaryExpr): + @staticmethod + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + +class AffineSymbolExpr(AffineExpr): + @staticmethod + def get(position: int, context: Context | None = None) -> AffineSymbolExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... + @property + def position(self) -> int: ... + +class ArrayAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(attributes: list, context: Context | None = None) -> ArrayAttr: + """ + Gets a uniqued Array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> ArrayAttr: ... + def __getitem__(self, arg0: int) -> Attribute: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> ArrayAttributeIterator: ... + def __len__(self) -> int: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class ArrayAttributeIterator: + def __iter__(self) -> ArrayAttributeIterator: ... + def __next__(self) -> Attribute: ... + +class AsmState: + @overload + def __init__(self, value: Value, use_local_scope: bool = False) -> None: ... + @overload + def __init__(self, op: _OperationBase, use_local_scope: bool = False) -> None: ... + +class AttrBuilder: + @staticmethod + def contains(arg0: str) -> bool: ... + @staticmethod + def get(arg0: str) -> Callable: ... + @staticmethod + def insert( + attribute_kind: str, attr_builder: Callable, replace: bool = False + ) -> None: + """ + Register an attribute builder for building MLIR attributes from python values. + """ + +class BF16Type(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> BF16Type: + """ + Create a bf16 type. 
+ """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Block: + @staticmethod + def create_at_start( + parent: Region, + arg_types: list[Type], + arg_locs: Sequence | None = None, + ) -> Block: + """ + Creates and returns a new Block at the beginning of the given region (with given argument types and locations). + """ + @overload + def __eq__(self, arg0: Block) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + def __iter__(self) -> OperationIterator: + """ + Iterates over operations in the block. + """ + def __str__(self) -> str: + """ + Returns the assembly form of the block. + """ + def append(self, operation: _OperationBase) -> None: + """ + Appends an operation to this block. If the operation is currently in another block, it will be moved. + """ + def append_to(self, arg0: Region) -> None: + """ + Append this block to a region, transferring ownership if necessary + """ + def create_after(self, *args, arg_locs: Sequence | None = None) -> Block: + """ + Creates and returns a new Block after this block (with given argument types and locations). + """ + def create_before(self, *args, arg_locs: Sequence | None = None) -> Block: + """ + Creates and returns a new Block before this block (with given argument types and locations). + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def arguments(self) -> BlockArgumentList: + """ + Returns a List of block arguments. + """ + @property + def operations(self) -> OperationList: + """ + Returns a forward-optimized sequence of operations. + """ + @property + def owner(self) -> OpView: + """ + Returns the owning operation of this block. + """ + @property + def region(self) -> Region: + """ + Returns the owning region of this block. + """ + +class BlockArgument(Value): + @staticmethod + def isinstance(other_value: Value) -> bool: ... 
+ def __init__(self, value: Value) -> None: ... + def maybe_downcast(self) -> Any: ... + def set_type(self, type: Type) -> None: ... + @property + def arg_number(self) -> int: ... + @property + def owner(self) -> Block: ... + +class BlockArgumentList: + @overload + def __getitem__(self, arg0: int) -> BlockArgument: ... + @overload + def __getitem__(self, arg0: slice) -> BlockArgumentList: ... + def __len__(self) -> int: ... + def __add__(self, arg0: BlockArgumentList) -> list[BlockArgument]: ... + @property + def types(self) -> list[Type]: ... + +class BlockIterator: + def __iter__(self) -> BlockIterator: ... + def __next__(self) -> Block: ... + +class BlockList: + def __getitem__(self, arg0: int) -> Block: ... + def __iter__(self) -> BlockIterator: ... + def __len__(self) -> int: ... + def append(self, *args, arg_locs: Sequence | None = None) -> Block: + """ + Appends a new block, with argument types as positional args. + + Returns: + The created block. + """ + +class BoolAttr(Attribute): + @staticmethod + def get(value: bool, context: Context | None = None) -> BoolAttr: + """ + Gets an uniqued bool attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __bool__(self: Attribute) -> bool: + """ + Converts the value of the bool attribute to a Python bool + """ + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> bool: + """ + Returns the value of the bool attribute + """ + +class ComplexType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(arg0: Type) -> ComplexType: + """ + Create a complex type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def element_type(self) -> Type: + """ + Returns element type. 
+ """ + @property + def typeid(self) -> TypeID: ... + +class Context: + current: ClassVar[Context] = ... # read-only + allow_unregistered_dialects: bool + @staticmethod + def _get_live_count() -> int: ... + def _CAPICreate(self) -> object: ... + def __enter__(self) -> Context: ... + def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... + def __init__(self) -> None: ... + def _clear_live_operations(self) -> int: ... + def _get_context_again(self) -> Context: ... + def _get_live_module_count(self) -> int: ... + def _get_live_operation_count(self) -> int: ... + def _get_live_operation_objects(self) -> list[Operation]: ... + def append_dialect_registry(self, registry: DialectRegistry) -> None: ... + def attach_diagnostic_handler( + self, callback: Callable[[Diagnostic], bool] + ) -> DiagnosticHandler: + """ + Attaches a diagnostic handler that will receive callbacks + """ + def enable_multithreading(self, enable: bool) -> None: ... + def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: + """ + Gets or loads a dialect by name, returning its descriptor object + """ + def is_registered_operation(self, operation_name: str) -> bool: ... + def load_all_available_dialects(self) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def d(self) -> Dialects: + """ + Alias for 'dialect' + """ + @property + def dialects(self) -> Dialects: + """ + Gets a container for accessing dialects by name + """ + +class DenseBoolArrayAttr(Attribute): + @staticmethod + def get( + values: Sequence[bool], context: Context | None = None + ) -> DenseBoolArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseBoolArrayAttr: ... + def __getitem__(self, arg0: int) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseBoolArrayIterator: ... + def __len__(self) -> int: ... 
+ @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseBoolArrayIterator: + def __iter__(self) -> DenseBoolArrayIterator: ... + def __next__(self) -> bool: ... + +class DenseElementsAttr(Attribute): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, + ) -> DenseElementsAttr: + """ + Gets a DenseElementsAttr from a Python buffer or array. + + When `type` is not provided, then some limited type inferencing is done based + on the buffer format. Support presently exists for 8/16/32/64 signed and + unsigned integers and float16/float32/float64. DenseElementsAttrs of these + types can also be converted back to a corresponding buffer. + + For conversions outside of these types, a `type=` must be explicitly provided + and the buffer contents must be bit-castable to the MLIR internal + representation: + + * Integer types (except for i1): the buffer must be byte aligned to the + next byte boundary. + * Floating point types: Must be bit-castable to the given floating point + size. + * i1 (bool): Bit packed into 8bit words where the bit pattern matches a + row major ordering. An arbitrary Numpy `bool_` array can be bit packed to + this specification with: `np.packbits(ary, axis=None, bitorder='little')`. + + If a single element buffer is passed (or for i1, a single byte with value 0 + or 255), then a splat will be created. + + Args: + array: The array or buffer to convert. + signless: If inferring an appropriate MLIR type, use signless types for + integers (defaults True). + type: Skips inference of the MLIR element type and uses this instead. The + storage size must be consistent with the actual contents of the buffer. + shape: Overrides the shape of the buffer when constructing the MLIR + shaped type. 
This is needed when the physical and logical shape differ (as + for i1). + context: Explicit context, if not from context manager. + + Returns: + DenseElementsAttr on success. + + Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. + """ + @staticmethod + def get_splat(shaped_type: Type, element_attr: Attribute) -> DenseElementsAttr: + """ + Gets a DenseElementsAttr where all values are the same + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __len__(self) -> int: ... + def get_splat_value(self) -> Attribute: ... + @property + def is_splat(self) -> bool: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseF32ArrayAttr(Attribute): + @staticmethod + def get( + values: Sequence[float], context: Context | None = None + ) -> DenseF32ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseF32ArrayAttr: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseF32ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseF32ArrayIterator: + def __iter__(self) -> DenseF32ArrayIterator: ... + def __next__(self) -> float: ... + +class DenseF64ArrayAttr(Attribute): + @staticmethod + def get( + values: Sequence[float], context: Context | None = None + ) -> DenseF64ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseF64ArrayAttr: ... 
+ def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseF64ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseF64ArrayIterator: + def __iter__(self) -> DenseF64ArrayIterator: ... + def __next__(self) -> float: ... + +class DenseFPElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, + ) -> DenseFPElementsAttr: ... + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI16ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI16ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI16ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI16ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI16ArrayIterator: + def __iter__(self) -> DenseI16ArrayIterator: ... + def __next__(self) -> int: ... 
+ +class DenseI32ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI32ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI32ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI32ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI32ArrayIterator: + def __iter__(self) -> DenseI32ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseI64ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI64ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI64ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI16ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI64ArrayIterator: + def __iter__(self) -> DenseI64ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseI8ArrayAttr(Attribute): + @staticmethod + def get(values: Sequence[int], context: Context | None = None) -> DenseI8ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: list) -> DenseI8ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... 
+ def __iter__( + self, + ) -> DenseI8ArrayIterator: ... + def __len__(self) -> int: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI8ArrayIterator: + def __iter__(self) -> DenseI8ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseIntElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, + ) -> DenseIntElementsAttr: ... + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseResourceElementsAttr(Attribute): + @staticmethod + def get_from_buffer( + array: Buffer, + name: str, + type: Type, + alignment: int | None = None, + is_mutable: bool = False, + context: Context | None = None, + ) -> DenseResourceElementsAttr: + """ + Gets a DenseResourceElementsAttr from a Python buffer or array. + + This function does minimal validation or massaging of the data, and it is + up to the caller to ensure that the buffer meets the characteristics + implied by the shape. + + The backing buffer and any user objects will be retained for the lifetime + of the resource blob. This is typically bounded to the context but the + resource can have a shorter lifespan depending on how it is used in + subsequent processing. + + Args: + buffer: The array or buffer to convert. + name: Name to provide to the resource (may be changed upon collision). + type: The explicit ShapedType to construct the attribute with. + context: Explicit context, if not from context manager. + + Returns: + DenseResourceElementsAttr on success. 
+ + Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class Diagnostic: + @property + def location(self) -> Location: ... + @property + def message(self) -> str: ... + @property + def notes(self) -> tuple[Diagnostic]: ... + @property + def severity(self) -> DiagnosticSeverity: ... + +class DiagnosticHandler: + def __enter__(self) -> DiagnosticHandler: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + def detach(self) -> None: ... + @property + def attached(self) -> bool: ... + @property + def had_error(self) -> bool: ... + +class DiagnosticInfo: + def __init__(self, arg0: Diagnostic) -> None: ... + @property + def location(self) -> Location: ... + @property + def message(self) -> str: ... + @property + def notes(self) -> list[DiagnosticInfo]: ... + @property + def severity(self) -> DiagnosticSeverity: ... + +class DiagnosticSeverity: + """ + Members: + + ERROR + + WARNING + + NOTE + + REMARK + """ + + ERROR: ClassVar[DiagnosticSeverity] # value = + NOTE: ClassVar[DiagnosticSeverity] # value = + REMARK: ClassVar[DiagnosticSeverity] # value = + WARNING: ClassVar[DiagnosticSeverity] # value = + __members__: ClassVar[ + dict[str, DiagnosticSeverity] + ] # value = {'ERROR': , 'WARNING': , 'NOTE': , 'REMARK': } + def __eq__(self, other: Any) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: int) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: Any) -> bool: ... + def __setstate__(self, state: int) -> None: ... + @property + def name(self) -> str: ... 
+ @property + def value(self) -> int: ... + +class Dialect: + def __init__(self, descriptor: DialectDescriptor) -> None: ... + @property + def descriptor(self) -> DialectDescriptor: ... + +class DialectDescriptor: + @property + def namespace(self) -> str: ... + +class DialectRegistry: + def _CAPICreate(self) -> DialectRegistry: ... + def __init__(self) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + +class Dialects: + def __getattr__(self, arg0: str) -> Dialect: ... + def __getitem__(self, arg0: str) -> Dialect: ... + +class DictAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(value: dict = {}, context: Context | None = None) -> DictAttr: + """ + Gets an uniqued Dict attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __contains__(self, arg0: str) -> bool: ... + @overload + def __getitem__(self, arg0: str) -> Attribute: ... + @overload + def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __len__(self) -> int: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class FloatType(Type): + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def width(self) -> int: + """ + Returns the width of the floating-point type. + """ + +class F16Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> F16Type: + """ + Create a f16 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class F32Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> F32Type: + """ + Create a f32 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... 
+ def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class F64Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> F64Type: + """ + Create a f64 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class FlatSymbolRefAttr(Attribute): + @staticmethod + def get(value: str, context: Context | None = None) -> FlatSymbolRefAttr: + """ + Gets a uniqued FlatSymbolRef attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> str: + """ + Returns the value of the FlatSymbolRef attribute as a string + """ + +class Float4E2M1FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float4E2M1FNType: + """ + Create a float4_e2m1fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float6E2M3FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float6E2M3FNType: + """ + Create a float6_e2m3fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float6E3M2FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float6E3M2FNType: + """ + Create a float6_e3m2fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... 
+ def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E3M4Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E3M4Type: + """ + Create a float8_e3m4 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3B11FNUZType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3B11FNUZType: + """ + Create a float8_e4m3b11fnuz type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3FNType: + """ + Create a float8_e4m3fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3FNUZType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3FNUZType: + """ + Create a float8_e4m3fnuz type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E4M3Type: + """ + Create a float8_e4m3 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... 
+ +class Float8E5M2FNUZType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E5M2FNUZType: + """ + Create a float8_e5m2fnuz type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E5M2Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E5M2Type: + """ + Create a float8_e5m2 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E8M0FNUType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E8M0FNUType: + """ + Create a float8_e8m0fnu type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class FloatAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(type: Type, value: float, loc: Location | None = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a type + """ + @staticmethod + def get_f32(value: float, context: Context | None = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a f32 type + """ + @staticmethod + def get_f64(value: float, context: Context | None = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a f64 type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __float__(self: Attribute) -> float: + """ + Converts the value of the float attribute to a Python float + """ + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... 
+ @property + def value(self) -> float: + """ + Returns the value of the float attribute + """ + +class FloatTF32Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> FloatTF32Type: + """ + Create a tf32 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class FunctionType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + inputs: list[Type], results: list[Type], context: Context | None = None + ) -> FunctionType: + """ + Gets a FunctionType from a List of input and result types + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def inputs(self) -> list: + """ + Returns the List of input types in the FunctionType. + """ + @property + def results(self) -> list: + """ + Returns the List of result types in the FunctionType. + """ + @property + def typeid(self) -> TypeID: ... + +class IndexType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> IndexType: + """ + Create a index type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class InferShapedTypeOpInterface: + def __init__(self, object: object, context: Context | None = None) -> None: + """ + Creates an interface from a given operation/opview object or from a + subclass of OpView. Raises ValueError if the operation does not implement the + interface. 
+ """ + def inferReturnTypeComponents( + self, + operands: list | None = None, + attributes: Attribute | None = None, + properties=None, + regions: list[Region] | None = None, + context: Context | None = None, + loc: Location | None = None, + ) -> list[ShapedTypeComponents]: + """ + Given the arguments required to build an operation, attempts to infer + its return shaped type components. Raises ValueError on failure. + """ + @property + def operation(self) -> Operation: + """ + Returns an Operation for which the interface was constructed. + """ + @property + def opview(self) -> OpView: + """ + Returns an OpView subclass _instance_ for which the interface was + constructed + """ + +class InferTypeOpInterface: + def __init__(self, object: object, context: Context | None = None) -> None: + """ + Creates an interface from a given operation/opview object or from a + subclass of OpView. Raises ValueError if the operation does not implement the + interface. + """ + def inferReturnTypes( + self, + operands: list | None = None, + attributes: Attribute | None = None, + properties=None, + regions: list[Region] | None = None, + context: Context | None = None, + loc: Location | None = None, + ) -> list[Type]: + """ + Given the arguments required to build an operation, attempts to infer + its return types. Raises ValueError on failure. + """ + @property + def operation(self) -> Operation: + """ + Returns an Operation for which the interface was constructed. + """ + @property + def opview(self) -> OpView: + """ + Returns an OpView subclass _instance_ for which the interface was + constructed + """ + +class InsertionPoint: + current: ClassVar[InsertionPoint] = ... # read-only + @staticmethod + def at_block_begin(block: Block) -> InsertionPoint: + """ + Inserts at the beginning of the block. + """ + @staticmethod + def at_block_terminator(block: Block) -> InsertionPoint: + """ + Inserts before the block terminator. + """ + def __enter__(self) -> InsertionPoint: ... 
+ def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... + @overload + def __init__(self, block: Block) -> None: + """ + Inserts after the last operation but still inside the block. + """ + @overload + def __init__(self, beforeOperation: _OperationBase) -> None: + """ + Inserts before a referenced operation. + """ + def insert(self, operation: _OperationBase) -> None: + """ + Inserts an operation. + """ + @property + def block(self) -> Block: + """ + Returns the block that this InsertionPoint points to. + """ + @property + def ref_operation(self) -> _OperationBase | None: + """ + The reference operation before which new operations are inserted, or None if the insertion point is at the end of the block + """ + +class IntegerAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(type: Type, value: int) -> IntegerAttr: + """ + Gets an uniqued integer attribute associated to a type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __int__(self) -> int: + """ + Converts the value of the integer attribute to a Python int + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> int: + """ + Returns the value of the integer attribute + """ + +class IntegerSet: + @staticmethod + def get( + num_dims: int, + num_symbols: int, + exprs: list, + eq_flags: list[bool], + context: Context | None = None, + ) -> IntegerSet: ... + @staticmethod + def get_empty( + num_dims: int, num_symbols: int, context: Context | None = None + ) -> IntegerSet: ... + def _CAPICreate(self) -> IntegerSet: ... + @overload + def __eq__(self, arg0: IntegerSet) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. 
+ """ + def get_replaced( + self, + dim_exprs: list, + symbol_exprs: list, + num_result_dims: int, + num_result_symbols: int, + ) -> IntegerSet: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def constraints(self) -> IntegerSetConstraintList: ... + @property + def context(self) -> Context: ... + @property + def is_canonical_empty(self) -> bool: ... + @property + def n_dims(self) -> int: ... + @property + def n_equalities(self) -> int: ... + @property + def n_inequalities(self) -> int: ... + @property + def n_inputs(self) -> int: ... + @property + def n_symbols(self) -> int: ... + +class IntegerSetAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(integer_set) -> IntegerSetAttr: + """ + Gets an attribute wrapping an IntegerSet. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class IntegerSetConstraint: + def __init__(self, *args, **kwargs) -> None: ... + @property + def expr(self) -> AffineExpr: ... + @property + def is_eq(self) -> bool: ... + +class IntegerSetConstraintList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: IntegerSetConstraintList) -> list[IntegerSetConstraint]: ... + @overload + def __getitem__(self, arg0: int) -> IntegerSetConstraint: ... + @overload + def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ... + def __len__(self) -> int: ... 
+ +class IntegerType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get_signed(width: int, context: Context | None = None) -> IntegerType: + """ + Create a signed integer type + """ + @staticmethod + def get_signless(width: int, context: Context | None = None) -> IntegerType: + """ + Create a signless integer type + """ + @staticmethod + def get_unsigned(width: int, context: Context | None = None) -> IntegerType: + """ + Create an unsigned integer type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def is_signed(self) -> bool: + """ + Returns whether this is a signed integer + """ + @property + def is_signless(self) -> bool: + """ + Returns whether this is a signless integer + """ + @property + def is_unsigned(self) -> bool: + """ + Returns whether this is an unsigned integer + """ + @property + def typeid(self) -> TypeID: ... + @property + def width(self) -> int: + """ + Returns the width of the integer type + """ + +class Location: + current: ClassVar[Location] = ... 
# read-only + __hash__: ClassVar[None] = None + @staticmethod + def callsite( + callee: Location, frames: Sequence[Location], context: Context | None = None + ) -> Location: + """ + Gets a Location representing a caller and callsite + """ + @staticmethod + def file( + filename: str, line: int, col: int, context: Context | None = None + ) -> Location: + """ + Gets a Location representing a file, line and column + """ + @staticmethod + def from_attr(attribute: Attribute, context: Context | None = None) -> Location: + """ + Gets a Location from a LocationAttr + """ + @staticmethod + def fused( + locations: Sequence[Location], + metadata: Attribute | None = None, + context: Context | None = None, + ) -> Location: + """ + Gets a Location representing a fused location with optional metadata + """ + @staticmethod + def name( + name: str, + childLoc: Location | None = None, + context: Context | None = None, + ) -> Location: + """ + Gets a Location representing a named location with optional child location + """ + @staticmethod + def unknown(context: Context | None = None) -> Location: + """ + Gets a Location representing an unknown location + """ + def _CAPICreate(self) -> Location: ... + def __enter__(self) -> Location: ... + @overload + def __eq__(self, arg0: Location) -> bool: ... + @overload + def __eq__(self, arg0: Location) -> bool: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + def emit_error(self, message: str) -> None: + """ + Emits an error at this location + """ + @property + def _CAPIPtr(self) -> object: ... 
+ @property + def attr(self) -> Attribute: + """ + Get the underlying LocationAttr + """ + @property + def context(self) -> Context: + """ + Context that owns the Location + """ + +class MemRefType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + shape: list[int], + element_type: Type, + layout: Attribute = None, + memory_space: Attribute = None, + loc: Location | None = None, + ) -> MemRefType: + """ + Create a memref type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def affine_map(self) -> AffineMap: + """ + The layout of the MemRef type as an affine map. + """ + @property + def layout(self) -> Attribute: + """ + The layout of the MemRef type. + """ + @property + def memory_space(self) -> Attribute | None: + """ + Returns the memory space of the given MemRef type. + """ + @property + def typeid(self) -> TypeID: ... + def get_strides_and_offset(self) -> tuple[list[int], int]: + """ + The strides and offset of the MemRef type. + """ + +class Module: + @staticmethod + def create(loc: Location | None = None) -> Module: + """ + Creates an empty module + """ + @staticmethod + def parse(asm: str | bytes, context: Context | None = None) -> Module: + """ + Parses a module's assembly format from a string. + + Returns a new MlirModule or raises an MLIRError if the parsing fails. + + See also: https://mlir.llvm.org/docs/LangRef/ + """ + @staticmethod + def parseFile(path: str, context: Context | None = None) -> Module: + """ + Parses a module's assembly format from file. + + Returns a new MlirModule or raises an MLIRError if the parsing fails. + + See also: https://mlir.llvm.org/docs/LangRef/ + """ + def _CAPICreate(self) -> Any: ... + def __str__(self) -> str: + """ + Gets the assembly form of the operation with default options. 
+ + If more advanced control over the assembly formatting or I/O options is needed, + use the dedicated print or get_asm method, which supports keyword arguments to + customize behavior. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @property + def _CAPIPtr(self) -> object: ... + @property + def body(self) -> Block: + """ + Return the block for this module + """ + @property + def context(self) -> Context: + """ + Context that created the Module + """ + @property + def operation(self) -> Operation: + """ + Accesses the module as an operation + """ + +class MLIRError(Exception): + def __init__( + self, message: str, error_diagnostics: list[DiagnosticInfo] + ) -> None: ... + +class NamedAttribute: + @property + def attr(self) -> Attribute: + """ + The underlying generic attribute of the NamedAttribute binding + """ + @property + def name(self) -> str: + """ + The name of the NamedAttribute binding + """ + +class NoneType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> NoneType: + """ + Create a none type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class OpAttributeMap: + def __contains__(self, arg0: str) -> bool: ... + def __delitem__(self, arg0: str) -> None: ... + @overload + def __getitem__(self, arg0: str) -> Attribute: ... + @overload + def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __len__(self) -> int: ... + def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... + +class OpOperand: + @property + def operand_number(self) -> int: ... + @property + def owner(self) -> _OperationBase: ... + +class OpOperandIterator: + def __iter__(self) -> OpOperandIterator: ... + def __next__(self) -> OpOperand: ... + +class OpOperandList: + def __add__(self, arg0: OpOperandList) -> list[Value]: ... 
+ @overload + def __getitem__(self, arg0: int) -> Value: ... + @overload + def __getitem__(self, arg0: slice) -> OpOperandList: ... + def __len__(self) -> int: ... + def __setitem__(self, arg0: int, arg1: Value) -> None: ... + +class OpResult(Value): + @staticmethod + def isinstance(other_value: Value) -> bool: ... + def __init__(self, value: Value) -> None: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + @property + def owner(self) -> _OperationBase: ... + @property + def result_number(self) -> int: ... + +class OpResultList: + def __add__(self, arg0: OpResultList) -> list[OpResult]: ... + @overload + def __getitem__(self, arg0: int) -> OpResult: ... + @overload + def __getitem__(self, arg0: slice) -> OpResultList: ... + def __len__(self) -> int: ... + @property + def owner(self) -> _OperationBase: ... + @property + def types(self) -> list[Type]: ... + +class OpSuccessors: + def __add__(self, arg0: OpSuccessors) -> list[Block]: ... + @overload + def __getitem__(self, arg0: int) -> Block: ... + @overload + def __getitem__(self, arg0: slice) -> OpSuccessors: ... + def __setitem__(self, arg0: int, arg1: Block) -> None: ... + def __len__(self) -> int: ... + +class OpView(_OperationBase): + _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... + _ODS_REGIONS: ClassVar[tuple] = ... + _ODS_RESULT_SEGMENTS: ClassVar[None] = ... + def __init__(self, operation: _OperationBase) -> None: ... + @classmethod + def build_generic( + cls: type[_TOperation], + results: Sequence[Type] | None = None, + operands: Sequence[Value] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: int | None = None, + loc: Location | None = None, + ip: InsertionPoint | None = None, + ) -> _TOperation: + """ + Builds a specific, generated OpView based on class level attributes. 
+ """ + @classmethod + def parse( + cls: type[_TOperation], + source: str | bytes, + *, + source_name: str = "", + context: Context | None = None, + ) -> _TOperation: + """ + Parses a specific, generated OpView based on class level attributes + """ + def __init__(self, operation: _OperationBase) -> None: ... + @property + def operation(self) -> _OperationBase: ... + @property + def opview(self) -> OpView: ... + @property + def successors(self) -> OpSuccessors: + """ + Returns the List of Operation successors. + """ + +class OpaqueAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + dialect_namespace: str, + buffer: Buffer, + type: Type, + context: Context | None = None, + ) -> OpaqueAttr: + """ + Gets an Opaque attribute. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def data(self) -> bytes: + """ + Returns the data for the Opaqued attributes as `bytes` + """ + @property + def dialect_namespace(self) -> str: + """ + Returns the dialect namespace for the Opaque attribute as a string + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class OpaqueType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + dialect_namespace: str, buffer: str, context: Context | None = None + ) -> OpaqueType: + """ + Create an unregistered (opaque) dialect type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def data(self) -> str: + """ + Returns the data for the Opaque type as a string. + """ + @property + def dialect_namespace(self) -> str: + """ + Returns the dialect namespace for the Opaque type as a string. + """ + @property + def typeid(self) -> TypeID: ... + +class Operation(_OperationBase): + def _CAPICreate(self) -> object: ... 
+ @staticmethod + def create( + name: str, + results: Sequence[Type] | None = None, + operands: Sequence[Value] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: int = 0, + loc: Location | None = None, + ip: InsertionPoint | None = None, + infer_type: bool = False, + ) -> Operation: + """ + Creates a new operation. + + Args: + name: Operation name (e.g. "dialect.operation"). + results: Sequence of Type representing op result types. + attributes: Dict of str:Attribute. + successors: List of Block for the operation's successors. + regions: Number of regions to create. + loc: A Location object (defaults to resolve from context manager). + ip: An InsertionPoint (defaults to resolve from context manager or set to + False to disable insertion, even with an insertion point set in the + context manager). + infer_type: Whether to infer result types. + Returns: + A new "detached" Operation object. Detached operations can be added + to blocks, which causes them to become "attached." + """ + @staticmethod + def parse( + source: str | bytes, *, source_name: str = "", context: Context | None = None + ) -> Operation: + """ + Parses an operation. Supports both text assembly format and binary bytecode format. + """ + def _CAPICreate(self) -> object: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def operation(self) -> Operation: ... + @property + def opview(self) -> OpView: ... + @property + def successors(self) -> OpSuccessors: + """ + Returns the List of Operation successors. + """ + +class OperationIterator: + def __iter__(self) -> OperationIterator: ... + def __next__(self) -> OpView: ... + +class OperationList: + def __getitem__(self, arg0: int) -> OpView: ... + def __iter__(self) -> OperationIterator: ... + def __len__(self) -> int: ... 
+ +class RankedTensorType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + shape: list[int], + element_type: Type, + encoding: Attribute | None = None, + loc: Location | None = None, + ) -> RankedTensorType: + """ + Create a ranked tensor type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def encoding(self) -> Attribute | None: ... + @property + def typeid(self) -> TypeID: ... + +class Region: + __hash__: ClassVar[None] = None + @overload + def __eq__(self, arg0: Region) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __iter__(self) -> BlockIterator: + """ + Iterates over blocks in the region. + """ + @property + def blocks(self) -> BlockList: + """ + Returns a forward-optimized sequence of blocks. + """ + @property + def owner(self) -> OpView: + """ + Returns the operation owning this region. + """ + +class RegionIterator: + def __iter__(self) -> RegionIterator: ... + def __next__(self) -> Region: ... + +class RegionSequence: + @overload + def __getitem__(self, arg0: int) -> Region: ... + @overload + def __getitem__(self, arg0: slice) -> Sequence[Region]: ... + def __iter__(self) -> RegionIterator: ... + def __len__(self) -> int: ... + +class ShapedType(Type): + @staticmethod + def get_dynamic_size() -> int: + """ + Returns the value used to indicate dynamic dimensions in shaped types. + """ + @staticmethod + def get_dynamic_stride_or_offset() -> int: + """ + Returns the value used to indicate dynamic strides or offsets in shaped types. + """ + @staticmethod + def is_dynamic_size(dim_size: int) -> bool: + """ + Returns whether the given dimension size indicates a dynamic dimension. + """ + @staticmethod + def is_static_size(dim_size: int) -> bool: + """ + Returns whether the given dimension size indicates a static dimension. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... 
+ def __init__(self, cast_from_type: Type) -> None: ... + def get_dim_size(self, dim: int) -> int: + """ + Returns the dim-th dimension of the given ranked shaped type. + """ + def is_dynamic_dim(self, dim: int) -> bool: + """ + Returns whether the dim-th dimension of the given shaped type is dynamic. + """ + def is_static_dim(self, dim: int) -> bool: + """ + Returns whether the dim-th dimension of the given shaped type is static. + """ + def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: + """ + Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types. + """ + def is_static_stride_or_offset(self, dim_size: int) -> bool: + """ + Returns whether the given shaped type stride or offset value is statically-sized. + """ + @property + def element_type(self) -> Type: + """ + Returns the element type of the shaped type. + """ + @property + def has_rank(self) -> bool: + """ + Returns whether the given shaped type is ranked. + """ + @property + def has_static_shape(self) -> bool: + """ + Returns whether the given shaped type has a static shape. + """ + @property + def rank(self) -> int: + """ + Returns the rank of the given ranked shaped type. + """ + @property + def shape(self) -> list[int]: + """ + Returns the shape of the ranked shaped type as a List of integers. + """ + @property + def static_typeid(self) -> TypeID: ... + @property + def typeid(self) -> TypeID: ... + +class ShapedTypeComponents: + @staticmethod + @overload + def get(element_type: Type) -> ShapedTypeComponents: + """ + Create an shaped type components object with only the element type. + """ + @staticmethod + @overload + def get(shape: list, element_type: Type) -> ShapedTypeComponents: + """ + Create a ranked shaped type components object. + """ + @staticmethod + @overload + def get( + shape: list, element_type: Type, attribute: Attribute + ) -> ShapedTypeComponents: + """ + Create a ranked shaped type components object with attribute. 
+ """ + @property + def element_type(self) -> Type: + """ + Returns the element type of the shaped type components. + """ + @property + def has_rank(self) -> bool: + """ + Returns whether the given shaped type component is ranked. + """ + @property + def rank(self) -> int: + """ + Returns the rank of the given ranked shaped type components. If the shaped type components does not have a rank, None is returned. + """ + @property + def shape(self) -> list[int]: + """ + Returns the shape of the ranked shaped type components as a List of integers. Returns none if the shaped type component does not have a rank. + """ + +class StridedLayoutAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + offset: int, strides: list[int], context: Context | None = None + ) -> StridedLayoutAttr: + """ + Gets a strided layout attribute. + """ + @staticmethod + def get_fully_dynamic( + rank: int, context: Context | None = None + ) -> StridedLayoutAttr: + """ + Gets a strided layout attribute with dynamic offset and strides of a given rank. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def offset(self) -> int: + """ + Returns the value of the float point attribute + """ + @property + def strides(self) -> list[int]: + """ + Returns the value of the float point attribute + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class StringAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(value: str | bytes, context: Context | None = None) -> StringAttr: + """ + Gets a uniqued string attribute + """ + @staticmethod + def get_typed(type: Type, value: str) -> StringAttr: + """ + Gets a uniqued string attribute associated to a type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... 
+ @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> str: + """ + Returns the value of the string attribute + """ + @property + def value_bytes(self) -> bytes: + """ + Returns the value of the string attribute as `bytes` + """ + +class SymbolRefAttr(Attribute): + @staticmethod + def get(symbols: list[str], context: Context | None = None) -> Attribute: + """ + Gets a uniqued SymbolRef attribute from a List of symbol names + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> list[str]: + """ + Returns the value of the SymbolRef attribute as a List[str] + """ + +class SymbolTable: + @staticmethod + def get_symbol_name(symbol: _OperationBase) -> Attribute: ... + @staticmethod + def get_visibility(symbol: _OperationBase) -> Attribute: ... + @staticmethod + def replace_all_symbol_uses( + old_symbol: str, new_symbol: str, from_op: _OperationBase + ) -> None: ... + @staticmethod + def set_symbol_name(symbol: _OperationBase, name: str) -> None: ... + @staticmethod + def set_visibility(symbol: _OperationBase, visibility: str) -> None: ... + @staticmethod + def walk_symbol_tables( + from_op: _OperationBase, + all_sym_uses_visible: bool, + callback: Callable[[_OperationBase, bool], None], + ) -> None: ... + def __contains__(self, arg0: str) -> bool: ... + def __delitem__(self, arg0: str) -> None: ... + def __getitem__(self, arg0: str) -> OpView: ... + def __init__(self, arg0: _OperationBase) -> None: ... + def erase(self, operation: _OperationBase) -> None: ... + def insert(self, operation: _OperationBase) -> Attribute: ... 
+ +class TupleType(Type): + static_typeid: ClassVar[TypeID] + @staticmethod + def get_tuple(elements: list[Type], context: Context | None = None) -> TupleType: + """ + Create a Tuple type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def get_type(self, pos: int) -> Type: + """ + Returns the pos-th type in the Tuple type. + """ + @property + def num_types(self) -> int: + """ + Returns the number of types contained in a Tuple. + """ + @property + def typeid(self) -> TypeID: ... + +class TypeAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(value: Type, context: Context | None = None) -> TypeAttr: + """ + Gets a uniqued Type attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> Type: ... + +class TypeID: + def _CAPICreate(self) -> TypeID: ... + @overload + def __eq__(self, arg0: TypeID) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + +class UnitAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> UnitAttr: + """ + Create a Unit attribute. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class UnrankedMemRefType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + element_type: Type, memory_space: Attribute, loc: Location | None = None + ) -> UnrankedMemRefType: + """ + Create a unranked memref type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... 
+ def __init__(self, cast_from_type: Type) -> None: ... + @property + def memory_space(self) -> Attribute | None: + """ + Returns the memory space of the given Unranked MemRef type. + """ + @property + def typeid(self) -> TypeID: ... + +class UnrankedTensorType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(element_type: Type, loc: Location | None = None) -> UnrankedTensorType: + """ + Create a unranked tensor type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + +class VectorType(ShapedType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get( + shape: list[int], + element_type: Type, + *, + scalable: list | None = None, + scalable_dims: list[int] | None = None, + loc: Location | None = None, + ) -> VectorType: + """ + Create a vector type + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def scalable(self) -> bool: ... + @property + def scalable_dims(self) -> list[bool]: ... + @property + def typeid(self) -> TypeID: ... + +class _GlobalDebug: + flag: ClassVar[bool] = False diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi new file mode 100644 index 000000000..1010dadda --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -0,0 +1,36 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlir.passmanager +# Local modifications: +# * Relative imports for cross-module references. +# * Add __all__ + + +from . import ir as _ir + +__all__ = [ + "PassManager", +] + +class PassManager: + def __init__(self, context: _ir.Context | None = None) -> None: ... + def _CAPICreate(self) -> object: ... + def _testing_release(self) -> None: ... 
+ def enable_ir_printing( + self, + print_before_all: bool = False, + print_after_all: bool = True, + print_module_scope: bool = False, + print_after_change: bool = False, + print_after_failure: bool = False, + large_elements_limit: int | None = None, + large_resource_limit: int | None = None, + enable_debug_info: bool = False, + print_generic_op_form: bool = False, + tree_printing_dir_path: str | None = None, + ) -> None: ... + def enable_verifier(self, enable: bool) -> None: ... + @staticmethod + def parse(pipeline: str, context: _ir.Context | None = None) -> PassManager: ... + def run(self, module: _ir._OperationBase) -> None: ... + @property + def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi new file mode 100644 index 000000000..58d453d2b --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -0,0 +1,23 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlirExecutionEngine +# Local modifications: +# * Relative imports for cross-module references. +# * Add __all__ + +from collections.abc import Sequence + +from ._mlir import ir as _ir + +__all__ = [ + "ExecutionEngine", +] + +class ExecutionEngine: + def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ... + def _CAPICreate(self) -> object: ... + def _testing_release(self) -> None: ... + def dump_to_object_file(self, file_name: str) -> None: ... + def raw_lookup(self, func_name: str) -> int: ... + def raw_register_runtime(self, name: str, callback: object) -> None: ... + @property + def _CAPIPtr(self) -> object: ... 
diff --git a/mlir/python/mlir/dialects/AMDGPUOps.td b/mlir/python/mlir/dialects/AMDGPUOps.td new file mode 100644 index 000000000..fe9371971 --- /dev/null +++ b/mlir/python/mlir/dialects/AMDGPUOps.td @@ -0,0 +1,14 @@ +//===-- AMDGPUOps.td - Entry point for AMDGPUOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_AMDGPU_OPS +#define PYTHON_BINDINGS_AMDGPU_OPS + +include "mlir/Dialect/AMDGPU/IR/AMDGPU.td" + +#endif diff --git a/mlir/python/mlir/dialects/AffineOps.td b/mlir/python/mlir/dialects/AffineOps.td new file mode 100644 index 000000000..e12ffafb8 --- /dev/null +++ b/mlir/python/mlir/dialects/AffineOps.td @@ -0,0 +1,14 @@ +//===-- AffineOps.td - Entry point for Affine bindings -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_AFFINE_OPS +#define PYTHON_BINDINGS_AFFINE_OPS + +include "mlir/Dialect/Affine/IR/AffineOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/ArithOps.td b/mlir/python/mlir/dialects/ArithOps.td new file mode 100644 index 000000000..60dbb08a0 --- /dev/null +++ b/mlir/python/mlir/dialects/ArithOps.td @@ -0,0 +1,14 @@ +//===-- ArithOps.td - Entry point for ArithOps bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ARITH_OPS +#define PYTHON_BINDINGS_ARITH_OPS + +include "mlir/Dialect/Arith/IR/ArithOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/AsyncOps.td b/mlir/python/mlir/dialects/AsyncOps.td new file mode 100644 index 000000000..2b05045cf --- /dev/null +++ b/mlir/python/mlir/dialects/AsyncOps.td @@ -0,0 +1,14 @@ +//===-- AsyncOps.td - Entry point async_dialect bindings --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ASYNC_OPS +#define PYTHON_BINDINGS_ASYNC_OPS + +include "mlir/Dialect/Async/IR/AsyncOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/BufferizationEnums.td b/mlir/python/mlir/dialects/BufferizationEnums.td new file mode 100644 index 000000000..f676ce082 --- /dev/null +++ b/mlir/python/mlir/dialects/BufferizationEnums.td @@ -0,0 +1,14 @@ +//===-- BufferizationEnums.td - Entry point for BufferizationEnums bindings ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_BUFFERIZATION_ENUMS +#define PYTHON_BINDINGS_BUFFERIZATION_ENUMS + +include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td" + +#endif // PYTHON_BINDINGS_BUFFERIZATION_ENUMS diff --git a/mlir/python/mlir/dialects/BufferizationOps.td b/mlir/python/mlir/dialects/BufferizationOps.td new file mode 100644 index 000000000..b2ac7e281 --- /dev/null +++ b/mlir/python/mlir/dialects/BufferizationOps.td @@ -0,0 +1,14 @@ +//===-- BufferizationOps.td - Entry point for BufferizationOps bindings ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_BUFFERIZATION_OPS +#define PYTHON_BINDINGS_BUFFERIZATION_OPS + +include "mlir/Dialect/Bufferization/IR/BufferizationOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/BufferizationTransformOps.td b/mlir/python/mlir/dialects/BufferizationTransformOps.td new file mode 100644 index 000000000..34213be22 --- /dev/null +++ b/mlir/python/mlir/dialects/BufferizationTransformOps.td @@ -0,0 +1,20 @@ +//===-- BufferizationTransformOps.td -----------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the bufferization dialect. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS +#define PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS + +include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td" + +#endif // PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS diff --git a/mlir/lib/Bindings/Python/BuiltinOps.td b/mlir/python/mlir/dialects/BuiltinOps.td similarity index 91% rename from mlir/lib/Bindings/Python/BuiltinOps.td rename to mlir/python/mlir/dialects/BuiltinOps.td index ecbb8227d..d1c595283 100644 --- a/mlir/lib/Bindings/Python/BuiltinOps.td +++ b/mlir/python/mlir/dialects/BuiltinOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_BUILTIN_OPS #define PYTHON_BINDINGS_BUILTIN_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/IR/BuiltinOps.td" #endif diff --git a/mlir/python/mlir/dialects/ComplexOps.td b/mlir/python/mlir/dialects/ComplexOps.td new file mode 100644 index 000000000..17825b6be --- /dev/null +++ b/mlir/python/mlir/dialects/ComplexOps.td @@ -0,0 +1,14 @@ +//===-- ComplexOps.td - Entry point for ComplexOps bindings ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_COMPLEX_OPS +#define PYTHON_BINDINGS_COMPLEX_OPS + +include "mlir/Dialect/Complex/IR/ComplexOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/ControlFlowOps.td b/mlir/python/mlir/dialects/ControlFlowOps.td new file mode 100644 index 000000000..c9610a3c6 --- /dev/null +++ b/mlir/python/mlir/dialects/ControlFlowOps.td @@ -0,0 +1,14 @@ +//===-- ControlFlowOps.td - Python ControlFlowOps bindings -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_CONTROL_FLOW_OPS +#define PYTHON_BINDINGS_CONTROL_FLOW_OPS + +include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/EmitC.td b/mlir/python/mlir/dialects/EmitC.td new file mode 100644 index 000000000..ff0a56d15 --- /dev/null +++ b/mlir/python/mlir/dialects/EmitC.td @@ -0,0 +1,14 @@ +//===-- EmitC.td - Entry point for EmitC bind --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_EMITC +#define PYTHON_BINDINGS_EMITC + +include "mlir/Dialect/EmitC/IR/EmitC.td" + +#endif diff --git a/mlir/lib/Bindings/Python/StandardOps.td b/mlir/python/mlir/dialects/FuncOps.td similarity index 63% rename from mlir/lib/Bindings/Python/StandardOps.td rename to mlir/python/mlir/dialects/FuncOps.td index 5b7caabc2..0816d6a3f 100644 --- a/mlir/lib/Bindings/Python/StandardOps.td +++ b/mlir/python/mlir/dialects/FuncOps.td @@ -1,4 +1,4 @@ -//===-- StandardOps.td - Entry point for StandardOps bind --*- tablegen -*-===// +//===-- FuncOps.td - Entry point for Func bind -------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,15 +6,14 @@ // //===----------------------------------------------------------------------===// // -// This is the main file from which the Python bindings for the Standard -// dialect are generated. 
+// This is the main file from which the Python bindings for the Func dialect +// are generated. // //===----------------------------------------------------------------------===// -#ifndef PYTHON_BINDINGS_STANDARD_OPS -#define PYTHON_BINDINGS_STANDARD_OPS +#ifndef PYTHON_BINDINGS_FUNC +#define PYTHON_BINDINGS_FUNC -include "mlir/Bindings/Python/Attributes.td" -include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Func/IR/FuncOps.td" #endif diff --git a/mlir/python/mlir/dialects/GPUOps.td b/mlir/python/mlir/dialects/GPUOps.td new file mode 100644 index 000000000..83b1f6cd4 --- /dev/null +++ b/mlir/python/mlir/dialects/GPUOps.td @@ -0,0 +1,14 @@ +//===-- GPUOps.td - Entry point GPU_dialect bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_GPU_OPS +#define PYTHON_BINDINGS_GPU_OPS + +include "mlir/Dialect/GPU/IR/GPUOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/GPUTransformOps.td b/mlir/python/mlir/dialects/GPUTransformOps.td new file mode 100644 index 000000000..08bd9537b --- /dev/null +++ b/mlir/python/mlir/dialects/GPUTransformOps.td @@ -0,0 +1,20 @@ +//===-- GPUTransformOps.td ---------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the GPU dialect. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_GPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_GPU_TRANSFORM_OPS + +include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.td" + +#endif // PYTHON_BINDINGS_GPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/IndexOps.td b/mlir/python/mlir/dialects/IndexOps.td new file mode 100644 index 000000000..13b1d782c --- /dev/null +++ b/mlir/python/mlir/dialects/IndexOps.td @@ -0,0 +1,14 @@ +//===-- IndexOps.td - Entry point for Index bindings -----*- tablegen -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_INDEX_OPS +#define PYTHON_BINDINGS_INDEX_OPS + +include "mlir/Dialect/Index/IR/IndexOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td new file mode 100644 index 000000000..30f047f21 --- /dev/null +++ b/mlir/python/mlir/dialects/LLVMOps.td @@ -0,0 +1,15 @@ +//===-- LlvmOps.td - Entry point for llvm bind ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_LLVM_OPS +#define PYTHON_BINDINGS_LLVM_OPS + +include "mlir/Dialect/LLVMIR/LLVMOps.td" +include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td similarity index 91% rename from mlir/lib/Bindings/Python/LinalgOps.td rename to mlir/python/mlir/dialects/LinalgOps.td index 7650e954d..89fb3f219 100644 --- a/mlir/lib/Bindings/Python/LinalgOps.td +++ b/mlir/python/mlir/dialects/LinalgOps.td @@ -9,8 +9,8 @@ #ifndef PYTHON_BINDINGS_LINALG_OPS #define PYTHON_BINDINGS_LINALG_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Linalg/IR/LinalgOps.td" include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" +include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td" #endif diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td b/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td new file mode 100644 index 000000000..e86c9b7dd --- /dev/null +++ b/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td @@ -0,0 +1,20 @@ +//===-- LinalgStructuredTransformEnums.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the structured transform ops +// provided by Linalg (and other dialects). 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS +#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS + +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" + +#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td new file mode 100644 index 000000000..e11065bf8 --- /dev/null +++ b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td @@ -0,0 +1,20 @@ +//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the structured transform ops +// provided by Linalg (and other dialects). +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS +#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS + +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td" + +#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/MLProgramOps.td b/mlir/python/mlir/dialects/MLProgramOps.td new file mode 100644 index 000000000..35b348d5f --- /dev/null +++ b/mlir/python/mlir/dialects/MLProgramOps.td @@ -0,0 +1,14 @@ +//===-- MLProgramOps.td - Entry point for MLProgramOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MLPROGRAM_OPS +#define PYTHON_BINDINGS_MLPROGRAM_OPS + +include "mlir/Dialect/MLProgram/IR/MLProgramOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/MathOps.td b/mlir/python/mlir/dialects/MathOps.td new file mode 100644 index 000000000..8f68467ea --- /dev/null +++ b/mlir/python/mlir/dialects/MathOps.td @@ -0,0 +1,14 @@ +//===-- MathOps.td - Entry point for MathOps bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MATH_OPS +#define PYTHON_BINDINGS_MATH_OPS + +include "mlir/Dialect/Math/IR/MathOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/MemRefOps.td b/mlir/python/mlir/dialects/MemRefOps.td new file mode 100644 index 000000000..ed346d5a2 --- /dev/null +++ b/mlir/python/mlir/dialects/MemRefOps.td @@ -0,0 +1,14 @@ +//===-- MemRefOps.td - Entry point for MemRefOps bind ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MEMREF_OPS +#define PYTHON_BINDINGS_MEMREF_OPS + +include "mlir/Dialect/MemRef/IR/MemRefOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/MemRefTransformOps.td b/mlir/python/mlir/dialects/MemRefTransformOps.td new file mode 100644 index 000000000..a64c2e238 --- /dev/null +++ b/mlir/python/mlir/dialects/MemRefTransformOps.td @@ -0,0 +1,14 @@ +//===-- MemRefTransformOps.td - Memref transform ops -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MEMREF_TRANSFORM_OPS +#define PYTHON_BINDINGS_MEMREF_TRANSFORM_OPS + +include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/NVGPUOps.td b/mlir/python/mlir/dialects/NVGPUOps.td new file mode 100644 index 000000000..cdf651901 --- /dev/null +++ b/mlir/python/mlir/dialects/NVGPUOps.td @@ -0,0 +1,14 @@ +//===-- NVGPUOps.td - Entry point for NVGPUOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_NVGPU_OPS +#define PYTHON_BINDINGS_NVGPU_OPS + +include "mlir/Dialect/NVGPU/IR/NVGPUOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/NVGPUTransformOps.td b/mlir/python/mlir/dialects/NVGPUTransformOps.td new file mode 100644 index 000000000..1f504e322 --- /dev/null +++ b/mlir/python/mlir/dialects/NVGPUTransformOps.td @@ -0,0 +1,20 @@ +//===-- NVGPUTransformOps.td -------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the NVGPU dialect. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS + +include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td" + +#endif // PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/NVVMOps.td b/mlir/python/mlir/dialects/NVVMOps.td new file mode 100644 index 000000000..f57d204a8 --- /dev/null +++ b/mlir/python/mlir/dialects/NVVMOps.td @@ -0,0 +1,14 @@ +//===-- NVVMOps.td - Entry point for NVVMOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_NVVM_OPS +#define PYTHON_BINDINGS_NVVM_OPS + +include "mlir/Dialect/LLVMIR/NVVMOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/OpenMPOps.td b/mlir/python/mlir/dialects/OpenMPOps.td new file mode 100644 index 000000000..b91179b0d --- /dev/null +++ b/mlir/python/mlir/dialects/OpenMPOps.td @@ -0,0 +1,14 @@ +//===-- OpenMPOps.td - Entry point for OpenMPOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_OPENMP_OPS +#define PYTHON_BINDINGS_OPENMP_OPS + +include "mlir/Dialect/OpenMP/OpenMPOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/PDLOps.td b/mlir/python/mlir/dialects/PDLOps.td new file mode 100644 index 000000000..a8c2d6bdb --- /dev/null +++ b/mlir/python/mlir/dialects/PDLOps.td @@ -0,0 +1,14 @@ +//===-- PDLOps.td - Entry point for PDLOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_PDL_OPS +#define PYTHON_BINDINGS_PDL_OPS + +include "mlir/Dialect/PDL/IR/PDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/ROCDLOps.td b/mlir/python/mlir/dialects/ROCDLOps.td new file mode 100644 index 000000000..fa5c9ebc3 --- /dev/null +++ b/mlir/python/mlir/dialects/ROCDLOps.td @@ -0,0 +1,14 @@ +//===-- ROCDLOps.td - Entry point for ROCDLOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ROCDL_OPS +#define PYTHON_BINDINGS_ROCDL_OPS + +include "mlir/Dialect/LLVMIR/ROCDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td new file mode 100644 index 000000000..4a904d578 --- /dev/null +++ b/mlir/python/mlir/dialects/SCFLoopTransformOps.td @@ -0,0 +1,21 @@ +//===-- SCFLoopTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the loop transform ops +// provided by the SCF (and other) dialects. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS +#define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS + +include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td" +include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td" + +#endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td new file mode 100644 index 000000000..f1fc8a8db --- /dev/null +++ b/mlir/python/mlir/dialects/SCFOps.td @@ -0,0 +1,14 @@ +//===-- SCFOps.td - Entry point for SCF dialect bindings ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SCF_OPS +#define PYTHON_BINDINGS_SCF_OPS + +include "mlir/Dialect/SCF/IR/SCFOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/SMTOps.td b/mlir/python/mlir/dialects/SMTOps.td new file mode 100644 index 000000000..e143f071e --- /dev/null +++ b/mlir/python/mlir/dialects/SMTOps.td @@ -0,0 +1,14 @@ +//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BINDINGS_PYTHON_SMT_OPS +#define BINDINGS_PYTHON_SMT_OPS + +include "mlir/Dialect/SMT/IR/SMT.td" + +#endif // BINDINGS_PYTHON_SMT_OPS diff --git a/mlir/python/mlir/dialects/SPIRVOps.td b/mlir/python/mlir/dialects/SPIRVOps.td new file mode 100644 index 000000000..eaae0e609 --- /dev/null +++ b/mlir/python/mlir/dialects/SPIRVOps.td @@ -0,0 +1,14 @@ +//===-- SPIRVOps.td - Entry point for SPIRVOps bind --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPIRV_OPS +#define PYTHON_BINDINGS_SPIRV_OPS + +include "mlir/Dialect/SPIRV/IR/SPIRVOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/ShapeOps.td b/mlir/python/mlir/dialects/ShapeOps.td similarity index 91% rename from mlir/lib/Bindings/Python/ShapeOps.td rename to mlir/python/mlir/dialects/ShapeOps.td index c469a586b..e217b2edc 100644 --- a/mlir/lib/Bindings/Python/ShapeOps.td +++ b/mlir/python/mlir/dialects/ShapeOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_SHAPE_OPS #define PYTHON_BINDINGS_SHAPE_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Shape/IR/ShapeOps.td" #endif diff --git a/mlir/python/mlir/dialects/SparseTensorAttrDefs.td b/mlir/python/mlir/dialects/SparseTensorAttrDefs.td new file mode 100644 index 000000000..5a86f55df --- /dev/null +++ b/mlir/python/mlir/dialects/SparseTensorAttrDefs.td @@ -0,0 +1,14 @@ +//===-- SparseTensorAttrDefs.td - Entry point for bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS +#define PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS + +include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" + +#endif // PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS diff --git a/mlir/python/mlir/dialects/SparseTensorOps.td b/mlir/python/mlir/dialects/SparseTensorOps.td new file mode 100644 index 000000000..3f0d522f3 --- /dev/null +++ b/mlir/python/mlir/dialects/SparseTensorOps.td @@ -0,0 +1,14 @@ +//===-- SparseTensorOps.td - Entry point for bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_OPS +#define PYTHON_BINDINGS_SPARSE_TENSOR_OPS + +include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/SparseTensorTransformOps.td b/mlir/python/mlir/dialects/SparseTensorTransformOps.td new file mode 100644 index 000000000..f4c4464ee --- /dev/null +++ b/mlir/python/mlir/dialects/SparseTensorTransformOps.td @@ -0,0 +1,14 @@ +//===-- SparseTensorTransformOps.td ------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS +#define PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS + +include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/TensorOps.td b/mlir/python/mlir/dialects/TensorOps.td similarity index 91% rename from mlir/lib/Bindings/Python/TensorOps.td rename to mlir/python/mlir/dialects/TensorOps.td index 40ecea7bf..d68cd2447 100644 --- a/mlir/lib/Bindings/Python/TensorOps.td +++ b/mlir/python/mlir/dialects/TensorOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_TENSOR_OPS #define PYTHON_BINDINGS_TENSOR_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" #endif diff --git a/mlir/python/mlir/dialects/TensorTransformOps.td b/mlir/python/mlir/dialects/TensorTransformOps.td new file mode 100644 index 000000000..87c5c7f39 --- /dev/null +++ b/mlir/python/mlir/dialects/TensorTransformOps.td @@ -0,0 +1,20 @@ +//===-- TensorTransformOps.td ------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the tensor dialect. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_TENSOR_TRANSFORM_OPS +#define PYTHON_BINDINGS_TENSOR_TRANSFORM_OPS + +include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td" + +#endif // PYTHON_BINDINGS_TENSOR_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/TosaOps.td b/mlir/python/mlir/dialects/TosaOps.td new file mode 100644 index 000000000..b429780bc --- /dev/null +++ b/mlir/python/mlir/dialects/TosaOps.td @@ -0,0 +1,14 @@ +//===-- TosaOps.td - Entry point for TosaOps bind ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TOSA_OPS +#define PYTHON_BINDINGS_TOSA_OPS + +include "mlir/Dialect/Tosa/IR/TosaOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/TransformAttrs.td b/mlir/python/mlir/dialects/TransformAttrs.td new file mode 100644 index 000000000..451118a5d --- /dev/null +++ b/mlir/python/mlir/dialects/TransformAttrs.td @@ -0,0 +1,14 @@ +//===-- TransformAttrs.td - Transform attrs bind entry point ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_ATTRS +#define PYTHON_BINDINGS_TRANSFORM_ATTRS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_ATTRS diff --git a/mlir/python/mlir/dialects/TransformDebugExtensionOps.td b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td new file mode 100644 index 000000000..22a85d236 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the Debug extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS + +include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/TransformOps.td b/mlir/python/mlir/dialects/TransformOps.td new file mode 100644 index 000000000..e2f6cf932 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformOps.td @@ -0,0 +1,14 @@ +//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_OPS +#define PYTHON_BINDINGS_TRANSFORM_OPS + +include "mlir/Dialect/Transform/IR/TransformOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td new file mode 100644 index 000000000..56fadd029 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the PDL extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS + +include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td new file mode 100644 index 000000000..c622c31e2 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the Tune extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS + +include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/Vector.td b/mlir/python/mlir/dialects/Vector.td new file mode 100644 index 000000000..f659f754b --- /dev/null +++ b/mlir/python/mlir/dialects/Vector.td @@ -0,0 +1,14 @@ +//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR +#define PYTHON_BINDINGS_VECTOR + +include "mlir/Dialect/Vector/IR/Vector.td" + +#endif // PYTHON_BINDINGS_VECTOR diff --git a/mlir/python/mlir/dialects/VectorAttributes.td b/mlir/python/mlir/dialects/VectorAttributes.td new file mode 100644 index 000000000..038e0ba21 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorAttributes.td @@ -0,0 +1,14 @@ +//===-- VectorAttributes.td - Entry point for bindings -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR_ATTRDEFS_TD +#define PYTHON_BINDINGS_VECTOR_ATTRDEFS_TD + +include "mlir/Dialect/Vector/IR/VectorAttributes.td" + +#endif // PYTHON_BINDINGS_VECTOR_ATTRDEFS_TD diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td new file mode 100644 index 000000000..69a1028c9 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -0,0 +1,14 @@ +//===-- VectorOps.td - Entry point for VectorOps bind ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR_OPS +#define PYTHON_BINDINGS_VECTOR_OPS + +include "mlir/Dialect/Vector/IR/VectorOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/VectorTransformOps.td b/mlir/python/mlir/dialects/VectorTransformOps.td new file mode 100644 index 000000000..42aa8c006 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorTransformOps.td @@ -0,0 +1,19 @@ +//===-- VectorTransformOps.td ------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the vector transform ops. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_VECTORTRANSFORMOPS +#define PYTHON_BINDINGS_VECTORTRANSFORMOPS + +include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.td" + +#endif // PYTHON_BINDINGS_VECTORTRANSFORMOPS diff --git a/mlir/python/mlir/dialects/VectorTransformsBase.td b/mlir/python/mlir/dialects/VectorTransformsBase.td new file mode 100644 index 000000000..acb4aeced --- /dev/null +++ b/mlir/python/mlir/dialects/VectorTransformsBase.td @@ -0,0 +1,19 @@ +//===-- VectorTransformsBase.td ------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the vector transform ops. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_VECTORTRANSFORMBASE +#define PYTHON_BINDINGS_VECTORTRANSFORMBASE + +include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" + +#endif // PYTHON_BINDINGS_VECTORTRANSFORMBASE diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py new file mode 100644 index 000000000..10abd06ff --- /dev/null +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -0,0 +1,307 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import ( + List as _List, + Optional as _Optional, + Sequence as _Sequence, + Tuple as _Tuple, + Type as _Type, + Union as _Union, +) + +from .._mlir_libs import _mlir as _cext +from ..ir import ( + ArrayAttr, + Attribute, + BoolAttr, + DenseI64ArrayAttr, + IntegerAttr, + IntegerType, + OpView, + Operation, + ShapedType, + Value, +) + +__all__ = [ + "equally_sized_accessor", + "get_default_loc_context", + "get_op_result_or_value", + "get_op_results_or_values", + "get_op_result_or_op_results", + "segmented_accessor", +] + + +def segmented_accessor(elements, raw_segments, idx): + """ + Returns a slice of elements corresponding to the idx-th segment. + + elements: a sliceable container (operands or results). + raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing + sizes of the segments. + idx: index of the segment. + """ + segments = _cext.ir.DenseI32ArrayAttr(raw_segments) + start = sum(segments[i] for i in range(idx)) + end = start + segments[idx] + return elements[start:end] + + +def equally_sized_accessor( + elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic +): + """ + Returns a starting position and a number of elements per variadic group + assuming equally-sized groups and the given numbers of preceding groups. + + elements: a sequential container. + n_simple: the number of non-variadic groups in the container. + n_variadic: the number of variadic groups in the container. + n_preceding_simple: the number of non-variadic groups preceding the current + group. + n_preceding_variadic: the number of variadic groups preceding the current + group. + """ + + total_variadic_length = len(elements) - n_simple + # This should be enforced by the C++-side trait verifier. 
+ assert total_variadic_length % n_variadic == 0 + + elements_per_group = total_variadic_length // n_variadic + start = n_preceding_simple + n_preceding_variadic * elements_per_group + return start, elements_per_group + + +def get_default_loc_context(location=None): + """ + Returns a context in which the defaulted location is created. If the location + is None, takes the current location from the stack. + """ + if location is None: + if _cext.ir.Location.current: + return _cext.ir.Location.current.context + return None + return location.context + + +def get_op_result_or_value( + arg: _Union[ + _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList + ] +) -> _cext.ir.Value: + """Returns the given value or the single result of the given op. + + This is useful to implement op constructors so that they can take other ops as + arguments instead of requiring the caller to extract results for every op. + Raises ValueError if provided with an op that doesn't have a single result. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.result + elif isinstance(arg, _cext.ir.Operation): + return arg.result + elif isinstance(arg, _cext.ir.OpResultList): + return arg[0] + else: + assert isinstance(arg, _cext.ir.Value), f"expects Value, got {type(arg)}" + return arg + + +def get_op_results_or_values( + arg: _Union[ + _cext.ir.OpView, + _cext.ir.Operation, + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + ] +) -> _Union[ + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + _cext.ir.OpResultList, +]: + """Returns the given sequence of values or the results of the given op. + + This is useful to implement op constructors so that they can take other ops as + lists of arguments instead of requiring the caller to extract results for + every op. 
+ """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.results + elif isinstance(arg, _cext.ir.Operation): + return arg.results + else: + return arg + + +def get_op_result_or_op_results( + op: _Union[_cext.ir.OpView, _cext.ir.Operation], +) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]: + results = op.results + num_results = len(results) + if num_results == 1: + return results[0] + elif num_results > 1: + return results + elif isinstance(op, _cext.ir.OpView): + return op.operation + else: + return op + + +ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value +ResultValueT = _Union[ResultValueTypeTuple] +VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] + +StaticIntLike = _Union[int, IntegerAttr] +ValueLike = _Union[Operation, OpView, Value] +MixedInt = _Union[StaticIntLike, ValueLike] + +IntOrAttrList = _Sequence[_Union[IntegerAttr, int]] +OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]] + +BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]] +OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]] + +MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] + +DynamicIndexList = _Sequence[_Union[MixedInt, _Sequence[MixedInt]]] + + +def _dispatch_dynamic_index_list( + indices: _Union[DynamicIndexList, ArrayAttr], +) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]: + """Dispatches a list of indices to the appropriate form. + + This is similar to the custom `DynamicIndexList` directive upstream: + provided indices may be in the form of dynamic SSA values or static values, + and they may be scalable (i.e., as a singleton list) or not. This function + dispatches each index into its respective form. It also extracts the SSA + values and static indices from various similar structures, respectively. 
+ """ + dynamic_indices = [] + static_indices = [ShapedType.get_dynamic_size()] * len(indices) + scalable_indices = [False] * len(indices) + + # ArrayAttr: Extract index values. + if isinstance(indices, ArrayAttr): + indices = [idx for idx in indices] + + def process_nonscalable_index(i, index): + """Processes any form of non-scalable index. + + Returns False if the given index was scalable and thus remains + unprocessed; True otherwise. + """ + if isinstance(index, int): + static_indices[i] = index + elif isinstance(index, IntegerAttr): + static_indices[i] = index.value # pytype: disable=attribute-error + elif isinstance(index, (Operation, Value, OpView)): + dynamic_indices.append(index) + else: + return False + return True + + # Process each index at a time. + for i, index in enumerate(indices): + if not process_nonscalable_index(i, index): + # If it wasn't processed, it must be a scalable index, which is + # provided as a _Sequence of one value, so extract and process that. + scalable_indices[i] = True + assert len(index) == 1 + ret = process_nonscalable_index(i, index[0]) + assert ret + + return dynamic_indices, static_indices, scalable_indices + + +# Dispatches `MixedValues` that all represents integers in various forms into +# the following three categories: +# - `dynamic_values`: a list of `Value`s, potentially from op results; +# - `packed_values`: a value handle, potentially from an op result, associated +# to one or more payload operations of integer type; +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. +# The input is in the form for `packed_values`, only that result is set and the +# other two are empty. Otherwise, the input can be a mix of the other two forms, +# and for each dynamic value, a special value is added to the `static_values`. 
+def _dispatch_mixed_values( + values: MixedValues, +) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]: + dynamic_values = [] + packed_values = None + static_values = None + if isinstance(values, ArrayAttr): + static_values = values + elif isinstance(values, (Operation, Value, OpView)): + packed_values = values + else: + static_values = [] + for size in values or []: + if isinstance(size, int): + static_values.append(size) + else: + static_values.append(ShapedType.get_dynamic_size()) + dynamic_values.append(size) + static_values = DenseI64ArrayAttr.get(static_values) + + return (dynamic_values, packed_values, static_values) + + +def _get_value_or_attribute_value( + value_or_attr: _Union[any, Attribute, ArrayAttr] +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr] +) -> _Sequence[any]: + return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] + + +def _get_int_array_attr( + values: _Optional[_Union[ArrayAttr, IntOrAttrList]] +) -> ArrayAttr: + if values is None: + return None + + # Turn into a Python list of Python ints. + values = _get_value_list(values) + + # Make an ArrayAttr of IntegerAttrs out of it. + return ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] + ) + + +def _get_int_array_array_attr( + values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]] +) -> ArrayAttr: + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. + + The input has to be a collection of a collection of integers, where any + Python _Sequence and ArrayAttr are admissible collections and Python ints and + any IntegerAttr are admissible integers. 
Both levels of collections are + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. + If the input is None, an empty ArrayAttr is returned. + """ + if values is None: + return None + + # Make sure the outer level is a list. + values = _get_value_list(values) + + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and + # Sequences. Make sure the nested values are all lists. + values = [_get_value_list(nested) for nested in values] + + # Turn each nested list into an ArrayAttr. + values = [_get_int_array_attr(nested) for nested in values] + + # Turn the outer list into an ArrayAttr. + return ArrayAttr.get(values) diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py new file mode 100644 index 000000000..7641d36e3 --- /dev/null +++ b/mlir/python/mlir/dialects/affine.py @@ -0,0 +1,216 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._affine_ops_gen import * +from ._affine_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ResultValueTypeTuple as _ResultValueTypeTuple, + ResultValueT as _ResultValueT, + VariadicResultValueT as _VariadicResultValueT, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineForOp(AffineForOp): + """Specialization for the Affine for op class.""" + + def __init__( + self, + lower_bound: Union[int, _ResultValueT, AffineMap], + upper_bound: Optional[Union[int, _ResultValueT, AffineMap]], + step: Optional[Union[int, Attribute]] = None, + iter_args: Optional[_ResultValueT] = None, + *, + lower_bound_operands: Optional[_VariadicResultValueT] = None, + upper_bound_operands: Optional[_VariadicResultValueT] = None, + loc=None, + ip=None, + ): + """Creates an Affine `for` operation. + + - `lower_bound` is the affine map to use as lower bound of the loop. + - `upper_bound` is the affine map to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. + - `lower_bound_operands` is the list of arguments to substitute the dimensions, + then symbols in the `lower_bound` affine map, in an increasing order. + - `upper_bound_operands` is the list of arguments to substitute the dimensions, + then symbols in the `upper_bound` affine map, in an increasing order. 
+ """ + + if lower_bound_operands is None: + lower_bound_operands = [] + if upper_bound_operands is None: + upper_bound_operands = [] + + if step is None: + step = 1 + + bounds_operands = [lower_bound_operands, upper_bound_operands] + bounds = [lower_bound, upper_bound] + bounds_names = ["lower", "upper"] + for i, name in enumerate(bounds_names): + if isinstance(bounds[i], int): + bounds[i] = AffineMap.get_constant(bounds[i]) + elif isinstance(bounds[i], _ResultValueTypeTuple): + if len(bounds_operands[i]): + raise ValueError( + f"Either a concrete {name} bound or an AffineMap in combination " + f"with {name} bound operands, but not both, is supported." + ) + if ( + isinstance(bounds[i], (OpView, Operation)) + and len(bounds[i].results) > 1 + ): + raise ValueError( + f"Only a single concrete value is supported for {name} bound." + ) + + bounds_operands[i].append(_get_op_result_or_value(bounds[i])) + bounds[i] = AffineMap.get_identity(1) + + if not isinstance(bounds[i], AffineMap): + raise ValueError( + f"{name} bound must be int | ResultValueT | AffineMap." + ) + if len(bounds_operands[i]) != bounds[i].n_inputs: + raise ValueError( + f"Wrong number of {name} bound operands passed to AffineForOp; " + + f"Expected {bounds[i].n_inputs}, got {len(bounds_operands[i])}." 
+ ) + + lower_bound, upper_bound = bounds + + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + + results = [arg.type for arg in iter_args] + super().__init__( + results_=results, + lowerBoundOperands=_get_op_results_or_values(lower_bound_operands), + upperBoundOperands=_get_op_results_or_values(upper_bound_operands), + inits=list(iter_args), + lowerBoundMap=AffineMapAttr.get(lower_bound), + upperBoundMap=AffineMapAttr.get(upper_bound), + step=step, + loc=loc, + ip=ip, + ) + self.regions[0].blocks.append(IndexType.get(), *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[1:] + + +def for_( + start, + stop, + step=None, + iter_args: Optional[Sequence[Value]] = None, + *, + loc=None, + ip=None, +): + for_op = AffineForOp( + start, + stop, + step, + iter_args=iter_args, + loc=loc, + ip=ip, + ) + iv = for_op.induction_variable + iter_args = tuple(for_op.inner_iter_args) + with InsertionPoint(for_op.body): + if len(iter_args) > 1: + yield iv, iter_args + elif len(iter_args) == 1: + yield iv, iter_args[0] + else: + yield iv + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineIfOp(AffineIfOp): + """Specialization for the Affine if op class.""" + + def __init__( + self, + cond: IntegerSet, + results_: Optional[Type] = None, + *, + cond_operands: Optional[_VariadicResultValueT] = None, + has_else: bool = False, + loc=None, + ip=None, + ): + """Creates an Affine `if` operation. + + - `cond` is the integer set used to determine which regions of code + will be executed. 
+ - `results` are the list of types to be yielded by the operand. + - `cond_operands` is the list of arguments to substitute the + dimensions, then symbols in the `cond` integer set expression to + determine whether they are in the set. + - `has_else` determines whether the affine if operation has the else + branch. + """ + if results_ is None: + results_ = [] + if cond_operands is None: + cond_operands = [] + + if cond.n_inputs != len(cond_operands): + raise ValueError( + f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}" + ) + + operands = [] + operands.extend(cond_operands) + results = [] + results.extend(results_) + + super().__init__(results, cond_operands, cond) + self.regions[0].blocks.append(*[]) + if has_else: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self) -> Block: + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self) -> Optional[Block]: + """Returns the else block of the if operation.""" + if len(self.regions[1].blocks) == 0: + return None + return self.regions[1].blocks[0] diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py new file mode 100644 index 000000000..43d905d0c --- /dev/null +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._amdgpu_ops_gen import * +from ._amdgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py new file mode 100644 index 000000000..92da5df9b --- /dev/null +++ b/mlir/python/mlir/dialects/arith.py @@ -0,0 +1,110 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._arith_ops_gen import *
+from ._arith_ops_gen import _Dialect
+from ._arith_enum_gen import *
+from array import array as _array
+from typing import overload
+
+try:
+    from ..ir import *
+    from ._ods_common import (
+        get_default_loc_context as _get_default_loc_context,
+        _cext as _ods_cext,
+        get_op_result_or_op_results as _get_op_result_or_op_results,
+    )
+
+    from typing import Any, List, Union
+except ImportError as e:
+    raise RuntimeError("Error loading imports from extension module") from e
+
+
+def _isa(obj: Any, cls: type):
+    try:
+        cls(obj)
+    except ValueError:
+        return False
+    return True
+
+
+def _is_any_of(obj: Any, classes: List[type]):
+    return any(_isa(obj, cls) for cls in classes)
+
+
+def _is_integer_like_type(type: Type):
+    return _is_any_of(type, [IntegerType, IndexType])
+
+
+def _is_float_type(type: Type):
+    return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type])
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class ConstantOp(ConstantOp):
+    """Specialization for the constant op class."""
+
+    @overload
+    def __init__(self, value: Attribute, *, loc=None, ip=None):
+        ...
+
+    @overload
+    def __init__(
+        self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None
+    ):
+        ...
+
+    def __init__(self, result, value=None, *, loc=None, ip=None):
+        if value is None:
+            assert isinstance(result, Attribute)
+            super().__init__(result, loc=loc, ip=ip)
+            return
+
+        if isinstance(value, int):
+            super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, float):
+            super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip)
+        elif isinstance(value, _array):
+            if 8 * value.itemsize != result.element_type.width:
+                raise ValueError(
+                    f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width."
+                )
+            if value.typecode in ["i", "l", "q"]:
+                super().__init__(DenseIntElementsAttr.get(value, type=result))
+            elif value.typecode in ["f", "d"]:
+                super().__init__(DenseFPElementsAttr.get(value, type=result))
+            else:
+                raise ValueError(f'Unsupported typecode: "{value.typecode}".')
+        else:
+            super().__init__(value, loc=loc, ip=ip)
+
+    @classmethod
+    def create_index(cls, value: int, *, loc=None, ip=None):
+        """Create an index-typed constant."""
+        return cls(
+            IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip
+        )
+
+    @property
+    def type(self):
+        return self.results[0].type
+
+    @property
+    def value(self):
+        return Attribute(self.operation.attributes["value"])
+
+    @property
+    def literal_value(self) -> Union[int, float]:
+        if _is_integer_like_type(self.type):
+            return IntegerAttr(self.value).value
+        elif _is_float_type(self.type):
+            return FloatAttr(self.value).value
+        else:
+            raise ValueError("only integer and float constants have literal values")
+
+
+def constant(
+    result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None
+) -> Value:
+    return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/async_dialect/__init__.py
similarity index 86%
rename from mlir/lib/Bindings/Python/mlir/dialects/tensor.py
rename to mlir/python/mlir/dialects/async_dialect/__init__.py
index 26edf6b64..6a5ecfc20 100644
--- a/mlir/lib/Bindings/Python/mlir/dialects/tensor.py
+++ b/mlir/python/mlir/dialects/async_dialect/__init__.py
@@ -2,4 +2,4 @@
 # See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._tensor_ops_gen import * +from .._async_ops_gen import * diff --git a/mlir/python/mlir/dialects/async_dialect/passes/__init__.py b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py new file mode 100644 index 000000000..851d56148 --- /dev/null +++ b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...._mlir_libs import _mlirAsyncPasses as _cextAsyncPasses diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py new file mode 100644 index 000000000..759b6aa24 --- /dev/null +++ b/mlir/python/mlir/dialects/bufferization.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._bufferization_ops_gen import * +from ._bufferization_enum_gen import * diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py new file mode 100644 index 000000000..1c69d6d7c --- /dev/null +++ b/mlir/python/mlir/dialects/builtin.py @@ -0,0 +1,48 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Dict, Optional + +from ._builtin_ops_gen import * +from ._builtin_ops_gen import _Dialect +from ..extras.meta import region_op + +try: + from ..ir import * + from ._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ModuleOp(ModuleOp): + """Specialization for the module op class.""" + + def __init__(self, *, loc=None, ip=None): + super().__init__(loc=loc, ip=ip) + body = self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] + + +@region_op +def module( + *, + sym_name=None, + sym_visibility=None, + attrs: Optional[Dict[str, Attribute]] = None, + loc=None, + ip=None, +): + mod = ModuleOp.__base__( + sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip + ) + if attrs is None: + attrs = {} + for attr_name, attr in attrs.items(): + mod.operation.attributes[attr_name] = attr + + return mod diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/python/mlir/dialects/cf.py similarity index 87% rename from mlir/lib/Bindings/Python/mlir/dialects/std.py rename to mlir/python/mlir/dialects/cf.py index 8e55807a0..c2e357a8e 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/std.py +++ b/mlir/python/mlir/dialects/cf.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._std_ops_gen import * +from ._cf_ops_gen import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/complex.py similarity index 86% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py rename to mlir/python/mlir/dialects/complex.py index 81949b8f8..ca81173cf 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/complex.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._linalg_ops_gen import * +from ._complex_ops_gen import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/emitc.py similarity index 86% rename from mlir/lib/Bindings/Python/mlir/dialects/builtin.py rename to mlir/python/mlir/dialects/emitc.py index 30279e161..99c3286e5 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/builtin.py +++ b/mlir/python/mlir/dialects/emitc.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._builtin_ops_gen import * +from ._emitc_ops_gen import * diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py new file mode 100644 index 000000000..1898fc156 --- /dev/null +++ b/mlir/python/mlir/dialects/func.py @@ -0,0 +1,330 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._func_ops_gen import * +from ._func_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) + + import inspect + + from typing import Any, List, Optional, Sequence, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstantOp(ConstantOp): + """Specialization for the constant op class.""" + + @property + def type(self): + return self.results[0].type + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. 
+ if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + if ARGUMENT_ATTRIBUTE_NAME not in self.attributes: + return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs]) + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func( + FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None, + ): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. 
+ + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and ( + param.kind == param.POSITIONAL_OR_KEYWORD + or param.kind == param.KEYWORD_ONLY + ): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results + ) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None." + ) + else: + # Coerce return values, add ReturnOp and rewrite func type. 
+ if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. + return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get( + inputs=inputs, results=return_types + ) + func_op.attributes["function_type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp( + return_types, FlatSymbolRefAttr.get(symbol_name), call_args + ) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator + + +func = FuncOp.from_py_func + + +@_ods_cext.register_operation(_Dialect, replace=True) +class CallOp(CallOp): + """Specialization for the call op class.""" + + def __init__( + self, + calleeOrResults: Union[FuncOp, List[Type]], + argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], + arguments: Optional[List] = None, + *, + loc=None, + ip=None, + ): + """Creates an call operation. + + The constructor accepts three different forms: + + 1. A function op to be called followed by a list of arguments. + 2. A list of result types, followed by the name of the function to be + called as string, following by a list of arguments. + 3. 
A list of result types, followed by the name of the function to be + called as symbol reference attribute, followed by a list of arguments. + + For example + + f = func.FuncOp("foo", ...) + func.CallOp(f, [args]) + func.CallOp([result_types], "foo", [args]) + + In all cases, the location and insertion point may be specified as keyword + arguments if not provided by the surrounding context managers. + """ + + # TODO: consider supporting constructor "overloads", e.g., through a custom + # or pybind-provided metaclass. + if isinstance(calleeOrResults, FuncOp): + if not isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function, expected " + + "the second argument to be a list of call arguments, " + + f"got {type(argumentsOrCallee)}" + ) + if arguments is not None: + raise ValueError( + "unexpected third argument when constructing a call" + + "to a function" + ) + + super().__init__( + calleeOrResults.type.results, + FlatSymbolRefAttr.get( + calleeOrResults.name.value, context=_get_default_loc_context(loc) + ), + argumentsOrCallee, + loc=loc, + ip=ip, + ) + return + + if isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function by name, " + + "expected the second argument to be a string or a " + + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" + ) + + if isinstance(argumentsOrCallee, FlatSymbolRefAttr): + super().__init__( + calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip + ) + elif isinstance(argumentsOrCallee, str): + super().__init__( + calleeOrResults, + FlatSymbolRefAttr.get( + argumentsOrCallee, context=_get_default_loc_context(loc) + ), + arguments, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py new file mode 100644 index 000000000..4cd80aa8b --- /dev/null +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -0,0 +1,7 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._gpu_ops_gen import * +from .._gpu_enum_gen import * +from ..._mlir_libs._mlirDialectsGPU import * diff --git a/mlir/python/mlir/dialects/gpu/passes/__init__.py b/mlir/python/mlir/dialects/gpu/passes/__init__.py new file mode 100644 index 000000000..9b1ef076a --- /dev/null +++ b/mlir/python/mlir/dialects/gpu/passes/__init__.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...._mlir_libs import _mlirGPUPasses as _cextGPUPasses diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py new file mode 100644 index 000000000..73708c7d7 --- /dev/null +++ b/mlir/python/mlir/dialects/index.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._index_ops_gen import * +from ._index_enum_gen import * diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py new file mode 100644 index 000000000..d387c12de --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -0,0 +1,354 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# Re-export the objects provided by pybind. +from ..._mlir_libs._mlirDialectsLinalg import * + +# These are the backing OpView classes generated from the linalg tablegen +# definitions following these steps: +# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. 
+from .._linalg_ops_gen import * +from .._linalg_enum_gen import * +from .._linalg_enum_gen import _iteratortypeenum + +# These are the ground truth functions defined as: +# ``` +# @linalg_structured_op +# def matmul(A=TensorDef(T1, S.M, S.K), +# B=TensorDef(T2, S.K, S.N), +# C=TensorDef(U, S.M, S.N, output=True)): +# ``` +# using the linalg-py eDSL. +# The linalg-py eDSL builds a python representation (PyRepr) that is +# used in following ways: +# 1. PyRepr -> YAML to generate the C++ and Python .td files. These +# then turn into the core C++ Op classes and Python OpView classes +# respectively (made available in _linalg_ops_gen). The generic OpView class +# mechanism makes the C++ classes available to python through the CAPI. +# PyRepr -> YAML currently occurs before compiler compile time. +# The other steps in this category occur at compiler compile time. +# 2. PyRepr -> linalg.core_named_ops calls: piggybacks on the +# _linalg_ops_gen classes and the OpView mechanism to build IR at +# runtime in python: +# a. by default, the Named Op Form is emitted, e.g.: +# `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR: +# ``` +# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# outs(%0 : tensor<4x8xf32>) +# -> tensor<4x8xf32> +# ``` +# b. by setting emit_generic=True, the Generic Op Form is emitted, e.g.: +# `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR: +# ``` +# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]} +# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# outs(%0 : tensor<4x8xf32>) { +# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): +# ... +# linalg.yield %3 : f32 +# } -> tensor<4x8xf32> +# ``` +# 3. PyRepr -> Runtime Custom Op definitions: directly generates a +# linalg.generic form like in 2.b. +# !!!WARNING!!!: if one creates a runtime custom op with the same name +# as an existing core named op, step 2. will likely take precedence. 
+# TODO: guard against surprises and fail create Runtime Custom Ops with +# the same name as existing Core Named Ops. +from .opdsl.ops.core_named_ops import * + +from ...ir import * +from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_result_or_op_results as _get_op_result_or_op_results, + _dispatch_mixed_values, +) +from ...extras.meta import region_op + + +def transpose( + input: Union[Operation, OpView, Sequence[Value]], + *, + outs: List[Union[Operation, OpView, Sequence[Value]]], + permutation: Union[DenseI64ArrayAttr, List[int]], +): + input = _get_op_result_or_value(input) + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = TransposeOp( + result=result_types, + input=input, + init=init, + permutation=permutation, + ) + fill_builtin_region(op.operation) + return op + + +def broadcast( + input: Union[Operation, OpView, Sequence[Value]], + *, + outs: List[Union[Operation, OpView, Sequence[Value]]], + dimensions: Union[DenseI64ArrayAttr, List[int]], +): + input = _get_op_result_or_value(input) + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = BroadcastOp( + result=result_types, + input=input, + init=init, + dimensions=dimensions, + ) + fill_builtin_region(op.operation) + return op + + +@register_attribute_builder("IteratorTypeArrayAttr") +def _IteratorTypeArrayAttr(x, context): + return ArrayAttr.get([_iteratortypeenum(v, context) for v in x]) + + +# The underscore is needed here so that there's no collision with opdsl generation. 
+class GenericOp_(GenericOp): + def __init__( + self, + inputs, + outputs, + indexing_maps, + iterator_types, + *, + doc=None, + library_call=None, + loc=None, + ip=None, + ): + result_types = [] + if isinstance(outputs[0].type, RankedTensorType): + result_types = [o.type for o in outputs] + + super().__init__( + result_types, + inputs, + outputs, + indexing_maps, + iterator_types, + doc=doc, + library_call=library_call, + loc=loc, + ip=ip, + ) + element_types = [i.type.element_type for i in inputs] + [ + o.type.element_type for o in outputs + ] + self.regions[0].blocks.append(*element_types) + + +generic = region_op(GenericOp_, terminator=YieldOp) + + +def _create_matmul_like_op( + op_type, + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = op_type( + result_tensors=result_types, + inputs=ins, + outputs=[init], + indexing_maps=indexing_maps, + cast=cast, + ) + fill_builtin_region(op.operation) + return op + + +def matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return _get_op_result_or_op_results( + _create_matmul_like_op( + MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + +def batch_matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return _get_op_result_or_op_results( + _create_matmul_like_op( + 
BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + +def batch_reduce_matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchReduceMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + +def contract( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Sequence[AffineMapAttr], + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return _get_op_result_or_op_results( + _create_matmul_like_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + +# Extend and shadow the TableGen-derived version to make sure correct default +# indexing_maps are derived (as there is no mechanism for doing so given the +# Python API bypasses the C++-builders). 
+class ElementwiseOp_(ElementwiseOp): + def __init__( + self, + result_tensors, + inputs, + outputs, + kind, + *, + indexing_maps=None, + loc=None, + ip=None, + ): + if indexing_maps is None: + inputs = [_get_op_result_or_value(in_) for in_ in inputs] + for in0, in1 in zip(inputs[:-1], inputs[1:]): + assert in0.type == in1.type + output = _get_op_result_or_value(outputs[0]) + assert inputs[0].type == output.type + num_args = len(inputs) + 1 + indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args + + super().__init__( + result_tensors=result_tensors, + inputs=inputs, + outputs=outputs, + kind=kind, + indexing_maps=indexing_maps, + loc=loc, + ip=ip, + ) + + +ElementwiseOp = ElementwiseOp_ + + +def elementwise( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + kind: Union[ElementwiseKind, Attribute], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) != 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = ElementwiseOp( + result_tensors=result_types, + inputs=ins, + outputs=[init], + kind=kind, + indexing_maps=indexing_maps, + ) + fill_builtin_region(op.operation) + return _get_op_result_or_op_results(op) + + +def pack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + padding_value=None, + outer_dims_perm=None, + loc=None, + ip=None, +) -> ir.Value: + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return _get_op_result_or_op_results( + PackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + padding_value=padding_value, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ) + ) + 
+
+def unpack(
+    source,
+    dest,
+    inner_dims_pos,
+    inner_tiles,
+    *,
+    outer_dims_perm=None,
+    loc=None,
+    ip=None,
+) -> ir.Value:
+    (
+        dynamic_inner_tiles,
+        # packed here means %1:2 packing (results packing)
+        _inner_tiles,
+        static_inner_tiles,
+    ) = _dispatch_mixed_values(inner_tiles)
+
+    return _get_op_result_or_op_results(
+        UnPackOp(
+            source=source,
+            dest=dest,
+            inner_dims_pos=inner_dims_pos,
+            inner_tiles=dynamic_inner_tiles,
+            static_inner_tiles=static_inner_tiles,
+            outer_dims_perm=outer_dims_perm,
+            loc=loc,
+            ip=ip,
+        )
+    ) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/__init__.py b/mlir/python/mlir/dialects/linalg/opdsl/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/__init__.py rename to mlir/python/mlir/dialects/linalg/opdsl/__init__.py diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py new file mode 100644 index 000000000..2f6513199 --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python
+# Command line tool to load an oplib module and dump all of the operations
+# it contains in some format.
+"""Loads one or more modules containing op definitions and dumps them.
+
+The dump format can be:
+
+* `--dump_format=yaml` (default)
+* `--dump_format=repr`
+
+Positional arguments are interpreted as module names (optionally, relative to
+this module). Loose module files can be specified via `--file `.
+
+Sample usage:
+  # Dump the YAML op definitions for the core named ops (as in the dialect
+  # source tree).
+  python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops
+
+Note: YAML output is emitted in "document list" format with each operation
+as its own "document". Practically, this means that each operation (or group
+of composite ops) is emitted with a "---" preceding it, which can be useful
+for testing.
+""" + +import argparse +import importlib + +from .lang import * +from .lang.config import * +from .lang.yaml_helper import * + + +def create_arg_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description="Dump an oplib in various formats") + p.add_argument( + "modules", metavar="M", type=str, nargs="*", help="Op module to dump" + ) + p.add_argument( + "--file", metavar="F", type=str, nargs="*", help="Python op file to dump" + ) + p.add_argument( + "--format", + type=str, + dest="format", + default="yaml", + choices=("yaml", "repr"), + help="Format in which to dump", + ) + return p + + +def load_module_from_file(module_name, file_path): + spec = importlib.util.spec_from_file_location(module_name, file_path) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + return m + + +def main(args): + # Load all configs. + configs = [] + modules = [] + for module_name in args.modules: + modules.append( + importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl") + ) + for i, file_path in enumerate(args.file or []): + modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) + for m in modules: + for attr_name, value in m.__dict__.items(): + # TODO: This class layering is awkward. + if isinstance(value, DefinedOpCallable): + try: + linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) + except Exception as e: + raise ValueError( + f"Could not create LinalgOpConfig from {value.op_def}" + ) from e + configs.extend(linalg_config) + + # Print. 
+ if args.format == "yaml": + print(yaml_dump_all(configs)) + elif args.format == "repr": + for config in configs: + print(repr(config)) + + +if __name__ == "__main__": + main(create_arg_parser().parse_args()) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/__init__.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/__init__.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/__init__.py diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py new file mode 100644 index 000000000..9fa626dfa --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -0,0 +1,306 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""DSL for constructing affine expressions and maps. + +These python wrappers allow construction of affine expressions in a more +pythonic fashion that is later instantiated as an IR AffineExpr. Separating the +AST from construction of the map allows for manipulations of symbols and dims +beyond the scope of one expression. + +Affine expression construction: + >>> with _ir.Context(): + ... s = AffineBuildState() + ... (S.K + S.M).build(s) + ... (S.K * S.M).build(s) + ... (S.K // S.M).build(s) + ... (S.K / S.M).build(s) + ... (S.K % 4).build(s) + ... (D.i + D.j * 4).build(s) + ... s + AffineExpr(s0 + s1) + AffineExpr(s0 * s1) + AffineExpr(s0 floordiv s1) + AffineExpr(s0 ceildiv s1) + AffineExpr(s0 mod 4) + AffineExpr(d0 + d1 * 4) + AffineBuildState< + symbols={'K': 0, 'M': 1} + dims={'i': 0, 'j': 1}> + +In the DSL, dimensions and symbols are name-uniqued instances of DimDef and +SymbolDef. 
There are shortcut "expando" instances that will create a +corresponding DimDef/SymbolDef upon accessing an attribute: + +Referencing a named dimension: + + >>> D.i + Dim(i) + >>> D.a is D.b + False + >>> D.a is D.a + True + +Referencing a named symbol: + + >>> S.foobar + Symbol(foobar) + >>> S.a is S.b + False + >>> S.a is S.a + True +""" + +from typing import Callable, Dict, Optional, Tuple, Union + +from ..... import ir as _ir + +__all__ = [ + "AffineBuildState", + "AffineExprDef", + "D", + "DimDef", + "S", + "SymbolDef", +] + + +class AffineBuildState: + """Internal state for the AffineExprDef._create impls. + + Note that a "local" AffineBuildState can be created relative to a "global" + AffineBuildState. In that case, any affine expressions built will inherit + symbol and dim bindings from the global state and will update both as new + ones are discovered. This allows for building expressions across contexts + which share a common symbol and dim space. + """ + + def __init__( + self, + *, + global_state: "AffineBuildState" = None, + allow_new_symbols: bool = True, + allow_new_dims: bool = True, + ): + if not global_state: + self.all_symbols = dict() # type: Dict[str, int] + self.all_dims = dict() # type: Dict[str, int] + else: + # Alias the global dict. + self.all_symbols = global_state.all_symbols + self.all_dims = global_state.all_dims + + # Map of symbols and dims in the current build. 
+ self.local_symbols = dict() # type: Dict[str, int] + self.local_dims = dict() # type: Dict[str, int] + self.allow_new_symbols = allow_new_symbols + self.allow_new_dims = allow_new_dims + + def get_dim(self, dimname: str) -> int: + """Gets the dim position given a name.""" + pos = self.all_dims.get(dimname) + if pos is None: + if not self.allow_new_dims: + raise ValueError( + f"New dimensions not allowed in the current affine expression: " + f"Requested '{dimname}', Availble: {self.all_dims}" + ) + pos = len(self.all_dims) + self.all_dims[dimname] = pos + self.local_dims[dimname] = pos + return pos + + def get_symbol(self, symname: str) -> int: + """Geta a symbol position given a name.""" + pos = self.all_symbols.get(symname) + if pos is None: + if not self.allow_new_symbols: + raise ValueError( + f"New symbols not allowed in the current affine expression: " + f"Requested '{symname}', Availble: {self.all_symbols}" + ) + pos = len(self.all_symbols) + self.all_symbols[symname] = pos + self.local_symbols[symname] = pos + return pos + + @property + def local_dim_count(self) -> int: + return len(self.local_dims) + + @property + def local_symbol_count(self) -> int: + return len(self.local_symbols) + + @property + def dim_count(self) -> int: + return len(self.all_dims) + + @property + def symbol_count(self) -> int: + return len(self.all_symbols) + + def __repr__(self): + lines = [f"AffineBuildState<"] + lines.append(f" symbols={self.local_symbols}") + lines.append(f" dims={self.local_dims}>") + return "\n".join(lines) + + +class AffineExprDef: + """Base class for an affine expression being defined.""" + + def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: + """Builds the corresponding _ir.AffineExpr from the definitions.""" + state = AffineBuildState() if state is None else state + expr = self._create(state) + return expr + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + raise NotImplementedError() + + @staticmethod + def 
coerce_from(py_value): + if isinstance(py_value, int): + return AffineConstantExpr(py_value) + assert isinstance(py_value, AffineExprDef) + return py_value + + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + callback(self) + + def __add__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs) + + def __mul__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) + + def __mod__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) + + def __floordiv__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) + + def __truediv__(lhs, rhs): + # TODO: Not really a ceil div - taking liberties for the DSL. + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) + + +class AffineConstantExpr(AffineExprDef): + """An affine constant being defined.""" + + def __init__(self, value: int): + assert isinstance(value, int) + self.value = value + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return _ir.AffineConstantExpr.get(self.value) + + def __repr__(self): + return f"Const({self.value})" + + +class AffineBinaryExprDef(AffineExprDef): + """An affine binary expression being defined.""" + + def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): + self.ir_ctor = ir_ctor + self.lhs = lhs + self.rhs = rhs + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) + + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + super().visit_affine_exprs(callback) + self.lhs.visit_affine_exprs(callback) + self.rhs.visit_affine_exprs(callback) + + def __repr__(self): + return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" + 
+ +class DimDef(AffineExprDef): + """Represents a named dimension.""" + + ALL_DIMS = dict() # type: Dict[str, "DimDef"] + + def __new__(cls, dimname: str): + existing = cls.ALL_DIMS.get(dimname) + if existing is not None: + return existing + new = super().__new__(cls) + new.dimname = dimname + cls.ALL_DIMS[dimname] = new + return new + + def __repr__(self): + return f"Dim({self.dimname})" + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_dim(self.dimname) + return _ir.AffineDimExpr.get(position=pos) + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" + + class ExpandoDims: + def __getattr__(self, n): + return cls(n) + + return ExpandoDims() + + +class SymbolDef(AffineExprDef): + """Represents a named symbol. + + >>> s1 = SymbolDef("s1") + >>> s1 + Symbol(s1) + >>> s2 = SymbolDef("s2") + >>> s1 is s2 + False + >>> s1 is SymbolDef("s1") + True + """ + + ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] + + def __new__(cls, symname: str): + existing = cls.ALL_SYMBOLS.get(symname) + if existing is not None: + return existing + new = super().__new__(cls) + new.symname = symname + cls.ALL_SYMBOLS[symname] = new + return new + + def __repr__(self): + return f"Symbol({self.symname})" + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_symbol(self.symname) + return _ir.AffineSymbolExpr.get(position=pos) + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" + + class ExpandoSymbols: + def __getattr__(self, n): + return cls(n) + + return ExpandoSymbols() + + +# Global accessor for on-demand dims and symbols. 
+D = DimDef.create_expando() +S = SymbolDef.create_expando() diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py new file mode 100644 index 000000000..4f81a3874 --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -0,0 +1,899 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Model classes representing a tensor comprehension. + +These classes model the language more at an AST level as evaluated. Reasoning +about it typically involves processing this form into config objects that +represent actual op definitions (i.e. YAML). +""" + +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple +from enum import Enum + +from ..... import ir as _ir +from .affine import * +from .scalar_expr import * +from .types import * +from .yaml_helper import * + +############################################################################### +# Tensor expression nodes. 
+############################################################################### + + +class TensorExpression: + """An expression that can appear on the RHS of a comprehension.""" + + def to_scalar_expression(self) -> ScalarExpression: + raise NotImplementedError() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + """Visits all tensor expression reachable by the expression.""" + callback(self) + + def collect_dim_uses(self, uses: Set["DimDef"]): + """Collects all DimDefs reachable through this expression.""" + + def visit_dim_def(dim_def: AffineExprDef): + if isinstance(dim_def, DimDef): + uses.add(dim_def) + + def visit_affine_exprs(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + for ind in expr.indices: + ind.visit_affine_exprs(visit_dim_def) + if isinstance(expr, TensorReduceFn): + for ind in expr.reduce_fn.reduce_dims: + ind.visit_affine_exprs(visit_dim_def) + + self.visit_tensor_exprs(visit_affine_exprs) + + def collect_tensor_uses(self, uses: Set["TensorUse"]): + """Collects all TensorUses reachable through this expression.""" + + def visit_tensor_use(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + uses.add(expr) + + self.visit_tensor_exprs(visit_tensor_use) + + def collect_indices(self, indices: Set["index"]): + """Collects all index accesses reachable through this expression.""" + + def visit_index(expr: "TensorExpression"): + if isinstance(expr, index): + indices.add(expr) + + self.visit_tensor_exprs(visit_index) + + def collect_scalar_uses(self, uses: Set["ScalarDef"]): + """Collects all ScalarDefs reachable through this expression.""" + + def visit_scalar_def(expr: "TensorExpression"): + if isinstance(expr, ScalarDef): + uses.add(expr) + + self.visit_tensor_exprs(visit_scalar_def) + + def __add__(self, rhs: "TensorExpression") -> "TensorExpression": + return BinaryFn.add(self, rhs) + + def __mul__(self, rhs) -> "TensorExpression": + return BinaryFn.mul(self, rhs) + + def __sub__(self, rhs) -> 
"TensorExpression": + return BinaryFn.sub(self, rhs) + + def __truediv__(self, rhs) -> "TensorExpression": + return BinaryFn.div(self, rhs) + + def __hash__(self): + return hash(id(self)) + + +class TensorUse(TensorExpression): + """A used tensor represented by its (tensor_name, indices). + + Note that forming a comprehension via direct assignment is performed through + __setitem__ on the TensorDef level. However, performing a reduction with + compound ops (+=, *=, etc) is done by doing a: + TensorDef.__getitem__ + TensorUse.__iadd__ + TensorDef.__setitem__ + """ + + def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]): + self.operand_def = operand_def + self.indices = tuple(indices) + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.tensor_name).expr() + + @property + def tensor_name(self) -> str: + name = self.operand_def.name + assert name is not None, "TensorDef not registered with an op" + return name + + def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: + # Computes the reduction dims for implicit reductions. Assumes that the rhs + # is the expression being reduced and self is being reduced into. Any + # indices referenced on the rhs and not in self are considered reduction + # dims and will be ordered as encountered on the rhs. 
+ rhs_dims = set() + lhs_dims = set() + rhs.collect_dim_uses(rhs_dims) + self.collect_dim_uses(lhs_dims) + return rhs_dims - lhs_dims + + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) + + def __repr__(self): + return ( + f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]" + ) + + +class TensorFn(TensorExpression): + """Application of a tensor function.""" + + def __init__( + self, + kind: "FunctionKind", + name: Optional[str], + operand_def: Optional["OperandDef"], + type_var: Optional[TypeVar], + args: Sequence[TensorExpression], + ): + if bool(name) + bool(operand_def) != 1: + raise ValueError("One of 'name', 'operand_def' must be specified") + self.name = name + self.kind = kind + self.operand_def = operand_def + self.type_var = type_var + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.operand_def: + assert self.operand_def.name, "TensorFn not registered with an op" + attr_name = self.operand_def.name if self.operand_def else None + args = [arg.to_scalar_expression() for arg in self.args] + return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + name = self.operand_def.name if self.operand_def else self.name + return ( + f"{self.kind.name}.{name}(type_var={self.type_var}, " + f"args={', '.join(repr(a) for a in self.args)})" + ) + + +class TensorReduceFn(TensorExpression): + """Application of a reduction function. + + This captures the lhs (initial value) separately from the rhs. 
+ """ + + def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]): + self.reduce_use = reduce_use + self.lhs = None # type: Optional[TensorUse] + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.lhs is None: + raise ValueError( + f"Cannot scalarize a TensorReduceFn that has not been " + f"bound to its lhs: {self}" + ) + full_args = [self.lhs.to_scalar_expression()] + [ + arg.to_scalar_expression() for arg in self.args + ] + fn_name = None + attr_name = None + if self.reduce_use.binary_fn: + fn_name = self.reduce_use.binary_fn.fn_name + if self.reduce_use.binary_attr: + attr_name = self.reduce_use.binary_attr.operand_def.name + return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" + + +class const(TensorExpression): + """Returns the given constant floating point or integer value.""" + + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)) + ) + else: + raise ValueError(f"const requires int or float but got {type(value)}") + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.value).expr() + + def __repr__(self): + return f"const({self.value})" + + +class index(TensorExpression): + """Returns the iteration index for a given dimension name. + + Resolves the given dimension name to obtain its position in the iteration + domain of the operation. 
+ """ + + def __init__(self, dim: DimDef): + self.dim_def = dim + self.dim = -1 + + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) + + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() + + def __repr__(self): + return f"index({repr(self.dim)})" + + +############################################################################### +# Function types and function definitions. +############################################################################### + + +class FunctionKind(Enum): + UNARY = 0 + BINARY = 1 + TERNARY = 2 + TYPE = 3 + + +class UnaryFnType: + """Unary function. + + A unary function takes one tensor expression and returns the + function evaluation result. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) + + def __repr__(self): + return f"{self.fn_name}" + + +class UnaryFn: + """Unary function namespace.""" + + exp = UnaryFnType("exp") + log = UnaryFnType("log") + abs = UnaryFnType("abs") + ceil = UnaryFnType("ceil") + floor = UnaryFnType("floor") + negf = UnaryFnType("negf") + reciprocal = UnaryFnType("reciprocal") + round = UnaryFnType("round") + sqrt = UnaryFnType("sqrt") + rsqrt = UnaryFnType("rsqrt") + square = UnaryFnType("square") + tanh = UnaryFnType("tanh") + erf = UnaryFnType("erf") + + +class BinaryFnType: + """Binary function. + + A binary function takes two tensor expressions and returns the + function evaluation result. 
+ """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) + + def __repr__(self): + return f"{self.fn_name}" + + +class BinaryFn: + """Binary function namespace. + + As the integer types are signless, signedness is implement by different + functions that treat integers as signed or unsigned values. + + Examples: + - max -> `arith.MaxSIOp` + - max_unsigned -> `arith.MaxUIOp` + """ + + add = BinaryFnType("add") + sub = BinaryFnType("sub") + mul = BinaryFnType("mul") + div = BinaryFnType("div") + div_unsigned = BinaryFnType("div_unsigned") + max_signed = BinaryFnType("max_signed") + min_signed = BinaryFnType("min_signed") + max_unsigned = BinaryFnType("max_unsigned") + min_unsigned = BinaryFnType("min_unsigned") + powf = BinaryFnType("powf") + + +class TernaryFnType: + """Ternary function. + + A ternary function takes three tensor expressions and returns the + function evaluation result. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__( + self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression + ) -> "TensorFn": + return TensorFn( + FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2] + ) + + def __repr__(self): + return f"{self.fn_name}" + + +class TernaryFn: + """Ternary function namespace.""" + + select = TernaryFnType("select") + + +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. 
+ """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast_signed`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast_signed(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + + cast_signed = TypeFnType("cast_signed") + cast_unsigned = TypeFnType("cast_unsigned") + + +class ReduceFnUse: + """Reduction function use. + + A reduction use specifies the reduction function and dimensions. + """ + + def __init__( + self, + binary_fn: Optional[BinaryFnType], + binary_attr: Optional["BinaryFnAttrDef"], + *reduce_dims: DimDef, + ): + if bool(binary_fn) + bool(binary_attr) != 1: + raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") + self.binary_fn = binary_fn + self.binary_attr = binary_attr + self.reduce_dims = reduce_dims + + def __call__(self, *args: TensorExpression) -> "TensorReduceFn": + return TensorReduceFn(self, args) + + def __repr__(self): + fn = self.binary_fn if self.binary_fn else self.binary_attr + return f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})" + + +class ReduceFnType: + """Reduction function. + + A binary function that reduces its RHS into its LHS. 
+ """ + + def __init__(self, binary_fn: BinaryFnType): + if not isinstance(binary_fn, BinaryFnType): + raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") + self.binary_fn = binary_fn + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.binary_fn, None, *reduce_dims) + + def __repr__(self): + return f"reduce_{repr(self.binary_fn)}" + + +class ReduceFn: + add = ReduceFnType(BinaryFn.add) + mul = ReduceFnType(BinaryFn.mul) + max_signed = ReduceFnType(BinaryFn.max_signed) + min_signed = ReduceFnType(BinaryFn.min_signed) + max_unsigned = ReduceFnType(BinaryFn.max_unsigned) + min_unsigned = ReduceFnType(BinaryFn.min_unsigned) + + +############################################################################### +# Operand definitions. +############################################################################### + + +class OperandKind(Enum): + INPUT_TENSOR = 0 + SCALAR = 1 + OUTPUT_TENSOR = 2 + INDEX_ATTR = 3 + UNARY_FN_ATTR = 4 + BINARY_FN_ATTR = 5 + TERNARY_FN_ATTR = 6 + TYPE_FN_ATTR = 7 + + +class OperandDef: + """Definition of an operand passed to an operation. + + Keep the meta information of Tensor, Scalar, and Attribute operands and + provide the shared registration functionality. 
+ """ + + def __init__( + self, + kind: OperandKind, + type_var: Optional[TypeVar] = None, + size_exprs: Optional[Sequence[AffineExprDef]] = None, + index_dims: Optional[Sequence[DimDef]] = None, + default_indices: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None, + ): + if type_var and not isinstance(type_var, TypeVar): + raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}") + self.owner = None # type: Optional["LinalgOpDef"] + self.type_var = type_var + self.size_exprs = size_exprs + self.index_dims = index_dims + self.default_indices = default_indices + self.default_fn = default_fn + self.kind = kind + self.name = None # type: Optional[str] + self.registered_index = -1 # type: int + + def attach(self, index: int, name: str, owner: "LinalgOpDef"): + if self.owner: + raise ValueError(f"OperandDef already registered with an op: {self}") + self.registered_index = index + self.name = name + self.owner = owner + + def is_input(self) -> bool: + return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR + + def is_tensor(self) -> bool: + return ( + self.kind == OperandKind.INPUT_TENSOR + or self.kind == OperandKind.OUTPUT_TENSOR + ) + + def is_attribute(self) -> bool: + return ( + self.kind == OperandKind.INDEX_ATTR + or self.kind == OperandKind.UNARY_FN_ATTR + or self.kind == OperandKind.BINARY_FN_ATTR + or self.kind == OperandKind.TERNARY_FN_ATTR + or self.kind == OperandKind.TYPE_FN_ATTR + ) + + def __hash__(self): + return hash(id(self)) + + def __repr__(self): + return ( + f"{self.name}:OperandDef(kind={self.kind.name}, " + f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, " + f"index_dims={self.index_dims}, " + f"default_indices={self.default_indices}, " + f"default_fn={self.default_fn})" + ) + + +class TensorDef: + """Tensor operand definition. + + Tensor operands are indexed using the associated indexing_map when forwarded + to the body of the structured op. 
A unique name identifies the tensor operands + and an index determines their position in the operation's parameter list. A + tensor definition takes type, a shape, and an optional flag to mark output + tensors. Additionally, a tuple of index dimensions may be used to map the + tensor to the loop dimensions of the operation. This mapping is needed to + compute the indexing map of shape-only tensors that have no uses. + """ + + def __init__( + self, + type_var: TypeVar, + *shape: AffineExprDef, + index_dims: Optional[Sequence[DimDef]] = None, + output: bool = False, + ): + if index_dims and len(shape) != len(index_dims): + raise ValueError( + f"Expected the shape rank {len(shape)} to match the " + f"number of index_dims {len(index_dims)}" + ) + if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): + raise ValueError( + f"TensorDef requires index dims of type DimDef but " f"got {index_dims}" + ) + kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR + self.operand_def = OperandDef( + kind, type_var=type_var, size_exprs=shape, index_dims=index_dims + ) + + def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: + assert self.operand_def.owner, "TensorDef is not registered with an op" + state = AffineBuildState( + global_state=self.operand_def.owner._affine_state, allow_new_symbols=False + ) + if not isinstance(dims, tuple): + dims = (dims,) # Handle single subscript case. + # Special case: (None) is a 0d-scalar use. + if dims == (None,): + dims = () + + exprs = [] + for expr_def in dims: + if not isinstance(expr_def, AffineExprDef): + raise KeyError( + "A TensorDef can only be subscripted by a tuple of affine dims" + ) + exprs.append(expr_def) + return TensorUse(self.operand_def, exprs) + + def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): + """Creates a new 1:1 comprehension by binding this tensor to an expression. 
+ + Note that due to the way assignment works in Python, we have to capture + direct assignment as a setitem on the TensorDef. + """ + if not isinstance(value, TensorExpression): + raise ValueError( + f"Only TensorExpressions can be assigned to TensorDefs. " + f"Got: {repr(value)}" + ) + use = self[dims] + comp = Comprehension((use, value)) + self.operand_def.owner.comprehensions.append(comp) + + +class ScalarDef(TensorExpression): + """Scalar operand definition. + + Scalar operands are forwarded to the body of the structured op as they are. + A unique name identifies the scalars and an index determines their position in + the operation's parameter list. + """ + + def __init__(self, type_var: TypeVar): + self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) + + @property + def scalar_name(self) -> str: + name = self.operand_def.name + assert name is not None, "ScalarDef not registered with an op" + return name + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.scalar_name).expr() + + +class IndexAttrDef: + """Index attribute definition. + + Index attributes provide a way to define and set symbols that can be used in + indexing expressions. Every attribute specifies a tuple of symbols that at + compile-time are replaced by integer values as well as their default values. 
+ """ + + def __init__(self, *sizes: SymbolDef, default: Sequence[int]): + if any(not isinstance(size, SymbolDef) for size in sizes): + raise ValueError( + f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}" + ) + if any(not isinstance(default_val, int) for default_val in default): + raise ValueError( + f"IndexAttrDef requires default values of type int " + f"but got {default}" + ) + if len(sizes) != len(default): + raise ValueError( + f"IndexAttrDef expects {len(sizes)} default values " + f"but got {len(default)}" + ) + self.operand_def = OperandDef( + OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default + ) + + +class UnaryFnAttrDef: + """Unary function attribute definition. + + Unary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default unary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "UnaryFnType"): + if not isinstance(default, UnaryFnType): + raise ValueError( + f"UnaryFnAttrDef requires default of type UnaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) + + +class BinaryFnAttrDef: + """Binary function attribute definition. + + Binary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default binary function + that may be overwritten at operation instantiation time. 
+ """ + + def __init__(self, default: "BinaryFnType"): + if not isinstance(default, BinaryFnType): + raise ValueError( + f"BinaryFnAttrDef requires default of type BinaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1]) + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) + + +class TernaryFnAttrDef: + """Ternary function attribute definition. + + Ternary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default Ternary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "TernaryFnType"): + if not isinstance(default, TernaryFnType): + raise ValueError( + f"TernaryFnAttrDef requires default of type TernaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn: + return TensorFn( + FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1] + ) + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) + + +class TypeFnAttrDef: + """Type conversion function attribute definition. + + Type conversion function attributes provide a way to make type conversions + parameterizable. Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. 
+ """ + + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError( + f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) + + +############################################################################### +# Operation definition. +############################################################################### + + +class Comprehension: + """Represents a single comprehension.""" + + def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): + self.definitions = list() # List[TensorUse] + self.values = list() # List[TensorExpression] + + # Find the lhs to reduction rhs. + for assign, value in bindings: + if isinstance(value, TensorReduceFn): + if value.lhs: + raise ValueError(f"Reduction expression already assigns: {value}") + value.lhs = assign + self.definitions.append(assign) + self.values.append(value) + + @property + def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: + """Gets the reduction dims for the comprehension or None.""" + result = set() + for use in self.values: + if isinstance(use, TensorReduceFn): + result.add(use.reduce_use.reduce_dims) + else: + result.add(tuple()) + return result + + def __repr__(self): + if len(self.definitions) > 1: + defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" + values_repr = f"({', '.join(repr(v) for v in self.values)})" + else: + defs_repr = f"{repr(self.definitions[0])}" + values_repr = f"{repr(self.values[0])}" + + return f"{defs_repr} = {values_repr}" + + +class OpInterfaceDef: + """An interface that an op implements.""" + + def __init__(self, cpp_name: str): + self.cpp_name = cpp_name + + +ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") 
+ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") +FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") + + +class OpDefinitionDef: + """A method that an op implements.""" + + def __init__(self, def_name: str): + self.def_name = def_name + + +Canonicalizer = OpDefinitionDef("hasCanonicalizer") + + +class OpMetadataDef(YAMLObject): + """Metadata about the op (generally not behavior impacting).""" + + yaml_tag = "!LinalgOpMetadata" + + def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): + self.name = name + self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name + self.doc = doc + self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionsDef] + + def to_yaml_custom_dict(self): + d = dict( + name=self.name, + cpp_class_name=self.cpp_class_name, + doc=self.doc, + ) + if self.implements: + d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] + return d + + +class LinalgOpDef: + """Definition of a linalg op.""" + + def __init__( + self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None + ): + self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) + self.registered_operands = dict() # type: Dict[str, OperandDef] + self.domain = list() # type: List[DimDef] + self.comprehensions = list() # type: List[Comprehension] + self._affine_state = AffineBuildState() + + def add_operand(self, name: str, operand: OperandDef): + """Registers an operand.""" + if name in self.registered_operands: + raise ValueError( + f"The operand {name} is already registered " + f"to {self.registered_operands['name']}" + ) + structured_op_methods = [ + "inputs", + "outputs", + "result_tensors", + "region", + "iterator_types", + "indexing_maps", + "getRegionBuilder", + "getLibraryCallName", + ] + if operand.is_attribute() and name in 
structured_op_methods: + raise ValueError( + f"The attribute name {name} conflicts with a structured " + f"op method name" + ) + # Ensure output tensors are registered after input tensors and scalars and + # attributes are registered after all other operand types. + if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OUTPUT_TENSOR and any( + op_def.is_attribute() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Output {name} registered after an attribute") + operand.attach(len(self.registered_operands), name, self) + self.registered_operands[name] = operand + + def __repr__(self): + lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"] + for name, operand in self.registered_operands.items(): + lines.append(f" {operand}") + if self.comprehensions: + lines[-1] += " {" + for comprehension in self.comprehensions: + lines.append(f" {comprehension}") + lines.append("}") + return "\n".join(lines) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py new file mode 100644 index 000000000..d522d5712 --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -0,0 +1,488 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Represents configured ops as emitted for code generation. + +Classes in this module generally are directly serializable to YAML for use +by the code generator. + +TODO: These should just be dumb containers or serialization code but they +currently encode too many details of how the language is interpreted. Move this +to helpers on the comprehension objects themselves. 
+""" + +from typing import Dict, Optional + +from ..... import ir as _ir +from .comprehension import * +from .yaml_helper import * + +__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"] + + +def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: + with affine_map.context: + # Affine map printing/parsing is via an AffineMap attr. + attr = _ir.AffineMapAttr.get(affine_map) + return str(attr) + + +class TensorUseConfig: + """Wrapper around a TensorUse with additional context-bound state.""" + + def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): + self.tensor_use = tensor_use + self.indexing_map = indexing_map + + def __repr__(self): + return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" + + +class OperandDefConfig(YAMLObject): + """Wrapper containing an operand definition with additional state.""" + + yaml_tag = "!LinalgOperandDefConfig" + + def __init__( + self, + operand_def: OperandDef, + shape_map: Optional[_ir.AffineMap] = None, + index_attr_map: Optional[_ir.AffineMap] = None, + ): + self.operand_def = operand_def + self.shape_map = shape_map # type: Optional[_ir.AffineMap] + self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] + self.indexing_map = None # type: Optional[_ir.AffineMap] + + @property + def name(self) -> str: + return self.operand_def.name + + @property + def kind(self) -> OperandKind: + return self.operand_def.kind + + @property + def type_var(self) -> TypeVar: + return self.operand_def.type_var + + def to_yaml_custom_dict(self): + self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) + if self.type_var: + self_dict["type_var"] = self.type_var.name + if self.shape_map: + self_dict["shape_map"] = _serialize_affine_map(self.shape_map) + if self.index_attr_map: + self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) + if self.operand_def.default_indices: + self_dict["default_indices"] = self.operand_def.default_indices + if 
self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn + return self_dict + + def __repr__(self): + return ( + f"OperandDefConfig({self.operand_def}, " + f"shape_map={self.shape_map}, " + f"index_attr_map={self.index_attr_map}, " + f"indexing_map={self.indexing_map})" + ) + + +class LinalgIndexingMapsConfig(YAMLObject): + """Abstracts the style of indexing maps that the op exports. + + Presently only static (tied to the op name) indexing maps are supported. In + the future, it is expected that we will have additional variants: + - Dynamic based on attributes + - Dynamic based on operands + Each is expected to require a different variant of specification. + """ + + yaml_tag = "!LinalgIndexingMapsConfig" + + def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): + self.static_indexing_maps = static_indexing_maps + + def to_yaml_custom_dict(self): + if self.static_indexing_maps is not None: + return dict( + static_indexing_maps=[ + _serialize_affine_map(m) for m in self.static_indexing_maps + ] + ) + raise ValueError( + f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)" + ) + + +class LinalgStructuredOpConfig(YAMLObject): + """Configuration for metadata sufficient to construct a linalg named op.""" + + yaml_tag = "!LinalgStructuredOpConfig" + + def __init__( + self, + comprehension: Comprehension, + domain: Sequence[DimDef], + registered_operands: Sequence[OperandDef], + context: Optional[_ir.Context] = None, + ): + self.context = context if context is not None else _ir.Context() + self.affine_state = AffineBuildState() + self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] + self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] + self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] + + # Compute the ordered set of writes and collect the tensor, capture, dims, + # and index uses. 
+ collected_tensor_uses = set() + collected_scalar_uses = set() + collected_dim_uses = set() + collected_indices = set() + for write_use, read_use in zip(comprehension.definitions, comprehension.values): + self.writes.append((write_use, read_use)) + + for write_use, read_use in self.writes: + collected_tensor_uses.add(write_use) + read_use.collect_tensor_uses(collected_tensor_uses) + read_use.collect_scalar_uses(collected_scalar_uses) + read_use.collect_dim_uses(collected_dim_uses) + write_use.collect_dim_uses(collected_dim_uses) + read_use.collect_indices(collected_indices) + + # Set domain to the sorted list of uses if no domain annotation is given. + if not domain: + domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) + + # Verify the domain dimensions match the used dimensions. + if len(domain) != len(collected_dim_uses) or any( + dim not in collected_dim_uses for dim in domain + ): + raise ValueError( + f"Expected the annotated domain dimensions {domain} to " + f"match the set of dimension used by the tensor " + f"comprehension {collected_dim_uses}" + ) + + # Instantiate the dimensions in the given order. + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + for dim in domain: + dim.build(state=local_state) + + # Collect all attribute definitions. + collected_attr_defs = list() + for operand in registered_operands: + if operand.is_attribute(): + collected_attr_defs.append(operand) + + # Collect all tensors with manual indexing annotation. + collected_index_defs = list() + for operand in registered_operands: + if operand.index_dims: + if any(dim not in collected_dim_uses for dim in operand.index_dims): + raise ValueError( + f"Expected all index dims {operand.index_dims} of " + f"operand {operand.name} to have uses." + ) + collected_index_defs.append(operand) + + # Collect the operand definitions of all tensor/scalar uses, attributes, and + # shape-only tensors. 
+ all_operand_defs = list() + for use in collected_tensor_uses: + all_operand_defs.append(use.operand_def) + for use in collected_scalar_uses: + all_operand_defs.append(use.operand_def) + for definition in collected_attr_defs: + all_operand_defs.append(definition) + for definition in collected_index_defs: + all_operand_defs.append(definition) + + # Add all operands in registration order to ensure the symbols are + # registered in the order they appear. + all_operand_defs = sorted( + all_operand_defs, key=lambda operand_def: operand_def.registered_index + ) + for operand_def in all_operand_defs: + self.add_operand(operand_def) + + # Add all shape-only tensor index_dim annotations and all tensor uses. + for definition in collected_index_defs: + self.add_indexed_operand(definition) + for use in collected_tensor_uses: + self.add_tensor_use(use) + + # Normalize all shape and indexing maps now that full count of dims and + # symbols are known. + for cuse in self.uses.values(): + cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) + for definition in collected_index_defs: + self.operands[definition].indexing_map = self._normalize_affine_map( + self.operands[definition].indexing_map + ) + for operand_config in self.operands.values(): + if operand_config.shape_map: + operand_config.shape_map = self._normalize_affine_map( + operand_config.shape_map, with_dims=False + ) + if operand_config.index_attr_map: + operand_config.index_attr_map = self._normalize_affine_map( + operand_config.index_attr_map, with_dims=False + ) + + # Now for each write use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. 
+ for write_use, _ in self.writes: + write_tensor_config = self.operands[write_use.operand_def] + if write_tensor_config.indexing_map: + raise ValueError( + f"Unexpected multi-write to a single tensor: {write_tensor_config}" + ) + write_tensor_config.indexing_map = self.uses[write_use].indexing_map + + # For each read use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for _, read_expr in self.writes: + read_uses = set() # type: Set[TensorUse] + read_expr.collect_tensor_uses(read_uses) + for read_use in read_uses: + read_operand_config = self.operands[read_use.operand_def] + if ( + read_operand_config.indexing_map + and read_operand_config.indexing_map + != self.uses[read_use].indexing_map + ): + raise ValueError( + f"Unexpected multi-read of a tensor with different accesses:" + f"{read_operand_config} vs {read_use}" + ) + read_operand_config.indexing_map = self.uses[read_use].indexing_map + + # Set the indexing map of all scalar uses to the empty map. + for operand_config in self.operands.values(): + if operand_config.operand_def.kind == OperandKind.SCALAR: + operand_config.indexing_map = self._get_scalar_map() + + # Check all registered tensor and scalar operands have an indexing map. + for operand in registered_operands: + if operand.is_attribute(): + continue + if not (operand in self.operands and self.operands[operand].indexing_map): + raise ValueError( + f"Failed to compute an indexing map for operand " f"{operand.name}" + ) + + # Collect reduction dims and ensure all the same. + all_reduction_dims = set(comprehension.all_reduction_dims) + if len(all_reduction_dims) != 1: + raise ValueError( + f"All writes within a generic must have the same reduction " + f"dims. Got: {all_reduction_dims}" + ) + self.reduction_dims = next(iter(all_reduction_dims)) + + # Check the index dimension exists and resolve. 
+ for index in collected_indices: + if index.dim_def.dimname not in self.affine_state.all_dims: + raise ValueError( + f"The dimension {index.dim_def.dimname} is not part of the " + f"iteration domain {self.affine_state.all_dims}" + ) + index.resolve_dimension_name(self.affine_state) + + # Generate the scalar assignments (used to build a body). + self.assignments = [ + ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) + for write_use, read_expr in self.writes + ] + + @property + def ordered_operands(self) -> Sequence[OperandDefConfig]: + return sorted( + self.operands.values(), + key=lambda operand: operand.operand_def.registered_index, + ) + + @property + def ordered_dims(self) -> Sequence[Tuple[str, int]]: + """Gets the ordered list of dim bindings (symbolic name, position). + + TODO: The original parser relies on parse ordering to arrive at the + iterator types, but that ordering is not defined on the Python side, so + this may be ambiguous. + """ + return list(self.affine_state.all_dims.items()) + + @property + def indexing_maps(self) -> Sequence[_ir.AffineMap]: + return [o.indexing_map for o in self.ordered_operands if o.indexing_map] + + @property + def iterator_types(self) -> Sequence[str]: + def get_type(symbolic_name, position): + for reduction_dim_expr in self.reduction_dims: + if reduction_dim_expr.dimname == symbolic_name: + return "reduction" + return "parallel" + + return [get_type(*dim) for dim in self.ordered_dims] + + def add_operand(self, operand_def: OperandDef): + if operand_def in self.operands: + return + if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR): + self.operands[operand_def] = OperandDefConfig(operand_def) + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_dims=False + ) + exprs = [] + for expr in operand_def.size_exprs: + exprs.append(expr.build(state=local_state)) + assert local_state.local_dim_count == 0 + affine_map = 
_ir.AffineMap.get( + dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs + ) + if operand_def.kind == OperandKind.INDEX_ATTR: + self.operands[operand_def] = OperandDefConfig( + operand_def, index_attr_map=affine_map + ) + else: + self.operands[operand_def] = OperandDefConfig( + operand_def, shape_map=affine_map + ) + + def add_indexed_operand(self, operand_def: OperandDef): + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in operand_def.index_dims: + exprs.append(expr.build(state=local_state)) + self.operands[operand_def].indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + def add_tensor_use(self, tensor_use: TensorUse): + if tensor_use in self.uses: + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in tensor_use.indices: + exprs.append(expr.build(state=local_state)) + indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + use_config = TensorUseConfig(tensor_use, indexing_map) + self.uses[tensor_use] = use_config + + def _get_scalar_map(self) -> _ir.AffineMap: + """Create an empty affine map used to index a scalar.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count, + symbol_count=self.affine_state.symbol_count, + exprs=list(), + ) + + def _normalize_affine_map( + self, affine_map: _ir.AffineMap, with_dims: bool = True + ) -> _ir.AffineMap: + """Normalizes an indexing map to have the max known symbols and dims.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count if with_dims else 0, + symbol_count=self.affine_state.symbol_count, + exprs=list(affine_map.results), + ) + + def to_yaml_custom_dict(self): + self_dict = 
dict(args=self.ordered_operands) + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). + self_dict["indexing_maps"] = LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps + ) + self_dict["iterator_types"] = self.iterator_types + self_dict["assignments"] = self.assignments + return self_dict + + def __repr__(self): + lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] + lines.append("operands=[") + for def_config in self.ordered_operands: + lines.append(f" {repr(def_config)}") + lines.append("], indexing_maps=[") + for m in self.indexing_maps: + lines.append(f" {repr(m)}") + lines.append(f"], iterator_types=[") + for t in self.iterator_types: + lines.append(f" {t}") + lines.append("])") + return "\n".join(lines) + + +class LinalgOpConfig(YAMLObject): + """Container for any supported linalg op type. + + This includes the concrete type by name for ease of parsing by systems + that ignore tags. + """ + + yaml_tag = "!LinalgOpConfig" + + def __init__( + self, + metadata: OpMetadataDef, + *, + structured_op: Optional[LinalgStructuredOpConfig] = None, + ): + self.metadata = metadata + self.structured_op = structured_op + + def to_yaml_custom_dict(self): + self_dict = dict( + metadata=self.metadata, + ) + if self.structured_op: + self_dict["structured_op"] = self.structured_op + return self_dict + + @staticmethod + def from_linalg_op_def( + op_def: LinalgOpDef, context: Optional[_ir.Context] = None + ) -> Sequence["LinalgOpConfig"]: + """Expands a LinalgOpDef into corresponding Linalg configured ops.""" + # TODO: Many LinalgOpDef patterns need to expand to multiple generics. 
        assert len(op_def.comprehensions) == 1, "Only one comprehension supported"
        return [
            LinalgOpConfig(
                op_def.metadata,
                structured_op=LinalgStructuredOpConfig(
                    op_def.comprehensions[0],
                    op_def.domain,
                    op_def.registered_operands.values(),
                    context,
                ),
            ),
        ]

    def __repr__(self):
        return (
            f"LinalgOpConfig(metadata={self.metadata},\n"
            f"structured_op={self.structured_op})"
        )
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
new file mode 100644
index 000000000..8b8726f8f
--- /dev/null
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py
@@ -0,0 +1,201 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Dict, List, Sequence, Union

from contextlib import contextmanager
import functools
import inspect
import threading

from ..... import ir
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)
from .comprehension import *
from .config import *
from .emitter import *

# Thread-local slot tracking the op definition currently being populated by a
# DSL function (see bind_op_def / current_op_def below).
_CONTEXT = threading.local()

# Accepted forms for the `outs` argument of a structured op call.
StructuredOpOuts = Union[
    ir.Operation,
    ir.OpView,
    ir.OpResultList,
    Sequence[Union[ir.Value, ir.Operation, ir.OpView]],
]


@contextmanager
def bind_op_def(op_def: LinalgOpDef):
    """Binds `op_def` as the current definition for the duration of the context.

    Non-reentrant: raises ValueError if a definition is already bound.
    """
    if hasattr(_CONTEXT, "current_op_def"):
        raise ValueError("Cannot recursively define an operation")
    _CONTEXT.current_op_def = op_def
    try:
        yield op_def
    finally:
        del _CONTEXT.current_op_def


def current_op_def() -> LinalgOpDef:
    """Returns the op definition bound by bind_op_def, raising if none is set."""
    try:
        return _CONTEXT.current_op_def
    except AttributeError:
        raise ValueError(
            "Attempt to access the current op definition being defined "
            "but none is set. Did you mean to call this in an op definition?"
        )


def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList:
    """Normalizes any supported `outs` form to a flat list of values."""
    if isinstance(outs, (ir.Operation, ir.OpView)):
        return _get_op_results_or_values(outs)
    elif isinstance(outs, ir.OpResultList):
        return outs

    return [_get_op_result_or_value(o) for o in outs]


class DefinedOpCallable:
    """Callable that wraps any defined op function."""

    def __init__(self, op_name: str, op_def: LinalgOpDef):
        self.op_name = op_name
        self.op_def = op_def

    def __call__(
        self,
        *ins: Union[ir.Operation, ir.OpView, ir.Value],
        outs: StructuredOpOuts,
        **kwargs,
    ):
        """Emits the corresponding op definition as IR.

        Most arguments are passed through to the underlying emitter. The following
        keyword argument is interpreted here:
          emit_generic: Emits a generic form as appropriate (default False). A
            generic form is also forced when the named form is not registered
            with the compiler.
        """
        emit_generic = kwargs.pop("emit_generic", False)
        if not isinstance(emit_generic, bool):
            raise ValueError(
                f"The named argument 'emit_generic' needs to be "
                f" of type bool but got {type(emit_generic)}"
            )

        op_configs = LinalgOpConfig.from_linalg_op_def(
            self.op_def, context=ir.Context.current
        )

        if len(op_configs) != 1:
            # TODO: Support composite ops.
            raise NotImplementedError(
                f"Emission of composite linalg ops not supported: {op_configs}"
            )

        ctx = ir.Context.current
        # NOTE(review): linalgDialect appears unused below — presumably kept to
        # force-load the dialect; confirm before removing.
        linalgDialect = ctx.get_dialect_descriptor("linalg")
        fully_qualified_name = "linalg."
+ self.op_name + emit_generic = emit_generic or not ctx.is_registered_operation( + fully_qualified_name + ) + + op_config = op_configs[0] + out_values = _prepare_structured_op_outs(outs) + in_values = [_get_op_result_or_value(i) for i in ins] + if op_config.structured_op: + if emit_generic: + return emit_generic_structured_op( + op_config.structured_op, *in_values, outs=out_values, **kwargs + ) + else: + return emit_named_structured_op( + op_config.structured_op, + self.op_name, + self.op_def.metadata.cpp_class_name, + *in_values, + outs=out_values, + **kwargs, + ) + + raise NotImplementedError( + f"Emission of linalg op type not supported: {op_config}" + ) + + +def linalg_structured_op( + dsl_func=None, *, op_name=None, op_class_name=None +) -> DefinedOpCallable: + if dsl_func is None: + # Curry the keyword args in for delayed application. + return functools.partial( + linalg_structured_op, op_name=op_name, op_class_name=op_class_name + ) + # Determine default names by introspecting the function. + if op_name is None: + op_name = dsl_func.__name__ + if op_class_name is None: + # Camel case it. + op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" + + op_def = LinalgOpDef( + name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func) + ) + + # Extract arguments and TensorDefs from the signature. 
    dsl_func_args = list()
    sig = inspect.signature(dsl_func)
    for param_name, param in sig.parameters.items():
        # Each DSL parameter must carry its operand definition as its default.
        param_default = param.default
        if isinstance(
            param_default,
            (
                TensorDef,
                ScalarDef,
                IndexAttrDef,
                UnaryFnAttrDef,
                BinaryFnAttrDef,
                TypeFnAttrDef,
            ),
        ):
            op_def.add_operand(param_name, param_default.operand_def)
        else:
            raise ValueError(
                f"@linalg_structured_op function parameters must be defaulted as "
                f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): "
                f"Found {param_name}: {param_default}"
            )
        dsl_func_args.append(param_default)

    # Invoke the DSL func to finish populating the op definition.
    with bind_op_def(op_def):
        dsl_func(*dsl_func_args)

    # TODO: The returned callable should be an IR emitter but that is not
    # upstreamed yet.
    return DefinedOpCallable(op_name, op_def)


def domain(*dimensions: DimDef):
    """Declares the iteration domain order of the op currently being defined."""
    if any(not isinstance(d, DimDef) for d in dimensions):
        raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}")
    current_op_def().domain.extend(dimensions)


def implements(*interfaces: OpInterfaceDef):
    """Declares op interfaces implemented by the op currently being defined."""
    if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces):
        raise ValueError(
            f"Expected interfaces of type OpInterfaceDef but got {interfaces}"
        )
    current_op_def().metadata.implements.extend(interfaces)


def defines(*definitions: OpDefinitionDef):
    """Declares extra definitions attached to the op currently being defined."""
    if any(not isinstance(defi, OpDefinitionDef) for defi in definitions):
        raise ValueError(
            f"Expected definitions of type OpDefinitionDef but got {definitions}"
        )
    current_op_def().metadata.defines.extend(definitions)
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
new file mode 100644
index 000000000..254458a97
--- /dev/null
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -0,0 +1,648 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Callable, Dict, List, Sequence, Tuple, Union

from .....ir import *

from .... import func
from .... import linalg
from .... import math
from .... import arith
from .... import complex
from ...._ods_common import (
    get_op_result_or_value as _get_op_result_or_value,
    get_op_results_or_values as _get_op_results_or_values,
)

from .scalar_expr import *
from .config import *
from .comprehension import *
import numpy as np

__all__ = [
    "emit_generic_structured_op",
    "emit_named_structured_op",
    "ValueList",
]

# Type aliases.
ValueList = Union[Sequence[Value], OpResultList]


def prepare_common_structured_op(
    op_config: LinalgStructuredOpConfig,
    *ins: Value,
    outs: ValueList,
    **attrs: Union[Sequence[int], TypeFnType],
):
    """Shared precomputation for generic and named structured op emission.

    Partitions the configured operands by kind, validates the call arity, and
    resolves index/function attributes into the values both emitters need.
    """
    all_arg_defs = op_config.ordered_operands
    # Partition the operand definitions by kind.
    in_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR]
    ]
    out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR]
    index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR]
    fn_attr_arg_defs = [
        d
        for d in all_arg_defs
        if d.kind
        in [
            OperandKind.UNARY_FN_ATTR,
            OperandKind.BINARY_FN_ATTR,
            OperandKind.TERNARY_FN_ATTR,
            OperandKind.TYPE_FN_ATTR,
        ]
    ]

    # Verify outs is a sequence or a list of results.
    if not isinstance(outs, (Sequence, OpResultList)):
        raise ValueError(
            f"Expected named argument outs to have type Sequence or "
            f"OpResultList but got {type(outs)}"
        )

    # Arity validation.
    # Arity: inputs must match exactly; outputs are optional (they can be
    # inferred) but must match when provided.
    if len(ins) != len(in_arg_defs):
        raise ValueError(
            f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}"
        )
    if outs and len(outs) != len(out_arg_defs):
        raise ValueError(
            f"Expected {len(out_arg_defs)} outputs but got "
            f"{len(outs)} for {op_config}"
        )

    # Compute a replacement list for all index attribute symbols.
    expressions = []  # type: Sequence[AffineExpr]
    replacements = []  # type: Sequence[AffineExpr]
    for index_attr in index_attr_arg_defs:
        # A value passed by the caller overrides the declared default indices.
        index_attr_vals = index_attr.operand_def.default_indices
        if index_attr.name in attrs:
            index_attr_vals = attrs.get(index_attr.name)
        assert index_attr_vals, "Index attribute has no value"
        if not all(isinstance(value, int) for value in index_attr_vals):
            raise ValueError(
                f"Attribute {index_attr.name} needs to be of type "
                f"Sequence[int] but got {type(index_attr_vals)}"
            )
        results = index_attr.index_attr_map.results  # type: AffineExprList
        if len(index_attr_vals) != len(results):
            raise ValueError(
                f"Attribute {index_attr.name} has length {len(results)} "
                f"but got {len(index_attr_vals)} values"
            )
        # Pair each symbolic expression with its constant replacement.
        for expr, value in zip(results, index_attr_vals):
            expressions.append(expr)
            replacements.append(AffineConstantExpr.get(value))

    # Replace all index attribute symbols by their value.
    # TODO: Add support for shape symbols.
    indexing_maps = []  # type: Sequence[AffineMap]
    for curr in op_config.indexing_maps:
        for expression, replacement in zip(expressions, replacements):
            curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols)
        indexing_maps.append(curr)

    # TODO: Linalg verification does not currently allow symbols.
    # Compress them for now and verify none are left.
    indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current)
    if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps):
        raise ValueError(
            f"Expected indexing_maps to use no symbols after "
            f"replacement and compression but got {indexing_maps}"
        )

    # Fill in any missing outputs (and collect output types) from the config.
    outs, out_types = _infer_structured_outs(
        op_config, in_arg_defs, ins, out_arg_defs, outs
    )

    # Only ranked-tensor outputs become op results here; other output types
    # are filtered out — presumably memref form, TODO confirm.
    result_types = [t for t in out_types if isinstance(t, RankedTensorType)]

    # Initialize the type dictionary with the predefined types.
    type_mapping = dict()  # type: Dict[str, Type]
    type_mapping["F32"] = F32Type.get()
    type_mapping["F64"] = F64Type.get()
    type_mapping["I32"] = IntegerType.get_signless(32)
    type_mapping["I64"] = IntegerType.get_signless(64)

    # Extract type vars for input/output based types.
    block_arg_types = list()  # type: List[Type]
    for arg_def, arg_element_type in zip(
        in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs)
    ):
        _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types)

    # Emit the generic op.
    # TODO: Support emission of pure memref form.
    indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps])
    iterator_types_attr = ArrayAttr.get(
        [
            Attribute.parse(f"#linalg.iterator_type<{s}>")
            for s in op_config.iterator_types
        ]
    )

    # Compute the index attributes used when emitting a named structured op.
    index_attrs = {}  # type: Dict[str, DenseElementAttr]
    for index_attr in index_attr_arg_defs:
        index_attr_vals = attrs.get(index_attr.name)
        # Only forward attributes set to a non-default value.
        if index_attr_vals:
            array = np.array(index_attr_vals, dtype=np.int64)
            index_attrs[index_attr.name] = DenseElementsAttr.get(array)

    # Compute the function attribute mapping.
+ fn_attr_mapping = {} + for fn_attr in fn_attr_arg_defs: + attr_val = fn_attr.operand_def.default_fn + attr_kind = fn_attr.kind + if fn_attr.name in attrs: + fn = attrs.get(fn_attr.name) + if attr_kind == OperandKind.UNARY_FN_ATTR: + if not isinstance(fn, UnaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"UnaryFnType but got {type(attr_val)}" + ) + elif attr_kind == OperandKind.BINARY_FN_ATTR: + if not isinstance(fn, BinaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"BinaryFnType but got {type(attr_val)}" + ) + elif attr_kind == OperandKind.TERNARY_FN_ATTR: + if not isinstance(fn, TernaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"TernaryFnType but got {type(attr_val)}" + ) + else: + if not isinstance(fn, TypeFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}" + ) + attr_val = fn.fn_name + assert attr_val, "Function attribute has no value" + fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) + + return ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) + + +def emit_generic_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # An operation that accesses only scalars and scalar/rank zero tensors is + # rank polymorhpic. We implement rank polymorphism by generating different + # indexing maps and iterators that match the rank of the first output tensor. 
+ # An operation is rank polymorphic if the iteration domain has rank zero. + if not iterator_types_attr: + rank = ShapedType(outs[0].type).rank + iterator_types_attr = ArrayAttr.get( + [Attribute.parse("#linalg.iterator_type")] * rank + ) + scalar_map = AffineMap.get(rank, 0, []) + tensor_map = AffineMap.get_identity(rank) + indexing_maps = [] + for arg_def in all_arg_defs: + if arg_def.operand_def.kind == OperandKind.SCALAR: + indexing_maps.append(scalar_map) + if arg_def.operand_def.is_tensor(): + idx = arg_def.operand_def.registered_index + if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + indexing_maps.append(scalar_map) + else: + indexing_maps.append(tensor_map) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps] + ) + + generic_op = linalg.GenericOp( + result_tensors=result_types, + inputs=ins, + outputs=outs, + indexing_maps=indexing_maps_attr, + iterator_types=iterator_types_attr, + doc=None, # TODO: Make optional. + library_call=None, + ) # TODO: Make optional. + + # Construct the body. 
    # Build the region body by evaluating the configured scalar assignments.
    block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs)
    block = generic_op.regions[0].blocks.append(*block_arg_types)
    block_arg_mapping = dict(zip(block_arg_names, block.arguments))
    with InsertionPoint(block):
        body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping)
        for assignment in op_config.assignments:
            body_builder.assign(assignment)
        body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs))

    # Unwrap a single result for caller convenience.
    if len(result_types) == 1:
        return generic_op.result
    else:
        return generic_op.results


def emit_named_structured_op(
    op_config: LinalgStructuredOpConfig,
    op_name: str,
    op_class_name: str,
    *ins: Value,
    outs: ValueList,
    **attrs: Sequence[int],
):
    """Emits `op_config` as the named op `op_name` backed by class `op_class_name`."""
    (
        all_arg_defs,
        in_arg_defs,
        out_arg_defs,
        outs,
        result_types,
        type_mapping,
        indexing_maps_attr,
        iterator_types_attr,
        index_attrs,
        fn_attr_mapping,
        block_arg_types,
    ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs)

    # If we get here, there must exist a builtin class `op_class_name`.
    ctx = Context.current
    fully_qualified_name = "linalg." + op_name
    if (
        not ctx.is_registered_operation(fully_qualified_name)
        or not op_class_name in linalg.__dict__.keys()
    ):
        raise NotImplementedError(
            f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}"
        )

    # Set the index attributes used to compute the indexing maps.
    named_op = getattr(linalg, op_class_name)(result_types, ins, outs)
    for name, value in index_attrs.items():
        named_op.operation.attributes[name] = value

    # Compute the function attributes by combining operand kind and function name.
    for name, (fn_name, kind) in fn_attr_mapping.items():
        assert kind.name.lower().endswith("_attr")
        # Strip the trailing "_attr" (5 chars) from the kind name to obtain the
        # enum attribute name, e.g. UNARY_FN_ATTR -> "#linalg.unary_fn<...>".
        enum_name = kind.name.lower()[:-5]
        named_op.operation.attributes[name] = Attribute.parse(
            f"#linalg.{enum_name}<{fn_name}>"
        )

    linalg.fill_builtin_region(named_op.operation)

    # Unwrap a single result for caller convenience.
    if len(result_types) == 1:
        return named_op.result
    else:
        return named_op.results


class _BodyBuilder:
    """Constructs a structured op body by evaluating assignments."""

    def __init__(
        self,
        type_mapping: Dict[str, Type],
        block_arg_mapping: Dict[str, Value],
        fn_attr_mapping: Dict[str, str],
    ):
        self.type_mapping = type_mapping
        self.block_arg_mapping = block_arg_mapping
        self.fn_attr_mapping = fn_attr_mapping
        # Maps each output argument name to the value it will yield.
        self.yield_mapping = dict()  # type: Dict[str, Value]

    def assign(self, assignment: ScalarAssign):
        """Evaluates an assignment, recording the value to yield for its arg."""
        if assignment.arg in self.yield_mapping:
            raise ValueError(
                f"Multiple assignments to the same argument are forbidden: "
                f"{assignment}"
            )
        self.yield_mapping[assignment.arg] = self.expression(assignment.value)

    def expression(self, expr: ScalarExpression) -> Value:
        """Recursively emits IR for a scalar expression and returns its value."""
        if expr.scalar_arg:
            try:
                return self.block_arg_mapping[expr.scalar_arg.arg]
            except KeyError:
                raise ValueError(
                    f"Argument {expr.scalar_arg.arg} is not bound for "
                    f"this structured op."
+ ) + elif expr.scalar_const: + value_attr = Attribute.parse(expr.scalar_const.value) + return arith.ConstantOp(value_attr.type, value_attr).result + elif expr.scalar_index: + dim_attr = IntegerAttr.get( + IntegerType.get_signless(64), expr.scalar_index.dim + ) + return linalg.IndexOp(dim_attr).result + elif expr.scalar_fn: + kind = expr.scalar_fn.kind.name.lower() + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] + fn = self._get_function(f"_{kind}_{fn_name}") + operand_values = [ + self.expression(operand) for operand in expr.scalar_fn.operands + ] + if expr.scalar_fn.kind == FunctionKind.TYPE: + operand_values = [expr.scalar_fn.type_var.name] + operand_values + return fn(*operand_values) + raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") + + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError( + f"Body assignments do not assign all outputs: " f"missing '{n}'" + ) + linalg.YieldOp(output_values) + + def _get_function(self, fn_name: str) -> Callable: + try: + fn = getattr(self, f"{fn_name}") + except AttributeError: + raise ValueError(f"Function '{fn_name}' is not a known function") + return fn + + def _cast( + self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False + ) -> Value: + try: + to_type = self.type_mapping[type_var_name] + except KeyError: + raise ValueError( + f"Unbound type variable '{type_var_name}' (" + f"expected one of {self.type_mapping.keys()}" + ) + if operand.type == to_type: + return operand + if _is_integer_type(to_type): + return self._cast_to_integer(to_type, operand, is_unsigned_cast) + elif _is_floating_point_type(to_type): + return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) + + def _cast_to_integer( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> 
Value: + to_width = IntegerType(to_type).width + operand_type = operand.type + if _is_floating_point_type(operand_type): + if is_unsigned_cast: + return arith.FPToUIOp(to_type, operand).result + return arith.FPToSIOp(to_type, operand).result + if _is_index_type(operand_type): + return arith.IndexCastOp(to_type, operand).result + # Assume integer. + from_width = IntegerType(operand_type).width + if to_width > from_width: + if is_unsigned_cast: + return arith.ExtUIOp(to_type, operand).result + return arith.ExtSIOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncIOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _cast_to_floating_point( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + operand_type = operand.type + if _is_integer_type(operand_type): + if is_unsigned_cast: + return arith.UIToFPOp(to_type, operand).result + return arith.SIToFPOp(to_type, operand).result + # Assume FloatType. 
+        to_width = _get_floating_point_width(to_type)
+        from_width = _get_floating_point_width(operand_type)
+        if to_width > from_width:
+            return arith.ExtFOp(to_type, operand).result
+        elif to_width < from_width:
+            return arith.TruncFOp(to_type, operand).result
+        raise ValueError(
+            f"Unable to cast body expression from {operand_type} to " f"{to_type}"
+        )
+
+    def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value:
+        return self._cast(type_var_name, operand, False)
+
+    def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value:
+        return self._cast(type_var_name, operand, True)
+
+    def _unary_exp(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.ExpOp(x).result
+        raise NotImplementedError(f"Unsupported 'exp' operand: {x}")  # f-string fix
+
+    def _unary_log(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.LogOp(x).result
+        raise NotImplementedError(f"Unsupported 'log' operand: {x}")
+
+    def _unary_abs(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.AbsFOp(x).result
+        raise NotImplementedError(f"Unsupported 'abs' operand: {x}")
+
+    def _unary_ceil(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.CeilOp(x).result
+        raise NotImplementedError(f"Unsupported 'ceil' operand: {x}")
+
+    def _unary_floor(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return math.FloorOp(x).result
+        raise NotImplementedError(f"Unsupported 'floor' operand: {x}")
+
+    def _unary_negf(self, x: Value) -> Value:
+        if _is_floating_point_type(x.type):
+            return arith.NegFOp(x).result
+        if _is_complex_type(x.type):
+            return complex.NegOp(x).result
+        raise NotImplementedError(f"Unsupported 'negf' operand: {x}")
+
+    def _binary_add(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.AddFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.AddIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.AddOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'add' operands: {lhs}, {rhs}")
+
+    def _binary_sub(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.SubFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.SubIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.SubOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'sub' operands: {lhs}, {rhs}")
+
+    def _binary_mul(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MulFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MulIOp(lhs, rhs).result
+        if _is_complex_type(lhs.type):
+            return complex.MulOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'mul' operands: {lhs}, {rhs}")
+
+    def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MaximumFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MaxSIOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'max' operands: {lhs}, {rhs}")
+
+    def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MaximumFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MaxUIOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'max_unsigned' operands: {lhs}, {rhs}")
+
+    def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MinimumFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MinSIOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'min' operands: {lhs}, {rhs}")
+
+    def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+        if _is_floating_point_type(lhs.type):
+            return arith.MinimumFOp(lhs, rhs).result
+        if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+            return arith.MinUIOp(lhs, rhs).result
+        raise NotImplementedError(f"Unsupported 'min_unsigned' operands: {lhs}, {rhs}")
+
+
+def _infer_structured_outs(
+    op_config: LinalgStructuredOpConfig,
+    in_arg_defs: Sequence[OperandDefConfig],
+    ins: Sequence[Value],
+    out_arg_defs: Sequence[OperandDefConfig],
+    outs: Union[Sequence[Value], OpResultList],
+) -> Tuple[ValueList, List[Type]]:
+    """Infers implicit outs and output types.
+
+    Respects existing contents of outs if not empty.
+
+    Returns:
+      normalized outs, output types
+    """
+    # If outs were explicitly provided, we accept them verbatim.
+    if outs:
+        return outs, [out.type for out in outs]
+
+    raise NotImplementedError(
+        f"Output tensor inference not yet supported for " "structured ops"
+    )
+
+
+def _get_types_from_values(*values: Value) -> Sequence[Type]:
+    types = []
+    for v in values:
+        types.append(v.type)
+    return types
+
+
+def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]:
+    return [odc.operand_def.name for odc in operand_configs]
+
+
+def _add_type_mapping(
+    operand_config: OperandDefConfig,
+    operand_type: Type,
+    type_mapping: Dict[str, Type],
+    block_arg_types: Sequence[Type],
+):
+    element_or_self_type = operand_type
+    # Get the element type for tensor operands and the type itself for scalars.
+ if operand_config.shape_map: + try: + element_or_self_type = ShapedType(operand_type).element_type + except Exception as e: + raise ValueError(f"Expected ShapedType but got {operand_type}") from e + name = operand_config.type_var.name + if name in type_mapping: + if type_mapping[name] != element_or_self_type: + raise ValueError( + f"Cannot overwrite type mapping {name} = " + f"{type_mapping[name]} by type {element_or_self_type}" + ) + type_mapping[name] = element_or_self_type + block_arg_types.append(element_or_self_type) + + +def _is_complex_type(t: Type) -> bool: + return ComplexType.isinstance(t) + + +def _is_floating_point_type(t: Type) -> bool: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + return ( + F64Type.isinstance(t) + or F32Type.isinstance(t) + or F16Type.isinstance(t) + or BF16Type.isinstance(t) + ) + + +def _is_integer_type(t: Type) -> bool: + return IntegerType.isinstance(t) + + +def _is_index_type(t: Type) -> bool: + return IndexType.isinstance(t) + + +def _get_floating_point_width(t: Type) -> int: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py new file mode 100644 index 000000000..86853994c --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -0,0 +1,166 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""Models DAGs of scalar math expressions. 
+
+Used for generating region bodies at the "math" level where they are still type
+polymorphic. This is modeled to be polymorphic by attribute name for interop
+with serialization schemes that are just plain-old-dicts.
+
+These classes are typically not user accessed and are created as a by-product
+of interpreting a comprehension DSL and model the operations to perform in the
+op body. The class hierarchy is laid out to map well to a form of YAML that
+can be easily consumed from the C++ side, not necessarily for ergonomics.
+"""
+
+from typing import Optional, Sequence
+
+from .comprehension import *
+from .types import *
+from .yaml_helper import *
+
+__all__ = [
+    "ScalarAssign",
+    "ScalarFn",
+    "ScalarArg",
+    "ScalarConst",
+    "ScalarIndex",
+    "ScalarExpression",
+]
+
+
+class ScalarFn:
+    """A type of ScalarExpression that applies a function."""
+
+    def __init__(
+        self,
+        kind: "FunctionKind",
+        fn_name: Optional[str],
+        attr_name: Optional[str],
+        type_var: Optional["TypeVar"],
+        operands: Sequence["ScalarExpression"],
+    ):
+        if bool(fn_name) + bool(attr_name) != 1:
+            raise ValueError("One of 'fn_name', 'attr_name' must be specified")
+        self.kind = kind
+        self.fn_name = fn_name
+        self.attr_name = attr_name
+        self.type_var = type_var
+        self.operands = operands
+
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_fn=self)
+
+    def __repr__(self):
+        name = self.fn_name if self.fn_name else self.attr_name
+        return (
+            f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, "
+            f"operands=[{', '.join(map(repr, self.operands))}])"  # operands are objects
+        )
+
+
+class ScalarArg:
+    """A type of ScalarExpression that references a named argument."""
+
+    def __init__(self, arg: str):
+        self.arg = arg
+
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_arg=self)
+
+    def __repr__(self):
+        return f"ScalarArg({self.arg})"
+
+
+class ScalarConst:
+    """A type of ScalarExpression representing a constant."""
+
+    def __init__(self, value: str):
+        self.value = value
+
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_const=self)
+
+    def __repr__(self):
+        return f"ScalarConst({self.value})"
+
+
+class ScalarIndex:
+    """A type of ScalarExpression accessing an iteration index."""
+
+    def __init__(self, dim: int):
+        self.dim = dim
+
+    def expr(self) -> "ScalarExpression":
+        return ScalarExpression(scalar_index=self)
+
+    def __repr__(self):
+        return f"ScalarIndex({self.dim})"
+
+
+class ScalarExpression(YAMLObject):
+    """An expression on scalar values.
+
+    Can be one of:
+      - ScalarFn
+      - ScalarArg
+      - ScalarConst
+      - ScalarIndex
+    """
+
+    yaml_tag = "!ScalarExpression"
+
+    def __init__(
+        self,
+        scalar_fn: Optional[ScalarFn] = None,
+        scalar_arg: Optional[ScalarArg] = None,
+        scalar_const: Optional[ScalarConst] = None,
+        scalar_index: Optional[ScalarIndex] = None,
+    ):
+        if (
+            bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)
+        ) != 1:
+            raise ValueError(
+                "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or "
+                "'scalar_index' must be specified"
+            )
+        self.scalar_fn = scalar_fn
+        self.scalar_arg = scalar_arg
+        self.scalar_const = scalar_const
+        self.scalar_index = scalar_index
+
+    def to_yaml_custom_dict(self):
+        if self.scalar_fn:
+            scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower())
+            if self.scalar_fn.fn_name:
+                scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name
+            if self.scalar_fn.attr_name:
+                scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name
+            if self.scalar_fn.type_var:
+                scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name
+            scalar_fn_dict["operands"] = list(self.scalar_fn.operands)
+            return dict(scalar_fn=scalar_fn_dict)
+        elif self.scalar_arg:
+            return dict(scalar_arg=self.scalar_arg.arg)
+        elif self.scalar_const:
+            return dict(scalar_const=self.scalar_const.value)
+        elif self.scalar_index:
+            return dict(scalar_index=self.scalar_index.dim)
+        else:
+            raise ValueError(f"Unexpected ScalarExpression type: {self}")
+
+
+class ScalarAssign(YAMLObject):
+    """An assignment to a named argument (LHS of a comprehension)."""
+
+    yaml_tag = "!ScalarAssign"
+
+    def __init__(self, arg: str, value: ScalarExpression):
+        self.arg = arg
+        self.value = value
+
+    def to_yaml_custom_dict(self):
+        return dict(arg=self.arg, value=self.value)
+
+    def __repr__(self):
+        return f"ScalarAssign({self.arg}, {self.value})"
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
new file mode 100644
index 000000000..4f36029b7
--- /dev/null
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py
@@ -0,0 +1,79 @@
+# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+"""Facility for symbolically referencing type variables.
+
+Type variables are instances of the TypeVar class, which is uniqued by name.
+An "expando" accessor `TV` is provided that generates a named TypeVar for
+any attribute access:
+
+  >>> TV.T
+  TypeVar(T)
+  >>> TV.T is TV.U
+  False
+  >>> TV.T is TV.T
+  True
+"""
+
+from enum import Enum
+from typing import Dict
+
+__all__ = [
+    "TypeVar",
+    "TV",
+    # Predefined types.
+    "I32",
+    "I64",
+    "F32",
+    "F64",
+    # TypeVar aliases.
+    "T",
+    "U",
+    "V",
+]
+
+
+class TypeVar:
+    """A replaceable type variable.
+
+    Type variables are uniqued by name.
+ """ + + ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] + + def __new__(cls, name: str): + existing = cls.ALL_TYPEVARS.get(name) + if existing is not None: + return existing + new = super().__new__(cls) + new.name = name + cls.ALL_TYPEVARS[name] = new + return new + + def __repr__(self): + return f"TypeVar({self.name})" + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique type vars on attr access.""" + + class ExpandoTypeVars: + def __getattr__(self, n): + return cls(n) + + return ExpandoTypeVars() + + +# Expando access via TV.foo +TV = TypeVar.create_expando() + +# Predefined types. +I32 = TV.I32 +I64 = TV.I64 +F32 = TV.F32 +F64 = TV.F64 + +# Some common type name aliases. +T = TV.T +U = TV.U +V = TV.V diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py new file mode 100644 index 000000000..1672656b3 --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py @@ -0,0 +1,53 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +"""YAML serialization is routed through here to centralize common logic.""" + +import sys + +try: + import yaml +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. 
" + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e + +__all__ = [ + "yaml_dump", + "yaml_dump_all", + "YAMLObject", +] + + +class YAMLObject(yaml.YAMLObject): + @classmethod + def to_yaml(cls, dumper, self): + """Default to a custom dictionary mapping.""" + return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) + + def to_yaml_custom_dict(self): + raise NotImplementedError() + + def as_linalg_yaml(self): + return yaml_dump(self) + + +def multiline_str_representer(dumper, data): + if len(data.splitlines()) > 1: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) + + +yaml.add_representer(str, multiline_str_representer) + + +def yaml_dump(data, sort_keys=False, **kwargs): + return yaml.dump(data, sort_keys=sort_keys, **kwargs) + + +def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): + return yaml.dump_all( + data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs + ) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/__init__.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/__init__.py rename to mlir/python/mlir/dialects/linalg/opdsl/ops/__init__.py diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py new file mode 100644 index 000000000..fd4a5a848 --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -0,0 +1,1771 @@ +from ..lang import * + +T1 = TV.T1 +T2 = TV.T2 + +Batch = S.Batch + + +@linalg_structured_op +def copy( + I=TensorDef(T1), + O=TensorDef(U, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Copies the tensor elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + defines(Canonicalizer) + O[None] = cast(U, I[None]) + + +@linalg_structured_op +def exp( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies exp(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.exp(I[None]) + + +@linalg_structured_op +def log( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies log(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.log(I[None]) + + +@linalg_structured_op +def abs( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies abs(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.abs(I[None]) + + +@linalg_structured_op +def ceil( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies ceil(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.ceil(I[None]) + + +@linalg_structured_op +def floor( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies floor(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.floor(I[None]) + + +@linalg_structured_op(op_class_name="NegFOp") +def negf( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies negf(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.negf(I[None]) + + +@linalg_structured_op(op_class_name="ReciprocalOp") +def reciprocal( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies reciprocal(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.reciprocal(I[None]) + + +@linalg_structured_op +def round( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies round(x) elementwise. + + No numeric casting is performed on the input operand. 
+ """ + O[None] = UnaryFn.round(I[None]) + + +@linalg_structured_op +def sqrt( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies sqrt(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.sqrt(I[None]) + + +@linalg_structured_op +def rsqrt( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies rsqrt(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.rsqrt(I[None]) + + +@linalg_structured_op +def square( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies square(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.square(I[None]) + + +@linalg_structured_op +def tanh( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies tanh(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.tanh(I[None]) + + +@linalg_structured_op +def erf( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies erf(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.erf(I[None]) + + +@linalg_structured_op +def add( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Adds two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.add` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.add(lhs[None], rhs[None]) + + +@linalg_structured_op +def sub( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Subtracts two tensors elementwise. 
+ + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.sub(lhs[None], rhs[None]) + + +@linalg_structured_op +def mul( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Multiplies two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.mul(lhs[None], rhs[None]) + + +@linalg_structured_op +def div( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.div(lhs[None], rhs[None]) + + +@linalg_structured_op +def div_unsigned( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. 
For integer + types, performs an unsigned division. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.div_unsigned(lhs[None], rhs[None]) + + +@linalg_structured_op +def max( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the max (signed) between two inputs, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.max` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.max_signed(lhs[None], rhs[None]) + + +@linalg_structured_op +def min( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the min (signed) between two inputs, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.min` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. 
+ """ + O[None] = BinaryFn.min_signed(lhs[None], rhs[None]) + + +@linalg_structured_op(op_class_name="PowFOp") +def powf( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`. + + Only applies to floating point values. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.powf` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.powf(lhs[None], rhs[None]) + + +@linalg_structured_op +def select( + cond=TensorDef(U), + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Chooses one value based on a binary condition supplied as its first operand. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.select` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) + + +@linalg_structured_op +def quantized_matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.m, D.n, D.k) + C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp) + ) + + +@linalg_structured_op +def mmt4d( + lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True), +): + """Performs a matrix-matrix-transpose multiplication of two 4D inputs. + + Differences from linalg.matmul: + * The right hand side is transposed, whence the 't' in 'mmt'. + * The input and output tensors have a 4D shape instead of a 2D shape. They + are interpreted as 2D matrices with one level of 2D tile subdivision, + whence the 2+2=4 dimensions. The inner tile dimensions are identified with + '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads + as: MxK tiles, each of shape M0xK0. + """ + domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) + implements(ContractionOpInterface) + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0] + ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + + +@linalg_structured_op +def batch_mmt4d( + lhs=TensorDef(TV.LhsType, Batch, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, Batch, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, Batch, S.M, S.N, S.M0, S.N0, output=True), +): + """Performs a batched matrix-matrix-transpose multiplication of two + batched-4D (5D) inputs. + + Besides the outermost batch dimension has the same semantic as + linalg.batch_matmul, the differences from linalg.batch_matmul in the + non-batch dimensions are the same as linalg.mmt4d vs. linalg.matmul. See the + description of lingalg.mmt4d. 
+ """ + domain(D.b, D.m, D.n, D.k, D.m0, D.n0, D.k0) + implements(ContractionOpInterface) + accum[D.b, D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.b, D.m, D.k, D.m0, D.k0] + ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0]) + + +@linalg_structured_op +def quantized_batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.b, D.m, D.n, D.k) + C[D.b, D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp) + ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) + + +@linalg_structured_op +def matvec( + A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True) +): + """Performs a matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n) + implements(ContractionOpInterface) + x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) + + +@linalg_structured_op +def vecmat( + y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True) +): + """Performs a vector-matrix multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.n, D.m) + implements(ContractionOpInterface) + x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) + + +@linalg_structured_op +def batch_matvec( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True), +): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.k) + implements(ContractionOpInterface) + C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k] + ) + + +@linalg_structured_op +def batch_vecmat( + A=TensorDef(T1, Batch, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.N, output=True), +): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) + + +@linalg_structured_op +def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): + """Performs a dot product of two vectors to a scalar result. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ContractionOpInterface) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) + + +@linalg_structured_op +def conv_1d( + I=TensorDef(T1, S.OW + S.KW), + K=TensorDef(T2, S.KW), + O=TensorDef(U, S.OW, output=True), +): + """Performs 1-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.ow, D.kw) + O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw]) + + +@linalg_structured_op +def conv_2d( + I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KH, S.KW), + O=TensorDef(U, S.OH, S.OW, output=True), +): + """Performs 2-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.oh, D.ow, D.kh, D.kw) + O[D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw]) + + +@linalg_structured_op +def conv_3d( + I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KD, S.KH, S.KW), + O=TensorDef(U, S.OD, S.OH, S.OW, output=True), +): + """Performs 3-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) + O[D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw]) + + +@linalg_structured_op +def conv_1d_nwc_wcf( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.f, D.kw, D.c) + O[D.n, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) + + +@linalg_structured_op +def conv_1d_ncw_fcw( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KW), + O=TensorDef(U, S.N, S.F, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Layout: + * Input: NCW. + * Kernel: FCW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.ow, D.c, D.kw) + O[D.n, D.f, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw]) + + +@linalg_structured_op +def conv_2d_nhwc_hwcf( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) + + +@linalg_structured_op +def conv_2d_nhwc_fhwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) + + +@linalg_structured_op +def conv_2d_nhwc_hwcf_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) + + +@linalg_structured_op +def conv_2d_nhwc_fhwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp)) + + +@linalg_structured_op +def conv_2d_nchw_fchw_q( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += ( + TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp)) + +@linalg_structured_op +def conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) + + +@linalg_structured_op +def conv_2d_ngchw_fgchw( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: FGCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw]) + + +@linalg_structured_op +def conv_2d_ngchw_gfchw( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: GFCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + + +@linalg_structured_op +def conv_2d_nhwgc_gfhwc( + I=TensorDef( + T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C + ), + K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NHWGC. + * Kernel: GFHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c] + ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c]) + + +@linalg_structured_op +def conv_2d_nhwgc_gfhwc_q( + I=TensorDef( + T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C + ), + K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution with zero point offsets. + + Layout: + * Input: NHWGC. + * Kernel: GFHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.g, D.fg] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c]) + - TypeFn.cast_signed(U, KZp) + ) + + +@linalg_structured_op +def conv_2d_ngchw_gfchw_q( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution with zero-point offsets. + + Layout: + * Input: NGCHW. + * Kernel: GFCHW. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += ( + TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + - TypeFn.cast_signed(U, KZp) + ) + + +@linalg_structured_op +def conv_3d_ndhwc_dhwcf( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + + +@linalg_structured_op +def conv_3d_ndhwc_dhwcf_q( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution with zero point offsets. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + - TypeFn.cast_signed(U, KZp) + ) + + +@linalg_structured_op +def conv_3d_ncdhw_fcdhw( + I=TensorDef( + T1, + S.N, + S.C, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw]) + + +@linalg_structured_op +def depthwise_conv_1d_nwc_wc( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC), + O=TensorDef(U, S.N, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic]) + + +@linalg_structured_op +def depthwise_conv_1d_ncw_cw( + I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KW), + O=TensorDef(U, S.N, S.IC, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ic, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kw]) + + +@linalg_structured_op +def depthwise_conv_1d_nwc_wcm( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.cm, D.kw) + O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) + + +@linalg_structured_op +def depthwise_conv_2d_nchw_chw( + I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp)) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwcm( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) + + +@linalg_structured_op +def depthwise_conv_2d_nhwc_hwcm_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic, D.cm] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp)) + + +@linalg_structured_op +def depthwise_conv_3d_ndhwc_dhwc( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.IC, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic]) + + +@linalg_structured_op +def depthwise_conv_3d_ncdhw_cdhw( + I=TensorDef( + T1, + S.N, + S.IC, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.ic, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw]) + + +@linalg_structured_op +def depthwise_conv_3d_ndhwc_dhwcm( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.IC, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) + + +@linalg_structured_op +def pooling_nhwc_sum( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs sum pooling. + + Layout: + * Input: NHWC. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + + +@linalg_structured_op +def pooling_nchw_sum( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs sum pooling. + + Layout: + * Input: NCHW. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) + + +@linalg_structured_op +def pooling_nhwc_max( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nhwc_max_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nchw_max( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) + + +@linalg_structured_op +def pooling_nhwc_min( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nhwc_min_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + + +@linalg_structured_op +def pooling_nwc_sum( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NWC. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + + +@linalg_structured_op +def pooling_ncw_sum( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NCW. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) + + +@linalg_structured_op +def pooling_nwc_max( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_nwc_max_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_ncw_max( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) + + +@linalg_structured_op +def pooling_nwc_min( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_nwc_min_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) + + +@linalg_structured_op +def pooling_ndhwc_sum( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D sum pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + + +@linalg_structured_op +def pooling_ndhwc_max( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + ) + + +@linalg_structured_op +def pooling_ndhwc_min( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + ) + + +@linalg_structured_op +def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): + """Fills the output tensor with the given value. + + Works for arbitrary ranked output tensors since the operation performs scalar + accesses only and is thus rank polymorphic. Numeric casting is performed on + the value operand, promoting it to the same data type as the output. + """ + implements(FillOpInterface) + defines(Canonicalizer) + O[None] = TypeFn.cast_signed(U, value) + + +@linalg_structured_op +def fill_rng_2d( + min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True), +): + """Fills the output tensor with pseudo random numbers. + + The operation generations pseudo random numbers using a linear congruential + generator. It provides no guarantees regarding the distribution of the + generated random numbers. Instead of generating the random numbers + sequentially, it instantiates one random number generator per data element + and runs them in parallel. The seed operand and the indices of the data + element seed the random number generation. The min and max operands limit + the range of the generated random numbers. 
+ """ + domain(D.m, D.n) + multiplier = TypeFn.cast_signed(I32, const(1103515245)) + increment = TypeFn.cast_signed(I32, const(12345)) + rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) + offset = TypeFn.cast_signed(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[D.m, D.n] = TypeFn.cast_signed( + T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min + ) diff --git a/mlir/python/mlir/dialects/linalg/passes/__init__.py b/mlir/python/mlir/dialects/linalg/passes/__init__.py new file mode 100644 index 000000000..0920e8ef4 --- /dev/null +++ b/mlir/python/mlir/dialects/linalg/passes/__init__.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...._mlir_libs import _mlirLinalgPasses as _cextLinalgPasses diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py new file mode 100644 index 000000000..941a58496 --- /dev/null +++ b/mlir/python/mlir/dialects/llvm.py @@ -0,0 +1,15 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._llvm_ops_gen import * +from ._llvm_enum_gen import * +from .._mlir_libs._mlirDialectsLLVM import * +from ..ir import Value +from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results + + +def mlir_constant(value, *, loc=None, ip=None) -> Value: + return _get_op_result_or_op_results( + ConstantOp(res=value.type, value=value, loc=loc, ip=ip) + ) diff --git a/mlir/python/mlir/dialects/math.py b/mlir/python/mlir/dialects/math.py new file mode 100644 index 000000000..f082bf461 --- /dev/null +++ b/mlir/python/mlir/dialects/math.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._math_ops_gen import * diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py new file mode 100644 index 000000000..bc9a3a527 --- /dev/null +++ b/mlir/python/mlir/dialects/memref.py @@ -0,0 +1,136 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import operator +from itertools import accumulate +from typing import Optional + +from ._memref_ops_gen import * +from ._ods_common import _dispatch_mixed_values, MixedValues +from .arith import ConstantOp, _is_integer_like_type +from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation + + +def _is_constant_int_like(i): + return ( + isinstance(i, Value) + and isinstance(i.owner, Operation) + and isinstance(i.owner.opview, ConstantOp) + and _is_integer_like_type(i.type) + ) + + +def _is_static_int_like(i): + return ( + isinstance(i, int) and not ShapedType.is_dynamic_size(i) + ) or _is_constant_int_like(i) + + +def _infer_memref_subview_result_type( + source_memref_type, offsets, static_sizes, static_strides +): + source_strides, source_offset = source_memref_type.get_strides_and_offset() + # "canonicalize" from tuple|list -> list + offsets, static_sizes, static_strides, source_strides = map( + list, (offsets, static_sizes, static_strides, source_strides) + ) + + if not all( + all(_is_static_int_like(i) for i in s) + for s in [ + static_sizes, + static_strides, + source_strides, + ] + ): + raise ValueError( + "Only inferring from python or mlir integer constant is supported." + ) + + for s in [offsets, static_sizes, static_strides]: + for idx, i in enumerate(s): + if _is_constant_int_like(i): + s[idx] = i.owner.opview.literal_value + + if any(not _is_static_int_like(i) for i in offsets + [source_offset]): + target_offset = ShapedType.get_dynamic_size() + else: + target_offset = source_offset + for offset, target_stride in zip(offsets, source_strides): + target_offset += offset * target_stride + + target_strides = [] + for source_stride, static_stride in zip(source_strides, static_strides): + target_strides.append(source_stride * static_stride) + + # If default striding then no need to complicate things for downstream ops (e.g., expand_shape). 
+ default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1] + if target_strides == default_strides and target_offset == 0: + layout = None + else: + layout = StridedLayoutAttr.get(target_offset, target_strides) + return ( + offsets, + static_sizes, + static_strides, + MemRefType.get( + static_sizes, + source_memref_type.element_type, + layout, + source_memref_type.memory_space, + ), + ) + + +_generated_subview = subview + + +def subview( + source: Value, + offsets: MixedValues, + sizes: MixedValues, + strides: MixedValues, + *, + result_type: Optional[MemRefType] = None, + loc=None, + ip=None, +): + if offsets is None: + offsets = [] + if sizes is None: + sizes = [] + if strides is None: + strides = [] + source_strides, source_offset = source.type.get_strides_and_offset() + if result_type is None and all( + all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides] + ): + # If any are arith.constant results then this will canonicalize to python int + # (which can then be used to fully specify the subview). + ( + offsets, + sizes, + strides, + result_type, + ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides) + elif result_type is None: + raise ValueError( + "mixed static/dynamic offset/sizes/strides requires explicit result type." 
+ ) + + offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets) + sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes) + strides, _packed_strides, static_strides = _dispatch_mixed_values(strides) + + return _generated_subview( + result_type, + source, + offsets, + sizes, + strides, + static_offsets, + static_sizes, + static_strides, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py new file mode 100644 index 000000000..dfb6d7f2c --- /dev/null +++ b/mlir/python/mlir/dialects/ml_program.py @@ -0,0 +1,119 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Union + +from ._ml_program_ops_gen import * +from ._ml_program_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. 
+ - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py new file mode 100644 index 000000000..d6a54f277 --- /dev/null +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -0,0 +1,7 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._nvgpu_ops_gen import * +from ._nvgpu_enum_gen import * +from .._mlir_libs._mlirDialectsNVGPU import * diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py new file mode 100644 index 000000000..9477de39c --- /dev/null +++ b/mlir/python/mlir/dialects/nvvm.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._nvvm_ops_gen import * +from ._nvvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/openmp.py b/mlir/python/mlir/dialects/openmp.py new file mode 100644 index 000000000..604f0bd03 --- /dev/null +++ b/mlir/python/mlir/dialects/openmp.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._omp_ops_gen import * diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py new file mode 100644 index 000000000..b7b8430ce --- /dev/null +++ b/mlir/python/mlir/dialects/pdl.py @@ -0,0 +1,236 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._pdl_ops_gen import * +from ._pdl_ops_gen import _Dialect +from .._mlir_libs._mlirDialectsPDL import * +from .._mlir_libs._mlirDialectsPDL import OperationType +from ..extras.meta import region_op + +try: + from ..ir import * + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union, Optional, Sequence, Mapping, NewType +from ._ods_common import ( + get_op_result_or_value as _get_value, + get_op_results_or_values as _get_values, + _cext as _ods_cext, +) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AttributeOp(AttributeOp): + """Specialization for PDL attribute op class.""" + + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): + valueType = valueType if valueType is None else _get_value(valueType) + result = pdl.AttributeType.get() + super().__init__(result, valueType=valueType, value=value, 
loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperandOp(OperandOp): + """Specialization for PDL operand op class.""" + + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, valueType=type, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperandsOp(OperandsOp): + """Specialization for PDL operands op class.""" + + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + types = types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, valueType=types, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperationOp(OperationOp): + """Specialization for PDL operand op class.""" + + def __init__( + self, + name: Optional[Union[str, StringAttr]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] + args = _get_values(args) + attrNames = [] + attrValues = [] + for attrName, attrValue in attributes.items(): + attrNames.append(StringAttr.get(attrName)) + attrValues.append(_get_value(attrValue)) + attrNames = ArrayAttr.get(attrNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__( + result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PatternOp(PatternOp): + """Specialization for PDL pattern op class.""" + + def __init__( + self, + benefit: Union[IntegerAttr, 
int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + """Creates an PDL `pattern` operation.""" + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] + + +pattern = region_op(PatternOp.__base__) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ReplaceOp(ReplaceOp): + """Specialization for PDL replace op class.""" + + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ResultOp(ResultOp): + """Specialization for PDL result op class.""" + + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class RewriteOp(RewriteOp): + """Specialization for PDL rewrite op class.""" + + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + root = root if root is None else _get_value(root) + args = _get_values(args) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the 
rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] + + +rewrite = region_op(RewriteOp) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypeOp(TypeOp): + """Specialization for PDL type op class.""" + + def __init__( + self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None + ): + result = pdl.TypeType.get() + super().__init__(result, constantType=constantType, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypesOp(TypesOp): + """Specialization for PDL types op class.""" + + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) + + +OperationTypeT = NewType("OperationType", OperationType) + + +def op_t() -> OperationTypeT: + return OperationTypeT(OperationType.get()) diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py new file mode 100644 index 000000000..9380896c8 --- /dev/null +++ b/mlir/python/mlir/dialects/python_test.py @@ -0,0 +1,16 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._python_test_ops_gen import * + + +def register_python_test_dialect(registry, use_nanobind): + if use_nanobind: + from .._mlir_libs import _mlirPythonTestNanobind + + _mlirPythonTestNanobind.register_dialect(registry) + else: + from .._mlir_libs import _mlirPythonTestPybind11 + + _mlirPythonTestPybind11.register_dialect(registry) diff --git a/mlir/python/mlir/dialects/quant.py b/mlir/python/mlir/dialects/quant.py new file mode 100644 index 000000000..bf1fc5f2d --- /dev/null +++ b/mlir/python/mlir/dialects/quant.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._mlir_libs._mlirDialectsQuant import * diff --git a/mlir/python/mlir/dialects/rocdl.py b/mlir/python/mlir/dialects/rocdl.py new file mode 100644 index 000000000..aa47cb4b5 --- /dev/null +++ b/mlir/python/mlir/dialects/rocdl.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._rocdl_ops_gen import * diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py new file mode 100644 index 000000000..678ceeeba --- /dev/null +++ b/mlir/python/mlir/dialects/scf.py @@ -0,0 +1,256 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + + +from ._scf_ops_gen import * +from ._scf_ops_gen import _Dialect +from .arith import constant + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import List, Optional, Sequence, Tuple, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ForOp(ForOp): + """Specialization for the SCF for op class.""" + + def __init__( + self, + lower_bound, + upper_bound, + step, + iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + """Creates an SCF `for` operation. + + - `lower_bound` is the value to use as lower bound of the loop. + - `upper_bound` is the value to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. + """ + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + + results = [arg.type for arg in iter_args] + super().__init__( + results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip + ) + self.regions[0].blocks.append(self.operands[0].type, *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. 
+ """ + return self.body.arguments[1:] + + +def _dispatch_index_op_fold_results( + ofrs: Sequence[Union[Operation, OpView, Value, int]], +) -> Tuple[List[Value], List[int]]: + """`mlir::dispatchIndexOpFoldResults`""" + dynamic_vals = [] + static_vals = [] + for ofr in ofrs: + if isinstance(ofr, (Operation, OpView, Value)): + val = _get_op_result_or_value(ofr) + dynamic_vals.append(val) + static_vals.append(ShapedType.get_dynamic_size()) + else: + static_vals.append(ofr) + return dynamic_vals, static_vals + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ForallOp(ForallOp): + """Specialization for the SCF forall op class.""" + + def __init__( + self, + lower_bounds: Sequence[Union[Operation, OpView, Value, int]], + upper_bounds: Sequence[Union[Operation, OpView, Value, int]], + steps: Sequence[Union[Value, int]], + shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + mapping=None, + loc=None, + ip=None, + ): + """Creates an SCF `forall` operation. + + - `lower_bounds` are the values to use as lower bounds of the loop. + - `upper_bounds` are the values to use as upper bounds of the loop. + - `steps` are the values to use as loop steps. + - `shared_outs` is a list of additional loop-carried arguments or an operation + producing them as results. 
+ """ + assert ( + len(lower_bounds) == len(upper_bounds) == len(steps) + ), "Mismatch in length of lower bounds, upper bounds, and steps" + if shared_outs is None: + shared_outs = [] + shared_outs = _get_op_results_or_values(shared_outs) + + dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds) + dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds) + dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps) + + results = [arg.type for arg in shared_outs] + super().__init__( + results, + dynamic_lbs, + dynamic_ubs, + dynamic_steps, + static_lbs, + static_ubs, + static_steps, + shared_outs, + mapping=mapping, + loc=loc, + ip=ip, + ) + rank = len(static_lbs) + iv_types = [IndexType.get()] * rank + self.regions[0].blocks.append(*iv_types, *results) + + @property + def body(self) -> Block: + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def rank(self) -> int: + """Returns the number of induction variables the loop has.""" + return len(self.staticLowerBound) + + @property + def induction_variables(self) -> BlockArgumentList: + """Returns the induction variables usable within the loop.""" + return self.body.arguments[: self.rank] + + @property + def inner_iter_args(self) -> BlockArgumentList: + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[self.rank :] + + def terminator(self) -> InParallelOp: + """ + Returns the loop terminator if it exists. + Otherwise, creates a new one. 
+ """ + ops = self.body.operations + with InsertionPoint(self.body): + if not ops: + return InParallelOp() + last = ops[len(ops) - 1] + return last if isinstance(last, InParallelOp) else InParallelOp() + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InParallelOp(InParallelOp): + """Specialization of the SCF forall.in_parallel op class.""" + + def __init__(self, loc=None, ip=None): + super().__init__(loc=loc, ip=ip) + self.region.blocks.append() + + @property + def block(self) -> Block: + return self.region.blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class IfOp(IfOp): + """Specialization for the SCF if op class.""" + + def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None): + """Creates an SCF `if` operation. + + - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. + - `hasElse` determines whether the if operation has the else branch. + """ + if results_ is None: + results_ = [] + operands = [] + operands.append(cond) + results = [] + results.extend(results_) + super().__init__(results, cond, loc=loc, ip=ip) + self.regions[0].blocks.append(*[]) + if hasElse: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self): + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self): + """Returns the else block of the if operation.""" + return self.regions[1].blocks[0] + + +def for_( + start, + stop=None, + step=None, + iter_args: Optional[Sequence[Value]] = None, + *, + loc=None, + ip=None, +): + if step is None: + step = 1 + if stop is None: + stop = start + start = 0 + params = [start, stop, step] + for i, p in enumerate(params): + if isinstance(p, int): + p = constant(IndexType.get(), p) + elif isinstance(p, float): + raise ValueError(f"{p=} must be int.") + params[i] = p + + start, stop, step = params + + for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip) + iv = 
for_op.induction_variable + iter_args = tuple(for_op.inner_iter_args) + with InsertionPoint(for_op.body): + if len(iter_args) > 1: + yield iv, iter_args, for_op.results + elif len(iter_args) == 1: + yield iv, iter_args[0], for_op.results[0] + else: + yield iv diff --git a/mlir/lib/Bindings/Python/mlir/dialects/shape.py b/mlir/python/mlir/dialects/shape.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/shape.py rename to mlir/python/mlir/dialects/shape.py diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py new file mode 100644 index 000000000..ae7a4c41c --- /dev/null +++ b/mlir/python/mlir/dialects/smt.py @@ -0,0 +1,33 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._smt_ops_gen import * + +from .._mlir_libs._mlirDialectsSMT import * +from ..extras.meta import region_op + + +def bool_t(): + return BoolType.get() + + +def bv_t(width): + return BitVectorType.get(width) + + +def _solver( + inputs=None, + results=None, + loc=None, + ip=None, +): + if inputs is None: + inputs = [] + if results is None: + results = [] + + return SolverOp(results, inputs, loc=loc, ip=ip) + + +solver = region_op(_solver, terminator=YieldOp) diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py new file mode 100644 index 000000000..209ecc95f --- /dev/null +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -0,0 +1,8 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._sparse_tensor_ops_gen import * +from ._sparse_tensor_enum_gen import * +from .._mlir_libs._mlirDialectsSparseTensor import * +from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses diff --git a/mlir/python/mlir/dialects/spirv.py b/mlir/python/mlir/dialects/spirv.py new file mode 100644 index 000000000..269678a20 --- /dev/null +++ b/mlir/python/mlir/dialects/spirv.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._spirv_ops_gen import * diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py new file mode 100644 index 000000000..146b5f85d --- /dev/null +++ b/mlir/python/mlir/dialects/tensor.py @@ -0,0 +1,67 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional + +from ._tensor_ops_gen import * +from ._tensor_ops_gen import _Dialect +from ..extras.meta import region_op + +try: + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Sequence, Union +from ._ods_common import _cext as _ods_cext +from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmptyOp(EmptyOp): + """Extends the tensor.empty op.""" + + def __init__( + self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + encoding: Optional[Attribute] = None, + loc=None, + ip=None, + ): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). 
+ dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type, encoding) + super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) + + +def empty( + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + encoding: Optional[Attribute] = None, + loc=None, + ip=None, +) -> _ods_cext.ir.Value: + return _get_op_result_or_op_results( + EmptyOp( + sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip + ) + ) + + +generate = region_op( + lambda result, dynamic_extents: GenerateOp(result, dynamic_extents), + terminator=lambda args: YieldOp(args[0]), +) diff --git a/mlir/python/mlir/dialects/tosa.py b/mlir/python/mlir/dialects/tosa.py new file mode 100644 index 000000000..aebda742f --- /dev/null +++ b/mlir/python/mlir/dialects/tosa.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._tosa_ops_gen import * diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py new file mode 100644 index 000000000..b075919d1 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -0,0 +1,306 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._transform_enum_gen import * +from .._transform_ops_gen import * +from .._transform_ops_gen import _Dialect +from ..._mlir_libs._mlirDialectsTransform import * +from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Dict, Optional, Sequence, Union, NewType + + +@register_attribute_builder("ParamOperandAttr") +def _paramOperandAttr(x: int, context) -> Attribute: + return Attribute.parse(f"#transform.param_operand", context=context) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class CastOp(CastOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None, + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyPatternsOp(ApplyPatternsOp): + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + loc=None, + ip=None, + ): + super().__init__(target, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def patterns(self) -> Block: + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GetParentOp(GetParentOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + nth_parent: int = 1, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + nth_parent=nth_parent, + loc=loc, + 
ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MergeHandlesOp(MergeHandlesOp): + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h) for h in handles], + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ReplicateOp(ReplicateOp): + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h).type for h in handles], + _get_op_result_or_value(pattern), + [_get_op_result_or_value(h) for h in handles], + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SequenceOp(SequenceOp): + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[ + Union[Sequence[Value], Sequence[Type], Operation, OpView] + ] = None, + ): + root = ( + _get_op_result_or_value(target) + if isinstance(target, (Operation, Value)) + else None + ) + root_type = root.type if not isinstance(target, Type) else target + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + extra_binding_types = extra_bindings + extra_bindings = [] + else: + extra_binding_types = [v.type for v in extra_bindings] + + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode, + root=root, + extra_bindings=extra_bindings, + ) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> 
Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class NamedSequenceOp(NamedSequenceOp): + def __init__( + self, + sym_name, + input_types: Sequence[Type], + result_types: Sequence[Type], + sym_visibility=None, + arg_attrs=None, + res_attrs=None, + ): + function_type = FunctionType.get(input_types, result_types) + super().__init__( + sym_name=sym_name, + function_type=TypeAttr.get(function_type), + sym_visibility=sym_visibility, + arg_attrs=arg_attrs, + res_attrs=res_attrs, + ) + self.regions[0].blocks.append(*input_types) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class YieldOp(YieldOp): + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] + super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) + + +OptionValueTypes = Union[ + Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool +] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyRegisteredPassOp(ApplyRegisteredPassOp): + def __init__( + self, + result: Type, + target: Union[Operation, Value, OpView], + pass_name: Union[str, StringAttr], + *, + options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None, + loc=None, + ip=None, + ): + options_dict = {} + dynamic_options = [] + + ParamOperandAttr = AttrBuilder.get("ParamOperandAttr") + context = (loc and loc.context) or Context.current + + cur_param_operand_idx = 0 + + def option_value_to_attr(value): + nonlocal cur_param_operand_idx + if isinstance(value, (Value, Operation, 
OpView)): + dynamic_options.append(_get_op_result_or_value(value)) + cur_param_operand_idx += 1 + return ParamOperandAttr(cur_param_operand_idx - 1, context) + elif isinstance(value, Attribute): + return value + # The following cases auto-convert Python values to attributes. + elif isinstance(value, bool): + return BoolAttr.get(value) + elif isinstance(value, int): + default_int_type = IntegerType.get_signless(64, context) + return IntegerAttr.get(default_int_type, value) + elif isinstance(value, str): + return StringAttr.get(value) + elif isinstance(value, Sequence): + return ArrayAttr.get([option_value_to_attr(elt) for elt in value]) + else: + raise TypeError(f"Unsupported option type: {type(value)}") + + for key, value in options.items() if options is not None else {}: + if isinstance(key, StringAttr): + key = key.value + options_dict[key] = option_value_to_attr(value) + super().__init__( + result, + _get_op_result_or_value(target), + pass_name, + dynamic_options, + options=DictAttr.get(options_dict), + loc=loc, + ip=ip, + ) + + +def apply_registered_pass( + result: Type, + target: Union[Operation, Value, OpView], + pass_name: Union[str, StringAttr], + *, + options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None, + loc=None, + ip=None, +) -> Value: + return ApplyRegisteredPassOp( + result=result, + pass_name=pass_name, + target=target, + options=options, + loc=loc, + ip=ip, + ).result + + +AnyOpTypeT = NewType("AnyOpType", AnyOpType) + + +def any_op_t() -> AnyOpTypeT: + return AnyOpTypeT(AnyOpType.get()) diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py new file mode 100644 index 000000000..485a8a36b --- /dev/null +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -0,0 +1,134 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._bufferization_transform_ops_gen import * +from .._bufferization_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from enum import Enum +from typing import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp): + """Specialization for EmptyTensorToAllocTensorOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.OperationType.get("bufferization.alloc_tensor") + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OneShotBufferizeOp(OneShotBufferizeOp): + """Specialization for OneShotBufferizeOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + 
ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + allow_return_allocs_from_loops=allow_return_allocs_from_loops, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + function_boundary_type_conversion=function_boundary_type_conversion, + memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py new file mode 100644 index 000000000..f7c04268d --- /dev/null +++ b/mlir/python/mlir/dialects/transform/debug.py @@ -0,0 +1,81 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from ...ir import Attribute, Operation, Value, StringAttr +from .._transform_debug_extension_ops_gen import * +from .._transform_pdl_extension_ops_gen import _Dialect + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmitParamAsRemarkOp(EmitParamAsRemarkOp): + def __init__( + self, + param: Attribute, + *, + anchor: Optional[Operation] = None, + message: Optional[Union[StringAttr, str]] = None, + loc=None, + ip=None, + ): + if isinstance(message, str): + message = StringAttr.get(message) + + super().__init__( + param, + anchor=anchor, + message=message, + loc=loc, + ip=ip, + ) + + +def emit_param_as_remark( + param: Attribute, + *, + anchor: Optional[Operation] = None, + message: Optional[Union[StringAttr, str]] = None, + loc=None, + ip=None, +): + return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmitRemarkAtOp(EmitRemarkAtOp): + def __init__( + self, + at: Union[Operation, Value], + message: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(message, str): + message = StringAttr.get(message) + + super().__init__( + at, + message, + loc=loc, + ip=ip, + ) + + +def emit_remark_at( + at: Union[Operation, Value], + message: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, +): + return EmitRemarkAtOp(at, message, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py new file mode 100644 index 000000000..8d045cad7 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -0,0 +1,246 @@ +# Part of the LLVM Project, under the 
Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Callable, Optional, Sequence, Union + +from ....extras.meta import region_op +from .... import ir +from ... import transform +from .. import ( + AnyOpType, + AnyParamType, + AnyValueType, + OperationType, + ParamType, + NamedSequenceOp, + YieldOp, + SequenceOp, + ApplyPatternsOp, +) +from .. import structured + + +class Handle(ir.Value): + """ + Base class for wrappers around different types of transform handle with + methods to chain further transforms. + + The fields `children` and `parent` are used to capture the relation of + handles statically in order to enable further analysis. The payload + operation of a child handle is nested into a region of the payload operation + of the corresponding parent handle. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional["Handle"] = None, + children: Optional[Sequence["Handle"]] = None, + ): + super().__init__(v) + self.parent = parent + self.children = children if children is not None else [] + +@ir.register_value_caster(AnyOpType.get_static_typeid()) +@ir.register_value_caster(OperationType.get_static_typeid()) +class OpHandle(Handle): + """ + Wrapper around a transform operation handle with methods to chain further + transforms. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + def get_result(self, indices: Sequence[int] = [0]) -> "ValueHandle": + """ + Emits a `transform.GetResultOp`. + Returns a handle to the result of the payload operation at the given + indices. 
+ """ + get_result_op = transform.GetResultOp( + AnyValueType.get(), + self, + indices, + ) + return get_result_op.result + + def match_ops( + self, + ops: Union[ + str, + ir.OpView, + structured.MatchInterfaceEnum, + Sequence[Union[str, ir.OpView]], + ], + ) -> "OpHandle": + """ + Emits a `transform.structured.MatchOp`. + Returns a handle to payload ops that match the given names, types, or + interface. If only a single type is given, the value wrapped by the + resulting handle is populated with the respective type. + """ + # Handle interface. + if isinstance(ops, structured.MatchInterfaceEnum) or ( + isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__ + ): + if isinstance(ops, str): + ops = structured.MatchInterfaceEnum[ops] + match_op = structured.MatchOp( + AnyOpType.get(), + self, + interface=ops, + ) + + # Handle op name(s), either given directly as string or given as op. + else: + if isinstance(ops, str): + op_type = OperationType.get(ops) + op_names = [ops] + elif isinstance(ops, Sequence): + op_type = AnyOpType.get() + op_names = [ + op if isinstance(op, str) else op.OPERATION_NAME for op in ops + ] + else: + op_type = OperationType.get(ops.OPERATION_NAME) + op_names = [ops.OPERATION_NAME] + match_op = structured.MatchOp.match_op_names( + op_type, + self, + op_names, + ) + + handle = OpHandle(match_op.results_, parent=self) + self.children.append(handle) + return handle + + def print(self, name: Optional[str] = None) -> "OpHandle": + """ + Emits a `transform.PrintOp` to print this handle and an optional message. + Returns the existing handle to facilitate further chaining. 
+ """ + transform.PrintOp(target=self, name=name) + return self + + +@ir.register_value_caster(AnyParamType.get_static_typeid()) +@ir.register_value_caster(ParamType.get_static_typeid()) +class ParamHandle(Handle): + """Wrapper around a transform param handle.""" + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + +@ir.register_value_caster(AnyValueType.get_static_typeid()) +class ValueHandle(Handle): + """ + Wrapper around a transform value handle with methods to chain further + transforms. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + def get_defining_op(self) -> OpHandle: + """ + Emits a `transform.GetDefiningOpOp`. + Returns a handle to the defining op of the wrapped value. + """ + get_defining_op = transform.GetDefiningOp( + AnyOpType.get(), + self, + ) + return get_defining_op.result + + +def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle: + """ + Emits a `transform.ParamConstantOp`. + Returns a handle to the newly created parameter. The type of the parameter + is `transfrom.any_param` if the value is not an integer, otherwise the type + is `transform.param` parametrized with the according integer type. + """ + if isinstance(value, int): + value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) + if isinstance(value.type, ir.IntegerType): + param_type = ParamType.get(value.type) + else: + param_type = AnyParamType.get() + op = transform.ParamConstantOp(param_type, value) + return op.param + + +def insert_transform_script( + block_or_insertion_point: Union[ir.Block, ir.InsertionPoint], + script: Callable[[OpHandle], None], + dump_script: bool = False, +) -> None: + """ + Inserts the transform script of the schedule into the module. 
The script + should accept an instance of OpHandle as argument, which will be called with + the block arg of the newly created named_sequence op. + + Example: + This python code + ``` + module = ir.Module.create() + def test_match_ops_single(module: OpHandle): + module.match_ops(scf.ForOp) + insert_transform_script(module.body, script) + ``` + generates the following IR: + ``` + module { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + ^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["scf.for"]} in %arg0 + : (!transform.any_op) -> !transform.op<"scf.for"> + } + } + ``` + """ + if isinstance(block_or_insertion_point, ir.Block): + context = block_or_insertion_point.owner.context + insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point) + else: + context = block_or_insertion_point.block.owner.context + insertion_point = block_or_insertion_point + + with context, ir.Location.unknown(context): + with insertion_point: + named_sequence_op = NamedSequenceOp( + "__transform_main", [AnyOpType.get()], [] + ) + with ir.InsertionPoint(named_sequence_op.body): + script(named_sequence_op.bodyTarget) + YieldOp([]) + + if dump_script: + print(named_sequence_op) + + +sequence = region_op(SequenceOp.__base__, terminator=YieldOp) +named_sequence = region_op(NamedSequenceOp, terminator=YieldOp) +apply_patterns = region_op(ApplyPatternsOp) diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py new file mode 100644 index 000000000..00cf0840e --- /dev/null +++ b/mlir/python/mlir/dialects/transform/gpu.py @@ -0,0 +1,130 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._gpu_transform_ops_gen import * +from .._gpu_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union, overload + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapForallToBlocks(MapForallToBlocks): + """Specialization for MapForallToBlocks class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Type, Value], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + + super().__init__( + result_type, + target, + grid_dims=grid_dims, + generate_gpu_launch=generate_gpu_launch, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapNestedForallToThreads(MapNestedForallToThreads): + """Specialization for MapNestedForallToThreads class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Value, Type], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + block_dims: Optional[Union[Sequence[int], Attribute]] = None, + warp_size: Optional[Union[Sequence[int], Attribute]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = result_type_or_target.type + target = result_type_or_target + super().__init__( + result_type, + target, + block_dims=block_dims, + warp_size=warp_size, + sync_after_distribute=sync_after_distribute, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py new file mode 100644 index 000000000..e69aa9630 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py @@ -0,0 +1,41 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ....ir import Operation +from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter + +TransformOptions = _cextTransformInterpreter.TransformOptions + + +def _unpack_operation(op): + if isinstance(op, Operation): + return op + return op.operation + + +def apply_named_sequence( + payload_root, transform_root, transform_module, transform_options=None +): + """Applies the transformation script starting at the given transform root + operation to the given payload operation. The module containing the + transform root as well as the transform options should be provided. 
+ The transform operation must implement TransformOpInterface and the module + must be a ModuleOp.""" + + args = tuple( + map(_unpack_operation, (payload_root, transform_root, transform_module)) + ) + if transform_options is None: + _cextTransformInterpreter.apply_named_sequence(*args) + else: + _cextTransformInterpreter.apply_named_sequence(*args, transform_options) + + +def copy_symbols_and_merge_into(target, other): + """Copies symbols from other into target, renaming private symbols to avoid + duplicates. Raises an error if copying would lead to duplicate public + symbols.""" + _cextTransformInterpreter.copy_symbols_and_merge_into( + _unpack_operation(target), _unpack_operation(other) + ) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py new file mode 100644 index 000000000..c4770b1c4 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -0,0 +1,127 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._loop_transform_ops_gen import *
from .._loop_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from .._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopOutlineOp(LoopOutlineOp):
    """Extension for LoopOutlineOp: accepts a plain str for func_name."""

    def __init__(
        self,
        function_type: Type,
        call_type: Type,
        target: Union[Operation, Value],
        *,
        func_name: Union[str, StringAttr],
        ip=None,
        loc=None,
    ):
        # Wrap a plain string in a StringAttr; pass an attribute through as-is.
        if not isinstance(func_name, StringAttr):
            func_name = StringAttr.get(func_name)
        super().__init__(
            function_type,
            call_type,
            _get_op_result_or_value(target),
            func_name=func_name,
            ip=ip,
            loc=loc,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPeelOp(LoopPeelOp):
    """Extension for LoopPeelOp: accepts plain bools for the flag attributes."""

    def __init__(
        self,
        main_loop_type: Type,
        remainder_loop_type: Type,
        target: Union[Operation, Value],
        *,
        peel_front: Union[bool, BoolAttr] = False,
        fail_if_already_divisible: Union[bool, BoolAttr] = False,
        ip=None,
        loc=None,
    ):
        def _as_bool_attr(value):
            # Wrap plain Python bools; BoolAttr values pass through untouched.
            return value if isinstance(value, BoolAttr) else BoolAttr.get(value)

        super().__init__(
            main_loop_type,
            remainder_loop_type,
            _get_op_result_or_value(target),
            peel_front=_as_bool_attr(peel_front),
            fail_if_already_divisible=_as_bool_attr(fail_if_already_divisible),
            ip=ip,
            loc=loc,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopPipelineOp(LoopPipelineOp):
    """Extension for LoopPipelineOp: supplies defaults for both intervals."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        iteration_interval: Optional[Union[int, IntegerAttr]] = None,
        read_latency: Optional[Union[int, IntegerAttr]] = None,
        ip=None,
        loc=None,
    ):
        # Fall back to 1 / 10 when the caller does not specify a value.
        super().__init__(
            result_type,
            _get_op_result_or_value(target),
            iteration_interval=1 if iteration_interval is None else iteration_interval,
            read_latency=10 if read_latency is None else read_latency,
            ip=ip,
            loc=loc,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class LoopUnrollOp(LoopUnrollOp):
    """Extension for LoopUnrollOp: unwraps the target operand."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        factor: Union[int, IntegerAttr],
        ip=None,
        loc=None,
    ):
        super().__init__(
            _get_op_result_or_value(target),
            factor=factor,
            ip=ip,
            loc=loc,
        )


# ---- mlir/python/mlir/dialects/transform/memref.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._memref_transform_ops_gen import *
from .._memref_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, overload, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp):
    """Specialization for MemRefAllocaToGlobalOp class."""

    @overload
    def __init__(
        self,
        get_global_type: Type,
        global_type: Type,
        alloca: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None):
        ...

    def __init__(
        self,
        get_global_type_or_alloca: Union[Operation, OpView, Type, Value],
        global_type_or_none: Optional[Type] = None,
        alloca_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(get_global_type_or_alloca, Type):
            # First overload: both result types were given explicitly.
            get_global_type = get_global_type_or_alloca
            global_type = global_type_or_none
            alloca = alloca_or_none
        else:
            # Second overload: default both result types to !transform.any_op.
            get_global_type = transform.AnyOpType.get()
            global_type = transform.AnyOpType.get()
            alloca = get_global_type_or_alloca

        super().__init__(get_global_type, global_type, alloca, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class MemRefMultiBufferOp(MemRefMultiBufferOp):
    """Specialization for MemRefMultiBufferOp class."""

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        factor: Union[int, IntegerAttr],
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        transformed_type_or_target: Type,
        target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None,
        factor_or_none: Optional[Union[int, IntegerAttr]] = None,
        *,
        skip_analysis: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        if isinstance(transformed_type_or_target, Type):
            # First overload: explicit result type.
            transformed_type = transformed_type_or_target
            target = target_or_factor
            factor = factor_or_none
        else:
            # Second overload: result type defaults to !transform.any_op.
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target
            factor = target_or_factor

        super().__init__(
            transformed_type,
            target,
            factor,
            skip_analysis=skip_analysis,
            loc=loc,
            ip=ip,
        )


# ---- mlir/python/mlir/dialects/transform/nvgpu.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._nvgpu_transform_ops_gen import *

# ---- mlir/python/mlir/dialects/transform/pdl.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._transform_pdl_extension_ops_gen import *
from .._transform_pdl_extension_ops_gen import _Dialect

try:
    from ...ir import *
    from .._ods_common import (
        get_op_result_or_value as _get_op_result_or_value,
        get_op_results_or_values as _get_op_results_or_values,
        _cext as _ods_cext,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Union


@_ods_cext.register_operation(_Dialect, replace=True)
class PDLMatchOp(PDLMatchOp):
    """Extension for PDLMatchOp: unwraps the target operand."""

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        pattern_name: Union[Attribute, str],
        *,
        loc=None,
        ip=None,
    ):
        root = _get_op_result_or_value(target)
        super().__init__(result_type, root, pattern_name, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class WithPDLPatternsOp(WithPDLPatternsOp):
    """Extension for WithPDLPatternsOp: builds the single-block body."""

    def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None):
        # A Type argument means "no root operand; the block argument has this
        # type"; otherwise derive the block-argument type from the root value.
        if isinstance(target, Type):
            root, root_type = None, target
        else:
            root = _get_op_result_or_value(target)
            root_type = root.type
        super().__init__(root=root, loc=loc, ip=ip)
        # The op's body is a single block taking the root handle as argument.
        self.regions[0].blocks.append(root_type)

    @property
    def body(self) -> Block:
        return self.regions[0].blocks[0]

    @property
    def bodyTarget(self) -> Value:
        return self.body.arguments[0]


# ---- mlir/python/mlir/dialects/transform/sparse_tensor.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._sparse_tensor_transform_ops_gen import *

# ---- mlir/python/mlir/dialects/transform/structured.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._structured_transform_ops_gen import *
from .._structured_transform_ops_gen import _Dialect
from .._structured_transform_enum_gen import *

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import (
        DynamicIndexList,
        IntOrAttrList,
        MixedValues,
        OptionalBoolList,
        OptionalIntList,
        _cext as _ods_cext,
        _dispatch_dynamic_index_list,
        _dispatch_mixed_values,
        _get_int_array_array_attr,
        _get_int_array_attr,
        _get_value_list,
        _get_value_or_attribute_value,
    )
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import List, Optional, Sequence, Union, overload


@_ods_cext.register_operation(_Dialect, replace=True)
class BufferizeToAllocationOp(BufferizeToAllocationOp):
    """Specialization for BufferizeToAllocationOp class.

    Accepts the memory space as an int, a string to parse, or an Attribute.
    """

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        memory_space: Optional[Union[int, str, Attribute]] = None,
        memcpy_op: Optional[str] = None,
        alloc_op: Optional[str] = None,
        bufferize_destination_only: Optional[bool] = None,
        loc=None,
        ip=None,
    ):
        # No other result types are allowed, so hard-code those here.
        allocated_buffer_type = transform.AnyValueType.get()
        new_ops_type = transform.AnyOpType.get()

        # Normalize memory space: int -> str -> parsed Attribute.
        if isinstance(memory_space, int):
            memory_space = str(memory_space)
        if isinstance(memory_space, str):
            memory_space = Attribute.parse(memory_space)

        super().__init__(
            allocated_buffer_type,
            new_ops_type,
            target,
            memory_space=memory_space,
            memcpy_op=memcpy_op,
            alloc_op=alloc_op,
            bufferize_destination_only=bufferize_destination_only,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class DecomposeOp(DecomposeOp):
    """Specialization for DecomposeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        transformed_type = transform.AnyOpType.get()
        super().__init__(transformed_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class FuseIntoContainingOp(FuseIntoContainingOp):
    """Specialization for FuseIntoContainingOp class.

    The two leading result types may be omitted, in which case they default
    to !transform.any_op.
    """

    @overload
    def __init__(
        self,
        fused_op_type: Type,
        new_containing_op_type: Type,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        producer_op: Union[Operation, OpView, Value],
        containing_op: Union[Operation, OpView, Value],
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value],
        new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value],
        producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(fused_op_type_or_producer_op, Type):
            # First overload: result types must be given together.
            if not isinstance(new_containing_op_type_or_containing_op, Type):
                raise TypeError(
                    "If 'fused_op_type_or_producer_op' is a type, then "
                    "'new_containing_op_type_or_containing_op' is expected "
                    "to be one as well."
                )
            fused_op_type = fused_op_type_or_producer_op
            new_containing_op_type = new_containing_op_type_or_containing_op
            producer_op = producer_op_or_none
            containing_op = containing_op_or_none
        else:
            # Second overload: both result types default to !transform.any_op.
            fused_op_type = transform.AnyOpType.get()
            new_containing_op_type = transform.AnyOpType.get()
            producer_op = fused_op_type_or_producer_op
            containing_op = new_containing_op_type_or_containing_op

        super().__init__(
            fused_op_type,
            new_containing_op_type,
            producer_op,
            containing_op,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class FuseOp(FuseOp):
    """Specialization for FuseOp class.

    Loop result types may be omitted (defaulting to !transform.any_op, one
    per non-zero tile size), given once (replicated), or given per loop.
    """

    @overload
    def __init__(
        self,
        loop_types: Union[Type, Sequence[Type]],
        target: Union[Operation, Value, OpView],
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        tile_interchange: OptionalIntList = None,
        apply_cleanup: Optional[bool] = False,
        loc=None,
        ip=None,
    ):
        tile_sizes = tile_sizes if tile_sizes else []
        tile_interchange = tile_interchange if tile_interchange else []
        _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
        _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
        # One loop is produced per non-zero tile size.
        num_loops = sum(1 for v in tile_sizes if v != 0)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert target_or_none is None, "Cannot construct FuseOp with two targets."
        else:
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none
        super().__init__(
            target.type,
            loop_types,
            target,
            tile_sizes=tile_sizes,
            tile_interchange=tile_interchange,
            apply_cleanup=apply_cleanup,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class GeneralizeOp(GeneralizeOp):
    """Specialization for GeneralizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        transformed_type = transform.AnyOpType.get()
        super().__init__(transformed_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class InterchangeOp(InterchangeOp):
    """Specialization for InterchangeOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        iterator_interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        transformed_type = transform.AnyOpType.get()
        super().__init__(
            transformed_type,
            target,
            iterator_interchange=iterator_interchange,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MapCopyToThreadsOp(MapCopyToThreadsOp):
    """Specialization for MapCopyToThreadsOp class.

    The two leading result types may be omitted (default !transform.any_op).
    """

    @overload
    def __init__(
        self,
        forall_op_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        forall_op_type_or_target: Union[Operation, OpView, Type, Value],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, OpView, Value]] = None,
        *,
        total_num_threads: Union[int, IntegerAttr],
        desired_bit_alignment: Union[int, IntegerAttr],
        loc=None,
        ip=None,
    ):
        if isinstance(forall_op_type_or_target, Type):
            # First overload: explicit result types.
            forall_op_type = forall_op_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            # Second overload: result types default to !transform.any_op.
            forall_op_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = forall_op_type_or_target

        super().__init__(
            forall_op_type,
            tiled_op_type,
            target,
            total_num_threads=total_num_threads,
            desired_bit_alignment=desired_bit_alignment,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeOp(VectorizeOp):
    """Specialization for VectorizeOp class.

    Vector sizes may be given as a single mixed dynamic/static list
    (`vector_sizes`) or pre-split into `static_vector_sizes` plus
    `scalable_sizes` with `vector_sizes` holding only the dynamic entries.
    """

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        *,
        vectorize_nd_extract: Optional[bool] = None,
        scalable_sizes: OptionalBoolList = None,
        static_vector_sizes: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        if (
            scalable_sizes is None
            and static_vector_sizes is None
            and vector_sizes is None
        ):
            dynamic_vector_sizes = []
        elif scalable_sizes is None and static_vector_sizes is None:
            # Split the mixed list into its dynamic/static/scalable parts.
            (
                dynamic_vector_sizes,
                static_vector_sizes,
                scalable_sizes,
            ) = _dispatch_dynamic_index_list(vector_sizes)
        elif scalable_sizes is None or static_vector_sizes is None:
            raise TypeError(
                "'scalable_sizes' and 'static_vector_sizes' must either both "
                "be given explicitly or both be given as part of 'vector_sizes'."
            )
        else:
            dynamic_vector_sizes = vector_sizes

        super().__init__(
            target,
            vector_sizes=dynamic_vector_sizes,
            static_vector_sizes=static_vector_sizes,
            scalable_sizes=scalable_sizes,
            vectorize_nd_extract=vectorize_nd_extract,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MatchOp(MatchOp):
    """Specialization for MatchOp class."""

    @overload
    @classmethod
    def match_op_names(
        cls,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    @classmethod
    def match_op_names(
        cls,
        result_type: Type,
        target: Union[Operation, Value],
        names: Union[str, Sequence[str]],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @classmethod
    def match_op_names(
        cls,
        result_type_or_target: Union[Type, Operation, Value],
        target_or_names: Union[Operation, Value, Sequence[str], str],
        names_or_none: Optional[Union[Sequence[str], str]] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a MatchOp that matches payload ops by name(s).

        The leading result type may be omitted (default !transform.any_op);
        a single name may be given instead of a sequence.
        """
        if isinstance(result_type_or_target, Type):
            result_type = result_type_or_target
            target = target_or_names
            names = names_or_none
        else:
            result_type = transform.AnyOpType.get()
            target = result_type_or_target
            names = target_or_names

        if isinstance(names, str):
            names = [names]

        return cls(
            result_type,
            target,
            ops=ArrayAttr.get([StringAttr.get(name) for name in names]),
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class MultiTileSizesOp(MultiTileSizesOp):
    """Specialization for MultiTileSizesOp class.

    All three results (low size, high size, split point) share `result_type`.
    """

    def __init__(
        self,
        result_type: Type,
        target: Union[Operation, Value],
        *,
        dimension: Union[int, IntegerAttr],
        target_size: Union[int, IntegerAttr],
        # Fixed: was redundantly annotated Optional[Optional[...]].
        divisor: Optional[Union[int, IntegerAttr]] = None,
        loc=None,
        ip=None,
    ):
        super().__init__(
            result_type,
            result_type,
            result_type,
            target,
            dimension=dimension,
            target_size=target_size,
            divisor=divisor,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class PadOp(PadOp):
    """Specialization for PadOp class."""

    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        *,
        pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None,
        padding_dimensions: OptionalIntList = None,
        nofold_flags: OptionalIntList = None,
        transpose_paddings: Optional[
            Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]
        ] = None,
        copy_back_op: Optional[Union[str, StringAttr]] = None,
        loc=None,
        ip=None,
    ):
        if pad_to_multiple_of is None:
            dynamic_pad_to_multiple_of = []
            static_pad_to_multiple_of = None
        else:
            # Split the mixed list into dynamic values and static constants.
            (
                dynamic_pad_to_multiple_of,
                static_pad_to_multiple_of,
                _,
            ) = _dispatch_dynamic_index_list(pad_to_multiple_of)

        transpose_paddings = _get_int_array_array_attr(transpose_paddings)

        any_op_type = transform.AnyOpType.get()
        super().__init__(
            any_op_type,
            any_op_type,
            any_op_type,
            target,
            pad_to_multiple_of=dynamic_pad_to_multiple_of,
            padding_values=padding_values,
            padding_dimensions=padding_dimensions,
            static_pad_to_multiple_of=static_pad_to_multiple_of,
            nofold_flags=nofold_flags,
            transpose_paddings=transpose_paddings,
            copy_back_op=copy_back_op,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class ScalarizeOp(ScalarizeOp):
    """Specialization for ScalarizeOp class."""

    def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
        result_type = transform.AnyOpType.get()
        super().__init__(result_type, target, loc=loc, ip=ip)


@_ods_cext.register_operation(_Dialect, replace=True)
class SplitOp(SplitOp):
    """Specialization for SplitOp class.

    Accepts the chunk size as a static int or as a dynamic value/attribute.
    """

    def __init__(
        self,
        target: Union[Operation, Value],
        dimension: Union[int, Attribute],
        chunk_sizes: Union[int, Operation, Value, Attribute],
        *,
        loc=None,
        ip=None,
    ):
        if isinstance(chunk_sizes, int):
            static_chunk_sizes = chunk_sizes
            dynamic_chunk_sizes = None
        else:
            # Dynamic chunk size: mark the static slot with the sentinel.
            static_chunk_sizes = ShapedType.get_dynamic_size()
            dynamic_chunk_sizes = chunk_sizes

        super().__init__(
            target.type,
            target,
            dimension=dimension,
            static_chunk_sizes=static_chunk_sizes,
            dynamic_chunk_sizes=dynamic_chunk_sizes,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForOp(TileUsingForOp):
    """Specialization for TileUsingForOp class.

    Loop result types may be omitted (defaulting to !transform.any_op, one
    per non-zero tile size), given once (replicated), or given per loop.
    """

    @overload
    def __init__(
        self,
        loop_types: Union[Type, List[Type]],
        target: Union[Operation, Value],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        loop_types_or_target: Union[Type, List[Type], Operation, Value],
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
        interchange: OptionalIntList = None,
        loc=None,
        ip=None,
    ):
        (
            dynamic_sizes,
            static_sizes,
            scalable_sizes,
        ) = _dispatch_dynamic_index_list(sizes)

        # One loop is produced per non-zero tile size (consistent with FuseOp).
        num_loops = sum(1 for v in static_sizes if v != 0)

        if isinstance(loop_types_or_target, (Operation, Value, OpView)):
            loop_types = [transform.AnyOpType.get()] * num_loops
            target = loop_types_or_target
            assert (
                target_or_none is None
            ), "Cannot construct TileUsingForOp with two targets."
        else:
            loop_types = (
                ([loop_types_or_target] * num_loops)
                if isinstance(loop_types_or_target, Type)
                else loop_types_or_target
            )
            target = target_or_none

        super().__init__(
            target.type,
            loop_types,
            target,
            dynamic_sizes=dynamic_sizes,
            static_sizes=static_sizes,
            interchange=interchange,
            scalable_sizes=scalable_sizes,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class TileUsingForallOp(TileUsingForallOp):
    """Specialization for TileUsingForallOp class.

    The two leading result types may be omitted (default !transform.any_op).
    """

    @overload
    def __init__(
        self,
        loops_type: Type,
        tiled_op_type: Type,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, Value, OpView],
        *,
        num_threads: Optional[MixedValues] = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        # Fixed: was a redundant nested Union with misplaced comments.
        # Either the `loops_type` (first overload) or the `target` (second).
        loops_type_or_target: Union[Type, Operation, Value, OpView],
        tiled_op_type_or_none: Optional[Type] = None,
        target_or_none: Optional[Union[Operation, Value, OpView]] = None,
        *,
        num_threads: MixedValues = None,
        tile_sizes: MixedValues = None,
        mapping=None,
        loc=None,
        ip=None,
    ):
        # `Type` arguments in the front are optional: add default values to front.
        if isinstance(loops_type_or_target, Type):
            # First overload: type arguments provided.
            if not isinstance(tiled_op_type_or_none, Type):
                raise TypeError(
                    "If 'loops_type_or_target' is a type, then "
                    "'tiled_op_type_or_none' is expected to be one as well."
                )
            loops_type = loops_type_or_target
            tiled_op_type = tiled_op_type_or_none
            target = target_or_none
        else:
            # Last overload: type arguments missing.
            loops_type = transform.AnyOpType.get()
            tiled_op_type = transform.AnyOpType.get()
            target = loops_type_or_target

        # Unpack mixed num_threads.
        (
            dynamic_num_threads,
            packed_num_threads,
            num_threads_attr,
        ) = _dispatch_mixed_values(num_threads)

        # Unpack mixed tile_sizes.
        (
            dynamic_tile_sizes,
            packed_tile_sizes,
            tile_sizes_attr,
        ) = _dispatch_mixed_values(tile_sizes)

        super().__init__(
            loops_type,
            tiled_op_type,
            target=target,
            tile_sizes=dynamic_tile_sizes,
            packed_tile_sizes=packed_tile_sizes,
            static_tile_sizes=tile_sizes_attr,
            num_threads=dynamic_num_threads,
            packed_num_threads=packed_num_threads,
            static_num_threads=num_threads_attr,
            mapping=mapping,
            loc=loc,
            ip=ip,
        )


@_ods_cext.register_operation(_Dialect, replace=True)
class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp):
    """Specialization for VectorizeChildrenAndApplyPatternsOp class."""

    def __init__(
        self,
        target: Union[Operation, Value],
        *,
        disable_multi_reduction_to_contract_patterns: bool = False,
        disable_transfer_permutation_map_lowering_patterns: bool = False,
        vectorize_nd_extract: bool = False,
        vectorize_padding: bool = False,
        loc=None,
        ip=None,
    ):
        transformed_type = transform.AnyOpType.get()
        super().__init__(
            transformed_type,
            target,
            disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns,
            disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns,
            vectorize_nd_extract=vectorize_nd_extract,
            vectorize_padding=vectorize_padding,
            loc=loc,
            ip=ip,
        )


# ---- mlir/python/mlir/dialects/transform/tensor.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._tensor_transform_ops_gen import *
from .._tensor_transform_ops_gen import _Dialect

try:
    from ...ir import *
    from ...dialects import transform
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Optional, overload, Union


@_ods_cext.register_operation(_Dialect, replace=True)
class MakeLoopIndependentOp(MakeLoopIndependentOp):
    """Specialization for MakeLoopIndependentOp class."""

    @overload
    def __init__(
        self,
        transformed_type: Type,
        target: Union[Operation, OpView, Value],
        num_loops: Union[int, IntegerAttr],
        *,
        loc=None,
        ip=None,
    ):
        ...

    @overload
    def __init__(
        self,
        target: Union[Operation, OpView, Value],
        num_loops: Union[int, IntegerAttr],
        *,
        loc=None,
        ip=None,
    ):
        ...

    def __init__(
        self,
        transformed_type_or_target: Type,
        target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None,
        num_loops_or_none: Optional[Union[int, IntegerAttr]] = None,
        *,
        loc=None,
        ip=None,
    ):
        # Dispatch on whether the optional leading result type was supplied.
        if isinstance(transformed_type_or_target, Type):
            transformed_type = transformed_type_or_target
            target = target_or_num_loops
            num_loops = num_loops_or_none
        else:
            # Result type omitted: default to !transform.any_op.
            transformed_type = transform.AnyOpType.get()
            target = transformed_type_or_target
            num_loops = target_or_num_loops

        super().__init__(
            transformed_type,
            target,
            num_loops,
            loc=loc,
            ip=ip,
        )


# ---- mlir/python/mlir/dialects/transform/tune.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional, Sequence

from ...ir import (
    Type,
    Attribute,
    ArrayAttr,
    StringAttr,
    F64Type,
    IntegerType,
    IntegerAttr,
    FloatAttr,
    BoolAttr,
)
from .._transform_tune_extension_ops_gen import *
from .._transform_tune_extension_ops_gen import _Dialect

try:
    from .._ods_common import _cext as _ods_cext
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

from typing import Union


@_ods_cext.register_operation(_Dialect, replace=True)
class KnobOp(KnobOp):
    """Specialization for KnobOp class.

    Accepts the knob name as a plain str and the options (and selected
    value) as plain Python bool/int/float/str values, converting them to
    the corresponding attributes.
    """

    def __init__(
        self,
        result: Type,  # !transform.any_param or !transform.param
        name: Union[StringAttr, str],
        options: Union[
            ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
        ],
        *,
        selected: Optional[Attribute] = None,
        loc=None,
        ip=None,
    ):
        if isinstance(name, str):
            name = StringAttr.get(name)

        def map_to_attr(value):
            # Convert plain Python values to attributes; note: bool must be
            # tested before int since bool is a subclass of int.
            if isinstance(value, bool):
                return BoolAttr.get(value)
            if isinstance(value, int):
                return IntegerAttr.get(IntegerType.get_signless(64), value)
            if isinstance(value, float):
                return FloatAttr.get(F64Type.get(), value)
            if isinstance(value, str):
                return StringAttr.get(value)
            assert isinstance(value, Attribute)
            return value

        if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
            options = ArrayAttr.get([map_to_attr(opt) for opt in options])

        super().__init__(
            result,
            name,
            options,
            # Fixed: `selected and map_to_attr(selected)` short-circuited on
            # falsy selections (False, 0, 0.0, ""), passing the raw Python
            # value through unconverted. Test for None explicitly instead.
            selected=None if selected is None else map_to_attr(selected),
            loc=loc,
            ip=ip,
        )


def knob(
    result: Type,  # !transform.any_param or !transform.param
    name: Union[StringAttr, str],
    options: Union[
        ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
    ],
    *,
    selected: Optional[Attribute] = None,
    loc=None,
    ip=None,
):
    """Convenience wrapper constructing a KnobOp."""
    return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)


# ---- mlir/python/mlir/dialects/transform/vector.py (diff header continues) ----
# ---- mlir/python/mlir/dialects/transform/vector.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .._vector_transform_enum_gen import *
from .._vector_transform_ops_gen import *

# ---- mlir/python/mlir/dialects/vector.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._vector_ops_gen import *
from ._vector_enum_gen import *

# ---- mlir/python/mlir/execution_engine.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Simply a wrapper around the extension module of the same name.
from ._mlir_libs import _mlirExecutionEngine as _execution_engine
import ctypes

__all__ = [
    "ExecutionEngine",
]


class ExecutionEngine(_execution_engine.ExecutionEngine):
    def lookup(self, name):
        """Looks up a function emitted with the `llvm.emit_c_interface`
        attribute and returns it as a ctypes callable.
        Raises a RuntimeError if the function isn't found.
        """
        func = self.raw_lookup("_mlir_ciface_" + name)
        if not func:
            raise RuntimeError("Unknown function " + name)
        # C-interface wrappers take a single pointer to the packed arguments.
        prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
        return prototype(func)

    def invoke(self, name, *ctypes_args):
        """Invokes the named function with the given ctypes arguments.
        All arguments must be pointers.
        Raises a RuntimeError if the function isn't found.
        """
        func = self.lookup(name)
        # Pack every argument into a contiguous array of void pointers.
        packed_args = (ctypes.c_void_p * len(ctypes_args))()
        for index, arg in enumerate(ctypes_args):
            packed_args[index] = ctypes.cast(arg, ctypes.c_void_p)
        func(packed_args)

    def register_runtime(self, name, ctypes_callback):
        """Registers a runtime function available to the jitted code
        under the provided `name`. The `ctypes_callback` must be a
        `CFuncType` that outlives the execution engine.
        """
        callback = ctypes.cast(ctypes_callback, ctypes.c_void_p)
        self.raw_register_runtime("_mlir_ciface_" + name, callback)


# ---- mlir/python/mlir/extras/meta.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import inspect
from functools import wraps

from ..dialects._ods_common import get_op_result_or_op_results
from ..ir import Type, InsertionPoint


def op_region_builder(op, op_region, terminator=None):
    def builder_wrapper(body_builder):
        # Add a block whose arguments' types are determined by the type
        # hints on the wrapped function.
        if len(op_region.blocks) == 0:
            sig = inspect.signature(body_builder)
            types = [p.annotation for p in sig.parameters.values()]
            if not (
                len(types) == len(sig.parameters)
                and all(isinstance(t, Type) for t in types)
            ):
                raise ValueError(
                    f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}"
                )

            op_region.blocks.append(*types)

        # Build the body by calling the wrapped function on the block args.
        with InsertionPoint(op_region.blocks[0]):
            results = body_builder(*list(op_region.blocks[0].arguments))

        # Terminate the last block, forwarding whatever the body returned.
        with InsertionPoint(list(op_region.blocks)[-1]):
            if terminator is not None:
                if results is None:
                    term_operands = []
                elif isinstance(results, (tuple, list)):
                    term_operands = list(results)
                else:
                    term_operands = [results]
                terminator(term_operands)

        return get_op_result_or_op_results(op)

    return builder_wrapper


def region_op(op_constructor, terminator=None):
    """Decorator to define an MLIR Op specified as a python function.

    Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
    active for the current thread (i.e. established in a `with` block).

    Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor.

    When applied as a decorator to a Python function, an entry block will
    be constructed for the Op with types as specified **as type hints on the args of the function**.
    The block arguments will be passed positionally to the Python function.

    If a terminator is specified then the return from the decorated function will be passed
    to the terminator as the last statement in the entry block. Note, the API for the terminator
    is a (possibly empty) list; terminator accepting single values should be wrapped in a
    `lambda args: term(args[0])`

    The identifier (name) of the function will become:
    1. A single value result if the Op returns a single value;
    2. An OpResultList (as a list) if the Op returns multiple values;
    3. The Operation if the Op returns no results.

    See examples in tensor.py and transform.extras.
    """

    def op_decorator(*args, **kwargs):
        op = op_constructor(*args, **kwargs)
        return op_region_builder(op, op.regions[0], terminator)

    @wraps(op_decorator)
    def maybe_no_args(*args, **kwargs):
        # Naked-decorator support: invoked directly with the function to wrap.
        if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
            return op_decorator()(args[0])
        return op_decorator(*args, **kwargs)

    return maybe_no_args


# ---- mlir/python/mlir/extras/types.py ----
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from functools import partial
from typing import Optional, List

from ..ir import (
    Attribute,
    BF16Type,
    ComplexType,
    F16Type,
    F32Type,
    F64Type,
    Float4E2M1FNType,
    Float6E2M3FNType,
    Float6E3M2FNType,
    Float8E3M4Type,
    Float8E4M3B11FNUZType,
    Float8E4M3FNType,
    Float8E4M3Type,
    Float8E5M2Type,
    Float8E8M0FNUType,
    FloatTF32Type,
    FunctionType,
    IndexType,
    IntegerType,
    MemRefType,
    NoneType,
    OpaqueType,
    RankedTensorType,
    StridedLayoutAttr,
    StringAttr,
    TupleType,
    Type,
    UnrankedMemRefType,
    UnrankedTensorType,
    VectorType,
)

index = lambda: IndexType.get()


def i(width):
    return IntegerType.get_signless(width)


def si(width):
    return IntegerType.get_signed(width)


def ui(width):
    return IntegerType.get_unsigned(width)


# NOTE: `bool` intentionally shadows the builtin; this module is an API of
# type-constructor shorthands meant to be used as `types.bool()` etc.
bool = lambda: i(1)
i8 = lambda: i(8)
i16 = lambda: i(16)
i32 = lambda: i(32)
i64 = lambda: i(64)

si8 = lambda: si(8)
si16 = lambda: si(16)
si32 = lambda: si(32)
si64 = lambda: si(64)

ui8 = lambda: ui(8)
ui16 = lambda: ui(16)
ui32 = lambda: ui(32)
ui64 = lambda: ui(64)

f16 = lambda: F16Type.get()
f32 = lambda: F32Type.get()
tf32 = lambda: FloatTF32Type.get()
# (file continues beyond this chunk; f64 completed from the adjacent pattern)
f64 = lambda: F64Type.get()
F64Type.get()
+bf16 = lambda: BF16Type.get()
+
+f8E5M2 = lambda: Float8E5M2Type.get()
+f8E4M3 = lambda: Float8E4M3Type.get()
+f8E4M3FN = lambda: Float8E4M3FNType.get()
+f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get()
+f8E3M4 = lambda: Float8E3M4Type.get()
+f4E2M1FN = lambda: Float4E2M1FNType.get()
+f6E2M3FN = lambda: Float6E2M3FNType.get()
+f6E3M2FN = lambda: Float6E3M2FNType.get()
+f8E8M0FNU = lambda: Float8E8M0FNUType.get()
+
+none = lambda: NoneType.get()
+
+
+def complex(type):
+    return ComplexType.get(type)
+
+
+def opaque(dialect_namespace, type_data):
+    return OpaqueType.get(dialect_namespace, type_data)
+
+
+def _shaped(*shape, element_type: Type = None, type_constructor=None):
+    if type_constructor is None:
+        raise ValueError("shaped is an abstract base class - cannot be constructed.")
+    if (element_type is None and shape and not isinstance(shape[-1], Type)) or (
+        shape and isinstance(shape[-1], Type) and element_type is not None
+    ):
+        raise ValueError(
+            "Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type."
+ ) + if element_type is not None: + type = element_type + sizes = shape + else: + type = shape[-1] + sizes = shape[:-1] + if sizes: + return type_constructor(sizes, type) + else: + return type_constructor(type) + + +def vector( + *shape, + element_type: Type = None, + scalable: Optional[List[bool]] = None, + scalable_dims: Optional[List[int]] = None, +): + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial( + VectorType.get, scalable=scalable, scalable_dims=scalable_dims + ), + ) + + +def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None): + if encoding is not None: + encoding = StringAttr.get(encoding) + if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)): + if encoding is not None: + raise ValueError("UnrankedTensorType does not support encoding.") + return _shaped( + *shape, element_type=element_type, type_constructor=UnrankedTensorType.get + ) + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial(RankedTensorType.get, encoding=encoding), + ) + + +def memref( + *shape, + element_type: Type = None, + memory_space: Optional[int] = None, + layout: Optional[StridedLayoutAttr] = None, +): + if memory_space is not None: + memory_space = Attribute.parse(str(memory_space)) + if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)): + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space), + ) + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial( + MemRefType.get, memory_space=memory_space, layout=layout + ), + ) + + +def tuple(*elements): + return TupleType.get_tuple(elements) + + +def function(*, inputs, results): + return FunctionType.get(inputs, results) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py new file mode 100644 index 000000000..6f37266d5 --- /dev/null +++ b/mlir/python/mlir/ir.py @@ -0,0 +1,316 @@ +# Part of the LLVM Project, 
under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._mlir_libs._mlir.ir import * +from ._mlir_libs._mlir.ir import _GlobalDebug +from ._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs import ( + get_dialect_registry, + append_load_on_create_dialect, + get_load_on_create_dialects, +) + + +# Convenience decorator for registering user-friendly Attribute builders. +def register_attribute_builder(kind, replace=False): + def decorator_builder(func): + AttrBuilder.insert(kind, func, replace=replace) + return func + + return decorator_builder + + +@register_attribute_builder("AffineMapAttr") +def _affineMapAttr(x, context): + return AffineMapAttr.get(x) + + +@register_attribute_builder("IntegerSetAttr") +def _integerSetAttr(x, context): + return IntegerSetAttr.get(x) + + +@register_attribute_builder("BoolAttr") +def _boolAttr(x, context): + return BoolAttr.get(x, context=context) + + +@register_attribute_builder("DictionaryAttr") +def _dictAttr(x, context): + return DictAttr.get(x, context=context) + + +@register_attribute_builder("IndexAttr") +def _indexAttr(x, context): + return IntegerAttr.get(IndexType.get(context=context), x) + + +@register_attribute_builder("I1Attr") +def _i1Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(1, context=context), x) + + +@register_attribute_builder("I8Attr") +def _i8Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(8, context=context), x) + + +@register_attribute_builder("I16Attr") +def _i16Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) + + +@register_attribute_builder("I32Attr") +def _i32Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) + + +@register_attribute_builder("I64Attr") +def _i64Attr(x, context): + return 
IntegerAttr.get(IntegerType.get_signless(64, context=context), x) + + +@register_attribute_builder("SI1Attr") +def _si1Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(1, context=context), x) + + +@register_attribute_builder("SI8Attr") +def _si8Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(8, context=context), x) + + +@register_attribute_builder("SI16Attr") +def _si16Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) + + +@register_attribute_builder("SI32Attr") +def _si32Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) + + +@register_attribute_builder("SI64Attr") +def _si64Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(64, context=context), x) + + +@register_attribute_builder("UI1Attr") +def _ui1Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x) + + +@register_attribute_builder("UI8Attr") +def _ui8Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x) + + +@register_attribute_builder("UI16Attr") +def _ui16Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x) + + +@register_attribute_builder("UI32Attr") +def _ui32Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x) + + +@register_attribute_builder("UI64Attr") +def _ui64Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x) + + +@register_attribute_builder("F32Attr") +def _f32Attr(x, context): + return FloatAttr.get_f32(x, context=context) + + +@register_attribute_builder("F64Attr") +def _f64Attr(x, context): + return FloatAttr.get_f64(x, context=context) + + +@register_attribute_builder("StrAttr") +def _stringAttr(x, context): + return StringAttr.get(x, context=context) + + +@register_attribute_builder("SymbolNameAttr") +def _symbolNameAttr(x, context): + return StringAttr.get(x, 
context=context) + + +@register_attribute_builder("SymbolRefAttr") +def _symbolRefAttr(x, context): + if isinstance(x, list): + return SymbolRefAttr.get(x, context=context) + else: + return FlatSymbolRefAttr.get(x, context=context) + + +@register_attribute_builder("FlatSymbolRefAttr") +def _flatSymbolRefAttr(x, context): + return FlatSymbolRefAttr.get(x, context=context) + + +@register_attribute_builder("UnitAttr") +def _unitAttr(x, context): + if x: + return UnitAttr.get(context=context) + else: + return None + + +@register_attribute_builder("ArrayAttr") +def _arrayAttr(x, context): + return ArrayAttr.get(x, context=context) + + +@register_attribute_builder("AffineMapArrayAttr") +def _affineMapArrayAttr(x, context): + return ArrayAttr.get([_affineMapAttr(v, context) for v in x]) + + +@register_attribute_builder("BoolArrayAttr") +def _boolArrayAttr(x, context): + return ArrayAttr.get([_boolAttr(v, context) for v in x]) + + +@register_attribute_builder("DictArrayAttr") +def _dictArrayAttr(x, context): + return ArrayAttr.get([_dictAttr(v, context) for v in x]) + + +@register_attribute_builder("FlatSymbolRefArrayAttr") +def _flatSymbolRefArrayAttr(x, context): + return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x]) + + +@register_attribute_builder("I32ArrayAttr") +def _i32ArrayAttr(x, context): + return ArrayAttr.get([_i32Attr(v, context) for v in x]) + + +@register_attribute_builder("I64ArrayAttr") +def _i64ArrayAttr(x, context): + return ArrayAttr.get([_i64Attr(v, context) for v in x]) + + +@register_attribute_builder("I64SmallVectorArrayAttr") +def _i64SmallVectorArrayAttr(x, context): + return _i64ArrayAttr(x, context=context) + + +@register_attribute_builder("IndexListArrayAttr") +def _indexListArrayAttr(x, context): + return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x]) + + +@register_attribute_builder("F32ArrayAttr") +def _f32ArrayAttr(x, context): + return ArrayAttr.get([_f32Attr(v, context) for v in x]) + + 
+@register_attribute_builder("F64ArrayAttr") +def _f64ArrayAttr(x, context): + return ArrayAttr.get([_f64Attr(v, context) for v in x]) + + +@register_attribute_builder("StrArrayAttr") +def _strArrayAttr(x, context): + return ArrayAttr.get([_stringAttr(v, context) for v in x]) + + +@register_attribute_builder("SymbolRefArrayAttr") +def _symbolRefArrayAttr(x, context): + return ArrayAttr.get([_symbolRefAttr(v, context) for v in x]) + + +@register_attribute_builder("DenseF32ArrayAttr") +def _denseF32ArrayAttr(x, context): + return DenseF32ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseF64ArrayAttr") +def _denseF64ArrayAttr(x, context): + return DenseF64ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI8ArrayAttr") +def _denseI8ArrayAttr(x, context): + return DenseI8ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI16ArrayAttr") +def _denseI16ArrayAttr(x, context): + return DenseI16ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI32ArrayAttr") +def _denseI32ArrayAttr(x, context): + return DenseI32ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI64ArrayAttr") +def _denseI64ArrayAttr(x, context): + return DenseI64ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseBoolArrayAttr") +def _denseBoolArrayAttr(x, context): + return DenseBoolArrayAttr.get(x, context=context) + + +@register_attribute_builder("TypeAttr") +def _typeAttr(x, context): + return TypeAttr.get(x, context=context) + + +@register_attribute_builder("TypeArrayAttr") +def _typeArrayAttr(x, context): + return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) + + +@register_attribute_builder("MemRefTypeAttr") +def _memref_type_attr(x, context): + return _typeAttr(x, context) + + +try: + import numpy as np + + @register_attribute_builder("F64ElementsAttr") + def _f64ElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, 
dtype=np.float64), + type=F64Type.get(context=context), + context=context, + ) + + @register_attribute_builder("I32ElementsAttr") + def _i32ElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int32), + type=IntegerType.get_signless(32, context=context), + context=context, + ) + + @register_attribute_builder("I64ElementsAttr") + def _i64ElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), + type=IntegerType.get_signless(64, context=context), + context=context, + ) + + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), + type=IndexType.get(context=context), + context=context, + ) + +except ImportError: + pass diff --git a/mlir/python/mlir/passmanager.py b/mlir/python/mlir/passmanager.py new file mode 100644 index 000000000..22e86b879 --- /dev/null +++ b/mlir/python/mlir/passmanager.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._mlir_libs._mlir.passmanager import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/python_test.py b/mlir/python/mlir/rewrite.py similarity index 83% rename from mlir/lib/Bindings/Python/mlir/dialects/python_test.py rename to mlir/python/mlir/rewrite.py index 524db4317..5bc1bba7a 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/rewrite.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._python_test_ops_gen import * +from ._mlir_libs._mlir.rewrite import * diff --git a/mlir/python/mlir/runtime/__init__.py b/mlir/python/mlir/runtime/__init__.py new file mode 100644 index 000000000..8a28fd935 --- /dev/null +++ b/mlir/python/mlir/runtime/__init__.py @@ -0,0 +1 @@ +from .np_to_memref import * diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py new file mode 100644 index 000000000..8cca1e7ad --- /dev/null +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -0,0 +1,184 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa. + +import numpy as np +import ctypes + +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes provides some optional low precision data-types for NumPy. 
+ ml_dtypes = None + + +class C128(ctypes.Structure): + """A ctype representation for MLIR's Double Complex.""" + + _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] + + +class C64(ctypes.Structure): + """A ctype representation for MLIR's Float Complex.""" + + _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] + + +class F16(ctypes.Structure): + """A ctype representation for MLIR's Float16.""" + + _fields_ = [("f16", ctypes.c_int16)] + + +class BF16(ctypes.Structure): + """A ctype representation for MLIR's BFloat16.""" + + _fields_ = [("bf16", ctypes.c_int16)] + +class F8E5M2(ctypes.Structure): + """A ctype representation for MLIR's Float8E5M2.""" + + _fields_ = [("f8E5M2", ctypes.c_int8)] + + +# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype +def as_ctype(dtp): + """Converts dtype to ctype.""" + if dtp == np.dtype(np.complex128): + return C128 + if dtp == np.dtype(np.complex64): + return C64 + if dtp == np.dtype(np.float16): + return F16 + if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: + return BF16 + if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2: + return F8E5M2 + return np.ctypeslib.as_ctypes_type(dtp) + + +def to_numpy(array): + """Converts ctypes array back to numpy dtype array.""" + if array.dtype == C128: + return array.view("complex128") + if array.dtype == C64: + return array.view("complex64") + if array.dtype == F16: + return array.view("float16") + assert not ( + array.dtype == BF16 and ml_dtypes is None + ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == BF16: + return array.view("bfloat16") + assert not ( + array.dtype == F8E5M2 and ml_dtypes is None + ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == F8E5M2: + return array.view("float8_e5m2") + return array + + +def make_nd_memref_descriptor(rank, dtype): + class MemRefDescriptor(ctypes.Structure): + 
"""Builds an empty descriptor for the given rank/dtype, where rank>0.""" + + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] + + return MemRefDescriptor + + +def make_zero_d_memref_descriptor(dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given dtype, where rank=0.""" + + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] + + return MemRefDescriptor + + +class UnrankedMemRefDescriptor(ctypes.Structure): + """Creates a ctype struct for memref descriptor""" + + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] + + +def get_ranked_memref_descriptor(nparray): + """Returns a ranked memref descriptor for the given numpy array.""" + ctp = as_ctype(nparray.dtype) + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + return x + + x = make_nd_memref_descriptor(nparray.ndim, ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. 
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x + + +def get_unranked_memref_descriptor(nparray): + """Returns a generic/unranked memref descriptor for the given numpy array.""" + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d + + +def move_aligned_ptr_by_offset(aligned_ptr, offset): + """Moves the supplied ctypes pointer ahead by `offset` elements.""" + aligned_addr = ctypes.addressof(aligned_ptr.contents) + elem_size = ctypes.sizeof(aligned_ptr.contents) + shift = offset * elem_size + content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr)) + return content_ptr + + +def unranked_memref_to_numpy(unranked_memref, np_dtype): + """Converts unranked memrefs to numpy arrays.""" + ctp = as_ctype(np_dtype) + descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + content_ptr = move_aligned_ptr_by_offset(val[0].aligned, val[0].offset) + np_arr = np.ctypeslib.as_array(content_ptr, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) + + +def ranked_memref_to_numpy(ranked_memref): + """Converts ranked memrefs to numpy arrays.""" + content_ptr = move_aligned_ptr_by_offset( + ranked_memref[0].aligned, ranked_memref[0].offset + ) + np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt new file mode 
100644
index 000000000..1a0075e82
--- /dev/null
+++ b/mlir/python/requirements.txt
@@ -0,0 +1,6 @@
+nanobind>=2.4, <3.0
+numpy>=1.19.5, <=2.1.2
+pybind11>=2.10.0, <=2.13.6
+PyYAML>=5.4.0, <=6.0.1
+ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16
+ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13"