diff --git a/.github/workflows/integrate_llvm.yml b/.github/workflows/integrate_llvm.yml new file mode 100644 index 000000000..64fcd021e --- /dev/null +++ b/.github/workflows/integrate_llvm.yml @@ -0,0 +1,120 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Copyright (c) 2024. + +name: Auto Integrate LLVM +on: + workflow_dispatch: + schedule: + # At minute 0 past hour 1. (see https://crontab.guru) + - cron: '00 01 * * *' + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + update-dep: + + name: "Integrate LLVM and send PR" + + runs-on: ubuntu-latest + + permissions: + contents: write + pull-requests: write + + steps: + - name: "Check out repository" + uses: actions/checkout@v4.2.2 + with: + fetch-depth: 0 + + - name: Restore LLVM project cache + id: llvm-project-restore + uses: actions/cache/restore@v4 + with: + path: /tmp/llvm-project + key: cache-llvm-project + + - name: "Get llvm-project" + shell: bash + id: get-llvm-project + run: | + + # https://github.com/actions/cache/issues/1566 + if [ "${{ steps.llvm-project-restore.outputs.cache-hit }}" == "" ]; then + git clone https://github.com/llvm/llvm-project.git /tmp/llvm-project + fi + + pushd /tmp/llvm-project + git pull origin main + echo "LLVM_SHA_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT + popd + + - name: Save LLVM project cache + id: llvm-project-save + uses: actions/cache/save@v4 + with: + path: /tmp/llvm-project + key: "cache-llvm-project-${{ format('{0}-{1}', github.ref_name, github.run_number) }}" + + - name: "Make filtered llvm-project" + id: make-filtered-llvm-project + shell: bash + run: | + + DEBIAN_FRONTEND=noninteractive sudo apt install -y git-filter-repo + HERE=$(pwd) + + pushd /tmp/llvm-project + bash $HERE/filter-llvm.sh + popd + + - name: "Rebase on top of llvm-project" + shell: bash + id: rebase-llvm-project + run: | + + git config user.email "github-actions[bot]@users.noreply.github.com" + git config user.name "github-actions[bot]" + git pull /tmp/llvm-project main --rebase + + - name: Generate token + uses: actions/create-github-app-token@v1 + id: generate-token + with: + app-id: ${{ secrets.BUMP_LLVM_CREATE_PR_APP_ID }} + private-key: ${{ secrets.BUMP_LLVM_CREATE_PR_APP_PRIVATE_KEY }} + + - name: "Create Pull Request" + id: cpr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ steps.generate-token.outputs.token }} + commit-message: "[LLVM] Integrate to ${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + title: "[LLVM] Integrate to ${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + body: "Integrate LLVM to https://github.com/llvm/llvm-project/commit/${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + author: 'github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>' + base: main + branch: update-llvm + delete-branch: true + + - name: Enable auto-merge + if: steps.cpr.outputs.pull-request-operation == 'created' + uses: peter-evans/enable-pull-request-automerge@v3 + with: + token: ${{ steps.generate-token.outputs.token }} + pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} + merge-method: rebase + + - name: Auto approve + if: steps.cpr.outputs.pull-request-operation == 'created' + run: gh pr review --approve "${{ steps.cpr.outputs.pull-request-number }}" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..1b17bee17 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +llvm-project \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 000000000..b076e8519 --- /dev/null +++ b/README.md @@ -0,0 +1,34 @@ +# TL;DR: + +In this repo: + +```shell +$ git clone git@github.com:llvm/llvm-project.git /tmp/llvm-project +$ cp filter-llvm.sh /tmp/llvm-project +$ pushd llvm-project +$ bash filter-llvm.sh +$ popd +$ git remote add upstream /tmp/llvm-project +$ git pull upstream main +$ git remote add origin https://github.com/makslevental/python_bindings_fork +$ git push -u origin main -f +``` + +In the other repo: + +``` +$ git remote add upstream git@github.com:makslevental/python_bindings_fork.git +$ git subtree add -P $WHEREVER_YOU_WANT upstream main +``` + +Note, you will have to have at least one commit for this to work in the "other repo"; do `git commit --allow-empty -m 'Initial commit'` if not. + +Then to pull changes from "this" repo: + +``` +$ git subtree pull -P $WHEREVER_YOU_WANT upstream main +``` + +# Requires + +https://github.com/newren/git-filter-repo diff --git a/filter-llvm.sh b/filter-llvm.sh new file mode 100755 index 000000000..b28ca7230 --- /dev/null +++ b/filter-llvm.sh @@ -0,0 +1,42 @@ +#!/usr/bin/env bash + +#rm -rf .git +#rm -rf mlir +#git init +#git checkout -b main + +emreg=$(cat <[-!#-'*+/-9=?A-Z^-~]+(\.[-!#-'*+/-9=?A-Z^-~]+)*|\"([]!#-[^-~ \t]|(\\[\t -~]))+\")@(?P[-!#-'*+/-9=?A-Z^-~]+(\.[-!#-'*+/-9=?A-Z^-~]+)*|\[[\t -Z^-~]*])"); +rr = re.compile(r"@(?P[a-z\d](?:[a-z\d]|-(?=[a-z\d])){0,38})(?=\b)"); + +message = message.decode("utf-8"); +message = r.sub(r"\g$$$\g", message); +message = rr.sub(r"\g", message); +message = message.replace("$$$", "@"); + +return message.encode("utf-8") + +EOF +) + +git replace --graft 0f0d0ed1c78f1a80139a1f2133fad5284691a121 5b4a01d4a63cb66ab981e52548f940813393bf42 + +git filter-repo \ + --path mlir/include/mlir/Bindings/Python \ + --path mlir/include/mlir/CAPI \ + --path mlir/include/mlir-c \ + --path mlir/lib/Bindings/Python \ + --path mlir/lib/CAPI \ + --path mlir/lib/Bindings/Python \ + --path mlir/python \ + --path mlir/tools/mlir-tblgen \ + --message-callback "$emreg" \ + --force + +#git remote add -f upstream /tmp/llvm-project +#git pull upstream main +#git remote add origin git@github.com:makslevental/python_bindings_fork.git +#git push -u origin main -f diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 65b14254e..c1ade9ed8 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -306,7 +306,8 @@ typedef enum MlirLLVMDINameTableKind MlirLLVMDINameTableKind; MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompileUnitAttrGet( MlirContext ctx, MlirAttribute id, unsigned int sourceLanguage, MlirAttribute file, MlirAttribute producer, bool isOptimized, - MlirLLVMDIEmissionKind emissionKind, MlirLLVMDINameTableKind nameTableKind); + MlirLLVMDIEmissionKind emissionKind, MlirLLVMDINameTableKind nameTableKind, + MlirAttribute splitDebugFilename); /// Creates a LLVM DIFlags attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 71c7d4378..c464e4da6 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -415,6 +415,12 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); /// The returned module is null when the input operation was not a ModuleOp. MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); +/// Checks if two modules are equal. +MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs); + +/// Compute a hash for the given module. +MLIR_CAPI_EXPORTED size_t mlirModuleHashValue(MlirModule mod); + //===----------------------------------------------------------------------===// // Operation state. //===----------------------------------------------------------------------===// @@ -619,12 +625,19 @@ static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; } MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, MlirOperation other); +/// Compute a hash for the given operation. +MLIR_CAPI_EXPORTED size_t mlirOperationHashValue(MlirOperation op); + /// Gets the context this operation is associated with MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); /// Gets the location of the operation. MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op); +/// Sets the location of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetLocation(MlirOperation op, + MlirLocation loc); + /// Gets the type id of the operation. /// Returns null if the operation does not have a registered operation /// description. diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 0d2e19ee7..1f63c6d0d 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -92,6 +92,18 @@ mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); MLIR_CAPI_EXPORTED void mlirPassManagerEnableTiming(MlirPassManager passManager); +/// Enumerated type of pass display modes. +/// Mainly used in mlirPassManagerEnableStatistics. +typedef enum { + MLIR_PASS_DISPLAY_MODE_LIST, + MLIR_PASS_DISPLAY_MODE_PIPELINE, +} MlirPassDisplayMode; + +/// Enable pass statistics. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableStatistics(MlirPassManager passManager, + MlirPassDisplayMode displayMode); + /// Nest an OpPassManager under the top-level PassManager, the nested /// passmanager will only run on operations matching the provided name. /// The returned OpPassManager will be destroyed when the parent is destroyed. diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 61d344631..5dd285ee0 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -37,6 +37,7 @@ DEFINE_C_API_STRUCT(MlirRewriterBase, void); DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); +DEFINE_C_API_STRUCT(MlirPatternRewriter, void); //===----------------------------------------------------------------------===// /// RewriterBase API inherited from OpBuilder @@ -100,6 +101,12 @@ mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter); MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter); +/// Returns the operation right after the current insertion point +/// of the rewriter. A null MlirOperation will be returned +// if the current insertion point is at the end of the block. +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter); + //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning //===----------------------------------------------------------------------===// @@ -301,16 +308,30 @@ mlirFreezeRewritePattern(MlirRewritePatternSet op); MLIR_CAPI_EXPORTED void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); +MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedilyWithOp( + MlirOperation op, MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig); + MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig); +//===----------------------------------------------------------------------===// +/// PatternRewriter API +//===----------------------------------------------------------------------===// + +/// Cast the PatternRewriter to a RewriterBase +MLIR_CAPI_EXPORTED MlirRewriterBase +mlirPatternRewriterAsBase(MlirPatternRewriter rewriter); + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); +DEFINE_C_API_STRUCT(MlirPDLValue, const void); +DEFINE_C_API_STRUCT(MlirPDLResultList, void); MLIR_CAPI_EXPORTED MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op); @@ -319,6 +340,69 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); MLIR_CAPI_EXPORTED MlirRewritePatternSet mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); + +/// Cast the MlirPDLValue to an MlirValue. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirValue mlirPDLValueAsValue(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirType. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirType mlirPDLValueAsType(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirOperation. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirOperation mlirPDLValueAsOperation(MlirPDLValue value); + +/// Cast the MlirPDLValue to an MlirAttribute. +/// Return a null value if the cast fails, just like llvm::dyn_cast. +MLIR_CAPI_EXPORTED MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value); + +/// Push the MlirValue into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackValue(MlirPDLResultList results, MlirValue value); + +/// Push the MlirType into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void mlirPDLResultListPushBackType(MlirPDLResultList results, + MlirType value); + +/// Push the MlirOperation into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value); + +/// Push the MlirAttribute into the given MlirPDLResultList. +MLIR_CAPI_EXPORTED void +mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value); + +/// This function type is used as callbacks for PDL native rewrite functions. +/// Input values can be accessed by `values` with its size `nValues`; +/// output values can be added into `results` by `mlirPDLResultListPushBack*` +/// APIs. And the return value indicates whether the rewrite succeeds. +typedef MlirLogicalResult (*MlirPDLRewriteFunction)( + MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, + MlirPDLValue *values, void *userData); + +/// Register a rewrite function into the given PDL pattern module. +/// `userData` will be provided as an argument to the rewrite function. +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData); + +/// This function type is used as callbacks for PDL native constraint functions. +/// Input values can be accessed by `values` with its size `nValues`; +/// output values can be added into `results` by `mlirPDLResultListPushBack*` +/// APIs. And the return value indicates whether the constraint holds. +typedef MlirLogicalResult (*MlirPDLConstraintFunction)( + MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues, + MlirPDLValue *values, void *userData); + +/// Register a constraint function into the given PDL pattern module. +/// `userData` will be provided as an argument to the constraint function. +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLConstraintFunction constraintFn, void *userData); + #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH #undef DEFINE_C_API_STRUCT diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 1428d5ccf..847951ab5 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -19,7 +19,9 @@ #ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H +#include #include +#include #include #include "mlir-c/Diagnostics.h" @@ -30,6 +32,57 @@ // clang-format on #include "llvm/ADT/Twine.h" +namespace mlir { +namespace python { +namespace { + +// Safely calls Python initialization code on first use, avoiding deadlocks. +template +class SafeInit { +public: + typedef std::unique_ptr (*F)(); + + explicit SafeInit(F init_fn) : initFn(init_fn) {} + + T &get() { + if (T *result = output.load()) { + return *result; + } + + // Note: init_fn() may be called multiple times if, for example, the GIL is + // released during its execution. The intended use case is for module + // imports which are safe to perform multiple times. We are careful not to + // hold a lock across init_fn() to avoid lock ordering problems. + std::unique_ptr m = initFn(); + { + nanobind::ft_lock_guard lock(mu); + if (T *result = output.load()) { + return *result; + } + T *p = m.release(); + output.store(p); + return *p; + } + } + +private: + nanobind::ft_mutex mu; + std::atomic output{nullptr}; + F initFn; +}; + +nanobind::module_ &irModule() { + static SafeInit init([]() { + return std::make_unique( + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))); + }); + return init.get(); +} + +} // namespace +} // namespace python +} // namespace mlir + // Raw CAPI type casters need to be declared before use, so always include them // first. namespace nanobind { @@ -63,7 +116,8 @@ mlirApiObjectToCapsule(nanobind::handle apiObject) { /// Casts object <-> MlirAffineMap. template <> struct type_caster { - NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")) + NB_TYPE_CASTER(MlirAffineMap, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToAffineMap(capsule->ptr()); @@ -75,7 +129,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonAffineMapToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("AffineMap") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -85,7 +139,8 @@ struct type_caster { /// Casts object <-> MlirAttribute. template <> struct type_caster { - NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")) + NB_TYPE_CASTER(MlirAttribute, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToAttribute(capsule->ptr()); @@ -97,7 +152,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonAttributeToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() @@ -108,7 +163,7 @@ struct type_caster { /// Casts object -> MlirBlock. template <> struct type_caster { - NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")) + NB_TYPE_CASTER(MlirBlock, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Block"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToBlock(capsule->ptr()); @@ -121,16 +176,15 @@ struct type_caster { /// Casts object -> MlirContext. template <> struct type_caster { - NB_TYPE_CASTER(MlirContext, const_name("MlirContext")) + NB_TYPE_CASTER(MlirContext, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Context"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. // TODO: This raises an error of "No current context" currently. // Update the implementation to pretty-print the helpful error that the // core implementations print in this case. - src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Context") - .attr("current"); + src = mlir::python::irModule().attr("Context").attr("current"); } std::optional capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToContext(capsule->ptr()); @@ -141,7 +195,8 @@ struct type_caster { /// Casts object <-> MlirDialectRegistry. template <> struct type_caster { - NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")) + NB_TYPE_CASTER(MlirDialectRegistry, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToDialectRegistry(capsule->ptr()); @@ -153,7 +208,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal( mlirPythonDialectRegistryToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("DialectRegistry") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -163,13 +218,12 @@ struct type_caster { /// Casts object <-> MlirLocation. template <> struct type_caster { - NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")) + NB_TYPE_CASTER(MlirLocation, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Location"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. - src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Location") - .attr("current"); + src = mlir::python::irModule().attr("Location").attr("current"); } if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToLocation(capsule->ptr()); @@ -181,7 +235,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonLocationToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Location") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -191,7 +245,7 @@ struct type_caster { /// Casts object <-> MlirModule. template <> struct type_caster { - NB_TYPE_CASTER(MlirModule, const_name("MlirModule")) + NB_TYPE_CASTER(MlirModule, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Module"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToModule(capsule->ptr()); @@ -203,7 +257,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonModuleToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Module") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -213,8 +267,9 @@ struct type_caster { /// Casts object <-> MlirFrozenRewritePatternSet. template <> struct type_caster { - NB_TYPE_CASTER(MlirFrozenRewritePatternSet, - const_name("MlirFrozenRewritePatternSet")) + NB_TYPE_CASTER( + MlirFrozenRewritePatternSet, + const_name(MAKE_MLIR_PYTHON_QUALNAME("rewrite.FrozenRewritePatternSet"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr()); @@ -236,7 +291,8 @@ struct type_caster { /// Casts object <-> MlirOperation. template <> struct type_caster { - NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")) + NB_TYPE_CASTER(MlirOperation, + const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Operation"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToOperation(capsule->ptr()); @@ -250,7 +306,7 @@ struct type_caster { return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonOperationToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Operation") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -260,7 +316,7 @@ struct type_caster { /// Casts object <-> MlirValue. template <> struct type_caster { - NB_TYPE_CASTER(MlirValue, const_name("MlirValue")) + NB_TYPE_CASTER(MlirValue, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Value"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToValue(capsule->ptr()); @@ -274,7 +330,7 @@ struct type_caster { return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonValueToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Value") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() @@ -285,7 +341,8 @@ struct type_caster { /// Casts object -> MlirPassManager. template <> struct type_caster { - NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")) + NB_TYPE_CASTER(MlirPassManager, const_name(MAKE_MLIR_PYTHON_QUALNAME( + "passmanager.PassManager"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToPassManager(capsule->ptr()); @@ -298,7 +355,7 @@ struct type_caster { /// Casts object <-> MlirTypeID. template <> struct type_caster { - NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")) + NB_TYPE_CASTER(MlirTypeID, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToTypeID(capsule->ptr()); @@ -312,7 +369,7 @@ struct type_caster { return nanobind::none(); nanobind::object capsule = nanobind::steal(mlirPythonTypeIDToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("TypeID") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -322,7 +379,7 @@ struct type_caster { /// Casts object <-> MlirType. template <> struct type_caster { - NB_TYPE_CASTER(MlirType, const_name("MlirType")) + NB_TYPE_CASTER(MlirType, const_name(MAKE_MLIR_PYTHON_QUALNAME("ir.Type"))) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToType(capsule->ptr()); @@ -334,7 +391,7 @@ struct type_caster { cleanup_list *cleanup) noexcept { nanobind::object capsule = nanobind::steal(mlirPythonTypeToCapsule(t)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + return mlir::python::irModule() .attr("Type") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() @@ -345,7 +402,7 @@ struct type_caster { /// Casts MlirStringRef -> object. template <> struct type_caster { - NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef")) + NB_TYPE_CASTER(MlirStringRef, const_name("str")) static handle from_cpp(MlirStringRef s, rv_policy, cleanup_list *cleanup) noexcept { return nanobind::str(s.data, s.length).release(); @@ -453,11 +510,9 @@ class mlir_attribute_subclass : public pure_subclass { mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_attribute_subclass( - scope, attrClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute"), - getTypeIDFunction) {} + : mlir_attribute_subclass(scope, attrClassName, isaFunction, + irModule().attr("Attribute"), + getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Attribute super-class. This must /// be used if the subclass is being defined in the same extension module @@ -512,8 +567,13 @@ class mlir_attribute_subclass : public pure_subclass { .attr("replace")(superCls.attr("__name__"), captureTypeName); }); if (getTypeIDFunction) { - def_staticmethod("get_static_typeid", - [getTypeIDFunction]() { return getTypeIDFunction(); }); + def_staticmethod( + "get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }, + // clang-format off + nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID")) + // clang-format on + ); nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( getTypeIDFunction())(nanobind::cpp_function( @@ -535,11 +595,8 @@ class mlir_type_subclass : public pure_subclass { mlir_type_subclass(nanobind::handle scope, const char *typeClassName, IsAFunctionTy isaFunction, GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_type_subclass( - scope, typeClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Type"), - getTypeIDFunction) {} + : mlir_type_subclass(scope, typeClassName, isaFunction, + irModule().attr("Type"), getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Type super-class. This must /// be used if the subclass is being defined in the same extension module @@ -582,8 +639,9 @@ class mlir_type_subclass : public pure_subclass { // 'isinstance' method. static const char kIsinstanceSig[] = - "def isinstance(other_type: " MAKE_MLIR_PYTHON_QUALNAME( - "ir") ".Type) -> bool"; + // clang-format off + "def isinstance(other_type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ") -> bool"; + // clang-format on def_staticmethod( "isinstance", [isaFunction](MlirType other) { return isaFunction(other); }, @@ -599,8 +657,13 @@ class mlir_type_subclass : public pure_subclass { // `def_property_readonly_static` is not available in `pure_subclass` and // we do not want to introduce the complexity that pybind uses to // implement it. - def_staticmethod("get_static_typeid", - [getTypeIDFunction]() { return getTypeIDFunction(); }); + def_staticmethod( + "get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }, + // clang-format off + nanobind::sig("def get_static_typeid() -> " MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID")) + // clang-format on + ); nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( getTypeIDFunction())(nanobind::cpp_function( @@ -620,10 +683,8 @@ class mlir_value_subclass : public pure_subclass { /// Subclasses by looking up the super-class dynamically. mlir_value_subclass(nanobind::handle scope, const char *valueClassName, IsAFunctionTy isaFunction) - : mlir_value_subclass( - scope, valueClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Value")) {} + : mlir_value_subclass(scope, valueClassName, isaFunction, + irModule().attr("Value")) {} /// Subclasses with a provided mlir.ir.Value super-class. This must /// be used if the subclass is being defined in the same extension module @@ -665,8 +726,9 @@ class mlir_value_subclass : public pure_subclass { // 'isinstance' method. static const char kIsinstanceSig[] = - "def isinstance(other_value: " MAKE_MLIR_PYTHON_QUALNAME( - "ir") ".Value) -> bool"; + // clang-format off + "def isinstance(other_value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ") -> bool"; + // clang-format on def_staticmethod( "isinstance", [isaFunction](MlirValue other) { return isaFunction(other); }, diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp index a21176fff..2568d535e 100644 --- a/mlir/lib/Bindings/Python/DialectGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -38,7 +38,7 @@ NB_MODULE(_mlirDialectsGPU, m) { return cls(mlirGPUAsyncTokenTypeGet(ctx)); }, "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"), - nb::arg("ctx").none() = nb::none()); + nb::arg("ctx") = nb::none()); //===-------------------------------------------------------------------===// // ObjectAttr @@ -62,7 +62,7 @@ NB_MODULE(_mlirDialectsGPU, m) { : MlirAttribute{nullptr})); }, "cls"_a, "target"_a, "format"_a, "object"_a, - "properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(), + "properties"_a = nb::none(), "kernels"_a = nb::none(), "Gets a gpu.object from parameters.") .def_property_readonly( "target", diff --git a/mlir/lib/Bindings/Python/DialectIRDL.cpp b/mlir/lib/Bindings/Python/DialectIRDL.cpp new file mode 100644 index 000000000..08bcab97c --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectIRDL.cpp @@ -0,0 +1,35 @@ +//===--- DialectIRDL.cpp - Pybind module for IRDL dialect API support ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/IRDL.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +static void populateDialectIRDLSubmodule(nb::module_ &m) { + m.def( + "load_dialects", + [](MlirModule module) { + if (mlirLogicalResultIsFailure(mlirLoadIRDLDialects(module))) + throw std::runtime_error( + "failed to load IRDL dialects from the input module"); + }, + nb::arg("module"), "Load IRDL dialects from the given module."); +} + +NB_MODULE(_mlirDialectsIRDL, m) { + m.doc() = "MLIR IRDL dialect."; + + populateDialectIRDLSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index ee106c032..38de4a0e3 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -33,21 +33,37 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { auto llvmStructType = mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); - llvmStructType.def_classmethod( - "get_literal", - [](const nb::object &cls, const std::vector &elements, - bool packed, MlirLocation loc) { - CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); - - MlirType type = mlirLLVMStructTypeLiteralGetChecked( - loc, elements.size(), elements.data(), packed); - if (mlirTypeIsNull(type)) { - throw nb::value_error(scope.takeMessage().c_str()); - } - return cls(type); - }, - "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, - "loc"_a.none() = nb::none()); + llvmStructType + .def_classmethod( + "get_literal", + [](const nb::object &cls, const std::vector &elements, + bool packed, MlirLocation loc) { + CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); + + MlirType type = mlirLLVMStructTypeLiteralGetChecked( + loc, elements.size(), elements.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return cls(type); + }, + "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "loc"_a = nb::none()) + .def_classmethod( + "get_literal_unchecked", + [](const nb::object &cls, const std::vector &elements, + bool packed, MlirContext context) { + CollectDiagnosticsToStringScope scope(context); + + MlirType type = mlirLLVMStructTypeLiteralGet( + context, elements.size(), elements.data(), packed); + if (mlirTypeIsNull(type)) { + throw nb::value_error(scope.takeMessage().c_str()); + } + return cls(type); + }, + "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "context"_a = nb::none()); llvmStructType.def_classmethod( "get_identified", @@ -55,7 +71,7 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { return cls(mlirLLVMStructTypeIdentifiedGet( context, mlirStringRefCreate(name.data(), name.size()))); }, - "cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none()); + "cls"_a, "name"_a, nb::kw_only(), "context"_a = nb::none()); llvmStructType.def_classmethod( "get_opaque", @@ -63,7 +79,7 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { return cls(mlirLLVMStructTypeOpaqueGet( context, mlirStringRefCreate(name.data(), name.size()))); }, - "cls"_a, "name"_a, "context"_a.none() = nb::none()); + "cls"_a, "name"_a, "context"_a = nb::none()); llvmStructType.def( "set_body", @@ -86,7 +102,7 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { elements.size(), elements.data(), packed)); }, "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false, - "context"_a.none() = nb::none()); + "context"_a = nb::none()); llvmStructType.def_property_readonly( "name", [](MlirType type) -> std::optional { @@ -133,8 +149,8 @@ static void populateDialectLLVMSubmodule(const nanobind::module_ &m) { } return cls(type); }, - "cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(), - "context"_a.none() = nb::none()) + "cls"_a, "address_space"_a = nb::none(), nb::kw_only(), + "context"_a = nb::none()) .def_property_readonly("address_space", [](MlirType type) { return mlirLLVMPointerTypeGetAddressSpace(type); }); diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp index bb3f519c9..189174164 100644 --- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -31,7 +31,7 @@ static void populateDialectNVGPUSubmodule(const nb::module_ &m) { "Gets an instance of TensorMapDescriptorType in the same context", nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"), nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"), - nb::arg("ctx").none() = nb::none()); + nb::arg("ctx") = nb::none()); } NB_MODULE(_mlirDialectsNVGPU, m) { diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp index 2acedbc26..1acb41080 100644 --- a/mlir/lib/Bindings/Python/DialectPDL.cpp +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -36,7 +36,7 @@ static void populateDialectPDLSubmodule(const nanobind::module_ &m) { return cls(mlirPDLAttributeTypeGet(ctx)); }, "Get an instance of AttributeType in given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); //===-------------------------------------------------------------------===// // OperationType @@ -50,7 +50,7 @@ static void populateDialectPDLSubmodule(const nanobind::module_ &m) { return cls(mlirPDLOperationTypeGet(ctx)); }, "Get an instance of OperationType in given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); //===-------------------------------------------------------------------===// // RangeType @@ -68,6 +68,8 @@ static void populateDialectPDLSubmodule(const nanobind::module_ &m) { rangeType.def_property_readonly( "element_type", [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); }, + nb::sig( + "def element_type(self) -> " MAKE_MLIR_PYTHON_QUALNAME("ir.Type")), "Get the element type."); //===-------------------------------------------------------------------===// @@ -81,7 +83,7 @@ static void populateDialectPDLSubmodule(const nanobind::module_ &m) { return cls(mlirPDLTypeTypeGet(ctx)); }, "Get an instance of TypeType in given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); //===-------------------------------------------------------------------===// // ValueType @@ -94,7 +96,7 @@ static void populateDialectPDLSubmodule(const nanobind::module_ &m) { return cls(mlirPDLValueTypeGet(ctx)); }, "Get an instance of TypeType in given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); } NB_MODULE(_mlirDialectsPDL, m) { diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp index cab4219fe..0d1d9e89f 100644 --- a/mlir/lib/Bindings/Python/DialectSMT.cpp +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -26,21 +26,26 @@ using namespace mlir::python::nanobind_adaptors; static void populateDialectSMTSubmodule(nanobind::module_ &m) { - auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) - .def_classmethod( - "get", - [](const nb::object &, MlirContext context) { - return mlirSMTTypeGetBool(context); - }, - "cls"_a, "context"_a.none() = nb::none()); + auto smtBoolType = + mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) + .def_staticmethod( + "get", + [](MlirContext context) { return mlirSMTTypeGetBool(context); }, + "context"_a = nb::none()); auto smtBitVectorType = mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector) - .def_classmethod( + .def_staticmethod( "get", - [](const nb::object &, int32_t width, MlirContext context) { + [](int32_t width, MlirContext context) { return mlirSMTTypeGetBitVector(context, width); }, - "cls"_a, "width"_a, "context"_a.none() = nb::none()); + "width"_a, "context"_a = nb::none()); + auto smtIntType = + mlir_type_subclass(m, "IntType", mlirSMTTypeIsAInt) + .def_staticmethod( + "get", + [](MlirContext context) { return mlirSMTTypeGetInt(context); }, + "context"_a = nb::none()); auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, bool indentLetBody) { diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 9d7dc1107..00b65ee97 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -53,9 +53,8 @@ static void populateDialectSparseTensorSubmodule(const nb::module_ &m) { }, nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(), nb::arg("lvl_to_dim").none(), nb::arg("pos_width"), - nb::arg("crd_width"), nb::arg("explicit_val").none() = nb::none(), - nb::arg("implicit_val").none() = nb::none(), - nb::arg("context").none() = nb::none(), + nb::arg("crd_width"), nb::arg("explicit_val") = nb::none(), + nb::arg("implicit_val") = nb::none(), nb::arg("context") = nb::none(), "Gets a sparse_tensor.encoding from parameters.") .def_classmethod( "build_level_type", diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index 0190edf79..150c69953 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -19,7 +19,7 @@ using namespace mlir; using namespace mlir::python; using namespace mlir::python::nanobind_adaptors; -void populateDialectTransformSubmodule(const nb::module_ &m) { +static void populateDialectTransformSubmodule(const nb::module_ &m) { //===-------------------------------------------------------------------===// // AnyOpType //===-------------------------------------------------------------------===// @@ -29,11 +29,11 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformAnyOpTypeGetTypeID); anyOpType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirTransformAnyOpTypeGet(ctx)); }, "Get an instance of AnyOpType in the given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); //===-------------------------------------------------------------------===// // AnyParamType @@ -44,11 +44,11 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformAnyParamTypeGetTypeID); anyParamType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirTransformAnyParamTypeGet(ctx)); }, "Get an instance of AnyParamType in the given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); //===-------------------------------------------------------------------===// // AnyValueType @@ -59,11 +59,11 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformAnyValueTypeGetTypeID); anyValueType.def_classmethod( "get", - [](nb::object cls, MlirContext ctx) { + [](const nb::object &cls, MlirContext ctx) { return cls(mlirTransformAnyValueTypeGet(ctx)); }, "Get an instance of AnyValueType in the given context.", nb::arg("cls"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); //===-------------------------------------------------------------------===// // OperationType @@ -74,7 +74,8 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformOperationTypeGetTypeID); operationType.def_classmethod( "get", - [](nb::object cls, const std::string &operationName, MlirContext ctx) { + [](const nb::object &cls, const std::string &operationName, + MlirContext ctx) { MlirStringRef cOperationName = mlirStringRefCreate(operationName.data(), operationName.size()); return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); @@ -82,7 +83,7 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { "Get an instance of OperationType for the given kind in the given " "context", nb::arg("cls"), nb::arg("operation_name"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); operationType.def_property_readonly( "operation_name", [](MlirType type) { @@ -101,11 +102,11 @@ void populateDialectTransformSubmodule(const nb::module_ &m) { mlirTransformParamTypeGetTypeID); paramType.def_classmethod( "get", - [](nb::object cls, MlirType type, MlirContext ctx) { + [](const nb::object &cls, MlirType type, MlirContext ctx) { return cls(mlirTransformParamTypeGet(ctx, type)); }, "Get an instance of ParamType for the given type in the given context.", - nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none()); + nb::arg("cls"), nb::arg("type"), nb::arg("context") = nb::none()); paramType.def_property_readonly( "type", [](MlirType type) { diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 4885d62c5..8bb493ed7 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -45,7 +45,7 @@ class PyExecutionEngine { referencedObjects.push_back(obj); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirExecutionEngine rawPm = mlirPythonCapsuleToExecutionEngine(capsule.ptr()); if (mlirExecutionEngineIsNull(rawPm)) @@ -113,7 +113,7 @@ NB_MODULE(_mlirExecutionEngine, m) { .def( "raw_register_runtime", [](PyExecutionEngine &executionEngine, const std::string &name, - nb::object callbackObj) { + const nb::object &callbackObj) { executionEngine.addReferencedObject(callbackObj); uintptr_t rawSym = nb::cast(nb::getattr(callbackObj, "value")); diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 71a051cb3..1e81f53e4 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -17,6 +17,7 @@ #include "NanobindUtils.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringExtras.h" @@ -151,6 +152,29 @@ class PyGlobals { TracebackLoc &getTracebackLoc() { return tracebackLoc; } + class TypeIDAllocator { + public: + TypeIDAllocator() : allocator(mlirTypeIDAllocatorCreate()) {} + ~TypeIDAllocator() { + if (allocator.ptr) + mlirTypeIDAllocatorDestroy(allocator); + } + TypeIDAllocator(const TypeIDAllocator &) = delete; + TypeIDAllocator(TypeIDAllocator &&other) : allocator(other.allocator) { + other.allocator.ptr = nullptr; + } + + MlirTypeIDAllocator get() { return allocator; } + MlirTypeID allocate() { + return mlirTypeIDAllocatorAllocateTypeID(allocator); + } + + private: + MlirTypeIDAllocator allocator; + }; + + MlirTypeID allocateTypeID() { return typeIDAllocator.allocate(); } + private: static PyGlobals *instance; @@ -173,6 +197,7 @@ class PyGlobals { llvm::StringSet<> loadedDialectModules; TracebackLoc tracebackLoc; + TypeIDAllocator typeIDAllocator; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 50f2a4f95..7147f2cba 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -17,9 +17,9 @@ #include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir-c/IntegerSet.h" #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir/Support/LLVM.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallVector.h" @@ -64,7 +64,7 @@ static void pyListToVector(const nb::list &list, } template -static bool isPermutation(std::vector permutation) { +static bool isPermutation(const std::vector &permutation) { llvm::SmallVector seen(permutation.size(), false); for (auto val : permutation) { if (val < permutation.size()) { @@ -142,7 +142,7 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr { static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); c.def_prop_ro("value", [](PyAffineConstantExpr &self) { return mlirAffineConstantExprGetValue(self); }); @@ -162,7 +162,7 @@ class PyAffineDimExpr : public PyConcreteAffineExpr { static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); c.def_prop_ro("position", [](PyAffineDimExpr &self) { return mlirAffineDimExprGetPosition(self); }); @@ -182,7 +182,7 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr { static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none()); + nb::arg("context") = nb::none()); c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { return mlirAffineSymbolExprGetPosition(self); }); @@ -366,7 +366,7 @@ nb::object PyAffineExpr::getCapsule() { return nb::steal(mlirPythonAffineExprToCapsule(*this)); } -PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { +PyAffineExpr PyAffineExpr::createFromCapsule(const nb::object &capsule) { MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); if (mlirAffineExprIsNull(rawAffineExpr)) throw nb::python_error(); @@ -424,7 +424,7 @@ nb::object PyAffineMap::getCapsule() { return nb::steal(mlirPythonAffineMapToCapsule(*this)); } -PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { +PyAffineMap PyAffineMap::createFromCapsule(const nb::object &capsule) { MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); if (mlirAffineMapIsNull(rawAffineMap)) throw nb::python_error(); @@ -500,7 +500,7 @@ nb::object PyIntegerSet::getCapsule() { return nb::steal(mlirPythonIntegerSetToCapsule(*this)); } -PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { +PyIntegerSet PyIntegerSet::createFromCapsule(const nb::object &capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if (mlirIntegerSetIsNull(rawIntegerSet)) throw nb::python_error(); @@ -574,7 +574,9 @@ void mlir::python::populateIRAffine(nb::module_ &m) { }) .def_prop_ro( "context", - [](PyAffineExpr &self) { return self.getContext().getObject(); }) + [](PyAffineExpr &self) -> nb::typed { + return self.getContext().getObject(); + }) .def("compose", [](PyAffineExpr &self, PyAffineMap &other) { return PyAffineExpr(self.getContext(), @@ -588,7 +590,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { self.getContext(), mlirAffineExprShiftDims(self, numDims, shift, offset)); }, - nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0) + nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset") = 0) .def( "shift_symbols", [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift, @@ -597,8 +599,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { self.getContext(), mlirAffineExprShiftSymbols(self, numSymbols, shift, offset)); }, - nb::arg("num_symbols"), nb::arg("shift"), - nb::arg("offset").none() = 0) + nb::arg("num_symbols"), nb::arg("shift"), nb::arg("offset") = 0) .def_static( "simplify_affine_expr", [](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) { @@ -655,15 +656,15 @@ void mlir::python::populateIRAffine(nb::module_ &m) { "Gets an affine expression containing the rounded-up result " "of dividing an expression by a constant.") .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"), - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Gets a constant affine expression with the given value.") .def_static( "get_dim", &PyAffineDimExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Gets an affine expression of a dimension at the given position.") .def_static( "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Gets an affine expression of a symbol at the given position.") .def( "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, @@ -707,34 +708,36 @@ void mlir::python::populateIRAffine(nb::module_ &m) { [](PyAffineMap &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_static("compress_unused_symbols", - [](nb::list affineMaps, DefaultingPyMlirContext context) { - SmallVector maps; - pyListToVector( - affineMaps, maps, "attempting to create an AffineMap"); - std::vector compressed(affineMaps.size()); - auto populate = [](void *result, intptr_t idx, - MlirAffineMap m) { - static_cast(result)[idx] = (m); - }; - mlirAffineMapCompressUnusedSymbols( - maps.data(), maps.size(), compressed.data(), populate); - std::vector res; - res.reserve(compressed.size()); - for (auto m : compressed) - res.emplace_back(context->getRef(), m); - return res; - }) + .def_static( + "compress_unused_symbols", + [](const nb::list &affineMaps, DefaultingPyMlirContext context) { + SmallVector maps; + pyListToVector( + affineMaps, maps, "attempting to create an AffineMap"); + std::vector compressed(affineMaps.size()); + auto populate = [](void *result, intptr_t idx, MlirAffineMap m) { + static_cast(result)[idx] = (m); + }; + mlirAffineMapCompressUnusedSymbols(maps.data(), maps.size(), + compressed.data(), populate); + std::vector res; + res.reserve(compressed.size()); + for (auto m : compressed) + res.emplace_back(context->getRef(), m); + return res; + }) .def_prop_ro( "context", - [](PyAffineMap &self) { return self.getContext().getObject(); }, + [](PyAffineMap &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Affine Map") .def( "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, kDumpDocstring) .def_static( "get", - [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, + [](intptr_t dimCount, intptr_t symbolCount, const nb::list &exprs, DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( @@ -745,7 +748,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return PyAffineMap(context->getRef(), map); }, nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Gets a map with the given expressions as results.") .def_static( "get_constant", @@ -754,7 +757,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { mlirAffineMapConstantGet(context->get(), value); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + nb::arg("value"), nb::arg("context") = nb::none(), "Gets an affine map with a single constant result") .def_static( "get_empty", @@ -762,7 +765,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("context").none() = nb::none(), "Gets an empty affine map.") + nb::arg("context") = nb::none(), "Gets an empty affine map.") .def_static( "get_identity", [](intptr_t nDims, DefaultingPyMlirContext context) { @@ -770,7 +773,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { mlirAffineMapMultiDimIdentityGet(context->get(), nDims); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("n_dims"), nb::arg("context").none() = nb::none(), + nb::arg("n_dims"), nb::arg("context") = nb::none(), "Gets an identity map with the given number of dimensions.") .def_static( "get_minor_identity", @@ -781,7 +784,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return PyAffineMap(context->getRef(), affineMap); }, nb::arg("n_dims"), nb::arg("n_results"), - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Gets a minor identity map with the given number of dimensions and " "results.") .def_static( @@ -795,7 +798,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { context->get(), permutation.size(), permutation.data()); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("permutation"), nb::arg("context").none() = nb::none(), + nb::arg("permutation"), nb::arg("context") = nb::none(), "Gets an affine map that permutes its inputs.") .def( "get_submap", @@ -869,7 +872,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) + .def("__eq__", + [](PyIntegerSet &self, const nb::object &other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; @@ -892,13 +896,15 @@ void mlir::python::populateIRAffine(nb::module_ &m) { }) .def_prop_ro( "context", - [](PyIntegerSet &self) { return self.getContext().getObject(); }) + [](PyIntegerSet &self) -> nb::typed { + return self.getContext().getObject(); + }) .def( "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, kDumpDocstring) .def_static( "get", - [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, + [](intptr_t numDims, intptr_t numSymbols, const nb::list &exprs, std::vector eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) throw nb::value_error( @@ -921,7 +927,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return PyIntegerSet(context->getRef(), set); }, nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), - nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) + nb::arg("eq_flags"), nb::arg("context") = nb::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, @@ -931,11 +937,12 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return PyIntegerSet(context->getRef(), set); }, nb::arg("num_dims"), nb::arg("num_symbols"), - nb::arg("context").none() = nb::none()) + nb::arg("context") = nb::none()) .def( "get_replaced", - [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, - intptr_t numResultDims, intptr_t numResultSymbols) { + [](PyIntegerSet &self, const nb::list &dimExprs, + const nb::list &symbolExprs, intptr_t numResultDims, + intptr_t numResultSymbols) { if (static_cast(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) throw nb::value_error( diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index db84ee1fc..045c0fbf4 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -167,7 +167,7 @@ struct nb_buffer_info { }; class nb_buffer : public nb::object { - NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); + NB_OBJECT_DEFAULT(nb_buffer, object, "Buffer", PyObject_CheckBuffer); nb_buffer_info request() const { int flags = PyBUF_STRIDES | PyBUF_FORMAT; @@ -252,8 +252,13 @@ class PyAffineMapAttribute : public PyConcreteAttribute { return PyAffineMapAttribute(affineMap.getContext(), attr); }, nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - c.def_prop_ro("value", mlirAffineMapAttrGetValue, - "Returns the value of the AffineMap attribute"); + c.def_prop_ro( + "value", + [](PyAffineMapAttribute &self) { + return PyAffineMap(self.getContext(), + mlirAffineMapAttrGetValue(self)); + }, + "Returns the value of the AffineMap attribute"); } }; @@ -351,7 +356,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { } return getAttribute(values, ctx->getRef()); }, - nb::arg("values"), nb::arg("context").none() = nb::none(), + nb::arg("values"), nb::arg("context") = nb::none(), "Gets a uniqued dense array attribute"); } else { c.def_static( @@ -359,7 +364,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { [](const std::vector &values, DefaultingPyMlirContext ctx) { return getAttribute(values, ctx->getRef()); }, - nb::arg("values"), nb::arg("context").none() = nb::none(), + nb::arg("values"), nb::arg("context") = nb::none(), "Gets a uniqued dense array attribute"); } // Bind the array methods. @@ -480,11 +485,13 @@ class PyArrayAttribute : public PyConcreteAttribute { PyArrayAttributeIterator &dunderIter() { return *this; } - MlirAttribute dunderNext() { + nb::typed dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw nb::stop_iteration(); - return mlirArrayAttrGetElement(attr.get(), nextIndex++); + return PyAttribute(this->attr.getContext(), + mlirArrayAttrGetElement(attr.get(), nextIndex++)) + .maybeDownCast(); } static void bind(nb::module_ &m) { @@ -505,7 +512,7 @@ class PyArrayAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](nb::list attributes, DefaultingPyMlirContext context) { + [](const nb::list &attributes, DefaultingPyMlirContext context) { SmallVector mlirAttributes; mlirAttributes.reserve(nb::len(attributes)); for (auto attribute : attributes) { @@ -515,14 +522,16 @@ class PyArrayAttribute : public PyConcreteAttribute { context->get(), mlirAttributes.size(), mlirAttributes.data()); return PyArrayAttribute(context->getRef(), attr); }, - nb::arg("attributes"), nb::arg("context").none() = nb::none(), + nb::arg("attributes"), nb::arg("context") = nb::none(), "Gets a uniqued Array attribute"); - c.def("__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { - if (i >= mlirArrayAttrGetNumElements(arr)) - throw nb::index_error("ArrayAttribute index out of range"); - return arr.getItem(i); - }) + c.def( + "__getitem__", + [](PyArrayAttribute &arr, + intptr_t i) -> nb::typed { + if (i >= mlirArrayAttrGetNumElements(arr)) + throw nb::index_error("ArrayAttribute index out of range"); + return PyAttribute(arr.getContext(), arr.getItem(i)).maybeDownCast(); + }) .def("__len__", [](const PyArrayAttribute &arr) { return mlirArrayAttrGetNumElements(arr); @@ -530,7 +539,7 @@ class PyArrayAttribute : public PyConcreteAttribute { .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); - c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { + c.def("__add__", [](PyArrayAttribute arr, const nb::list &extras) { std::vector attributes; intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); attributes.reserve(numOldElements + nb::len(extras)); @@ -564,7 +573,19 @@ class PyFloatAttribute : public PyConcreteAttribute { throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, - nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), + nb::arg("type"), nb::arg("value"), nb::arg("loc") = nb::none(), + "Gets an uniqued float point attribute associated to a type"); + c.def_static( + "get_unchecked", + [](PyType &type, double value, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute attr = + mlirFloatAttrDoubleGet(context.get()->get(), type, value); + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); + return PyFloatAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), nb::arg("context") = nb::none(), "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", @@ -573,7 +594,7 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF32TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + nb::arg("value"), nb::arg("context") = nb::none(), "Gets an uniqued float point attribute associated to a f32 type"); c.def_static( "get_f64", @@ -582,7 +603,7 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF64TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + nb::arg("value"), nb::arg("context") = nb::none(), "Gets an uniqued float point attribute associated to a f64 type"); c.def_prop_ro("value", mlirFloatAttrGetValueDouble, "Returns the value of the float attribute"); @@ -611,10 +632,12 @@ class PyIntegerAttribute : public PyConcreteAttribute { "Returns the value of the integer attribute"); c.def("__int__", toPyInt, "Converts the value of the integer attribute to a Python int"); - c.def_prop_ro_static("static_typeid", - [](nb::object & /*class*/) -> MlirTypeID { - return mlirIntegerAttrGetTypeID(); - }); + c.def_prop_ro_static( + "static_typeid", + [](nb::object & /*class*/) { + return PyTypeID(mlirIntegerAttrGetTypeID()); + }, + nanobind::sig("def static_typeid(/) -> TypeID")); } private: @@ -642,7 +665,7 @@ class PyBoolAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirBoolAttrGet(context->get(), value); return PyBoolAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + nb::arg("value"), nb::arg("context") = nb::none(), "Gets an uniqued bool attribute"); c.def_prop_ro("value", mlirBoolAttrGetValue, "Returns the value of the bool attribute"); @@ -657,8 +680,8 @@ class PySymbolRefAttribute : public PyConcreteAttribute { static constexpr const char *pyClassName = "SymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; - static MlirAttribute fromList(const std::vector &symbols, - PyMlirContext &context) { + static PySymbolRefAttribute fromList(const std::vector &symbols, + PyMlirContext &context) { if (symbols.empty()) throw std::runtime_error("SymbolRefAttr must be composed of at least " "one symbol."); @@ -668,8 +691,10 @@ class PySymbolRefAttribute : public PyConcreteAttribute { referenceAttrs.push_back( mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); } - return mlirSymbolRefAttrGet(context.get(), rootSymbol, - referenceAttrs.size(), referenceAttrs.data()); + return PySymbolRefAttribute(context.getRef(), + mlirSymbolRefAttrGet(context.get(), rootSymbol, + referenceAttrs.size(), + referenceAttrs.data())); } static void bindDerived(ClassTy &c) { @@ -679,7 +704,7 @@ class PySymbolRefAttribute : public PyConcreteAttribute { DefaultingPyMlirContext context) { return PySymbolRefAttribute::fromList(symbols, context.resolve()); }, - nb::arg("symbols"), nb::arg("context").none() = nb::none(), + nb::arg("symbols"), nb::arg("context") = nb::none(), "Gets a uniqued SymbolRef attribute from a list of symbol names"); c.def_prop_ro( "value", @@ -708,12 +733,12 @@ class PyFlatSymbolRefAttribute static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string value, DefaultingPyMlirContext context) { + [](const std::string &value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); return PyFlatSymbolRefAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + nb::arg("value"), nb::arg("context") = nb::none(), "Gets a uniqued FlatSymbolRef attribute"); c.def_prop_ro( "value", @@ -736,8 +761,8 @@ class PyOpaqueAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, nb_buffer buffer, PyType &type, - DefaultingPyMlirContext context) { + [](const std::string &dialectNamespace, const nb_buffer &buffer, + PyType &type, DefaultingPyMlirContext context) { const nb_buffer_info bufferInfo = buffer.request(); intptr_t bufferSize = bufferInfo.size; MlirAttribute attr = mlirOpaqueAttrGet( @@ -746,7 +771,11 @@ class PyOpaqueAttribute : public PyConcreteAttribute { return PyOpaqueAttribute(context->getRef(), attr); }, nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), - nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); + nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get(dialect_namespace: str, buffer: typing_extensions.Buffer, type: Type, context: Context | None = None) -> OpaqueAttr"), + // clang-format on + "Gets an Opaque attribute."); c.def_prop_ro( "dialect_namespace", [](PyOpaqueAttribute &self) { @@ -764,59 +793,6 @@ class PyOpaqueAttribute : public PyConcreteAttribute { } }; -class PyStringAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; - static constexpr const char *pyClassName = "StringAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirStringAttrGetTypeID; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get", - [](nb::bytes value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get_typed", - [](PyType &type, std::string value) { - MlirAttribute attr = - mlirStringAttrTypedGet(type, toMlirStringRef(value)); - return PyStringAttribute(type.getContext(), attr); - }, - nb::arg("type"), nb::arg("value"), - "Gets a uniqued string attribute associated to a type"); - c.def_prop_ro( - "value", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute"); - c.def_prop_ro( - "value_bytes", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return nb::bytes(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute as `bytes`"); - } -}; - // TODO: Support construction of string elements. class PyDenseElementsAttribute : public PyConcreteAttribute { @@ -826,7 +802,7 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromList(nb::list attributes, std::optional explicitType, + getFromList(const nb::list &attributes, std::optional explicitType, DefaultingPyMlirContext contextWrapper) { const size_t numAttributes = nb::len(attributes); if (numAttributes == 0) @@ -878,8 +854,8 @@ class PyDenseElementsAttribute } static PyDenseElementsAttribute - getFromBuffer(nb_buffer array, bool signless, - std::optional explicitType, + getFromBuffer(const nb_buffer &array, bool signless, + const std::optional &explicitType, std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. @@ -894,8 +870,8 @@ class PyDenseElementsAttribute auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); MlirContext context = contextWrapper->get(); - MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, - explicitShape, context); + MlirAttribute attr = getAttributeFromBuffer( + view, signless, explicitType, std::move(explicitShape), context); if (mlirAttributeIsNull(attr)) { throw std::invalid_argument( "DenseElementsAttr could not be constructed from the given buffer. " @@ -1028,15 +1004,17 @@ class PyDenseElementsAttribute PyDenseElementsAttribute::bf_releasebuffer; #endif c.def("__len__", &PyDenseElementsAttribute::dunderLen) - .def_static("get", PyDenseElementsAttribute::getFromBuffer, - nb::arg("array"), nb::arg("signless") = true, - nb::arg("type").none() = nb::none(), - nb::arg("shape").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kDenseElementsAttrGetDocstring) + .def_static( + "get", PyDenseElementsAttribute::getFromBuffer, nb::arg("array"), + nb::arg("signless") = true, nb::arg("type") = nb::none(), + nb::arg("shape") = nb::none(), nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get(array: typing_extensions.Buffer, signless: bool = True, type: Type | None = None, shape: Sequence[int] | None = None, context: Context | None = None) -> DenseElementsAttr"), + // clang-format on + kDenseElementsAttrGetDocstring) .def_static("get", PyDenseElementsAttribute::getFromList, - nb::arg("attrs"), nb::arg("type").none() = nb::none(), - nb::arg("context").none() = nb::none(), + nb::arg("attrs"), nb::arg("type") = nb::none(), + nb::arg("context") = nb::none(), kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, nb::arg("shaped_type"), nb::arg("element_attr"), @@ -1045,12 +1023,16 @@ class PyDenseElementsAttribute [](PyDenseElementsAttribute &self) -> bool { return mlirDenseElementsAttrIsSplat(self); }) - .def("get_splat_value", [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw nb::value_error( - "get_splat_value called on a non-splat attribute"); - return mlirDenseElementsAttrGetSplatValue(self); - }); + .def("get_splat_value", + [](PyDenseElementsAttribute &self) + -> nb::typed { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + "get_splat_value called on a non-splat attribute"); + return PyAttribute(self.getContext(), + mlirDenseElementsAttrGetSplatValue(self)) + .maybeDownCast(); + }); } static PyType_Slot slots[]; @@ -1092,16 +1074,16 @@ class PyDenseElementsAttribute "when the type is not a shaped type."); } return *bulkLoadElementType; - } else { - MlirAttribute encodingAttr = mlirAttributeGetNull(); - return mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); } + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); } static MlirAttribute getAttributeFromBuffer( Py_buffer &view, bool signless, std::optional explicitType, - std::optional> explicitShape, MlirContext &context) { + const std::optional> &explicitShape, + MlirContext &context) { // Detect format codes that are suitable for bulk loading. This includes // all byte aligned integer and floating point types up to 8 bytes. // Notably, this excludes exotics types which do not have a direct @@ -1125,7 +1107,7 @@ class PyDenseElementsAttribute bulkLoadElementType = mlirF16TypeGet(context); } else if (format == "?") { // i1 - // The i1 type needs to be bit-packed, so we will handle it seperately + // The i1 type needs to be bit-packed, so we will handle it separately return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, context); } else if (isSignedIntegerFormat(format)) { @@ -1205,8 +1187,8 @@ class PyDenseElementsAttribute packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); - MlirType bitpackedType = - getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); + MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), + std::move(explicitShape), view); assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of // packedBooleans, hence the MlirAttribute will remain valid even when @@ -1365,7 +1347,7 @@ class PyDenseIntElementsAttribute /// Returns the element at the given linear position. Asserts if the index /// is out of range. - nb::object dunderGetItem(intptr_t pos) { + nb::int_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { throw nb::index_error("attempt to access out of bounds element"); } @@ -1443,9 +1425,9 @@ class PyDenseResourceElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, - std::optional alignment, bool isMutable, - DefaultingPyMlirContext contextWrapper) { + getFromBuffer(const nb_buffer &buffer, const std::string &name, + const PyType &type, std::optional alignment, + bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { throw std::invalid_argument( "Constructing a DenseResourceElementsAttr requires a ShapedType."); @@ -1505,12 +1487,15 @@ class PyDenseResourceElementsAttribute } static void bindDerived(ClassTy &c) { - c.def_static( - "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, - nb::arg("array"), nb::arg("name"), nb::arg("type"), - nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, - nb::arg("context").none() = nb::none(), - kDenseResourceElementsAttrGetFromBufferDocstring); + c.def_static("get_from_buffer", + PyDenseResourceElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("name"), nb::arg("type"), + nb::arg("alignment") = nb::none(), + nb::arg("is_mutable") = false, nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get_from_buffer(array: typing_extensions.Buffer, name: str, type: Type, alignment: int | None = None, is_mutable: bool = False, context: Context | None = None) -> DenseResourceElementsAttr"), + // clang-format on + kDenseResourceElementsAttrGetFromBufferDocstring); } }; @@ -1534,7 +1519,7 @@ class PyDictAttribute : public PyConcreteAttribute { c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", - [](nb::dict attributes, DefaultingPyMlirContext context) { + [](const nb::dict &attributes, DefaultingPyMlirContext context) { SmallVector mlirNamedAttributes; mlirNamedAttributes.reserve(attributes.size()); for (std::pair it : attributes) { @@ -1550,15 +1535,17 @@ class PyDictAttribute : public PyConcreteAttribute { mlirNamedAttributes.data()); return PyDictAttribute(context->getRef(), attr); }, - nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), + nb::arg("value") = nb::dict(), nb::arg("context") = nb::none(), "Gets an uniqued dict attribute"); - c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { - MlirAttribute attr = - mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) - throw nb::key_error("attempt to access a non-existent attribute"); - return attr; - }); + c.def("__getitem__", + [](PyDictAttribute &self, + const std::string &name) -> nb::typed { + MlirAttribute attr = + mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) + throw nb::key_error("attempt to access a non-existent attribute"); + return PyAttribute(self.getContext(), attr).maybeDownCast(); + }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { throw nb::index_error("attempt to access out of bounds attribute"); @@ -1618,15 +1605,17 @@ class PyTypeAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType value, DefaultingPyMlirContext context) { + [](const PyType &value, DefaultingPyMlirContext context) { MlirAttribute attr = mlirTypeAttrGet(value.get()); return PyTypeAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + nb::arg("value"), nb::arg("context") = nb::none(), "Gets a uniqued Type attribute"); - c.def_prop_ro("value", [](PyTypeAttribute &self) { - return mlirTypeAttrGetValue(self.get()); - }); + c.def_prop_ro( + "value", [](PyTypeAttribute &self) -> nb::typed { + return PyType(self.getContext(), mlirTypeAttrGetValue(self.get())) + .maybeDownCast(); + }); } }; @@ -1646,7 +1635,7 @@ class PyUnitAttribute : public PyConcreteAttribute { return PyUnitAttribute(context->getRef(), mlirUnitAttrGet(context->get())); }, - nb::arg("context").none() = nb::none(), "Create a Unit attribute."); + nb::arg("context") = nb::none(), "Create a Unit attribute."); } }; @@ -1663,14 +1652,13 @@ class PyStridedLayoutAttribute static void bindDerived(ClassTy &c) { c.def_static( "get", - [](int64_t offset, const std::vector strides, + [](int64_t offset, const std::vector &strides, DefaultingPyMlirContext ctx) { MlirAttribute attr = mlirStridedLayoutAttrGet( ctx->get(), offset, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - nb::arg("offset"), nb::arg("strides"), - nb::arg("context").none() = nb::none(), + nb::arg("offset"), nb::arg("strides"), nb::arg("context") = nb::none(), "Gets a strided layout attribute."); c.def_static( "get_fully_dynamic", @@ -1682,7 +1670,7 @@ class PyStridedLayoutAttribute ctx->get(), dynamic, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - nb::arg("rank"), nb::arg("context").none() = nb::none(), + nb::arg("rank"), nb::arg("context") = nb::none(), "Gets a strided layout attribute with dynamic offset and strides of " "a " "given rank."); @@ -1744,9 +1732,9 @@ nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { return nb::cast(PyBoolAttribute(pyAttribute)); if (PyIntegerAttribute::isaFunction(pyAttribute)) return nb::cast(PyIntegerAttribute(pyAttribute)); - std::string msg = - std::string("Can't cast unknown element type DenseArrayAttr (") + - nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + std::string msg = std::string("Can't cast unknown attribute type Attr (") + + nb::cast(nb::repr(nb::cast(pyAttribute))) + + ")"; throw nb::type_error(msg.c_str()); } @@ -1763,6 +1751,50 @@ nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { } // namespace +void PyStringAttribute::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context") = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get", + [](const nb::bytes &value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context") = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get_typed", + [](PyType &type, const std::string &value) { + MlirAttribute attr = + mlirStringAttrTypedGet(type, toMlirStringRef(value)); + return PyStringAttribute(type.getContext(), attr); + }, + nb::arg("type"), nb::arg("value"), + "Gets a uniqued string attribute associated to a type"); + c.def_prop_ro( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return nb::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); + c.def_prop_ro( + "value_bytes", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return nb::bytes(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute as `bytes`"); +} + void mlir::python::populateIRAttributes(nb::module_ &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4b3a06cbc..7b1710656 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; +static const char kModuleCAPICreate[] = + R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). +Note this returns a new object BUT _clear_mlir_module(module) must be called to +prevent double-frees (of the underlying mlir::Module). +)"; + static const char kOperationCreateDocstring[] = R"(Creates a new operation. @@ -196,7 +202,7 @@ operations. /// Helper for creating an @classmethod. template -nb::object classmethod(Func f, Args... args) { +static nb::object classmethod(Func f, Args... args) { nb::object cf = nb::cpp_function(f, args...); return nb::borrow((PyClassMethod_New(cf.ptr()))); } @@ -507,7 +513,7 @@ class PyOperationIterator { PyOperationIterator &dunderIter() { return *this; } - nb::object dunderNext() { + nb::typed dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { throw nb::stop_iteration(); @@ -556,7 +562,7 @@ class PyOperationList { return count; } - nb::object dunderGetItem(intptr_t index) { + nb::typed dunderGetItem(intptr_t index) { parentOperation->checkValid(); if (index < 0) { index += dunderLen(); @@ -592,7 +598,7 @@ class PyOpOperand { public: PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} - nb::object getOwner() { + PyOpView getOwner() { MlirOperation owner = mlirOpOperandGetOwner(opOperand); PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(owner)); @@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } -size_t PyMlirContext::getLiveOperationCount() { - nb::ft_lock_guard lock(liveOperationsMutex); - return liveOperations.size(); -} - -std::vector PyMlirContext::getLiveOperationObjects() { - std::vector liveObjects; - nb::ft_lock_guard lock(liveOperationsMutex); - for (auto &entry : liveOperations) - liveObjects.push_back(entry.second.second); - return liveObjects; -} - -size_t PyMlirContext::clearLiveOperations() { - - LiveOperationMap operations; - { - nb::ft_lock_guard lock(liveOperationsMutex); - std::swap(operations, liveOperations); - } - for (auto &op : operations) - op.second.second->setInvalid(); - size_t numInvalidated = operations.size(); - return numInvalidated; -} - -void PyMlirContext::clearOperation(MlirOperation op) { - PyOperation *py_op; - { - nb::ft_lock_guard lock(liveOperationsMutex); - auto it = liveOperations.find(op.ptr); - if (it == liveOperations.end()) { - return; - } - py_op = it->second.second; - liveOperations.erase(it); - } - py_op->setInvalid(); -} - -void PyMlirContext::clearOperationsInside(PyOperationBase &op) { - typedef struct { - PyOperation &rootOp; - bool rootSeen; - } callBackData; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - callBackData *data = static_cast(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation().getContext()->clearOperation(op); - else - data->rootSeen = true; - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast(&data), MlirWalkPreOrder); -} -void PyMlirContext::clearOperationsInside(MlirOperation op) { - PyOperationRef opRef = PyOperation::forOperation(getRef(), op); - clearOperationsInside(opRef->getOperation()); -} - -void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - PyMlirContextRef &contextRef = *static_cast(userData); - contextRef->clearOperation(op); - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - &op.getOperation().getContext(), MlirWalkPreOrder); -} - -size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } - nb::object PyMlirContext::contextEnter(nb::object context) { return PyThreadContextEntry::pushContext(context); } @@ -797,7 +725,7 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { new PyDiagnosticHandler(get(), std::move(callback)); nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::take_ownership); - pyHandlerObject.inc_ref(); + (void)pyHandlerObject.inc_ref(); // In these C callbacks, the userData is a PyDiagnosticHandler* that is // guaranteed to be known to pybind. @@ -1207,15 +1135,11 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - - // Otherwise, invalidate the operation and remove it from live map when it is - // attached. - if (isAttached()) { - getContext()->clearOperation(*this); - } else { - // And destroy it when it is detached, i.e. owned by Python, in which case - // all nested operations must be invalidated at removed from the live map as - // well. + // Otherwise, invalidate the operation when it is attached. + if (isAttached()) + setInvalid(); + else { + // And destroy it when it is detached, i.e. owned by Python. erase(); } } @@ -1252,35 +1176,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. - PyOperationRef result = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(result.getObject(), result.get()); - return result; - } - // Use existing. - PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - assert(liveOperations.count(operation.ptr) == 0 && - "cannot create detached operation that already exists"); - (void)liveOperations; PyOperationRef created = createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(created.getObject(), created.get()); created->attached = false; return created; } @@ -1491,7 +1395,7 @@ nb::object PyOperation::getCapsule() { return nb::steal(mlirPythonOperationToCapsule(get())); } -nb::object PyOperation::createFromCapsule(nb::object capsule) { +nb::object PyOperation::createFromCapsule(const nb::object &capsule) { MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); if (mlirOperationIsNull(rawOperation)) throw nb::python_error(); @@ -1652,7 +1556,7 @@ nb::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - getContext()->clearOperationAndInside(*this); + setInvalid(); mlirOperationDestroy(operation); } @@ -1701,7 +1605,9 @@ class PyConcreteValue : public PyValue { }, nb::arg("other_value")); cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) { return self.maybeDownCast(); }); + [](DerivedTy &self) -> nb::typed { + return self.maybeDownCast(); + }); DerivedTy::bindDerived(cls); } @@ -1719,13 +1625,14 @@ class PyOpResult : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyOpResult &self) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation().getObject(); - }); + c.def_prop_ro( + "owner", [](PyOpResult &self) -> nb::typed { + assert(mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation().getObject(); + }); c.def_prop_ro("result_number", [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); }); @@ -1734,12 +1641,14 @@ class PyOpResult : public PyConcreteValue { /// Returns the list of types of the values held by container. template -static std::vector getValueTypes(Container &container, - PyMlirContextRef &context) { - std::vector result; +static std::vector> +getValueTypes(Container &container, PyMlirContextRef &context) { + std::vector> result; result.reserve(container.size()); for (int i = 0, e = container.size(); i < e; ++i) { - result.push_back(mlirValueGetType(container.getElement(i).get())); + result.push_back(PyType(context->getRef(), + mlirValueGetType(container.getElement(i).get())) + .maybeDownCast()); } return result; } @@ -1765,9 +1674,10 @@ class PyOpResultList : public Sliceable { c.def_prop_ro("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); - c.def_prop_ro("owner", [](PyOpResultList &self) { - return self.operation->createOpView(); - }); + c.def_prop_ro("owner", + [](PyOpResultList &self) -> nb::typed { + return self.operation->createOpView(); + }); } PyOperationRef &getOperation() { return operation; } @@ -2105,7 +2015,7 @@ nb::object PyOpView::buildGeneric( // Delegate to create. return PyOperation::create(name, /*results=*/std::move(resultTypes), - /*operands=*/std::move(operands), + /*operands=*/operands, /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), /*regions=*/*regions, location, maybeIp, @@ -2130,12 +2040,15 @@ PyOpView::PyOpView(const nb::object &operationObject) // PyInsertionPoint. //------------------------------------------------------------------------------ -PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {} +PyInsertionPoint::PyInsertionPoint(const PyBlock &block) : block(block) {} PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) : refOperation(beforeOperationBase.getOperation().getRef()), block((*refOperation)->getBlock()) {} +PyInsertionPoint::PyInsertionPoint(PyOperationRef beforeOperationRef) + : refOperation(beforeOperationRef), block((*refOperation)->getBlock()) {} + void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) @@ -2184,8 +2097,21 @@ PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { return PyInsertionPoint{block, std::move(terminatorOpRef)}; } +PyInsertionPoint PyInsertionPoint::after(PyOperationBase &op) { + PyOperation &operation = op.getOperation(); + PyBlock block = operation.getBlock(); + MlirOperation nextOperation = mlirOperationGetNextInBlock(operation); + if (mlirOperationIsNull(nextOperation)) + return PyInsertionPoint(block); + PyOperationRef nextOpRef = PyOperation::forOperation( + block.getParentOperation()->getContext(), nextOperation); + return PyInsertionPoint{block, std::move(nextOpRef)}; +} + +size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } + nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { - return PyThreadContextEntry::pushInsertionPoint(insertPoint); + return PyThreadContextEntry::pushInsertionPoint(std::move(insertPoint)); } void PyInsertionPoint::contextExit(const nb::object &excType, @@ -2206,7 +2132,7 @@ nb::object PyAttribute::getCapsule() { return nb::steal(mlirPythonAttributeToCapsule(*this)); } -PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { +PyAttribute PyAttribute::createFromCapsule(const nb::object &capsule) { MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); if (mlirAttributeIsNull(rawAttr)) throw nb::python_error(); @@ -2214,6 +2140,20 @@ PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); } +nb::object PyAttribute::maybeDownCast() { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(this->get()); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = PyGlobals::get().lookupTypeCaster( + mlirTypeID, mlirAttributeGetDialect(this->get())); + // nb::rv_policy::move means use std::move to move the return value + // contents into a new instance that will be owned by Python. + nb::object thisObj = nb::cast(this, nb::rv_policy::move); + if (!typeCaster) + return thisObj; + return typeCaster.value()(thisObj); +} + //------------------------------------------------------------------------------ // PyNamedAttribute. //------------------------------------------------------------------------------ @@ -2246,6 +2186,20 @@ PyType PyType::createFromCapsule(nb::object capsule) { rawType); } +nb::object PyType::maybeDownCast() { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(this->get()); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = PyGlobals::get().lookupTypeCaster( + mlirTypeID, mlirTypeGetDialect(this->get())); + // nb::rv_policy::move means use std::move to move the return value + // contents into a new instance that will be owned by Python. + nb::object thisObj = nb::cast(this, nb::rv_policy::move); + if (!typeCaster) + return thisObj; + return typeCaster.value()(thisObj); +} + //------------------------------------------------------------------------------ // PyTypeID. //------------------------------------------------------------------------------ @@ -2344,17 +2298,19 @@ void PySymbolTable::dunderDel(const std::string &name) { erase(nb::cast(operation)); } -MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { +PyStringAttribute PySymbolTable::insert(PyOperationBase &symbol) { operation->checkValid(); symbol.getOperation().checkValid(); MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) throw nb::value_error("Expected operation to have a symbol name."); - return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); + return PyStringAttribute( + symbol.getOperation().getContext(), + mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); } -MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { +PyStringAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { // Op must already be a symbol. PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -2363,7 +2319,8 @@ MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) throw nb::value_error("Expected operation to have a symbol name."); - return existingNameAttr; + return PyStringAttribute(symbol.getOperation().getContext(), + existingNameAttr); } void PySymbolTable::setSymbolName(PyOperationBase &symbol, @@ -2381,7 +2338,7 @@ void PySymbolTable::setSymbolName(PyOperationBase &symbol, mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); } -MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { +PyStringAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { PyOperation &operation = symbol.getOperation(); operation.checkValid(); MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); @@ -2389,7 +2346,7 @@ MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) throw nb::value_error("Expected operation to have a symbol visibility."); - return existingVisAttr; + return PyStringAttribute(symbol.getOperation().getContext(), existingVisAttr); } void PySymbolTable::setVisibility(PyOperationBase &symbol, @@ -2727,13 +2684,14 @@ class PyOpAttributeMap { PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - MlirAttribute dunderGetItemNamed(const std::string &name) { + nb::typed + dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw nb::key_error("attempt to access a non-existent attribute"); } - return attr; + return PyAttribute(operation->getContext(), attr).maybeDownCast(); } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { @@ -2998,20 +2956,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("exc_type").none(), nb::arg("exc_value").none(), nb::arg("traceback").none()); - //---------------------------------------------------------------------------- - // Mapping of MlirContext. - // Note that this is exported as _BaseContext. The containing, Python level - // __init__.py will subclass it with site-specific functionality and set a - // "Context" attribute on this module. - //---------------------------------------------------------------------------- - // Expose DefaultThreadPool to python nb::class_(m, "ThreadPool") .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); - nb::class_(m, "_BaseContext") + nb::class_(m, "Context") .def("__init__", [](PyMlirContext &self) { MlirContext context = mlirContextCreateWithThreading(false); @@ -3019,31 +2970,27 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", - [](PyMlirContext &self) { + [](PyMlirContext &self) -> nb::typed { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_operation_objects", - &PyMlirContext::getLiveOperationObjects) - .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) - .def("_clear_live_operations_inside", - nb::overload_cast( - &PyMlirContext::clearOperationsInside)) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), nb::arg("traceback").none()) .def_prop_ro_static( "current", - [](nb::object & /*class*/) { + [](nb::object & /*class*/) + -> std::optional> { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - return nb::none(); + return {}; return nb::cast(context); }, + nb::sig("def current(/) -> Context | None"), "Gets the Context bound to the current thread or raises ValueError") .def_prop_ro( "dialects", @@ -3114,7 +3061,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirContextAppendDialectRegistry(self.get(), registry); }, nb::arg("registry")) - .def_prop_rw("emit_error_diagnostics", nullptr, + .def_prop_rw("emit_error_diagnostics", + &PyMlirContext::getEmitErrorDiagnostics, &PyMlirContext::setEmitErrorDiagnostics, "Emit error diagnostics to diagnostic handlers. By default " "error diagnostics are captured and reported through " @@ -3132,13 +3080,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirStringRef ns = mlirDialectGetNamespace(self.get()); return nb::str(ns.data, ns.length); }) - .def("__repr__", [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - std::string repr(""); - return repr; - }); + .def( + "__repr__", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + std::string repr(""); + return repr; + }, + nb::sig("def __repr__(self) -> str")); //---------------------------------------------------------------------------- // Mapping of PyDialects @@ -3167,20 +3118,24 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def(nb::init(), nb::arg("descriptor")) .def_prop_ro("descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](nb::object self) { - auto clazz = self.attr("__class__"); - return nb::str(""); - }); + .def( + "__repr__", + [](const nb::object &self) { + auto clazz = self.attr("__class__"); + return nb::str(""); + }, + nb::sig("def __repr__(self) -> str")); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- nb::class_(m, "DialectRegistry") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyDialectRegistry::createFromCapsule) .def(nb::init<>()); //---------------------------------------------------------------------------- @@ -3188,7 +3143,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Location") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), nb::arg("exc_value").none(), nb::arg("traceback").none()) @@ -3205,6 +3160,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { return std::nullopt; return loc; }, + // clang-format off + nb::sig("def current(/) -> Location | None"), + // clang-format on "Gets the Location bound to the current thread or raises ValueError") .def_static( "unknown", @@ -3212,7 +3170,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), mlirLocationUnknownGet(context->get())); }, - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Gets a Location representing an unknown location") .def_static( "callsite", @@ -3227,12 +3185,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), mlirLocationCallSiteGet(callee.get(), caller)); }, - nb::arg("callee"), nb::arg("frames"), - nb::arg("context").none() = nb::none(), + nb::arg("callee"), nb::arg("frames"), nb::arg("context") = nb::none(), kContextGetCallSiteLocationDocstring) .def("is_a_callsite", mlirLocationIsACallSite) - .def_prop_ro("callee", mlirLocationCallSiteGetCallee) - .def_prop_ro("caller", mlirLocationCallSiteGetCaller) + .def_prop_ro("callee", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCallee(self)); + }) + .def_prop_ro("caller", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationCallSiteGetCaller(self)); + }) .def_static( "file", [](std::string filename, int line, int col, @@ -3243,8 +3208,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { context->get(), toMlirStringRef(filename), line, col)); }, nb::arg("filename"), nb::arg("line"), nb::arg("col"), - nb::arg("context").none() = nb::none(), - kContextGetFileLocationDocstring) + nb::arg("context") = nb::none(), kContextGetFileLocationDocstring) .def_static( "file", [](std::string filename, int startLine, int startCol, int endLine, @@ -3256,7 +3220,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), nb::arg("end_line"), nb::arg("end_col"), - nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) + nb::arg("context") = nb::none(), kContextGetFileRangeDocstring) .def("is_a_file", mlirLocationIsAFileLineColRange) .def_prop_ro("filename", [](MlirLocation loc) { @@ -3281,19 +3245,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { metadata ? metadata->get() : MlirAttribute{0}); return PyLocation(context->getRef(), location); }, - nb::arg("locations"), nb::arg("metadata").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kContextGetFusedLocationDocstring) + nb::arg("locations"), nb::arg("metadata") = nb::none(), + nb::arg("context") = nb::none(), kContextGetFusedLocationDocstring) .def("is_a_fused", mlirLocationIsAFused) - .def_prop_ro("locations", - [](MlirLocation loc) { - unsigned numLocations = - mlirLocationFusedGetNumLocations(loc); - std::vector locations(numLocations); - if (numLocations) - mlirLocationFusedGetLocations(loc, locations.data()); - return locations; - }) + .def_prop_ro( + "locations", + [](PyLocation &self) { + unsigned numLocations = mlirLocationFusedGetNumLocations(self); + std::vector locations(numLocations); + if (numLocations) + mlirLocationFusedGetLocations(self, locations.data()); + std::vector pyLocations{}; + pyLocations.reserve(numLocations); + for (unsigned i = 0; i < numLocations; ++i) + pyLocations.emplace_back(self.getContext(), locations[i]); + return pyLocations; + }) .def_static( "name", [](std::string name, std::optional childLoc, @@ -3305,30 +3272,38 @@ void mlir::python::populateIRCore(nb::module_ &m) { childLoc ? childLoc->get() : mlirLocationUnknownGet(context->get()))); }, - nb::arg("name"), nb::arg("childLoc").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kContextGetNameLocationDocString) + nb::arg("name"), nb::arg("childLoc") = nb::none(), + nb::arg("context") = nb::none(), kContextGetNameLocationDocString) .def("is_a_name", mlirLocationIsAName) .def_prop_ro("name_str", [](MlirLocation loc) { return mlirIdentifierStr(mlirLocationNameGetName(loc)); }) - .def_prop_ro("child_loc", mlirLocationNameGetChildLoc) + .def_prop_ro("child_loc", + [](PyLocation &self) { + return PyLocation(self.getContext(), + mlirLocationNameGetChildLoc(self)); + }) .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { return PyLocation(context->getRef(), mlirLocationFromAttribute(attribute)); }, - nb::arg("attribute"), nb::arg("context").none() = nb::none(), + nb::arg("attribute"), nb::arg("context") = nb::none(), "Gets a Location from a LocationAttr") .def_prop_ro( "context", - [](PyLocation &self) { return self.getContext().getObject(); }, + [](PyLocation &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Location") .def_prop_ro( "attr", - [](PyLocation &self) { return mlirLocationGetAttribute(self); }, + [](PyLocation &self) { + return PyAttribute(self.getContext(), + mlirLocationGetAttribute(self)); + }, "Get the underlying LocationAttr") .def( "emit_error", @@ -3348,10 +3323,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, + kModuleCAPICreate) + .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", - [](const std::string &moduleAsm, DefaultingPyMlirContext context) { + [](const std::string &moduleAsm, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); @@ -3359,11 +3337,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + nb::arg("asm"), nb::arg("context") = nb::none(), kModuleParseDocstring) .def_static( "parse", - [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { + [](nb::bytes moduleAsm, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); @@ -3371,11 +3350,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + nb::arg("asm"), nb::arg("context") = nb::none(), kModuleParseDocstring) .def_static( "parseFile", - [](const std::string &path, DefaultingPyMlirContext context) { + [](const std::string &path, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParseFromFile( context->get(), toMlirStringRef(path)); @@ -3383,23 +3363,26 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("path"), nb::arg("context").none() = nb::none(), + nb::arg("path"), nb::arg("context") = nb::none(), kModuleParseDocstring) .def_static( "create", - [](const std::optional &loc) { + [](const std::optional &loc) + -> nb::typed { PyLocation pyLoc = maybeGetTracebackLocation(loc); MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("loc").none() = nb::none(), "Creates an empty module") + nb::arg("loc") = nb::none(), "Creates an empty module") .def_prop_ro( "context", - [](PyModule &self) { return self.getContext().getObject(); }, + [](PyModule &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that created the Module") .def_prop_ro( "operation", - [](PyModule &self) { + [](PyModule &self) -> nb::typed { return PyOperation::forOperation(self.getContext(), mlirModuleGetOperation(self.get()), self.getRef().releaseObject()) @@ -3424,11 +3407,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { kDumpDocstring) .def( "__str__", - [](nb::object self) { + [](const nb::object &self) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - kOperationStrDunderDocstring); + nb::sig("def __str__(self) -> str"), kOperationStrDunderDocstring) + .def( + "__eq__", + [](PyModule &self, PyModule &other) { + return mlirModuleEqual(self.get(), other.get()); + }, + "other"_a) + .def("__hash__", + [](PyModule &self) { return mlirModuleHashValue(self.get()); }); //---------------------------------------------------------------------------- // Mapping of Operation. @@ -3440,13 +3431,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { - return &self.getOperation() == &other.getOperation(); + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); }) .def("__eq__", [](PyOperationBase &self, nb::object other) { return false; }) .def("__hash__", [](PyOperationBase &self) { - return static_cast(llvm::hash_value(&self.getOperation())); + return mlirOperationHashValue(self.getOperation().get()); }) .def_prop_ro("attributes", [](PyOperationBase &self) { @@ -3454,7 +3446,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def_prop_ro( "context", - [](PyOperationBase &self) { + [](PyOperationBase &self) -> nb::typed { PyOperation &concreteOperation = self.getOperation(); concreteOperation.checkValid(); return concreteOperation.getContext().getObject(); @@ -3483,28 +3475,35 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns the list of Operation results.") .def_prop_ro( "result", - [](PyOperationBase &self) { + [](PyOperationBase &self) -> nb::typed { auto &operation = self.getOperation(); return PyOpResult(operation.getRef(), getUniqueResult(operation)) .maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_prop_ro( + .def_prop_rw( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); return PyLocation(operation.getContext(), mlirOperationGetLocation(operation.get())); }, - "Returns the source location the operation was defined or derived " - "from.") + [](PyOperationBase &self, const PyLocation &location) { + PyOperation &operation = self.getOperation(); + mlirOperationSetLocation(operation.get(), location.get()); + }, + nb::for_getter("Returns the source location the operation was " + "defined or derived from."), + nb::for_setter("Sets the source location the operation was defined " + "or derived from.")) .def_prop_ro("parent", - [](PyOperationBase &self) -> nb::object { + [](PyOperationBase &self) + -> std::optional> { auto parent = self.getOperation().getParentOperation(); if (parent) return parent->getObject(); - return nb::none(); + return {}; }) .def( "__str__", @@ -3520,35 +3519,36 @@ void mlir::python::populateIRCore(nb::module_ &m) { /*assumeVerified=*/false, /*skipRegions=*/false); }, + nb::sig("def __str__(self) -> str"), "Returns the assembly form of the operation.") .def("print", nb::overload_cast( &PyOperationBase::print), - nb::arg("state"), nb::arg("file").none() = nb::none(), + nb::arg("state"), nb::arg("file") = nb::none(), nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", nb::overload_cast, std::optional, bool, bool, bool, bool, bool, bool, nb::object, bool, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. - nb::arg("large_elements_limit").none() = nb::none(), - nb::arg("large_resource_limit").none() = nb::none(), + nb::arg("large_elements_limit") = nb::none(), + nb::arg("large_resource_limit") = nb::none(), nb::arg("enable_debug_info") = false, nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, nb::arg("use_local_scope") = false, nb::arg("use_name_loc_as_prefix") = false, - nb::arg("assume_verified") = false, - nb::arg("file").none() = nb::none(), nb::arg("binary") = false, - nb::arg("skip_regions") = false, kOperationPrintDocstring) + nb::arg("assume_verified") = false, nb::arg("file") = nb::none(), + nb::arg("binary") = false, nb::arg("skip_regions") = false, + kOperationPrintDocstring) .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), - nb::arg("desired_version").none() = nb::none(), + nb::arg("desired_version") = nb::none(), kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. nb::arg("binary") = false, - nb::arg("large_elements_limit").none() = nb::none(), - nb::arg("large_resource_limit").none() = nb::none(), + nb::arg("large_elements_limit") = nb::none(), + nb::arg("large_resource_limit") = nb::none(), nb::arg("enable_debug_info") = false, nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, @@ -3574,13 +3574,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { "of the parent block.") .def( "clone", - [](PyOperationBase &self, nb::object ip) { + [](PyOperationBase &self, + const nb::object &ip) -> nb::typed { return self.getOperation().clone(ip); }, - nb::arg("ip").none() = nb::none()) + nb::arg("ip") = nb::none()) .def( "detach_from_parent", - [](PyOperationBase &self) { + [](PyOperationBase &self) -> nb::typed { PyOperation &operation = self.getOperation(); operation.checkValid(); if (!operation.isAttached()) @@ -3600,7 +3601,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Reports if the operation is attached to its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) .def("walk", &PyOperationBase::walk, nb::arg("callback"), - nb::arg("walk_order") = MlirWalkPostOrder); + nb::arg("walk_order") = MlirWalkPostOrder, + // clang-format off + nb::sig("def walk(self, callback: Callable[[Operation], WalkResult], walk_order: WalkOrder) -> None") + // clang-format on + ); nb::class_(m, "Operation") .def_static( @@ -3611,7 +3616,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional attributes, std::optional> successors, int regions, const std::optional &location, - const nb::object &maybeIp, bool inferType) { + const nb::object &maybeIp, + bool inferType) -> nb::typed { // Unpack/validate operands. llvm::SmallVector mlirOperands; if (operands) { @@ -3628,38 +3634,48 @@ void mlir::python::populateIRCore(nb::module_ &m) { successors, regions, pyLoc, maybeIp, inferType); }, - nb::arg("name"), nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0, - nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + nb::arg("name"), nb::arg("results") = nb::none(), + nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), + nb::arg("successors") = nb::none(), nb::arg("regions") = 0, + nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(), nb::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, - DefaultingPyMlirContext context) { + DefaultingPyMlirContext context) + -> nb::typed { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_prop_ro("operation", [](nb::object self) { return self; }) - .def_prop_ro("opview", &PyOperation::createOpView) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyOperation::createFromCapsule) + .def_prop_ro("operation", + [](nb::object self) -> nb::typed { + return self; + }) + .def_prop_ro("opview", + [](PyOperation &self) -> nb::typed { + return self.createOpView(); + }) .def_prop_ro("block", &PyOperation::getBlock) .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, - "Returns the list of Operation successors."); + "Returns the list of Operation successors.") + .def("_set_invalid", &PyOperation::setInvalid, + "Invalidate the operation."); auto opViewClass = nb::class_(m, "OpView") - .def(nb::init(), nb::arg("operation")) + .def(nb::init>(), + nb::arg("operation")) .def( "__init__", [](PyOpView *self, std::string_view name, @@ -3679,18 +3695,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { attributes, successors, regions, pyLoc, maybeIp)); }, nb::arg("name"), nb::arg("opRegionSpec"), - nb::arg("operandSegmentSpecObj").none() = nb::none(), - nb::arg("resultSegmentSpecObj").none() = nb::none(), - nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("loc").none() = nb::none(), - nb::arg("ip").none() = nb::none()) - - .def_prop_ro("operation", &PyOpView::getOperationObject) - .def_prop_ro("opview", [](nb::object self) { return self; }) + nb::arg("operandSegmentSpecObj") = nb::none(), + nb::arg("resultSegmentSpecObj") = nb::none(), + nb::arg("results") = nb::none(), nb::arg("operands") = nb::none(), + nb::arg("attributes") = nb::none(), + nb::arg("successors") = nb::none(), + nb::arg("regions") = nb::none(), nb::arg("loc") = nb::none(), + nb::arg("ip") = nb::none()) + .def_prop_ro( + "operation", + [](PyOpView &self) -> nb::typed { + return self.getOperationObject(); + }) + .def_prop_ro("opview", + [](nb::object self) -> nb::typed { + return self; + }) .def( "__str__", [](PyOpView &self) { return nb::str(self.getOperationObject()); }) @@ -3699,7 +3719,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, - "Returns the list of Operation successors."); + "Returns the list of Operation successors.") + .def( + "_set_invalid", + [](PyOpView &self) { self.getOperation().setInvalid(); }, + "Invalidate the operation."); opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); @@ -3723,16 +3747,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { operandList, attributes, successors, regions, pyLoc, maybeIp); }, - nb::arg("cls"), nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + nb::arg("cls"), nb::arg("results") = nb::none(), + nb::arg("operands") = nb::none(), nb::arg("attributes") = nb::none(), + nb::arg("successors") = nb::none(), nb::arg("regions") = nb::none(), + nb::arg("loc") = nb::none(), nb::arg("ip") = nb::none(), "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( [](const nb::object &cls, const std::string &sourceStr, - const std::string &sourceName, DefaultingPyMlirContext context) { + const std::string &sourceName, + DefaultingPyMlirContext context) -> nb::typed { PyOperationRef parsed = PyOperation::parse(context->getRef(), sourceStr, sourceName); @@ -3752,7 +3775,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyOpView::constructDerived(cls, parsed.getObject()); }, nb::arg("cls"), nb::arg("source"), nb::kw_only(), - nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), + nb::arg("source_name") = "", nb::arg("context") = nb::none(), "Parses a specific, generated OpView based on class level attributes"); //---------------------------------------------------------------------------- @@ -3767,7 +3790,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns a forward-optimized sequence of blocks.") .def_prop_ro( "owner", - [](PyRegion &self) { + [](PyRegion &self) -> nb::typed { return self.getParentOperation()->createOpView(); }, "Returns the operation owning this region.") @@ -3792,7 +3815,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) .def_prop_ro( "owner", - [](PyBlock &self) { + [](PyBlock &self) -> nb::typed { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") @@ -3812,8 +3835,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def( "add_argument", [](PyBlock &self, const PyType &type, const PyLocation &loc) { - return mlirBlockAddArgument(self.get(), type, loc); + return PyBlockArgument(self.getParentOperation(), + mlirBlockAddArgument(self.get(), type, loc)); }, + "type"_a, "loc"_a, "Append an argument of the specified type to the block and returns " "the newly added argument.") .def( @@ -3955,6 +3980,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw nb::value_error("No current InsertionPoint"); return ip; }, + nb::sig("def current(/) -> InsertionPoint"), "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") .def(nb::init(), nb::arg("beforeOperation"), @@ -3963,6 +3989,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("block"), "Inserts at the beginning of the block.") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, nb::arg("block"), "Inserts before the block terminator.") + .def_static("after", &PyInsertionPoint::after, nb::arg("operation"), + "Inserts after the operation.") .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), "Inserts an operation.") .def_prop_ro( @@ -3970,11 +3998,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns the block that this InsertionPoint points to.") .def_prop_ro( "ref_operation", - [](PyInsertionPoint &self) -> nb::object { + [](PyInsertionPoint &self) + -> std::optional> { auto refOperation = self.getRefOperation(); if (refOperation) return refOperation->getObject(); - return nb::none(); + return {}; }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " @@ -3989,26 +4018,34 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed attribute to the generic Attribute") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, + &PyAttribute::createFromCapsule) .def_static( "parse", - [](const std::string &attrSpec, DefaultingPyMlirContext context) { + [](const std::string &attrSpec, DefaultingPyMlirContext context) + -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirAttribute attr = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); if (mlirAttributeIsNull(attr)) throw MLIRError("Unable to parse attribute", errors.take()); - return attr; + return PyAttribute(context.get()->getRef(), attr).maybeDownCast(); }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + nb::arg("asm"), nb::arg("context") = nb::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " "failure.") .def_prop_ro( "context", - [](PyAttribute &self) { return self.getContext().getObject(); }, + [](PyAttribute &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Attribute") .def_prop_ro("type", - [](PyAttribute &self) { return mlirAttributeGetType(self); }) + [](PyAttribute &self) -> nb::typed { + return PyType(self.getContext(), + mlirAttributeGetType(self)) + .maybeDownCast(); + }) .def( "get_named", [](PyAttribute &self, std::string name) { @@ -4049,23 +4086,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }) .def_prop_ro("typeid", - [](PyAttribute &self) -> MlirTypeID { + [](PyAttribute &self) { MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - return mlirTypeID; + return PyTypeID(mlirTypeID); }) - .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - std::optional typeCaster = - PyGlobals::get().lookupTypeCaster(mlirTypeID, - mlirAttributeGetDialect(self)); - if (!typeCaster) - return nb::cast(self); - return typeCaster.value()(self); - }); + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyAttribute &self) -> nb::typed { + return self.maybeDownCast(); + }); //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute @@ -4094,7 +4124,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def_prop_ro( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, - nb::keep_alive<0, 1>(), + nb::keep_alive<0, 1>(), nb::sig("def attr(self) -> Attribute"), "The underlying generic attribute of the NamedAttribute binding"); //---------------------------------------------------------------------------- @@ -4106,21 +4136,25 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed type to the generic Type") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", - [](std::string typeSpec, DefaultingPyMlirContext context) { + [](std::string typeSpec, + DefaultingPyMlirContext context) -> nb::typed { PyMlirContext::ErrorCapture errors(context->getRef()); MlirType type = mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); if (mlirTypeIsNull(type)) throw MLIRError("Unable to parse type", errors.take()); - return type; + return PyType(context.get()->getRef(), type).maybeDownCast(); }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + nb::arg("asm"), nb::arg("context") = nb::none(), kContextParseTypeDocstring) .def_prop_ro( - "context", [](PyType &self) { return self.getContext().getObject(); }, + "context", + [](PyType &self) -> nb::typed { + return self.getContext().getObject(); + }, "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) .def( @@ -4155,21 +4189,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { return printAccum.join(); }) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyType &self) { - MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - std::optional typeCaster = - PyGlobals::get().lookupTypeCaster(mlirTypeID, - mlirTypeGetDialect(self)); - if (!typeCaster) - return nb::cast(self); - return typeCaster.value()(self); + [](PyType &self) -> nb::typed { + return self.maybeDownCast(); }) - .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { + .def_prop_ro("typeid", [](PyType &self) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) - return mlirTypeID; + return PyTypeID(mlirTypeID); auto origRepr = nb::cast(nb::repr(nb::cast(self))); throw nb::value_error( (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); @@ -4180,7 +4206,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "TypeID") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether // the Python objects are the same (i.e., PyTypeID is a value type). @@ -4201,10 +4227,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::class_(m, "Value") .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) + .def_static(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_prop_ro( "context", - [](PyValue &self) { return self.getParentOperation()->getContext(); }, + [](PyValue &self) -> nb::typed { + return self.getParentOperation()->getContext().getObject(); + }, "Context in which the value lives.") .def( "dump", [](PyValue &self) { mlirValueDump(self.get()); }, @@ -4214,11 +4242,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyValue &self) -> nb::object { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in " - "the IR"); + assert(mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match " + "that in " + "the IR"); return self.getParentOperation().getObject(); } @@ -4229,7 +4257,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { assert(false && "Value must be a block argument or an op result"); return nb::none(); - }) + }, + // clang-format off + nb::sig("def owner(self) -> Operation | Block | None")) + // clang-format on .def_prop_ro("uses", [](PyValue &self) { return PyOpOperandIterator( @@ -4287,7 +4318,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("state"), kGetNameAsOperand) .def_prop_ro("type", - [](PyValue &self) { return mlirValueGetType(self.get()); }) + [](PyValue &self) -> nb::typed { + return PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())) + .maybeDownCast(); + }) .def( "set_type", [](PyValue &self, const PyType &type) { @@ -4302,15 +4337,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { kValueReplaceAllUsesWithDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, PyOperation &exception) { + [](PyValue &self, PyValue &with, PyOperation &exception) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - nb::arg("with"), nb::arg("exceptions"), + nb::arg("with_"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, nb::list exceptions) { + [](PyValue &self, PyValue &with, const nb::list &exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector exceptionOps; for (nb::handle exception : exceptions) { @@ -4321,10 +4356,34 @@ void mlir::python::populateIRCore(nb::module_ &m) { self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - nb::arg("with"), nb::arg("exceptions"), + nb::arg("with_"), nb::arg("exceptions"), + kValueReplaceAllUsesExceptDocstring) + .def( + "replace_all_uses_except", + [](PyValue &self, PyValue &with, PyOperation &exception) { + MlirOperation exceptedUser = exception.get(); + mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); + }, + nb::arg("with_"), nb::arg("exceptions"), + kValueReplaceAllUsesExceptDocstring) + .def( + "replace_all_uses_except", + [](PyValue &self, PyValue &with, + std::vector &exceptions) { + // Convert Python list to a SmallVector of MlirOperations + llvm::SmallVector exceptionOps; + for (PyOperation &exception : exceptions) + exceptionOps.push_back(exception); + mlirValueReplaceAllUsesExcept( + self, with, static_cast(exceptionOps.size()), + exceptionOps.data()); + }, + nb::arg("with_"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyValue &self) { return self.maybeDownCast(); }) + [](PyValue &self) -> nb::typed { + return self.maybeDownCast(); + }) .def_prop_ro( "location", [](MlirValue self) { @@ -4349,7 +4408,11 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "SymbolTable") .def(nb::init()) - .def("__getitem__", &PySymbolTable::dunderGetItem) + .def("__getitem__", + [](PySymbolTable &self, + const std::string &name) -> nb::typed { + return self.dunderGetItem(name); + }) .def("insert", &PySymbolTable::insert, nb::arg("operation")) .def("erase", &PySymbolTable::erase, nb::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 9e1fedaab..31d4798ff 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -195,7 +195,7 @@ class PyConcreteOpInterface { static void bind(nb::module_ &m) { nb::class_ cls(m, ConcreteIface::pyClassName); cls.def(nb::init(), nb::arg("object"), - nb::arg("context").none() = nb::none(), constructorDoc) + nb::arg("context") = nb::none(), constructorDoc) .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, operationDoc) .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); @@ -212,22 +212,18 @@ class PyConcreteOpInterface { /// Returns the operation instance from which this object was constructed. /// Throws a type error if this object was constructed from a subclass of /// OpView. - nb::object getOperationObject() { - if (operation == nullptr) { + nb::typed getOperationObject() { + if (operation == nullptr) throw nb::type_error("Cannot get an operation from a static interface"); - } - return operation->getRef().releaseObject(); } /// Returns the opview of the operation instance from which this object was /// constructed. Throws a type error if this object was constructed form a /// subclass of OpView. - nb::object getOpView() { - if (operation == nullptr) { + nb::typed getOpView() { + if (operation == nullptr) throw nb::type_error("Cannot get an opview from a static interface"); - } - return operation->createOpView(); } @@ -303,12 +299,11 @@ class PyInferTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("properties").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("context").none() = nb::none(), - nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); + nb::arg("operands") = nb::none(), + nb::arg("attributes") = nb::none(), + nb::arg("properties") = nb::none(), nb::arg("regions") = nb::none(), + nb::arg("context") = nb::none(), nb::arg("loc") = nb::none(), + inferReturnTypesDoc); } }; @@ -332,6 +327,7 @@ class PyShapedTypeComponents { .def_prop_ro( "element_type", [](PyShapedTypeComponents &self) { return self.elementType; }, + nb::sig("def element_type(self) -> Type"), "Returns the element type of the shaped type components.") .def_static( "get", @@ -362,10 +358,9 @@ class PyShapedTypeComponents { "Returns whether the given shaped type component is ranked.") .def_prop_ro( "rank", - [](PyShapedTypeComponents &self) -> nb::object { - if (!self.ranked) { - return nb::none(); - } + [](PyShapedTypeComponents &self) -> std::optional { + if (!self.ranked) + return {}; return nb::int_(self.shape.size()); }, "Returns the rank of the given ranked shaped type components. If " @@ -373,10 +368,9 @@ class PyShapedTypeComponents { "returned.") .def_prop_ro( "shape", - [](PyShapedTypeComponents &self) -> nb::object { - if (!self.ranked) { - return nb::none(); - } + [](PyShapedTypeComponents &self) -> std::optional { + if (!self.ranked) + return {}; return nb::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " @@ -463,12 +457,10 @@ class PyInferShapedTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypeComponents", &PyInferShapedTypeOpInterface::inferReturnTypeComponents, - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("properties").none() = nb::none(), - nb::arg("context").none() = nb::none(), - nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); + nb::arg("operands") = nb::none(), + nb::arg("attributes") = nb::none(), nb::arg("regions") = nb::none(), + nb::arg("properties") = nb::none(), nb::arg("context") = nb::none(), + nb::arg("loc") = nb::none(), inferReturnTypeComponentsDoc); } }; diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fa16ae3ce..e706be3b4 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -19,6 +19,7 @@ #include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" @@ -93,6 +94,8 @@ class PyObjectRef { } operator bool() const { return referrent && object; } + using NBTypedT = nanobind::typed; + private: T *referrent; nanobind::object object; @@ -218,36 +221,6 @@ class PyMlirContext { /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); - /// Get a list of Python objects which are still in the live context map. - std::vector getLiveOperationObjects(); - - /// Gets the count of live operations associated with this context. - /// Used for testing. - size_t getLiveOperationCount(); - - /// Clears the live operations map, returning the number of entries which were - /// invalidated. To be used as a safety mechanism so that API end-users can't - /// corrupt by holding references they shouldn't have accessed in the first - /// place. - size_t clearLiveOperations(); - - /// Removes an operation from the live operations map and sets it invalid. - /// This is useful for when some non-bindings code destroys the operation and - /// the bindings need to made aware. For example, in the case when pass - /// manager is run. - /// - /// Note that this does *NOT* clear the nested operations. - void clearOperation(MlirOperation op); - - /// Clears all operations nested inside the given op using - /// `clearOperation(MlirOperation)`. - void clearOperationsInside(PyOperationBase &op); - void clearOperationsInside(MlirOperation op); - - /// Clears the operaiton _and_ all operations inside using - /// `clearOperation(MlirOperation)`. - void clearOperationAndInside(PyOperationBase &op); - /// Gets the count of live modules associated with this context. /// Used for testing. size_t getLiveModuleCount(); @@ -265,6 +238,7 @@ class PyMlirContext { /// Controls whether error diagnostics should be propagated to diagnostic /// handlers, instead of being captured by `ErrorCapture`. void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } + bool getEmitErrorDiagnostics() { return emitErrorDiagnostics; } struct ErrorCapture; private: @@ -286,17 +260,6 @@ class PyMlirContext { llvm::DenseMap>; LiveModuleMap liveModules; - // Interns all live operations associated with this context. Operations - // tracked in this map are valid. When an operation is invalidated, it is - // removed from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveOperationMap = - llvm::DenseMap>; - nanobind::ft_mutex liveOperationsMutex; - - // Guarded by liveOperationsMutex in free-threading mode. - LiveOperationMap liveOperations; - bool emitErrorDiagnostics = false; MlirContext context; @@ -310,7 +273,7 @@ class DefaultingPyMlirContext : public Defaulting { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = "mlir.ir.Context"; + static constexpr const char kTypeDescription[] = "Context"; static PyMlirContext &resolve(); }; @@ -536,7 +499,7 @@ class DefaultingPyLocation : public Defaulting { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = "mlir.ir.Location"; + static constexpr const char kTypeDescription[] = "Location"; static PyLocation &resolve(); operator MlirLocation() const { return *get(); } @@ -548,8 +511,8 @@ class PyModule; using PyModuleRef = PyObjectRef; class PyModule : public BaseContextObject { public: - /// Returns a PyModule reference for the given MlirModule. This may return - /// a pre-existing or new object. + /// Returns a PyModule reference for the given MlirModule. This always returns + /// a new object. static PyModuleRef forModule(MlirModule module); PyModule(PyModule &) = delete; PyModule(PyMlirContext &&) = delete; @@ -570,11 +533,12 @@ class PyModule : public BaseContextObject { nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. - /// Note that PyModule instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirModule - /// is taken by calling this function. + /// Note this returns a new object BUT clearMlirModule() must be called to + /// prevent double-frees (of the underlying mlir::Module). static nanobind::object createFromCapsule(nanobind::object capsule); + void clearMlirModule() { module = {nullptr}; } + private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; @@ -636,6 +600,7 @@ class PyOperationBase { /// drops to zero or it is attached to a parent, at which point its lifetime /// is bounded by its top-level parent reference. class PyOperation; +class PyOpView; using PyOperationRef = PyObjectRef; class PyOperation : public PyOperationBase, public BaseContextObject { public: @@ -704,7 +669,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates a PyOperation from the MlirOperation wrapped by a capsule. /// Ownership of the underlying MlirOperation is taken by calling this /// function. - static nanobind::object createFromCapsule(nanobind::object capsule); + static nanobind::object createFromCapsule(const nanobind::object &capsule); /// Creates an operation. See corresponding python docstring. static nanobind::object @@ -873,14 +838,19 @@ class PyInsertionPoint { public: /// Creates an insertion point positioned after the last operation in the /// block, but still inside the block. - PyInsertionPoint(PyBlock &block); + PyInsertionPoint(const PyBlock &block); /// Creates an insertion point positioned before a reference operation. PyInsertionPoint(PyOperationBase &beforeOperationBase); + /// Creates an insertion point positioned before a reference operation. + PyInsertionPoint(PyOperationRef beforeOperationRef); /// Shortcut to create an insertion point at the beginning of the block. static PyInsertionPoint atBlockBegin(PyBlock &block); /// Shortcut to create an insertion point before the block terminator. static PyInsertionPoint atBlockTerminator(PyBlock &block); + /// Shortcut to create an insertion point to the node after the specified + /// operation. + static PyInsertionPoint after(PyOperationBase &op); /// Inserts an operation. void insert(PyOperationBase &operationBase); @@ -922,6 +892,8 @@ class PyType : public BaseContextObject { /// is taken by calling this function. static PyType createFromCapsule(nanobind::object capsule); + nanobind::object maybeDownCast(); + private: MlirType type; }; @@ -995,16 +967,18 @@ class PyConcreteType : public BaseTy { }, nanobind::arg("other")); cls.def_prop_ro_static( - "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { + "static_typeid", + [](nanobind::object & /*class*/) { if (DerivedTy::getTypeIdFunction) - return DerivedTy::getTypeIdFunction(); + return PyTypeID(DerivedTy::getTypeIdFunction()); throw nanobind::attribute_error( (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) .str() .c_str()); - }); + }, + nanobind::sig("def static_typeid(/) -> TypeID")); cls.def_prop_ro("typeid", [](PyType &self) { - return nanobind::cast(nanobind::cast(self).attr("typeid")); + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -1046,7 +1020,9 @@ class PyAttribute : public BaseContextObject { /// Note that PyAttribute instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAttribute /// is taken by calling this function. - static PyAttribute createFromCapsule(nanobind::object capsule); + static PyAttribute createFromCapsule(const nanobind::object &capsule); + + nanobind::object maybeDownCast(); private: MlirAttribute attr; @@ -1126,18 +1102,24 @@ class PyConcreteAttribute : public BaseTy { }, nanobind::arg("other")); cls.def_prop_ro( - "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); + "type", + [](PyAttribute &attr) -> nanobind::typed { + return PyType(attr.getContext(), mlirAttributeGetType(attr)) + .maybeDownCast(); + }); cls.def_prop_ro_static( - "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { + "static_typeid", + [](nanobind::object & /*class*/) -> PyTypeID { if (DerivedTy::getTypeIdFunction) - return DerivedTy::getTypeIdFunction(); + return PyTypeID(DerivedTy::getTypeIdFunction()); throw nanobind::attribute_error( (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) .str() .c_str()); - }); + }, + nanobind::sig("def static_typeid(/) -> TypeID")); cls.def_prop_ro("typeid", [](PyAttribute &self) { - return nanobind::cast(nanobind::cast(self).attr("typeid")); + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -1165,6 +1147,17 @@ class PyConcreteAttribute : public BaseTy { static void bindDerived(ClassTy &m) {} }; +class PyStringAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; + static constexpr const char *pyClassName = "StringAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; + + static void bindDerived(ClassTy &c); +}; + /// Wrapper around the generic MlirValue. /// Values are managed completely by the operation that resulted in their /// definition. For op result value, this is the operation that defines the @@ -1216,7 +1209,7 @@ class PyAffineExpr : public BaseContextObject { /// Note that PyAffineExpr instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr /// is taken by calling this function. - static PyAffineExpr createFromCapsule(nanobind::object capsule); + static PyAffineExpr createFromCapsule(const nanobind::object &capsule); PyAffineExpr add(const PyAffineExpr &other) const; PyAffineExpr mul(const PyAffineExpr &other) const; @@ -1243,7 +1236,7 @@ class PyAffineMap : public BaseContextObject { /// Note that PyAffineMap instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineMap /// is taken by calling this function. - static PyAffineMap createFromCapsule(nanobind::object capsule); + static PyAffineMap createFromCapsule(const nanobind::object &capsule); private: MlirAffineMap affineMap; @@ -1263,7 +1256,7 @@ class PyIntegerSet : public BaseContextObject { /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. /// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. - static PyIntegerSet createFromCapsule(nanobind::object capsule); + static PyIntegerSet createFromCapsule(const nanobind::object &capsule); private: MlirIntegerSet integerSet; @@ -1291,14 +1284,14 @@ class PySymbolTable { /// Inserts the given operation into the symbol table. The operation must have /// the symbol trait. - MlirAttribute insert(PyOperationBase &symbol); + PyStringAttribute insert(PyOperationBase &symbol); /// Gets and sets the name of a symbol op. - static MlirAttribute getSymbolName(PyOperationBase &symbol); + static PyStringAttribute getSymbolName(PyOperationBase &symbol); static void setSymbolName(PyOperationBase &symbol, const std::string &name); /// Gets and sets the visibility of a symbol op. - static MlirAttribute getVisibility(PyOperationBase &symbol); + static PyStringAttribute getVisibility(PyOperationBase &symbol); static void setVisibility(PyOperationBase &symbol, const std::string &visibility); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index b11e3f75b..34c5b8dd8 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -49,7 +49,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - nb::arg("width"), nb::arg("context").none() = nb::none(), + nb::arg("width"), nb::arg("context") = nb::none(), "Create a signless integer type"); c.def_static( "get_signed", @@ -57,7 +57,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeSignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - nb::arg("width"), nb::arg("context").none() = nb::none(), + nb::arg("width"), nb::arg("context") = nb::none(), "Create a signed integer type"); c.def_static( "get_unsigned", @@ -65,7 +65,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - nb::arg("width"), nb::arg("context").none() = nb::none(), + nb::arg("width"), nb::arg("context") = nb::none(), "Create an unsigned integer type"); c.def_prop_ro( "width", @@ -108,7 +108,7 @@ class PyIndexType : public PyConcreteType { MlirType t = mlirIndexTypeGet(context->get()); return PyIndexType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a index type."); + nb::arg("context") = nb::none(), "Create a index type."); } }; @@ -142,7 +142,7 @@ class PyFloat4E2M1FNType MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); return PyFloat4E2M1FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); + nb::arg("context") = nb::none(), "Create a float4_e2m1fn type."); } }; @@ -163,7 +163,7 @@ class PyFloat6E2M3FNType MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); return PyFloat6E2M3FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); + nb::arg("context") = nb::none(), "Create a float6_e2m3fn type."); } }; @@ -184,7 +184,7 @@ class PyFloat6E3M2FNType MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); return PyFloat6E3M2FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); + nb::arg("context") = nb::none(), "Create a float6_e3m2fn type."); } }; @@ -205,7 +205,7 @@ class PyFloat8E4M3FNType MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); return PyFloat8E4M3FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); + nb::arg("context") = nb::none(), "Create a float8_e4m3fn type."); } }; @@ -225,7 +225,7 @@ class PyFloat8E5M2Type : public PyConcreteType { MlirType t = mlirFloat8E5M2TypeGet(context->get()); return PyFloat8E5M2Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); + nb::arg("context") = nb::none(), "Create a float8_e5m2 type."); } }; @@ -245,7 +245,7 @@ class PyFloat8E4M3Type : public PyConcreteType { MlirType t = mlirFloat8E4M3TypeGet(context->get()); return PyFloat8E4M3Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); + nb::arg("context") = nb::none(), "Create a float8_e4m3 type."); } }; @@ -266,8 +266,7 @@ class PyFloat8E4M3FNUZType MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); return PyFloat8E4M3FNUZType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e4m3fnuz type."); + nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type."); } }; @@ -288,8 +287,7 @@ class PyFloat8E4M3B11FNUZType MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); return PyFloat8E4M3B11FNUZType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e4m3b11fnuz type."); + nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type."); } }; @@ -310,8 +308,7 @@ class PyFloat8E5M2FNUZType MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); return PyFloat8E5M2FNUZType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e5m2fnuz type."); + nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type."); } }; @@ -331,7 +328,7 @@ class PyFloat8E3M4Type : public PyConcreteType { MlirType t = mlirFloat8E3M4TypeGet(context->get()); return PyFloat8E3M4Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); + nb::arg("context") = nb::none(), "Create a float8_e3m4 type."); } }; @@ -352,8 +349,7 @@ class PyFloat8E8M0FNUType MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); return PyFloat8E8M0FNUType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e8m0fnu type."); + nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type."); } }; @@ -373,7 +369,7 @@ class PyBF16Type : public PyConcreteType { MlirType t = mlirBF16TypeGet(context->get()); return PyBF16Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a bf16 type."); + nb::arg("context") = nb::none(), "Create a bf16 type."); } }; @@ -393,7 +389,7 @@ class PyF16Type : public PyConcreteType { MlirType t = mlirF16TypeGet(context->get()); return PyF16Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a f16 type."); + nb::arg("context") = nb::none(), "Create a f16 type."); } }; @@ -413,7 +409,7 @@ class PyTF32Type : public PyConcreteType { MlirType t = mlirTF32TypeGet(context->get()); return PyTF32Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a tf32 type."); + nb::arg("context") = nb::none(), "Create a tf32 type."); } }; @@ -433,7 +429,7 @@ class PyF32Type : public PyConcreteType { MlirType t = mlirF32TypeGet(context->get()); return PyF32Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a f32 type."); + nb::arg("context") = nb::none(), "Create a f32 type."); } }; @@ -453,7 +449,7 @@ class PyF64Type : public PyConcreteType { MlirType t = mlirF64TypeGet(context->get()); return PyF64Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a f64 type."); + nb::arg("context") = nb::none(), "Create a f64 type."); } }; @@ -473,7 +469,7 @@ class PyNoneType : public PyConcreteType { MlirType t = mlirNoneTypeGet(context->get()); return PyNoneType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a none type."); + nb::arg("context") = nb::none(), "Create a none type."); } }; @@ -505,7 +501,10 @@ class PyComplexType : public PyConcreteType { "Create a complex type"); c.def_prop_ro( "element_type", - [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, + [](PyComplexType &self) -> nb::typed { + return PyType(self.getContext(), mlirComplexTypeGetElementType(self)) + .maybeDownCast(); + }, "Returns element type."); } }; @@ -516,7 +515,10 @@ class PyComplexType : public PyConcreteType { void mlir::PyShapedType::bindDerived(ClassTy &c) { c.def_prop_ro( "element_type", - [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, + [](PyShapedType &self) -> nb::typed { + return PyType(self.getContext(), mlirShapedTypeGetElementType(self)) + .maybeDownCast(); + }, "Returns the element type of the shaped type."); c.def_prop_ro( "has_rank", @@ -637,11 +639,16 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, nb::arg("shape"), + c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"), nb::arg("element_type"), nb::kw_only(), - nb::arg("scalable").none() = nb::none(), - nb::arg("scalable_dims").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a vector type") + nb::arg("scalable") = nb::none(), + nb::arg("scalable_dims") = nb::none(), + nb::arg("loc") = nb::none(), "Create a vector type") + .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable") = nb::none(), + nb::arg("scalable_dims") = nb::none(), + nb::arg("context") = nb::none(), "Create a vector type") .def_prop_ro( "scalable", [](MlirType self) { return mlirVectorTypeIsScalable(self); }) @@ -656,10 +663,11 @@ class PyVectorType : public PyConcreteType { } private: - static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, - std::optional> scalableDims, - DefaultingPyLocation loc) { + static PyVectorType + getChecked(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyLocation loc) { if (scalable && scalableDims) { throw nb::value_error("'scalable' and 'scalable_dims' kwargs " "are mutually exclusive."); @@ -694,6 +702,42 @@ class PyVectorType : public PyConcreteType { throw MLIRError("Invalid type", errors.take()); return PyVectorType(elementType.getContext(), type); } + + static PyVectorType get(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyMlirContext context) { + if (scalable && scalableDims) { + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); + } + + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw nb::value_error("Expected len(scalable) == len(shape)."); + + SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( + *scalable, [](const nb::handle &h) { return nb::cast(h); })); + type = mlirVectorTypeGetScalable(shape.size(), shape.data(), + scalableDimFlags.data(), elementType); + } else if (scalableDims) { + SmallVector scalableDimFlags(shape.size(), false); + for (int64_t dim : *scalableDims) { + if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) + throw nb::value_error("Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; + } + type = mlirVectorTypeGetScalable(shape.size(), shape.data(), + scalableDimFlags.data(), elementType); + } else { + type = mlirVectorTypeGet(shape.size(), shape.data(), elementType); + } + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); + } }; /// Ranked Tensor Type subclass - RankedTensorType. @@ -720,16 +764,33 @@ class PyRankedTensorType return PyRankedTensorType(elementType.getContext(), t); }, nb::arg("shape"), nb::arg("element_type"), - nb::arg("encoding").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); - c.def_prop_ro("encoding", - [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = - mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return encoding; - }); + nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(), + "Create a ranked tensor type"); + c.def_static( + "get_unchecked", + [](std::vector shape, PyType &elementType, + std::optional &encodingAttr, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType t = mlirRankedTensorTypeGet( + shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(), + "Create a ranked tensor type"); + c.def_prop_ro( + "encoding", + [](PyRankedTensorType &self) + -> std::optional> { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return PyAttribute(self.getContext(), encoding).maybeDownCast(); + }); } }; @@ -753,7 +814,18 @@ class PyUnrankedTensorType throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, - nb::arg("element_type"), nb::arg("loc").none() = nb::none(), + nb::arg("element_type"), nb::arg("loc") = nb::none(), + "Create a unranked tensor type"); + c.def_static( + "get_unchecked", + [](PyType &elementType, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType t = mlirUnrankedTensorTypeGet(elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("context") = nb::none(), "Create a unranked tensor type"); } }; @@ -785,13 +857,35 @@ class PyMemRefType : public PyConcreteType { return PyMemRefType(elementType.getContext(), t); }, nb::arg("shape"), nb::arg("element_type"), - nb::arg("layout").none() = nb::none(), - nb::arg("memory_space").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a memref type") + nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(), + nb::arg("loc") = nb::none(), "Create a memref type") + .def_static( + "get_unchecked", + [](std::vector shape, PyType &elementType, + PyAttribute *layout, PyAttribute *memorySpace, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute layoutAttr = + layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGet(elementType, shape.size(), shape.data(), + layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout") = nb::none(), + nb::arg("memory_space") = nb::none(), + nb::arg("context") = nb::none(), "Create a memref type") .def_prop_ro( "layout", - [](PyMemRefType &self) -> MlirAttribute { - return mlirMemRefTypeGetLayout(self); + [](PyMemRefType &self) -> nb::typed { + return PyAttribute(self.getContext(), + mlirMemRefTypeGetLayout(self)) + .maybeDownCast(); }, "The layout of the MemRef type.") .def( @@ -815,11 +909,12 @@ class PyMemRefType : public PyConcreteType { "The layout of the MemRef type as an affine map.") .def_prop_ro( "memory_space", - [](PyMemRefType &self) -> std::optional { + [](PyMemRefType &self) + -> std::optional> { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); if (mlirAttributeIsNull(a)) return std::nullopt; - return a; + return PyAttribute(self.getContext(), a).maybeDownCast(); }, "Returns the memory space of the given MemRef type."); } @@ -852,14 +947,31 @@ class PyUnrankedMemRefType return PyUnrankedMemRefType(elementType.getContext(), t); }, nb::arg("element_type"), nb::arg("memory_space").none(), - nb::arg("loc").none() = nb::none(), "Create a unranked memref type") + nb::arg("loc") = nb::none(), "Create a unranked memref type") + .def_static( + "get_unchecked", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("context") = nb::none(), "Create a unranked memref type") .def_prop_ro( "memory_space", - [](PyUnrankedMemRefType &self) -> std::optional { + [](PyUnrankedMemRefType &self) + -> std::optional> { MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); if (mlirAttributeIsNull(a)) return std::nullopt; - return a; + return PyAttribute(self.getContext(), a).maybeDownCast(); }, "Returns the memory space of the given Unranked MemRef type."); } @@ -875,6 +987,20 @@ class PyTupleType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](const std::vector &elements, + DefaultingPyMlirContext context) { + std::vector mlirElements; + mlirElements.reserve(elements.size()); + for (const auto &element : elements) + mlirElements.push_back(element.get()); + MlirType t = mlirTupleTypeGet(context->get(), elements.size(), + mlirElements.data()); + return PyTupleType(context->getRef(), t); + }, + nb::arg("elements"), nb::arg("context") = nb::none(), + "Create a tuple type"); c.def_static( "get_tuple", [](std::vector elements, DefaultingPyMlirContext context) { @@ -882,12 +1008,16 @@ class PyTupleType : public PyConcreteType { elements.data()); return PyTupleType(context->getRef(), t); }, - nb::arg("elements"), nb::arg("context").none() = nb::none(), + nb::arg("elements"), nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"), + // clang-format on "Create a tuple type"); c.def( "get_type", - [](PyTupleType &self, intptr_t pos) { - return mlirTupleTypeGetType(self, pos); + [](PyTupleType &self, intptr_t pos) -> nb::typed { + return PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) + .maybeDownCast(); }, nb::arg("pos"), "Returns the pos-th type in the tuple type."); c.def_prop_ro( @@ -909,6 +1039,26 @@ class PyFunctionType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + std::vector mlirInputs; + mlirInputs.reserve(inputs.size()); + for (const auto &input : inputs) + mlirInputs.push_back(input.get()); + std::vector mlirResults; + mlirResults.reserve(results.size()); + for (const auto &result : results) + mlirResults.push_back(result.get()); + + MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(), + mlirInputs.data(), results.size(), + mlirResults.data()); + return PyFunctionType(context->getRef(), t); + }, + nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), + "Gets a FunctionType from a list of input and result types"); c.def_static( "get", [](std::vector inputs, std::vector results, @@ -918,8 +1068,10 @@ class PyFunctionType : public PyConcreteType { results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, - nb::arg("inputs"), nb::arg("results"), - nb::arg("context").none() = nb::none(), + nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), + // clang-format off + nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"), + // clang-format on "Gets a FunctionType from a list of input and result types"); c.def_prop_ro( "inputs", @@ -963,7 +1115,7 @@ class PyOpaqueType : public PyConcreteType { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, std::string typeData, + [](const std::string &dialectNamespace, const std::string &typeData, DefaultingPyMlirContext context) { MlirType type = mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace), @@ -971,7 +1123,7 @@ class PyOpaqueType : public PyConcreteType { return PyOpaqueType(context->getRef(), type); }, nb::arg("dialect_namespace"), nb::arg("buffer"), - nb::arg("context").none() = nb::none(), + nb::arg("context") = nb::none(), "Create an unregistered (opaque) dialect type."); c.def_prop_ro( "dialect_namespace", diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 278847e7a..a14f09f77 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -52,9 +52,14 @@ NB_MODULE(_mlir, m) { [](PyGlobals &self, bool enabled) { self.getTracebackLoc().setLocTracebacksEnabled(enabled); }) + .def("loc_tracebacks_frame_limit", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebackFramesLimit(); + }) .def("set_loc_tracebacks_frame_limit", - [](PyGlobals &self, int n) { - self.getTracebackLoc().setLocTracebackFramesLimit(n); + [](PyGlobals &self, std::optional n) { + self.getTracebackLoc().setLocTracebackFramesLimit( + n.value_or(PyGlobals::TracebackLoc::kMaxFrames)); }) .def("register_traceback_file_inclusion", [](PyGlobals &self, const std::string &filename) { @@ -136,7 +141,7 @@ NB_MODULE(_mlir, m) { populateRewriteSubmodule(rewriteModule); // Define and populate PassManager submodule. - auto passModule = + auto passManagerModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); - populatePassManagerSubmodule(passModule); + populatePassManagerSubmodule(passManagerModule); } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 20017e25b..572afa902 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,10 +8,13 @@ #include "Pass.h" +#include "Globals.h" #include "IRModule.h" #include "mlir-c/Pass.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on namespace nb = nanobind; using namespace nb::literals; @@ -39,7 +42,7 @@ class PyPassManager { return nb::steal(mlirPythonPassManagerToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) throw nb::python_error(); @@ -54,6 +57,20 @@ class PyPassManager { /// Create the `mlir.passmanager` here. void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { + //---------------------------------------------------------------------------- + // Mapping of enumerated types + //---------------------------------------------------------------------------- + nb::enum_(m, "PassDisplayMode") + .value("LIST", MLIR_PASS_DISPLAY_MODE_LIST) + .value("PIPELINE", MLIR_PASS_DISPLAY_MODE_PIPELINE); + + //---------------------------------------------------------------------------- + // Mapping of MlirExternalPass + //---------------------------------------------------------------------------- + nb::class_(m, "ExternalPass") + .def("signal_pass_failure", + [](MlirExternalPass pass) { mlirExternalPassSignalFailure(pass); }); + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- @@ -67,7 +84,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirStringRefCreate(anchorOp.data(), anchorOp.size())); new (&self) PyPassManager(passManager); }, - "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), + "anchor_op"_a = nb::str("any"), "context"_a = nb::none(), + // clang-format off + nb::sig("def __init__(self, anchor_op: str = 'any', context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> None"), + // clang-format on "Create a new PassManager for the current (or provided) Context.") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) @@ -109,10 +129,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, - "large_elements_limit"_a.none() = nb::none(), - "large_resource_limit"_a.none() = nb::none(), - "enable_debug_info"_a = false, "print_generic_op_form"_a = false, - "tree_printing_dir_path"_a.none() = nb::none(), + "large_elements_limit"_a = nb::none(), + "large_resource_limit"_a = nb::none(), "enable_debug_info"_a = false, + "print_generic_op_form"_a = false, + "tree_printing_dir_path"_a = nb::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", @@ -126,6 +146,14 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableTiming(passManager.get()); }, "Enable pass timing.") + .def( + "enable_statistics", + [](PyPassManager &passManager, MlirPassDisplayMode displayMode) { + mlirPassManagerEnableStatistics(passManager.get(), displayMode); + }, + "displayMode"_a = + MlirPassDisplayMode::MLIR_PASS_DISPLAY_MODE_PIPELINE, + "Enable pass statistics.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -139,7 +167,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { throw nb::value_error(errorMsg.join().c_str()); return new PyPassManager(passManager); }, - "pipeline"_a, "context"_a.none() = nb::none(), + "pipeline"_a, "context"_a = nb::none(), + // clang-format off + nb::sig("def parse(pipeline: str, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> PassManager"), + // clang-format on "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -158,12 +189,45 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "Add textual pipeline elements to the pass manager. Throws a " "ValueError if the pipeline can't be parsed.") .def( - "run", - [](PyPassManager &passManager, PyOperationBase &op, - bool invalidateOps) { - if (invalidateOps) { - op.getOperation().getContext()->clearOperationsInside(op); + "add", + [](PyPassManager &passManager, const nb::callable &run, + std::optional &name, const std::string &argument, + const std::string &description, const std::string &opName) { + if (!name.has_value()) { + name = nb::cast( + nb::borrow(run.attr("__name__"))); } + MlirTypeID passID = PyGlobals::get().allocateTypeID(); + MlirExternalPassCallbacks callbacks; + callbacks.construct = [](void *obj) { + (void)nb::handle(static_cast(obj)).inc_ref(); + }; + callbacks.destruct = [](void *obj) { + (void)nb::handle(static_cast(obj)).dec_ref(); + }; + callbacks.initialize = nullptr; + callbacks.clone = [](void *) -> void * { + throw std::runtime_error("Cloning Python passes not supported"); + }; + callbacks.run = [](MlirOperation op, MlirExternalPass pass, + void *userData) { + nb::handle(static_cast(userData))(op, pass); + }; + auto externalPass = mlirCreateExternalPass( + passID, mlirStringRefCreate(name->data(), name->length()), + mlirStringRefCreate(argument.data(), argument.length()), + mlirStringRefCreate(description.data(), description.length()), + mlirStringRefCreate(opName.data(), opName.size()), + /*nDependentDialects*/ 0, /*dependentDialects*/ nullptr, + callbacks, /*userData*/ run.ptr()); + mlirPassManagerAddOwnedPass(passManager.get(), externalPass); + }, + "run"_a, "name"_a.none() = nb::none(), "argument"_a.none() = "", + "description"_a.none() = "", "op_name"_a.none() = "", + "Add a python-defined pass to the pass manager.") + .def( + "run", + [](PyPassManager &passManager, PyOperationBase &op) { // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( @@ -172,7 +236,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - "operation"_a, "invalidate_ops"_a = true, + "operation"_a, + // clang-format off + nb::sig("def run(self, operation: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ") -> None"), + // clang-format on "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp index 3ba42bec5..3edcb099c 100644 --- a/mlir/lib/Bindings/Python/RegisterEverything.cpp +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/RegisterEverything.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" NB_MODULE(_mlirRegisterEverything, m) { m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration"; diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 0373f9c7a..9e3d9703c 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -9,10 +9,15 @@ #include "Rewrite.h" #include "IRModule.h" +#include "mlir-c/IR.h" #include "mlir-c/Rewrite.h" +#include "mlir-c/Support.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on #include "mlir/Config/mlir-config.h" +#include "nanobind/nanobind.h" namespace nb = nanobind; using namespace mlir; @@ -21,7 +26,65 @@ using namespace mlir::python; namespace { +class PyPatternRewriter { +public: + PyPatternRewriter(MlirPatternRewriter rewriter) + : base(mlirPatternRewriterAsBase(rewriter)), + ctx(PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {} + + PyInsertionPoint getInsertionPoint() const { + MlirBlock block = mlirRewriterBaseGetInsertionBlock(base); + MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion(base); + + if (mlirOperationIsNull(op)) { + MlirOperation owner = mlirBlockGetParentOperation(block); + auto parent = PyOperation::forOperation(ctx, owner); + return PyInsertionPoint(PyBlock(parent, block)); + } + + return PyInsertionPoint(PyOperation::forOperation(ctx, op)); + } + +private: + MlirRewriterBase base; + PyMlirContextRef ctx; +}; + #if MLIR_ENABLE_PDL_IN_PATTERNMATCH +static nb::object objectFromPDLValue(MlirPDLValue value) { + if (MlirValue v = mlirPDLValueAsValue(value); !mlirValueIsNull(v)) + return nb::cast(v); + if (MlirOperation v = mlirPDLValueAsOperation(value); !mlirOperationIsNull(v)) + return nb::cast(v); + if (MlirAttribute v = mlirPDLValueAsAttribute(value); !mlirAttributeIsNull(v)) + return nb::cast(v); + if (MlirType v = mlirPDLValueAsType(value); !mlirTypeIsNull(v)) + return nb::cast(v); + + throw std::runtime_error("unsupported PDL value type"); +} + +static std::vector objectsFromPDLValues(size_t nValues, + MlirPDLValue *values) { + std::vector args; + args.reserve(nValues); + for (size_t i = 0; i < nValues; ++i) + args.push_back(objectFromPDLValue(values[i])); + return args; +} + +// Convert the Python object to a boolean. +// If it evaluates to False, treat it as success; +// otherwise, treat it as failure. +// Note that None is considered success. +static MlirLogicalResult logicalResultFromObject(const nb::object &obj) { + if (obj.is_none()) + return mlirLogicalResultSuccess(); + + return nb::cast(obj) ? mlirLogicalResultFailure() + : mlirLogicalResultSuccess(); +} + /// Owning Wrapper around a PDLPatternModule. class PyPDLPatternModule { public: @@ -36,6 +99,36 @@ class PyPDLPatternModule { } MlirPDLPatternModule get() { return module; } + void registerRewriteFunction(const std::string &name, + const nb::callable &fn) { + mlirPDLPatternModuleRegisterRewriteFunction( + get(), mlirStringRefCreate(name.data(), name.size()), + [](MlirPatternRewriter rewriter, MlirPDLResultList results, + size_t nValues, MlirPDLValue *values, + void *userData) -> MlirLogicalResult { + nb::handle f = nb::handle(static_cast(userData)); + return logicalResultFromObject( + f(PyPatternRewriter(rewriter), results, + objectsFromPDLValues(nValues, values))); + }, + fn.ptr()); + } + + void registerConstraintFunction(const std::string &name, + const nb::callable &fn) { + mlirPDLPatternModuleRegisterConstraintFunction( + get(), mlirStringRefCreate(name.data(), name.size()), + [](MlirPatternRewriter rewriter, MlirPDLResultList results, + size_t nValues, MlirPDLValue *values, + void *userData) -> MlirLogicalResult { + nb::handle f = nb::handle(static_cast(userData)); + return logicalResultFromObject( + f(PyPatternRewriter(rewriter), results, + objectsFromPDLValues(nValues, values))); + }, + fn.ptr()); + } + private: MlirPDLPatternModule module; }; @@ -60,7 +153,7 @@ class PyFrozenRewritePatternSet { mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static nb::object createFromCapsule(const nb::object &capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) @@ -76,10 +169,50 @@ class PyFrozenRewritePatternSet { /// Create the `mlir.rewrite` here. void mlir::python::populateRewriteSubmodule(nb::module_ &m) { + nb::class_(m, "PatternRewriter") + .def_prop_ro("ip", &PyPatternRewriter::getInsertionPoint, + "The current insertion point of the PatternRewriter."); //---------------------------------------------------------------------------- - // Mapping of the top-level PassManager + // Mapping of the PDLResultList and PDLModule //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "PDLResultList") + .def( + "append", + [](MlirPDLResultList results, const PyValue &value) { + mlirPDLResultListPushBackValue(results, value); + }, + // clang-format off + nb::sig("def append(self, value: " MAKE_MLIR_PYTHON_QUALNAME("ir.Value") ")") + // clang-format on + ) + .def( + "append", + [](MlirPDLResultList results, const PyOperation &op) { + mlirPDLResultListPushBackOperation(results, op); + }, + // clang-format off + nb::sig("def append(self, op: " MAKE_MLIR_PYTHON_QUALNAME("ir.Operation") ")") + // clang-format on + ) + .def( + "append", + [](MlirPDLResultList results, const PyType &type) { + mlirPDLResultListPushBackType(results, type); + }, + // clang-format off + nb::sig("def append(self, type: " MAKE_MLIR_PYTHON_QUALNAME("ir.Type") ")") + // clang-format on + ) + .def( + "append", + [](MlirPDLResultList results, const PyAttribute &attr) { + mlirPDLResultListPushBackAttribute(results, attr); + }, + // clang-format off + nb::sig("def append(self, attr: " MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute") ")") + // clang-format on + ); nb::class_(m, "PDLModule") .def( "__init__", @@ -87,11 +220,41 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { new (&self) PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); }, + // clang-format off + nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), + // clang-format on "module"_a, "Create a PDL module from the given module.") - .def("freeze", [](PyPDLPatternModule &self) { - return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( - mlirRewritePatternSetFromPDLPatternModule(self.get()))); - }); + .def( + "__init__", + [](PyPDLPatternModule &self, PyModule &module) { + new (&self) PyPDLPatternModule( + mlirPDLPatternModuleFromModule(module.get())); + }, + // clang-format off + nb::sig("def __init__(self, module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ") -> None"), + // clang-format on + "module"_a, "Create a PDL module from the given module.") + .def( + "freeze", + [](PyPDLPatternModule &self) { + return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( + mlirRewritePatternSetFromPDLPatternModule(self.get()))); + }, + nb::keep_alive<0, 1>()) + .def( + "register_rewrite_function", + [](PyPDLPatternModule &self, const std::string &name, + const nb::callable &fn) { + self.registerRewriteFunction(name, fn); + }, + nb::keep_alive<1, 3>()) + .def( + "register_constraint_function", + [](PyPDLPatternModule &self, const std::string &name, + const nb::callable &fn) { + self.registerConstraintFunction(name, fn); + }, + nb::keep_alive<1, 3>()); #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH nb::class_(m, "FrozenRewritePatternSet") .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, @@ -99,14 +262,63 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( - "apply_patterns_and_fold_greedily", - [](MlirModule module, MlirFrozenRewritePatternSet set) { - auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); - if (mlirLogicalResultIsFailure(status)) - // FIXME: Not sure this is the right error to throw here. - throw nb::value_error("pattern application failed to converge"); - }, - "module"_a, "set"_a, - "Applys the given patterns to the given module greedily while folding " - "results."); + "apply_patterns_and_fold_greedily", + [](PyModule &module, PyFrozenRewritePatternSet &set) { + auto status = + mlirApplyPatternsAndFoldGreedily(module.get(), set.get(), {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error("pattern application failed to converge"); + }, + "module"_a, "set"_a, + // clang-format off + nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"), + // clang-format on + "Applys the given patterns to the given module greedily while folding " + "results.") + .def( + "apply_patterns_and_fold_greedily", + [](PyModule &module, MlirFrozenRewritePatternSet set) { + auto status = + mlirApplyPatternsAndFoldGreedily(module.get(), set, {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error( + "pattern application failed to converge"); + }, + "module"_a, "set"_a, + // clang-format off + nb::sig("def apply_patterns_and_fold_greedily(module: " MAKE_MLIR_PYTHON_QUALNAME("ir.Module") ", set: FrozenRewritePatternSet) -> None"), + // clang-format on + "Applys the given patterns to the given module greedily while " + "folding " + "results.") + .def( + "apply_patterns_and_fold_greedily", + [](PyOperationBase &op, PyFrozenRewritePatternSet &set) { + auto status = mlirApplyPatternsAndFoldGreedilyWithOp( + op.getOperation(), set.get(), {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error( + "pattern application failed to converge"); + }, + "op"_a, "set"_a, + // clang-format off + nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), + // clang-format on + "Applys the given patterns to the given op greedily while folding " + "results.") + .def( + "apply_patterns_and_fold_greedily", + [](PyOperationBase &op, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedilyWithOp( + op.getOperation(), set, {}); + if (mlirLogicalResultIsFailure(status)) + throw std::runtime_error( + "pattern application failed to converge"); + }, + "op"_a, "set"_a, + // clang-format off + nb::sig("def apply_patterns_and_fold_greedily(op: " MAKE_MLIR_PYTHON_QUALNAME("ir._OperationBase") ", set: FrozenRewritePatternSet) -> None"), + // clang-format on + "Applys the given patterns to the given op greedily while folding " + "results."); } diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f9b0fed62..920bca886 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) { // root. This is awkward, but we don't have access to PyMlirContext // object here otherwise. nb::object obj = nb::cast(payloadRoot); - obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); MlirLogicalResult result = mlirTransformApplyNamedSequence( payloadRoot, transformRoot, transformModule, options.options); diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 69c804b76..eaad8a87a 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -197,12 +197,12 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( cast(unwrap(name)), cast(unwrap(file)), line, cast(unwrap(scope)), cast(unwrap(baseType)), DIFlags(flags), sizeInBits, alignInBits, - llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return cast(a); }), cast(unwrap(dataLocation)), cast(unwrap(rank)), cast(unwrap(allocated)), - cast(unwrap(associated)))); + cast(unwrap(associated)), + llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), + [](Attribute a) { return cast(a); }))); } MlirAttribute mlirLLVMDIDerivedTypeAttrGet( @@ -253,17 +253,16 @@ MlirAttribute mlirLLVMDIFileAttrGet(MlirContext ctx, MlirAttribute name, cast(unwrap(directory)))); } -MlirAttribute -mlirLLVMDICompileUnitAttrGet(MlirContext ctx, MlirAttribute id, - unsigned int sourceLanguage, MlirAttribute file, - MlirAttribute producer, bool isOptimized, - MlirLLVMDIEmissionKind emissionKind, - MlirLLVMDINameTableKind nameTableKind) { +MlirAttribute mlirLLVMDICompileUnitAttrGet( + MlirContext ctx, MlirAttribute id, unsigned int sourceLanguage, + MlirAttribute file, MlirAttribute producer, bool isOptimized, + MlirLLVMDIEmissionKind emissionKind, MlirLLVMDINameTableKind nameTableKind, + MlirAttribute splitDebugFilename) { return wrap(DICompileUnitAttr::get( unwrap(ctx), cast(unwrap(id)), sourceLanguage, cast(unwrap(file)), cast(unwrap(producer)), - isOptimized, DIEmissionKind(emissionKind), - DINameTableKind(nameTableKind))); + isOptimized, DIEmissionKind(emissionKind), DINameTableKind(nameTableKind), + cast(unwrap(splitDebugFilename)))); } MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, uint64_t value) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 21db18dfd..5c2a65d2c 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -59,7 +59,7 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { if (failed(maybeDims)) return result; - linalg::ContractionDimensions contractionDims = *maybeDims; + const linalg::ContractionDimensions &contractionDims = *maybeDims; MLIRContext *ctx = linalgOp.getContext(); auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { @@ -95,7 +95,7 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) { if (failed(maybeDims)) return result; - linalg::ConvolutionDimensions dims = *maybeDims; + const linalg::ConvolutionDimensions &dims = *maybeDims; MLIRContext *ctx = linalgOp.getContext(); auto toI32Attr = diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 9d8554aab..f5f4ed302 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -465,10 +465,6 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); } -MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) { - return wrap(llvm::cast(unwrap(type)).getElementType()); -} - //===----------------------------------------------------------------------===// // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8491553da..188186598 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -465,6 +465,14 @@ MlirModule mlirModuleFromOperation(MlirOperation op) { return wrap(dyn_cast(unwrap(op))); } +bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) { + return unwrap(lhs) == unwrap(rhs); +} + +size_t mlirModuleHashValue(MlirModule mod) { + return OperationEquivalence::computeHash(unwrap(mod).getOperation()); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// @@ -636,6 +644,10 @@ bool mlirOperationEqual(MlirOperation op, MlirOperation other) { return unwrap(op) == unwrap(other); } +size_t mlirOperationHashValue(MlirOperation op) { + return OperationEquivalence::computeHash(unwrap(op)); +} + MlirContext mlirOperationGetContext(MlirOperation op) { return wrap(unwrap(op)->getContext()); } @@ -644,6 +656,10 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) { return wrap(unwrap(op)->getLoc()); } +void mlirOperationSetLocation(MlirOperation op, MlirLocation loc) { + unwrap(op)->setLoc(unwrap(loc)); +} + MlirTypeID mlirOperationGetTypeID(MlirOperation op) { if (auto info = unwrap(op)->getRegisteredInfo()) return wrap(info->getTypeID()); diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 3c499c3e4..72bec11f7 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/Pass/PassManager.h" +#include "llvm/Support/ErrorHandling.h" #include using namespace mlir; @@ -79,6 +80,20 @@ void mlirPassManagerEnableTiming(MlirPassManager passManager) { unwrap(passManager)->enableTiming(); } +void mlirPassManagerEnableStatistics(MlirPassManager passManager, + MlirPassDisplayMode displayMode) { + PassDisplayMode mode; + switch (displayMode) { + case MLIR_PASS_DISPLAY_MODE_LIST: + mode = PassDisplayMode::List; + break; + case MLIR_PASS_DISPLAY_MODE_PIPELINE: + mode = PassDisplayMode::Pipeline; + break; + } + unwrap(passManager)->enableStatistics(mode); +} + MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName) { return wrap(&unwrap(passManager)->nest(unwrap(operationName))); @@ -145,10 +160,14 @@ class ExternalPass : public Pass { : Pass(passID, opName), id(passID), name(name), argument(argument), description(description), dependentDialects(dependentDialects), callbacks(callbacks), userData(userData) { - callbacks.construct(userData); + if (callbacks.construct) + callbacks.construct(userData); } - ~ExternalPass() override { callbacks.destruct(userData); } + ~ExternalPass() override { + if (callbacks.destruct) + callbacks.destruct(userData); + } StringRef getName() const override { return name; } StringRef getArgument() const override { return argument; } diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index a4df97f7b..c15a73b99 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -13,6 +13,8 @@ #include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/PDLPatternMatch.h.inc" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -68,6 +70,17 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { return wrap(unwrap(rewriter)->getBlock()); } +MlirOperation +mlirRewriterBaseGetOperationAfterInsertion(MlirRewriterBase rewriter) { + mlir::RewriterBase *base = unwrap(rewriter); + mlir::Block *block = base->getInsertionBlock(); + mlir::Block::iterator it = base->getInsertionPoint(); + if (it == block->end()) + return {nullptr}; + + return wrap(std::addressof(*it)); +} + //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning //===----------------------------------------------------------------------===// @@ -257,22 +270,23 @@ void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { /// RewritePatternSet and FrozenRewritePatternSet API //===----------------------------------------------------------------------===// -inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { +static inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast(module.ptr)); } -inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { +static inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { return {module}; } -inline mlir::FrozenRewritePatternSet * +static inline mlir::FrozenRewritePatternSet * unwrap(MlirFrozenRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return static_cast(module.ptr); } -inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { +static inline MlirFrozenRewritePatternSet +wrap(mlir::FrozenRewritePatternSet *module) { return {module}; } @@ -294,17 +308,41 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op, return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } +MlirLogicalResult +mlirApplyPatternsAndFoldGreedilyWithOp(MlirOperation op, + MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig) { + return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); +} + +//===----------------------------------------------------------------------===// +/// PatternRewriter API +//===----------------------------------------------------------------------===// + +inline mlir::PatternRewriter *unwrap(MlirPatternRewriter rewriter) { + assert(rewriter.ptr && "unexpected null rewriter"); + return static_cast(rewriter.ptr); +} + +inline MlirPatternRewriter wrap(mlir::PatternRewriter *rewriter) { + return {rewriter}; +} + +MlirRewriterBase mlirPatternRewriterAsBase(MlirPatternRewriter rewriter) { + return wrap(static_cast(unwrap(rewriter))); +} + //===----------------------------------------------------------------------===// /// PDLPatternModule API //===----------------------------------------------------------------------===// #if MLIR_ENABLE_PDL_IN_PATTERNMATCH -inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { +static inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); return static_cast(module.ptr); } -inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { +static inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { return {module}; } @@ -324,4 +362,93 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { op.ptr = nullptr; return wrap(m); } + +inline const mlir::PDLValue *unwrap(MlirPDLValue value) { + assert(value.ptr && "unexpected null PDL value"); + return static_cast(value.ptr); +} + +inline MlirPDLValue wrap(const mlir::PDLValue *value) { return {value}; } + +inline mlir::PDLResultList *unwrap(MlirPDLResultList results) { + assert(results.ptr && "unexpected null PDL results"); + return static_cast(results.ptr); +} + +inline MlirPDLResultList wrap(mlir::PDLResultList *results) { + return {results}; +} + +MlirValue mlirPDLValueAsValue(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +MlirType mlirPDLValueAsType(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +MlirOperation mlirPDLValueAsOperation(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +MlirAttribute mlirPDLValueAsAttribute(MlirPDLValue value) { + return wrap(unwrap(value)->dyn_cast()); +} + +void mlirPDLResultListPushBackValue(MlirPDLResultList results, + MlirValue value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackType(MlirPDLResultList results, MlirType value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackOperation(MlirPDLResultList results, + MlirOperation value) { + unwrap(results)->push_back(unwrap(value)); +} + +void mlirPDLResultListPushBackAttribute(MlirPDLResultList results, + MlirAttribute value) { + unwrap(results)->push_back(unwrap(value)); +} + +inline std::vector wrap(ArrayRef values) { + std::vector mlirValues; + mlirValues.reserve(values.size()); + for (auto &value : values) { + mlirValues.push_back(wrap(&value)); + } + return mlirValues; +} + +void mlirPDLPatternModuleRegisterRewriteFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLRewriteFunction rewriteFn, void *userData) { + unwrap(pdlModule)->registerRewriteFunction( + unwrap(name), + [userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results, + ArrayRef values) -> LogicalResult { + std::vector mlirValues = wrap(values); + return unwrap(rewriteFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} + +void mlirPDLPatternModuleRegisterConstraintFunction( + MlirPDLPatternModule pdlModule, MlirStringRef name, + MlirPDLConstraintFunction constraintFn, void *userData) { + unwrap(pdlModule)->registerConstraintFunction( + unwrap(name), + [userData, constraintFn](PatternRewriter &rewriter, + PDLResultList &results, + ArrayRef values) -> LogicalResult { + std::vector mlirValues = wrap(values); + return unwrap(constraintFn(wrap(&rewriter), wrap(&results), + mlirValues.size(), mlirValues.data(), + userData)); + }); +} #endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7a0c95ebb..9f5246de6 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -1,5 +1,9 @@ include(AddMLIRPython) +# Specifies that all MLIR packages are co-located under the `MLIR_PYTHON_PACKAGE_PREFIX.` +# top level package (the API has been embedded in a relocatable way). +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=${MLIR_PYTHON_PACKAGE_PREFIX}.") + ################################################################################ # Structural groupings. ################################################################################ @@ -19,15 +23,11 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python ADD_TO_PARENT MLIRPythonSources.Core SOURCES _mlir_libs/__init__.py + _mlir_libs/_mlir/py.typed ir.py passmanager.py rewrite.py dialects/_ods_common.py - - # The main _mlir module has submodules: include stubs from each. - _mlir_libs/_mlir/__init__.pyi - _mlir_libs/_mlir/ir.pyi - _mlir_libs/_mlir/passmanager.pyi ) declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras @@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" DIALECT_NAME transform EXTENSION_NAME transform_pdl_extension) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformSMTExtensionOps.td + SOURCES + dialects/transform/smt.py + DIALECT_NAME transform + EXTENSION_NAME transform_smt_extension) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -453,6 +462,14 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME tosa ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/UBOps.td + SOURCES dialects/ub.py + DIALECT_NAME ub +) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -462,6 +479,15 @@ declare_mlir_dialect_python_bindings( GEN_ENUM_BINDINGS_TD_FILE "dialects/VectorAttributes.td") +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/IRDLOps.td + SOURCES dialects/irdl.py + DIALECT_NAME irdl + GEN_ENUM_BINDINGS +) + ################################################################################ # Python extensions. # The sources for these are all in lib/Bindings/Python, but since they have to @@ -637,6 +663,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind MLIRCAPITransformDialect ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.IRDL.Pybind + MODULE_NAME _mlirDialectsIRDL + ADD_TO_PARENT MLIRPythonSources.Dialects.irdl + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectIRDL.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIIRDL +) + declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async @@ -655,7 +695,7 @@ if(MLIR_ENABLE_EXECUTION_ENGINE) MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine ROOT_DIR "${PYTHON_SOURCE_DIR}" - PYTHON_BINDINGS_LIBRARY nanobind + PYTHON_BINDINGS_LIBRARY nanobind SOURCES ExecutionEngineModule.cpp PRIVATE_LINK_LIBS @@ -806,10 +846,11 @@ endif() # once ready. ################################################################################ +set(MLIRPythonModules_ROOT_PREFIX "${MLIR_BINARY_DIR}/${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}") add_mlir_python_common_capi_library(MLIRPythonCAPI INSTALL_COMPONENT MLIRPythonModules INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" - OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs" + OUTPUT_DIRECTORY "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs" RELATIVE_INSTALL_ROOT "../../../.." DECLARED_HEADERS MLIRPythonCAPI.HeaderSources @@ -832,19 +873,113 @@ if(NOT LLVM_ENABLE_IDE) ) endif() +# Stubgen doesn't work when cross-compiling (stubgen will run in the host interpreter and then fail +# to find the extension module for the host arch). +if(NOT CMAKE_CROSSCOMPILING) + # _mlir stubgen + # Note: All this needs to come before add_mlir_python_modules(MLIRPythonModules so that the install targets for the + # generated type stubs get created. + + set(_core_type_stub_sources + _mlir/__init__.pyi + _mlir/ir.pyi + _mlir/passmanager.pyi + _mlir/rewrite.pyi + ) + + # Note 1: INTERFACE_SOURCES is a genex ($ $) + # which will be evaluated by file(GENERATE ...) inside mlir_generate_type_stubs. This will evaluate to the correct + # thing in the build dir (i.e., actual source dir paths) and in the install dir + # (where it's a conventional path; see install/lib/cmake/mlir/MLIRTargets.cmake). + # + # Note 2: MLIRPythonExtension.Core is the target that is defined using target_sources(INTERFACE) + # **NOT** MLIRPythonModules.extension._mlir.dso. So be sure to use the correct target! + get_target_property(_core_extension_srcs MLIRPythonExtension.Core INTERFACE_SOURCES) + + # Why is MODULE_NAME _mlir here but mlir._mlir_libs._mlirPythonTestNanobind below??? + # The _mlir extension can be imported independently of any other python code and/or extension modules. + # I.e., you could do `cd $MLIRPythonModules_ROOT_PREFIX/_mlir_libs && python -c "import _mlir"` (try it!). + # _mlir is also (currently) the only extension for which this is possible because dialect extensions modules, + # which generally make use of `mlir_value_subclass/mlir_type_subclass/mlir_attribute_subclass`, perform an + # `import mlir` right when they're loaded (see the mlir_*_subclass ctors in NanobindAdaptors.h). + # Note, this also why IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs" here while below + # "${MLIRPythonModules_ROOT_PREFIX}/.." (because MLIR_BINDINGS_PYTHON_INSTALL_PREFIX, by default, ends at mlir). + # + # Further note: this function creates file targets like + # "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs/_mlir/__init__.pyi". These must match the file targets + # that declare_mlir_python_sources expects, which are like "${ROOT_DIR}/${WHATEVER_SOURCE}". + # This is why _mlir_libs is prepended below. + mlir_generate_type_stubs( + MODULE_NAME _mlir + DEPENDS_TARGETS MLIRPythonModules.extension._mlir.dso + OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs" + OUTPUTS "${_core_type_stub_sources}" + DEPENDS_TARGET_SRC_DEPS "${_core_extension_srcs}" + IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/_mlir_libs" + ) + set(_mlir_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}") + + list(TRANSFORM _core_type_stub_sources PREPEND "_mlir_libs/") + # Note, we do not do ADD_TO_PARENT here so that the type stubs are not associated (as mlir_DEPENDS) with + # MLIRPythonSources.Core (or something) when a distro is installed/created. Otherwise they would not be regenerated + # by users of the distro (the stubs are still installed in the distro - they are just not added to mlir_DEPENDS). + declare_mlir_python_sources( + MLIRPythonExtension.Core.type_stub_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs" + SOURCES "${_core_type_stub_sources}" + ) + + # _mlirPythonTestNanobind stubgen + + if(MLIR_INCLUDE_TESTS) + get_target_property(_test_extension_srcs MLIRPythonTestSources.PythonTestExtensionNanobind INTERFACE_SOURCES) + mlir_generate_type_stubs( + # This is the FQN path because dialect modules import _mlir when loaded. See above. + MODULE_NAME mlir._mlir_libs._mlirPythonTestNanobind + DEPENDS_TARGETS + # You need both _mlir and _mlirPythonTestNanobind because dialect modules import _mlir when loaded + # (so _mlir needs to be built before calling stubgen). + MLIRPythonModules.extension._mlir.dso + MLIRPythonModules.extension._mlirPythonTestNanobind.dso + # You need this one so that ir.py "built" because mlir._mlir_libs.__init__.py import mlir.ir in _site_initialize. + MLIRPythonModules.sources.MLIRPythonSources.Core.Python + OUTPUT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs/_mlir_libs" + OUTPUTS _mlirPythonTestNanobind.pyi + DEPENDS_TARGET_SRC_DEPS "${_test_extension_srcs}" + IMPORT_PATHS "${MLIRPythonModules_ROOT_PREFIX}/.." + ) + set(_mlirPythonTestNanobind_typestub_gen_target "${NB_STUBGEN_CUSTOM_TARGET}") + declare_mlir_python_sources( + MLIRPythonTestSources.PythonTestExtensionNanobind.type_stub_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}/type_stubs" + ADD_TO_PARENT MLIRPythonTestSources.Dialects + SOURCES _mlir_libs/_mlirPythonTestNanobind.pyi + ) + endif() +endif() + ################################################################################ # The fully assembled package of modules. # This must come last. ################################################################################ +set(_declared_sources MLIRPythonSources MLIRPythonExtension.RegisterEverything) +if(NOT CMAKE_CROSSCOMPILING) + list(APPEND _declared_sources MLIRPythonExtension.Core.type_stub_gen) +endif() + add_mlir_python_modules(MLIRPythonModules - ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir" + ROOT_PREFIX ${MLIRPythonModules_ROOT_PREFIX} INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}" DECLARED_SOURCES - MLIRPythonSources - MLIRPythonExtension.RegisterEverything + ${_declared_sources} ${_ADDL_TEST_SOURCES} COMMON_CAPI_LINK_LIBS MLIRPythonCAPI ) - +if(NOT CMAKE_CROSSCOMPILING) + add_dependencies(MLIRPythonModules "${_mlir_typestub_gen_target}") + if(MLIR_INCLUDE_TESTS) + add_dependencies(MLIRPythonModules "${_mlirPythonTestNanobind_typestub_gen_target}") + endif() +endif() diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 083a9075f..63244212b 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -147,7 +147,9 @@ def process_initializer_module(module_name): if not process_initializer_module(module_name): break - class Context(ir._BaseContext): + ir._Context = ir.Context + + class Context(ir._Context): def __init__( self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs ): diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi deleted file mode 100644 index 03449b70b..000000000 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ /dev/null @@ -1,12 +0,0 @@ - -globals: "_Globals" - -class _Globals: - dialect_search_modules: list[str] - def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... - def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ... - def append_dialect_search_prefix(self, module_name: str) -> None: ... - def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... - -def register_dialect(dialect_class: type) -> type: ... -def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi deleted file mode 100644 index dcae3dd74..000000000 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ /dev/null @@ -1,2846 +0,0 @@ -# Originally imported via: -# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir -# but with the following diff (in order to remove pipes from types, -# which we won't support until bumping minimum python to 3.10) -# -# --------------------- diff begins ------------------------------------ -# -# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py -# index 1f755aa..4924927 100644 -# --- a/pybind11_stubgen/printer.py -# +++ b/pybind11_stubgen/printer.py -# @@ -283,14 +283,6 @@ class Printer: -# return split[0] + "..." -# -# def print_type(self, type_: ResolvedType) -> str: -# - if ( -# - str(type_.name) == "typing.Optional" -# - and type_.parameters is not None -# - and len(type_.parameters) == 1 -# - ): -# - return f"{self.print_annotation(type_.parameters[0])} | None" -# - if str(type_.name) == "typing.Union" and type_.parameters is not None: -# - return " | ".join(self.print_annotation(p) for p in type_.parameters) -# if type_.parameters: -# param_str = ( -# "[" -# -# --------------------- diff ends ------------------------------------ -# -# Local modifications: -# * Rewrite references to 'mlir.ir' to local types. -# * Drop `typing.` everywhere (top-level import instead). -# * List -> List, dict -> Dict, Tuple -> Tuple. -# * copy-paste Buffer type from typing_extensions. -# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top. -# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`. -# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute. -# * Local edits to signatures and types that pybind11-stubgen did not auto detect (or detected incorrectly). -# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand. -# * Fill in `Any`s where possible. -# * black formatting. - -from __future__ import annotations - -import abc -import collections -from collections.abc import Callable, Sequence -from pathlib import Path -from typing import Any, BinaryIO, ClassVar, Literal, TypeVar, overload - -__all__ = [ - "AffineAddExpr", - "AffineBinaryExpr", - "AffineCeilDivExpr", - "AffineConstantExpr", - "AffineDimExpr", - "AffineExpr", - "AffineExprList", - "AffineFloorDivExpr", - "AffineMap", - "AffineMapAttr", - "AffineModExpr", - "AffineMulExpr", - "AffineSymbolExpr", - "ArrayAttr", - "ArrayAttributeIterator", - "AsmState", - "AttrBuilder", - "Attribute", - "BF16Type", - "Block", - "BlockArgument", - "BlockArgumentList", - "BlockIterator", - "BlockList", - "BoolAttr", - "ComplexType", - "Context", - "DenseBoolArrayAttr", - "DenseBoolArrayIterator", - "DenseElementsAttr", - "DenseF32ArrayAttr", - "DenseF32ArrayIterator", - "DenseF64ArrayAttr", - "DenseF64ArrayIterator", - "DenseFPElementsAttr", - "DenseI16ArrayAttr", - "DenseI16ArrayIterator", - "DenseI32ArrayAttr", - "DenseI32ArrayIterator", - "DenseI64ArrayAttr", - "DenseI64ArrayIterator", - "DenseI8ArrayAttr", - "DenseI8ArrayIterator", - "DenseIntElementsAttr", - "DenseResourceElementsAttr", - "Diagnostic", - "DiagnosticHandler", - "DiagnosticInfo", - "DiagnosticSeverity", - "Dialect", - "DialectDescriptor", - "DialectRegistry", - "Dialects", - "DictAttr", - "F16Type", - "F32Type", - "F64Type", - "FlatSymbolRefAttr", - "Float4E2M1FNType", - "Float6E2M3FNType", - "Float6E3M2FNType", - "Float8E3M4Type", - "Float8E4M3B11FNUZType", - "Float8E4M3FNType", - "Float8E4M3FNUZType", - "Float8E4M3Type", - "Float8E5M2FNUZType", - "Float8E5M2Type", - "Float8E8M0FNUType", - "FloatAttr", - "FloatTF32Type", - "FloatType", - "FunctionType", - "IndexType", - "InferShapedTypeOpInterface", - "InferTypeOpInterface", - "InsertionPoint", - "IntegerAttr", - "IntegerSet", - "IntegerSetAttr", - "IntegerSetConstraint", - "IntegerSetConstraintList", - "IntegerType", - "Location", - "MemRefType", - "Module", - "NamedAttribute", - "NoneType", - "OpAttributeMap", - "OpOperand", - "OpOperandIterator", - "OpOperandList", - "OpResult", - "OpResultList", - "OpSuccessors", - "OpView", - "OpaqueAttr", - "OpaqueType", - "Operation", - "OperationIterator", - "OperationList", - "RankedTensorType", - "Region", - "RegionIterator", - "RegionSequence", - "ShapedType", - "ShapedTypeComponents", - "StridedLayoutAttr", - "StringAttr", - "SymbolRefAttr", - "SymbolTable", - "TupleType", - "Type", - "TypeAttr", - "TypeID", - "UnitAttr", - "UnrankedMemRefType", - "UnrankedTensorType", - "Value", - "VectorType", - "_GlobalDebug", - "_OperationBase", -] - -if hasattr(collections.abc, "Buffer"): - Buffer = collections.abc.Buffer -else: - class Buffer(abc.ABC): - pass - -class _OperationBase: - @overload - def __eq__(self, arg0: _OperationBase) -> bool: ... - @overload - def __eq__(self, arg0: _OperationBase) -> bool: ... - def __hash__(self) -> int: ... - def __str__(self) -> str: - """ - Returns the assembly form of the operation. - """ - def clone(self, ip: InsertionPoint = None) -> OpView: ... - def detach_from_parent(self) -> OpView: - """ - Detaches the operation from its parent block. - """ - - @property - def attached(self) -> bool: - """ - Reports if the operation is attached to its parent block. - """ - - def erase(self) -> None: ... - - @overload - def get_asm( - binary: Literal[True], - large_elements_limit: int | None = None, - large_resource_limit: int | None = None, - enable_debug_info: bool = False, - pretty_debug_info: bool = False, - print_generic_op_form: bool = False, - use_local_scope: bool = False, - assume_verified: bool = False, - skip_regions: bool = False, - ) -> bytes: ... - @overload - def get_asm( - self, - binary: bool = False, - large_elements_limit: int | None = None, - large_resource_limit: int | None = None, - enable_debug_info: bool = False, - pretty_debug_info: bool = False, - print_generic_op_form: bool = False, - use_local_scope: bool = False, - assume_verified: bool = False, - skip_regions: bool = False, - ) -> str: - """ - Returns the assembly form of the operation. - - See the print() method for common keyword arguments for configuring - the output. - """ - - def move_after(self, other: _OperationBase) -> None: - """ - Puts self immediately after the other operation in its parent block. - """ - def move_before(self, other: _OperationBase) -> None: - """ - Puts self immediately before the other operation in its parent block. - """ - @overload - def print( - self, - state: AsmState, - file: Any | None = None, - binary: bool = False, - ) -> None: - """ - Prints the assembly form of the operation to a file like object. - - Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - state: AsmState capturing the operation numbering and flags. - """ - @overload - def print( - self, - large_elements_limit: int | None = None, - large_resource_limit: int | None = None, - enable_debug_info: bool = False, - pretty_debug_info: bool = False, - print_generic_op_form: bool = False, - use_local_scope: bool = False, - assume_verified: bool = False, - file: Any | None = None, - binary: bool = False, - skip_regions: bool = False, - ) -> None: - """ - Prints the assembly form of the operation to a file like object. - - Args: - file: The file like object to write to. Defaults to sys.stdout. - binary: Whether to write bytes (True) or str (False). Defaults to False. - large_elements_limit: Whether to elide elements attributes above this - number of elements. Defaults to None (no limit). - large_resource_limit: Whether to elide resource strings above this - number of characters. Defaults to None (no limit). If large_elements_limit - is set and this is None, the behavior will be to use large_elements_limit - as large_resource_limit. - enable_debug_info: Whether to print debug/location information. Defaults - to False. - pretty_debug_info: Whether to format debug information for easier reading - by a human (warning: the result is unparseable). - print_generic_op_form: Whether to print the generic assembly forms of all - ops. Defaults to False. - use_local_Scope: Whether to print in a way that is more optimized for - multi-threaded access but may not be consistent with how the overall - module prints. - assume_verified: By default, if not printing generic form, the verifier - will be run and if it fails, generic form will be printed with a comment - about failed verification. While a reasonable default for interactive use, - for systematic use, it is often better for the caller to verify explicitly - and report failures in a more robust fashion. Set this to True if doing this - in order to avoid running a redundant verification. If the IR is actually - invalid, behavior is undefined. - skip_regions: Whether to skip printing regions. Defaults to False. - """ - def verify(self) -> bool: - """ - Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. - """ - def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None: - """ - Write the bytecode form of the operation to a file like object. - - Args: - file: The file like object or path to write to. - desired_version: The version of bytecode to emit. - Returns: - The bytecode writer status. - """ - @property - def _CAPIPtr(self) -> object: ... - @property - def attributes(self) -> OpAttributeMap: ... - @property - def context(self) -> Context: - """ - Context that owns the Operation - """ - @property - def location(self) -> Location: - """ - Returns the source location the operation was defined or derived from. - """ - @property - def name(self) -> str: ... - @property - def operands(self) -> OpOperandList: ... - @property - def parent(self) -> _OperationBase | None: ... - @property - def regions(self) -> RegionSequence: ... - @property - def result(self) -> OpResult: - """ - Shortcut to get an op result if it has only one (throws an error otherwise). - """ - @property - def results(self) -> OpResultList: - """ - Returns the List of Operation results. - """ - -_TOperation = TypeVar("_TOperation", bound=_OperationBase) - -class AffineExpr: - @staticmethod - @overload - def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: - """ - Gets an affine expression containing a sum of two expressions. - """ - @staticmethod - @overload - def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr: - """ - Gets an affine expression containing a sum of a constant and another expression. - """ - @staticmethod - @overload - def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr: - """ - Gets an affine expression containing a sum of an expression and a constant. - """ - @staticmethod - @overload - def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: - """ - Gets an affine expression containing the rounded-up result of dividing one expression by another. - """ - @staticmethod - @overload - def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr: - """ - Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression. - """ - @staticmethod - @overload - def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr: - """ - Gets an affine expression containing the rounded-up result of dividing an expression by a constant. - """ - @staticmethod - def get_constant( - value: int, context: Context | None = None - ) -> AffineConstantExpr: - """ - Gets a constant affine expression with the given value. - """ - @staticmethod - def get_dim(position: int, context: Context | None = None) -> AffineDimExpr: - """ - Gets an affine expression of a dimension at the given position. - """ - @staticmethod - @overload - def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: - """ - Gets an affine expression containing the rounded-down result of dividing one expression by another. - """ - @staticmethod - @overload - def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr: - """ - Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression. - """ - @staticmethod - @overload - def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr: - """ - Gets an affine expression containing the rounded-down result of dividing an expression by a constant. - """ - @staticmethod - @overload - def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: - """ - Gets an affine expression containing the modulo of dividing one expression by another. - """ - @staticmethod - @overload - def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr: - """ - Gets a semi-affine expression containing the modulo of dividing a constant by an expression. - """ - @staticmethod - @overload - def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr: - """ - Gets an affine expression containing the module of dividingan expression by a constant. - """ - @staticmethod - @overload - def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: - """ - Gets an affine expression containing a product of two expressions. - """ - @staticmethod - @overload - def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr: - """ - Gets an affine expression containing a product of a constant and another expression. - """ - @staticmethod - @overload - def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr: - """ - Gets an affine expression containing a product of an expression and a constant. - """ - @staticmethod - def get_symbol( - position: int, context: Context | None = None - ) -> AffineSymbolExpr: - """ - Gets an affine expression of a symbol at the given position. - """ - def _CAPICreate(self) -> AffineExpr: ... - @overload - def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ... - @overload - def __add__(self, arg0: int) -> AffineAddExpr: ... - @overload - def __eq__(self, arg0: AffineExpr) -> bool: ... - @overload - def __eq__(self, arg0: Any) -> bool: ... - def __hash__(self) -> int: ... - @overload - def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ... - @overload - def __mod__(self, arg0: int) -> AffineModExpr: ... - @overload - def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ... - @overload - def __mul__(self, arg0: int) -> AffineMulExpr: ... - def __radd__(self, arg0: int) -> AffineAddExpr: ... - def __rmod__(self, arg0: int) -> AffineModExpr: ... - def __rmul__(self, arg0: int) -> AffineMulExpr: ... - def __rsub__(self, arg0: int) -> AffineAddExpr: ... - @overload - def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ... - @overload - def __sub__(self, arg0: int) -> AffineAddExpr: ... - def compose(self, arg0: AffineMap) -> AffineExpr: ... - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> Context: ... - -class Attribute: - @staticmethod - def parse(asm: str | bytes, context: Context | None = None) -> Attribute: - """ - Parses an attribute from an assembly form. Raises an MLIRError on failure. - """ - def _CAPICreate(self) -> Attribute: ... - @overload - def __eq__(self, arg0: Attribute) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - def __init__(self, cast_from_type: Attribute) -> None: - """ - Casts the passed attribute to the generic Attribute - """ - def __str__(self) -> str: - """ - Returns the assembly form of the Attribute. - """ - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - def get_named(self, arg0: str) -> NamedAttribute: - """ - Binds a name to the attribute - """ - def maybe_downcast(self) -> Any: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> Context: - """ - Context that owns the Attribute - """ - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class Type: - @staticmethod - def parse(asm: str | bytes, context: Context | None = None) -> Type: - """ - Parses the assembly form of a type. - - Returns a Type object or raises an MLIRError if the type cannot be parsed. - - See also: https://mlir.llvm.org/docs/LangRef/#type-system - """ - def _CAPICreate(self) -> Type: ... - @overload - def __eq__(self, arg0: Type) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - def __init__(self, cast_from_type: Type) -> None: - """ - Casts the passed type to the generic Type - """ - def __str__(self) -> str: - """ - Returns the assembly form of the type. - """ - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - def maybe_downcast(self) -> Any: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> Context: - """ - Context that owns the Type - """ - @property - def typeid(self) -> TypeID: ... - -class Value: - def _CAPICreate(self) -> Value: ... - @overload - def __eq__(self, arg0: Value) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - def __init__(self, value: Value) -> None: ... - def __str__(self) -> str: - """ - Returns the string form of the value. - - If the value is a block argument, this is the assembly form of its type and the - position in the argument List. If the value is an operation result, this is - equivalent to printing the operation that produced it. - """ - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - @overload - def get_name(self, use_local_scope: bool = False, use_name_loc_as_prefix: bool = True) -> str: ... - @overload - def get_name(self, state: AsmState) -> str: - """ - Returns the string form of value as an operand (i.e., the ValueID). - """ - def maybe_downcast(self) -> Any: ... - def replace_all_uses_with(self, arg0: Value) -> None: - """ - Replace all uses of value with the new value, updating anything in - the IR that uses 'self' to use the other value instead. - """ - def set_type(self, type: Type) -> None: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> Context: - """ - Context in which the value lives. - """ - @property - def owner(self) -> _OperationBase: ... - @property - def type(self) -> Type: ... - @property - def uses(self) -> OpOperandIterator: ... - -class AffineAddExpr(AffineBinaryExpr): - @staticmethod - def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - -class AffineBinaryExpr(AffineExpr): - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - @property - def lhs(self) -> AffineExpr: ... - @property - def rhs(self) -> AffineExpr: ... - -class AffineCeilDivExpr(AffineBinaryExpr): - @staticmethod - def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - -class AffineConstantExpr(AffineExpr): - @staticmethod - def get(value: int, context: Context | None = None) -> AffineConstantExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - @property - def value(self) -> int: ... - -class AffineDimExpr(AffineExpr): - @staticmethod - def get(position: int, context: Context | None = None) -> AffineDimExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - @property - def position(self) -> int: ... - -class AffineExprList: - def __add__(self, arg0: AffineExprList) -> list[AffineExpr]: ... - -class AffineFloorDivExpr(AffineBinaryExpr): - @staticmethod - def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - -class AffineMap: - @staticmethod - def compress_unused_symbols( - arg0: list, arg1: Context | None - ) -> list[AffineMap]: ... - @staticmethod - def get( - dim_count: int, - symbol_count: int, - exprs: list, - context: Context | None = None, - ) -> AffineMap: - """ - Gets a map with the given expressions as results. - """ - @staticmethod - def get_constant(value: int, context: Context | None = None) -> AffineMap: - """ - Gets an affine map with a single constant result - """ - @staticmethod - def get_empty(context: Context | None = None) -> AffineMap: - """ - Gets an empty affine map. - """ - @staticmethod - def get_identity(n_dims: int, context: Context | None = None) -> AffineMap: - """ - Gets an identity map with the given number of dimensions. - """ - @staticmethod - def get_minor_identity( - n_dims: int, n_results: int, context: Context | None = None - ) -> AffineMap: - """ - Gets a minor identity map with the given number of dimensions and results. - """ - @staticmethod - def get_permutation( - permutation: list[int], context: Context | None = None - ) -> AffineMap: - """ - Gets an affine map that permutes its inputs. - """ - def _CAPICreate(self) -> AffineMap: ... - @overload - def __eq__(self, arg0: AffineMap) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - def get_major_submap(self, n_results: int) -> AffineMap: ... - def get_minor_submap(self, n_results: int) -> AffineMap: ... - def get_submap(self, result_positions: list[int]) -> AffineMap: ... - def replace( - self, - expr: AffineExpr, - replacement: AffineExpr, - n_result_dims: int, - n_result_syms: int, - ) -> AffineMap: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> Context: - """ - Context that owns the Affine Map - """ - @property - def is_permutation(self) -> bool: ... - @property - def is_projected_permutation(self) -> bool: ... - @property - def n_dims(self) -> int: ... - @property - def n_inputs(self) -> int: ... - @property - def n_symbols(self) -> int: ... - @property - def results(self) -> AffineMapExprList: ... - -class AffineMapAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(affine_map: AffineMap) -> AffineMapAttr: - """ - Gets an attribute wrapping an AffineMap. - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class AffineModExpr(AffineBinaryExpr): - @staticmethod - def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - -class AffineMulExpr(AffineBinaryExpr): - @staticmethod - def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - -class AffineSymbolExpr(AffineExpr): - @staticmethod - def get(position: int, context: Context | None = None) -> AffineSymbolExpr: ... - @staticmethod - def isinstance(other: AffineExpr) -> bool: ... - def __init__(self, expr: AffineExpr) -> None: ... - @property - def position(self) -> int: ... - -class ArrayAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(attributes: list, context: Context | None = None) -> ArrayAttr: - """ - Gets a uniqued Array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> ArrayAttr: ... - def __getitem__(self, arg0: int) -> Attribute: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> ArrayAttributeIterator: ... - def __len__(self) -> int: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class ArrayAttributeIterator: - def __iter__(self) -> ArrayAttributeIterator: ... - def __next__(self) -> Attribute: ... - -class AsmState: - @overload - def __init__(self, value: Value, use_local_scope: bool = False) -> None: ... - @overload - def __init__(self, op: _OperationBase, use_local_scope: bool = False) -> None: ... - -class AttrBuilder: - @staticmethod - def contains(arg0: str) -> bool: ... - @staticmethod - def get(arg0: str) -> Callable: ... - @staticmethod - def insert( - attribute_kind: str, attr_builder: Callable, replace: bool = False - ) -> None: - """ - Register an attribute builder for building MLIR attributes from python values. - """ - -class BF16Type(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> BF16Type: - """ - Create a bf16 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Block: - @staticmethod - def create_at_start( - parent: Region, - arg_types: list[Type], - arg_locs: Sequence | None = None, - ) -> Block: - """ - Creates and returns a new Block at the beginning of the given region (with given argument types and locations). - """ - @overload - def __eq__(self, arg0: Block) -> bool: ... - @overload - def __eq__(self, arg0: Any) -> bool: ... - def __hash__(self) -> int: ... - def __iter__(self) -> OperationIterator: - """ - Iterates over operations in the block. - """ - def __str__(self) -> str: - """ - Returns the assembly form of the block. - """ - def append(self, operation: _OperationBase) -> None: - """ - Appends an operation to this block. If the operation is currently in another block, it will be moved. - """ - def append_to(self, arg0: Region) -> None: - """ - Append this block to a region, transferring ownership if necessary - """ - def create_after(self, *args, arg_locs: Sequence | None = None) -> Block: - """ - Creates and returns a new Block after this block (with given argument types and locations). - """ - def create_before(self, *args, arg_locs: Sequence | None = None) -> Block: - """ - Creates and returns a new Block before this block (with given argument types and locations). - """ - @property - def _CAPIPtr(self) -> object: ... - @property - def arguments(self) -> BlockArgumentList: - """ - Returns a List of block arguments. - """ - @property - def operations(self) -> OperationList: - """ - Returns a forward-optimized sequence of operations. - """ - @property - def owner(self) -> OpView: - """ - Returns the owning operation of this block. - """ - @property - def region(self) -> Region: - """ - Returns the owning region of this block. - """ - -class BlockArgument(Value): - @staticmethod - def isinstance(other_value: Value) -> bool: ... - def __init__(self, value: Value) -> None: ... - def maybe_downcast(self) -> Any: ... - def set_type(self, type: Type) -> None: ... - @property - def arg_number(self) -> int: ... - @property - def owner(self) -> Block: ... - -class BlockArgumentList: - @overload - def __getitem__(self, arg0: int) -> BlockArgument: ... - @overload - def __getitem__(self, arg0: slice) -> BlockArgumentList: ... - def __len__(self) -> int: ... - def __add__(self, arg0: BlockArgumentList) -> list[BlockArgument]: ... - @property - def types(self) -> list[Type]: ... - -class BlockIterator: - def __iter__(self) -> BlockIterator: ... - def __next__(self) -> Block: ... - -class BlockList: - def __getitem__(self, arg0: int) -> Block: ... - def __iter__(self) -> BlockIterator: ... - def __len__(self) -> int: ... - def append(self, *args, arg_locs: Sequence | None = None) -> Block: - """ - Appends a new block, with argument types as positional args. - - Returns: - The created block. - """ - -class BoolAttr(Attribute): - @staticmethod - def get(value: bool, context: Context | None = None) -> BoolAttr: - """ - Gets an uniqued bool attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __bool__(self: Attribute) -> bool: - """ - Converts the value of the bool attribute to a Python bool - """ - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> bool: - """ - Returns the value of the bool attribute - """ - -class ComplexType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(arg0: Type) -> ComplexType: - """ - Create a complex type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def element_type(self) -> Type: - """ - Returns element type. - """ - @property - def typeid(self) -> TypeID: ... - -class Context: - current: ClassVar[Context] = ... # read-only - allow_unregistered_dialects: bool - @staticmethod - def _get_live_count() -> int: ... - def _CAPICreate(self) -> object: ... - def __enter__(self) -> Context: ... - def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... - def __init__(self) -> None: ... - def _clear_live_operations(self) -> int: ... - def _get_context_again(self) -> Context: ... - def _get_live_module_count(self) -> int: ... - def _get_live_operation_count(self) -> int: ... - def _get_live_operation_objects(self) -> list[Operation]: ... - def append_dialect_registry(self, registry: DialectRegistry) -> None: ... - def attach_diagnostic_handler( - self, callback: Callable[[Diagnostic], bool] - ) -> DiagnosticHandler: - """ - Attaches a diagnostic handler that will receive callbacks - """ - def enable_multithreading(self, enable: bool) -> None: ... - def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: - """ - Gets or loads a dialect by name, returning its descriptor object - """ - def is_registered_operation(self, operation_name: str) -> bool: ... - def load_all_available_dialects(self) -> None: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def d(self) -> Dialects: - """ - Alias for 'dialect' - """ - @property - def dialects(self) -> Dialects: - """ - Gets a container for accessing dialects by name - """ - -class DenseBoolArrayAttr(Attribute): - @staticmethod - def get( - values: Sequence[bool], context: Context | None = None - ) -> DenseBoolArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseBoolArrayAttr: ... - def __getitem__(self, arg0: int) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseBoolArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseBoolArrayIterator: - def __iter__(self) -> DenseBoolArrayIterator: ... - def __next__(self) -> bool: ... - -class DenseElementsAttr(Attribute): - @staticmethod - def get( - array: Buffer, - signless: bool = True, - type: Type | None = None, - shape: list[int] | None = None, - context: Context | None = None, - ) -> DenseElementsAttr: - """ - Gets a DenseElementsAttr from a Python buffer or array. - - When `type` is not provided, then some limited type inferencing is done based - on the buffer format. Support presently exists for 8/16/32/64 signed and - unsigned integers and float16/float32/float64. DenseElementsAttrs of these - types can also be converted back to a corresponding buffer. - - For conversions outside of these types, a `type=` must be explicitly provided - and the buffer contents must be bit-castable to the MLIR internal - representation: - - * Integer types (except for i1): the buffer must be byte aligned to the - next byte boundary. - * Floating point types: Must be bit-castable to the given floating point - size. - * i1 (bool): Bit packed into 8bit words where the bit pattern matches a - row major ordering. An arbitrary Numpy `bool_` array can be bit packed to - this specification with: `np.packbits(ary, axis=None, bitorder='little')`. - - If a single element buffer is passed (or for i1, a single byte with value 0 - or 255), then a splat will be created. - - Args: - array: The array or buffer to convert. - signless: If inferring an appropriate MLIR type, use signless types for - integers (defaults True). - type: Skips inference of the MLIR element type and uses this instead. The - storage size must be consistent with the actual contents of the buffer. - shape: Overrides the shape of the buffer when constructing the MLIR - shaped type. This is needed when the physical and logical shape differ (as - for i1). - context: Explicit context, if not from context manager. - - Returns: - DenseElementsAttr on success. - - Raises: - ValueError: If the type of the buffer or array cannot be matched to an MLIR - type or if the buffer does not meet expectations. - """ - @staticmethod - def get_splat(shaped_type: Type, element_attr: Attribute) -> DenseElementsAttr: - """ - Gets a DenseElementsAttr where all values are the same - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __len__(self) -> int: ... - def get_splat_value(self) -> Attribute: ... - @property - def is_splat(self) -> bool: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseF32ArrayAttr(Attribute): - @staticmethod - def get( - values: Sequence[float], context: Context | None = None - ) -> DenseF32ArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseF32ArrayAttr: ... - def __getitem__(self, arg0: int) -> float: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseF32ArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseF32ArrayIterator: - def __iter__(self) -> DenseF32ArrayIterator: ... - def __next__(self) -> float: ... - -class DenseF64ArrayAttr(Attribute): - @staticmethod - def get( - values: Sequence[float], context: Context | None = None - ) -> DenseF64ArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseF64ArrayAttr: ... - def __getitem__(self, arg0: int) -> float: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseF64ArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseF64ArrayIterator: - def __iter__(self) -> DenseF64ArrayIterator: ... - def __next__(self) -> float: ... - -class DenseFPElementsAttr(DenseElementsAttr): - @staticmethod - def get( - array: Buffer, - signless: bool = True, - type: Type | None = None, - shape: list[int] | None = None, - context: Context | None = None, - ) -> DenseFPElementsAttr: ... - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __getitem__(self, arg0: int) -> float: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseI16ArrayAttr(Attribute): - @staticmethod - def get(values: Sequence[int], context: Context | None = None) -> DenseI16ArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseI16ArrayAttr: ... - def __getitem__(self, arg0: int) -> int: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseI16ArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseI16ArrayIterator: - def __iter__(self) -> DenseI16ArrayIterator: ... - def __next__(self) -> int: ... - -class DenseI32ArrayAttr(Attribute): - @staticmethod - def get(values: Sequence[int], context: Context | None = None) -> DenseI32ArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseI32ArrayAttr: ... - def __getitem__(self, arg0: int) -> int: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseI32ArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseI32ArrayIterator: - def __iter__(self) -> DenseI32ArrayIterator: ... - def __next__(self) -> int: ... - -class DenseI64ArrayAttr(Attribute): - @staticmethod - def get(values: Sequence[int], context: Context | None = None) -> DenseI64ArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseI64ArrayAttr: ... - def __getitem__(self, arg0: int) -> int: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseI16ArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseI64ArrayIterator: - def __iter__(self) -> DenseI64ArrayIterator: ... - def __next__(self) -> int: ... - -class DenseI8ArrayAttr(Attribute): - @staticmethod - def get(values: Sequence[int], context: Context | None = None) -> DenseI8ArrayAttr: - """ - Gets a uniqued dense array attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: list) -> DenseI8ArrayAttr: ... - def __getitem__(self, arg0: int) -> int: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __iter__( - self, - ) -> DenseI8ArrayIterator: ... - def __len__(self) -> int: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseI8ArrayIterator: - def __iter__(self) -> DenseI8ArrayIterator: ... - def __next__(self) -> int: ... - -class DenseIntElementsAttr(DenseElementsAttr): - @staticmethod - def get( - array: Buffer, - signless: bool = True, - type: Type | None = None, - shape: list[int] | None = None, - context: Context | None = None, - ) -> DenseIntElementsAttr: ... - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __getitem__(self, arg0: int) -> int: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class DenseResourceElementsAttr(Attribute): - @staticmethod - def get_from_buffer( - array: Buffer, - name: str, - type: Type, - alignment: int | None = None, - is_mutable: bool = False, - context: Context | None = None, - ) -> DenseResourceElementsAttr: - """ - Gets a DenseResourceElementsAttr from a Python buffer or array. - - This function does minimal validation or massaging of the data, and it is - up to the caller to ensure that the buffer meets the characteristics - implied by the shape. - - The backing buffer and any user objects will be retained for the lifetime - of the resource blob. This is typically bounded to the context but the - resource can have a shorter lifespan depending on how it is used in - subsequent processing. - - Args: - buffer: The array or buffer to convert. - name: Name to provide to the resource (may be changed upon collision). - type: The explicit ShapedType to construct the attribute with. - context: Explicit context, if not from context manager. - - Returns: - DenseResourceElementsAttr on success. - - Raises: - ValueError: If the type of the buffer or array cannot be matched to an MLIR - type or if the buffer does not meet expectations. - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class Diagnostic: - @property - def location(self) -> Location: ... - @property - def message(self) -> str: ... - @property - def notes(self) -> tuple[Diagnostic]: ... - @property - def severity(self) -> DiagnosticSeverity: ... - -class DiagnosticHandler: - def __enter__(self) -> DiagnosticHandler: ... - def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... - def detach(self) -> None: ... - @property - def attached(self) -> bool: ... - @property - def had_error(self) -> bool: ... - -class DiagnosticInfo: - def __init__(self, arg0: Diagnostic) -> None: ... - @property - def location(self) -> Location: ... - @property - def message(self) -> str: ... - @property - def notes(self) -> list[DiagnosticInfo]: ... - @property - def severity(self) -> DiagnosticSeverity: ... - -class DiagnosticSeverity: - """ - Members: - - ERROR - - WARNING - - NOTE - - REMARK - """ - - ERROR: ClassVar[DiagnosticSeverity] # value = - NOTE: ClassVar[DiagnosticSeverity] # value = - REMARK: ClassVar[DiagnosticSeverity] # value = - WARNING: ClassVar[DiagnosticSeverity] # value = - __members__: ClassVar[ - dict[str, DiagnosticSeverity] - ] # value = {'ERROR': , 'WARNING': , 'NOTE': , 'REMARK': } - def __eq__(self, other: Any) -> bool: ... - def __getstate__(self) -> int: ... - def __hash__(self) -> int: ... - def __index__(self) -> int: ... - def __init__(self, value: int) -> None: ... - def __int__(self) -> int: ... - def __ne__(self, other: Any) -> bool: ... - def __setstate__(self, state: int) -> None: ... - @property - def name(self) -> str: ... - @property - def value(self) -> int: ... - -class Dialect: - def __init__(self, descriptor: DialectDescriptor) -> None: ... - @property - def descriptor(self) -> DialectDescriptor: ... - -class DialectDescriptor: - @property - def namespace(self) -> str: ... - -class DialectRegistry: - def _CAPICreate(self) -> DialectRegistry: ... - def __init__(self) -> None: ... - @property - def _CAPIPtr(self) -> object: ... - -class Dialects: - def __getattr__(self, arg0: str) -> Dialect: ... - def __getitem__(self, arg0: str) -> Dialect: ... - -class DictAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(value: dict = {}, context: Context | None = None) -> DictAttr: - """ - Gets an uniqued Dict attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __contains__(self, arg0: str) -> bool: ... - @overload - def __getitem__(self, arg0: str) -> Attribute: ... - @overload - def __getitem__(self, arg0: int) -> NamedAttribute: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __len__(self) -> int: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class FloatType(Type): - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def width(self) -> int: - """ - Returns the width of the floating-point type. - """ - -class F16Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> F16Type: - """ - Create a f16 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class F32Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> F32Type: - """ - Create a f32 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class F64Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> F64Type: - """ - Create a f64 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class FlatSymbolRefAttr(Attribute): - @staticmethod - def get(value: str, context: Context | None = None) -> FlatSymbolRefAttr: - """ - Gets a uniqued FlatSymbolRef attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> str: - """ - Returns the value of the FlatSymbolRef attribute as a string - """ - -class Float4E2M1FNType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float4E2M1FNType: - """ - Create a float4_e2m1fn type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float6E2M3FNType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float6E2M3FNType: - """ - Create a float6_e2m3fn type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float6E3M2FNType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float6E3M2FNType: - """ - Create a float6_e3m2fn type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E3M4Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E3M4Type: - """ - Create a float8_e3m4 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E4M3B11FNUZType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E4M3B11FNUZType: - """ - Create a float8_e4m3b11fnuz type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E4M3FNType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E4M3FNType: - """ - Create a float8_e4m3fn type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E4M3FNUZType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E4M3FNUZType: - """ - Create a float8_e4m3fnuz type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E4M3Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E4M3Type: - """ - Create a float8_e4m3 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E5M2FNUZType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E5M2FNUZType: - """ - Create a float8_e5m2fnuz type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E5M2Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E5M2Type: - """ - Create a float8_e5m2 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class Float8E8M0FNUType(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> Float8E8M0FNUType: - """ - Create a float8_e8m0fnu type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class FloatAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(type: Type, value: float, loc: Location | None = None) -> FloatAttr: - """ - Gets an uniqued float point attribute associated to a type - """ - @staticmethod - def get_f32(value: float, context: Context | None = None) -> FloatAttr: - """ - Gets an uniqued float point attribute associated to a f32 type - """ - @staticmethod - def get_f64(value: float, context: Context | None = None) -> FloatAttr: - """ - Gets an uniqued float point attribute associated to a f64 type - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __float__(self: Attribute) -> float: - """ - Converts the value of the float attribute to a Python float - """ - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> float: - """ - Returns the value of the float attribute - """ - -class FloatTF32Type(FloatType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> FloatTF32Type: - """ - Create a tf32 type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class FunctionType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - inputs: list[Type], results: list[Type], context: Context | None = None - ) -> FunctionType: - """ - Gets a FunctionType from a List of input and result types - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def inputs(self) -> list: - """ - Returns the List of input types in the FunctionType. - """ - @property - def results(self) -> list: - """ - Returns the List of result types in the FunctionType. - """ - @property - def typeid(self) -> TypeID: ... - -class IndexType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> IndexType: - """ - Create a index type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class InferShapedTypeOpInterface: - def __init__(self, object: object, context: Context | None = None) -> None: - """ - Creates an interface from a given operation/opview object or from a - subclass of OpView. Raises ValueError if the operation does not implement the - interface. - """ - def inferReturnTypeComponents( - self, - operands: list | None = None, - attributes: Attribute | None = None, - properties=None, - regions: list[Region] | None = None, - context: Context | None = None, - loc: Location | None = None, - ) -> list[ShapedTypeComponents]: - """ - Given the arguments required to build an operation, attempts to infer - its return shaped type components. Raises ValueError on failure. - """ - @property - def operation(self) -> Operation: - """ - Returns an Operation for which the interface was constructed. - """ - @property - def opview(self) -> OpView: - """ - Returns an OpView subclass _instance_ for which the interface was - constructed - """ - -class InferTypeOpInterface: - def __init__(self, object: object, context: Context | None = None) -> None: - """ - Creates an interface from a given operation/opview object or from a - subclass of OpView. Raises ValueError if the operation does not implement the - interface. - """ - def inferReturnTypes( - self, - operands: list | None = None, - attributes: Attribute | None = None, - properties=None, - regions: list[Region] | None = None, - context: Context | None = None, - loc: Location | None = None, - ) -> list[Type]: - """ - Given the arguments required to build an operation, attempts to infer - its return types. Raises ValueError on failure. - """ - @property - def operation(self) -> Operation: - """ - Returns an Operation for which the interface was constructed. - """ - @property - def opview(self) -> OpView: - """ - Returns an OpView subclass _instance_ for which the interface was - constructed - """ - -class InsertionPoint: - current: ClassVar[InsertionPoint] = ... # read-only - @staticmethod - def at_block_begin(block: Block) -> InsertionPoint: - """ - Inserts at the beginning of the block. - """ - @staticmethod - def at_block_terminator(block: Block) -> InsertionPoint: - """ - Inserts before the block terminator. - """ - def __enter__(self) -> InsertionPoint: ... - def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... - @overload - def __init__(self, block: Block) -> None: - """ - Inserts after the last operation but still inside the block. - """ - @overload - def __init__(self, beforeOperation: _OperationBase) -> None: - """ - Inserts before a referenced operation. - """ - def insert(self, operation: _OperationBase) -> None: - """ - Inserts an operation. - """ - @property - def block(self) -> Block: - """ - Returns the block that this InsertionPoint points to. - """ - @property - def ref_operation(self) -> _OperationBase | None: - """ - The reference operation before which new operations are inserted, or None if the insertion point is at the end of the block - """ - -class IntegerAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(type: Type, value: int) -> IntegerAttr: - """ - Gets an uniqued integer attribute associated to a type - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - def __int__(self) -> int: - """ - Converts the value of the integer attribute to a Python int - """ - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> int: - """ - Returns the value of the integer attribute - """ - -class IntegerSet: - @staticmethod - def get( - num_dims: int, - num_symbols: int, - exprs: list, - eq_flags: list[bool], - context: Context | None = None, - ) -> IntegerSet: ... - @staticmethod - def get_empty( - num_dims: int, num_symbols: int, context: Context | None = None - ) -> IntegerSet: ... - def _CAPICreate(self) -> IntegerSet: ... - @overload - def __eq__(self, arg0: IntegerSet) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - def get_replaced( - self, - dim_exprs: list, - symbol_exprs: list, - num_result_dims: int, - num_result_symbols: int, - ) -> IntegerSet: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def constraints(self) -> IntegerSetConstraintList: ... - @property - def context(self) -> Context: ... - @property - def is_canonical_empty(self) -> bool: ... - @property - def n_dims(self) -> int: ... - @property - def n_equalities(self) -> int: ... - @property - def n_inequalities(self) -> int: ... - @property - def n_inputs(self) -> int: ... - @property - def n_symbols(self) -> int: ... - -class IntegerSetAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(integer_set) -> IntegerSetAttr: - """ - Gets an attribute wrapping an IntegerSet. - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class IntegerSetConstraint: - def __init__(self, *args, **kwargs) -> None: ... - @property - def expr(self) -> AffineExpr: ... - @property - def is_eq(self) -> bool: ... - -class IntegerSetConstraintList: - def __init__(self, *args, **kwargs) -> None: ... - def __add__(self, arg0: IntegerSetConstraintList) -> list[IntegerSetConstraint]: ... - @overload - def __getitem__(self, arg0: int) -> IntegerSetConstraint: ... - @overload - def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ... - def __len__(self) -> int: ... - -class IntegerType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get_signed(width: int, context: Context | None = None) -> IntegerType: - """ - Create a signed integer type - """ - @staticmethod - def get_signless(width: int, context: Context | None = None) -> IntegerType: - """ - Create a signless integer type - """ - @staticmethod - def get_unsigned(width: int, context: Context | None = None) -> IntegerType: - """ - Create an unsigned integer type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def is_signed(self) -> bool: - """ - Returns whether this is a signed integer - """ - @property - def is_signless(self) -> bool: - """ - Returns whether this is a signless integer - """ - @property - def is_unsigned(self) -> bool: - """ - Returns whether this is an unsigned integer - """ - @property - def typeid(self) -> TypeID: ... - @property - def width(self) -> int: - """ - Returns the width of the integer type - """ - -class Location: - current: ClassVar[Location] = ... # read-only - __hash__: ClassVar[None] = None - @staticmethod - def callsite( - callee: Location, frames: Sequence[Location], context: Context | None = None - ) -> Location: - """ - Gets a Location representing a caller and callsite - """ - @staticmethod - def file( - filename: str, line: int, col: int, context: Context | None = None - ) -> Location: - """ - Gets a Location representing a file, line and column - """ - @staticmethod - def from_attr(attribute: Attribute, context: Context | None = None) -> Location: - """ - Gets a Location from a LocationAttr - """ - @staticmethod - def fused( - locations: Sequence[Location], - metadata: Attribute | None = None, - context: Context | None = None, - ) -> Location: - """ - Gets a Location representing a fused location with optional metadata - """ - @staticmethod - def name( - name: str, - childLoc: Location | None = None, - context: Context | None = None, - ) -> Location: - """ - Gets a Location representing a named location with optional child location - """ - @staticmethod - def unknown(context: Context | None = None) -> Location: - """ - Gets a Location representing an unknown location - """ - def _CAPICreate(self) -> Location: ... - def __enter__(self) -> Location: ... - @overload - def __eq__(self, arg0: Location) -> bool: ... - @overload - def __eq__(self, arg0: Location) -> bool: ... - def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... - def emit_error(self, message: str) -> None: - """ - Emits an error at this location - """ - @property - def _CAPIPtr(self) -> object: ... - @property - def attr(self) -> Attribute: - """ - Get the underlying LocationAttr - """ - @property - def context(self) -> Context: - """ - Context that owns the Location - """ - -class MemRefType(ShapedType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - shape: list[int], - element_type: Type, - layout: Attribute = None, - memory_space: Attribute = None, - loc: Location | None = None, - ) -> MemRefType: - """ - Create a memref type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def affine_map(self) -> AffineMap: - """ - The layout of the MemRef type as an affine map. - """ - @property - def layout(self) -> Attribute: - """ - The layout of the MemRef type. - """ - @property - def memory_space(self) -> Attribute | None: - """ - Returns the memory space of the given MemRef type. - """ - @property - def typeid(self) -> TypeID: ... - def get_strides_and_offset(self) -> tuple[list[int], int]: - """ - The strides and offset of the MemRef type. - """ - -class Module: - @staticmethod - def create(loc: Location | None = None) -> Module: - """ - Creates an empty module - """ - @staticmethod - def parse(asm: str | bytes, context: Context | None = None) -> Module: - """ - Parses a module's assembly format from a string. - - Returns a new MlirModule or raises an MLIRError if the parsing fails. - - See also: https://mlir.llvm.org/docs/LangRef/ - """ - @staticmethod - def parseFile(path: str, context: Context | None = None) -> Module: - """ - Parses a module's assembly format from file. - - Returns a new MlirModule or raises an MLIRError if the parsing fails. - - See also: https://mlir.llvm.org/docs/LangRef/ - """ - def _CAPICreate(self) -> Any: ... - def __str__(self) -> str: - """ - Gets the assembly form of the operation with default options. - - If more advanced control over the assembly formatting or I/O options is needed, - use the dedicated print or get_asm method, which supports keyword arguments to - customize behavior. - """ - def dump(self) -> None: - """ - Dumps a debug representation of the object to stderr. - """ - @property - def _CAPIPtr(self) -> object: ... - @property - def body(self) -> Block: - """ - Return the block for this module - """ - @property - def context(self) -> Context: - """ - Context that created the Module - """ - @property - def operation(self) -> Operation: - """ - Accesses the module as an operation - """ - -class MLIRError(Exception): - def __init__( - self, message: str, error_diagnostics: list[DiagnosticInfo] - ) -> None: ... - -class NamedAttribute: - @property - def attr(self) -> Attribute: - """ - The underlying generic attribute of the NamedAttribute binding - """ - @property - def name(self) -> str: - """ - The name of the NamedAttribute binding - """ - -class NoneType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> NoneType: - """ - Create a none type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class OpAttributeMap: - def __contains__(self, arg0: str) -> bool: ... - def __delitem__(self, arg0: str) -> None: ... - @overload - def __getitem__(self, arg0: str) -> Attribute: ... - @overload - def __getitem__(self, arg0: int) -> NamedAttribute: ... - def __len__(self) -> int: ... - def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... - -class OpOperand: - @property - def operand_number(self) -> int: ... - @property - def owner(self) -> _OperationBase: ... - -class OpOperandIterator: - def __iter__(self) -> OpOperandIterator: ... - def __next__(self) -> OpOperand: ... - -class OpOperandList: - def __add__(self, arg0: OpOperandList) -> list[Value]: ... - @overload - def __getitem__(self, arg0: int) -> Value: ... - @overload - def __getitem__(self, arg0: slice) -> OpOperandList: ... - def __len__(self) -> int: ... - def __setitem__(self, arg0: int, arg1: Value) -> None: ... - -class OpResult(Value): - @staticmethod - def isinstance(other_value: Value) -> bool: ... - def __init__(self, value: Value) -> None: ... - @staticmethod - def isinstance(arg: Any) -> bool: ... - @property - def owner(self) -> _OperationBase: ... - @property - def result_number(self) -> int: ... - -class OpResultList: - def __add__(self, arg0: OpResultList) -> list[OpResult]: ... - @overload - def __getitem__(self, arg0: int) -> OpResult: ... - @overload - def __getitem__(self, arg0: slice) -> OpResultList: ... - def __len__(self) -> int: ... - @property - def owner(self) -> _OperationBase: ... - @property - def types(self) -> list[Type]: ... - -class OpSuccessors: - def __add__(self, arg0: OpSuccessors) -> list[Block]: ... - @overload - def __getitem__(self, arg0: int) -> Block: ... - @overload - def __getitem__(self, arg0: slice) -> OpSuccessors: ... - def __setitem__(self, arg0: int, arg1: Block) -> None: ... - def __len__(self) -> int: ... - -class OpView(_OperationBase): - _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... - _ODS_REGIONS: ClassVar[tuple] = ... - _ODS_RESULT_SEGMENTS: ClassVar[None] = ... - def __init__(self, operation: _OperationBase) -> None: ... - @classmethod - def build_generic( - cls: type[_TOperation], - results: Sequence[Type] | None = None, - operands: Sequence[Value] | None = None, - attributes: dict[str, Attribute] | None = None, - successors: Sequence[Block] | None = None, - regions: int | None = None, - loc: Location | None = None, - ip: InsertionPoint | None = None, - ) -> _TOperation: - """ - Builds a specific, generated OpView based on class level attributes. - """ - @classmethod - def parse( - cls: type[_TOperation], - source: str | bytes, - *, - source_name: str = "", - context: Context | None = None, - ) -> _TOperation: - """ - Parses a specific, generated OpView based on class level attributes - """ - def __init__(self, operation: _OperationBase) -> None: ... - @property - def operation(self) -> _OperationBase: ... - @property - def opview(self) -> OpView: ... - @property - def successors(self) -> OpSuccessors: - """ - Returns the List of Operation successors. - """ - -class OpaqueAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - dialect_namespace: str, - buffer: Buffer, - type: Type, - context: Context | None = None, - ) -> OpaqueAttr: - """ - Gets an Opaque attribute. - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def data(self) -> bytes: - """ - Returns the data for the Opaqued attributes as `bytes` - """ - @property - def dialect_namespace(self) -> str: - """ - Returns the dialect namespace for the Opaque attribute as a string - """ - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class OpaqueType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - dialect_namespace: str, buffer: str, context: Context | None = None - ) -> OpaqueType: - """ - Create an unregistered (opaque) dialect type. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def data(self) -> str: - """ - Returns the data for the Opaque type as a string. - """ - @property - def dialect_namespace(self) -> str: - """ - Returns the dialect namespace for the Opaque type as a string. - """ - @property - def typeid(self) -> TypeID: ... - -class Operation(_OperationBase): - def _CAPICreate(self) -> object: ... - @staticmethod - def create( - name: str, - results: Sequence[Type] | None = None, - operands: Sequence[Value] | None = None, - attributes: dict[str, Attribute] | None = None, - successors: Sequence[Block] | None = None, - regions: int = 0, - loc: Location | None = None, - ip: InsertionPoint | None = None, - infer_type: bool = False, - ) -> Operation: - """ - Creates a new operation. - - Args: - name: Operation name (e.g. "dialect.operation"). - results: Sequence of Type representing op result types. - attributes: Dict of str:Attribute. - successors: List of Block for the operation's successors. - regions: Number of regions to create. - loc: A Location object (defaults to resolve from context manager). - ip: An InsertionPoint (defaults to resolve from context manager or set to - False to disable insertion, even with an insertion point set in the - context manager). - infer_type: Whether to infer result types. - Returns: - A new "detached" Operation object. Detached operations can be added - to blocks, which causes them to become "attached." - """ - @staticmethod - def parse( - source: str | bytes, *, source_name: str = "", context: Context | None = None - ) -> Operation: - """ - Parses an operation. Supports both text assembly format and binary bytecode format. - """ - def _CAPICreate(self) -> object: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def operation(self) -> Operation: ... - @property - def opview(self) -> OpView: ... - @property - def successors(self) -> OpSuccessors: - """ - Returns the List of Operation successors. - """ - -class OperationIterator: - def __iter__(self) -> OperationIterator: ... - def __next__(self) -> OpView: ... - -class OperationList: - def __getitem__(self, arg0: int) -> OpView: ... - def __iter__(self) -> OperationIterator: ... - def __len__(self) -> int: ... - -class RankedTensorType(ShapedType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - shape: list[int], - element_type: Type, - encoding: Attribute | None = None, - loc: Location | None = None, - ) -> RankedTensorType: - """ - Create a ranked tensor type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def encoding(self) -> Attribute | None: ... - @property - def typeid(self) -> TypeID: ... - -class Region: - __hash__: ClassVar[None] = None - @overload - def __eq__(self, arg0: Region) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __iter__(self) -> BlockIterator: - """ - Iterates over blocks in the region. - """ - @property - def blocks(self) -> BlockList: - """ - Returns a forward-optimized sequence of blocks. - """ - @property - def owner(self) -> OpView: - """ - Returns the operation owning this region. - """ - -class RegionIterator: - def __iter__(self) -> RegionIterator: ... - def __next__(self) -> Region: ... - -class RegionSequence: - @overload - def __getitem__(self, arg0: int) -> Region: ... - @overload - def __getitem__(self, arg0: slice) -> Sequence[Region]: ... - def __iter__(self) -> RegionIterator: ... - def __len__(self) -> int: ... - -class ShapedType(Type): - @staticmethod - def get_dynamic_size() -> int: - """ - Returns the value used to indicate dynamic dimensions in shaped types. - """ - @staticmethod - def get_dynamic_stride_or_offset() -> int: - """ - Returns the value used to indicate dynamic strides or offsets in shaped types. - """ - @staticmethod - def is_dynamic_size(dim_size: int) -> bool: - """ - Returns whether the given dimension size indicates a dynamic dimension. - """ - @staticmethod - def is_static_size(dim_size: int) -> bool: - """ - Returns whether the given dimension size indicates a static dimension. - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - def get_dim_size(self, dim: int) -> int: - """ - Returns the dim-th dimension of the given ranked shaped type. - """ - def is_dynamic_dim(self, dim: int) -> bool: - """ - Returns whether the dim-th dimension of the given shaped type is dynamic. - """ - def is_static_dim(self, dim: int) -> bool: - """ - Returns whether the dim-th dimension of the given shaped type is static. - """ - def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: - """ - Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types. - """ - def is_static_stride_or_offset(self, dim_size: int) -> bool: - """ - Returns whether the given shaped type stride or offset value is statically-sized. - """ - @property - def element_type(self) -> Type: - """ - Returns the element type of the shaped type. - """ - @property - def has_rank(self) -> bool: - """ - Returns whether the given shaped type is ranked. - """ - @property - def has_static_shape(self) -> bool: - """ - Returns whether the given shaped type has a static shape. - """ - @property - def rank(self) -> int: - """ - Returns the rank of the given ranked shaped type. - """ - @property - def shape(self) -> list[int]: - """ - Returns the shape of the ranked shaped type as a List of integers. - """ - @property - def static_typeid(self) -> TypeID: ... - @property - def typeid(self) -> TypeID: ... - -class ShapedTypeComponents: - @staticmethod - @overload - def get(element_type: Type) -> ShapedTypeComponents: - """ - Create an shaped type components object with only the element type. - """ - @staticmethod - @overload - def get(shape: list, element_type: Type) -> ShapedTypeComponents: - """ - Create a ranked shaped type components object. - """ - @staticmethod - @overload - def get( - shape: list, element_type: Type, attribute: Attribute - ) -> ShapedTypeComponents: - """ - Create a ranked shaped type components object with attribute. - """ - @property - def element_type(self) -> Type: - """ - Returns the element type of the shaped type components. - """ - @property - def has_rank(self) -> bool: - """ - Returns whether the given shaped type component is ranked. - """ - @property - def rank(self) -> int: - """ - Returns the rank of the given ranked shaped type components. If the shaped type components does not have a rank, None is returned. - """ - @property - def shape(self) -> list[int]: - """ - Returns the shape of the ranked shaped type components as a List of integers. Returns none if the shaped type component does not have a rank. - """ - -class StridedLayoutAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - offset: int, strides: list[int], context: Context | None = None - ) -> StridedLayoutAttr: - """ - Gets a strided layout attribute. - """ - @staticmethod - def get_fully_dynamic( - rank: int, context: Context | None = None - ) -> StridedLayoutAttr: - """ - Gets a strided layout attribute with dynamic offset and strides of a given rank. - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def offset(self) -> int: - """ - Returns the value of the float point attribute - """ - @property - def strides(self) -> list[int]: - """ - Returns the value of the float point attribute - """ - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class StringAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(value: str | bytes, context: Context | None = None) -> StringAttr: - """ - Gets a uniqued string attribute - """ - @staticmethod - def get_typed(type: Type, value: str) -> StringAttr: - """ - Gets a uniqued string attribute associated to a type - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> str: - """ - Returns the value of the string attribute - """ - @property - def value_bytes(self) -> bytes: - """ - Returns the value of the string attribute as `bytes` - """ - -class SymbolRefAttr(Attribute): - @staticmethod - def get(symbols: list[str], context: Context | None = None) -> Attribute: - """ - Gets a uniqued SymbolRef attribute from a List of symbol names - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def static_typeid(self) -> TypeID: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> list[str]: - """ - Returns the value of the SymbolRef attribute as a List[str] - """ - -class SymbolTable: - @staticmethod - def get_symbol_name(symbol: _OperationBase) -> Attribute: ... - @staticmethod - def get_visibility(symbol: _OperationBase) -> Attribute: ... - @staticmethod - def replace_all_symbol_uses( - old_symbol: str, new_symbol: str, from_op: _OperationBase - ) -> None: ... - @staticmethod - def set_symbol_name(symbol: _OperationBase, name: str) -> None: ... - @staticmethod - def set_visibility(symbol: _OperationBase, visibility: str) -> None: ... - @staticmethod - def walk_symbol_tables( - from_op: _OperationBase, - all_sym_uses_visible: bool, - callback: Callable[[_OperationBase, bool], None], - ) -> None: ... - def __contains__(self, arg0: str) -> bool: ... - def __delitem__(self, arg0: str) -> None: ... - def __getitem__(self, arg0: str) -> OpView: ... - def __init__(self, arg0: _OperationBase) -> None: ... - def erase(self, operation: _OperationBase) -> None: ... - def insert(self, operation: _OperationBase) -> Attribute: ... - -class TupleType(Type): - static_typeid: ClassVar[TypeID] - @staticmethod - def get_tuple(elements: list[Type], context: Context | None = None) -> TupleType: - """ - Create a Tuple type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - def get_type(self, pos: int) -> Type: - """ - Returns the pos-th type in the Tuple type. - """ - @property - def num_types(self) -> int: - """ - Returns the number of types contained in a Tuple. - """ - @property - def typeid(self) -> TypeID: ... - -class TypeAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(value: Type, context: Context | None = None) -> TypeAttr: - """ - Gets a uniqued Type attribute - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - @property - def value(self) -> Type: ... - -class TypeID: - def _CAPICreate(self) -> TypeID: ... - @overload - def __eq__(self, arg0: TypeID) -> bool: ... - @overload - def __eq__(self, arg0: Any) -> bool: ... - def __hash__(self) -> int: ... - @property - def _CAPIPtr(self) -> object: ... - -class UnitAttr(Attribute): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(context: Context | None = None) -> UnitAttr: - """ - Create a Unit attribute. - """ - @staticmethod - def isinstance(other: Attribute) -> bool: ... - def __init__(self, cast_from_attr: Attribute) -> None: ... - @property - def type(self) -> Type: ... - @property - def typeid(self) -> TypeID: ... - -class UnrankedMemRefType(ShapedType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - element_type: Type, memory_space: Attribute, loc: Location | None = None - ) -> UnrankedMemRefType: - """ - Create a unranked memref type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def memory_space(self) -> Attribute | None: - """ - Returns the memory space of the given Unranked MemRef type. - """ - @property - def typeid(self) -> TypeID: ... - -class UnrankedTensorType(ShapedType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get(element_type: Type, loc: Location | None = None) -> UnrankedTensorType: - """ - Create a unranked tensor type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def typeid(self) -> TypeID: ... - -class VectorType(ShapedType): - static_typeid: ClassVar[TypeID] - @staticmethod - def get( - shape: list[int], - element_type: Type, - *, - scalable: list | None = None, - scalable_dims: list[int] | None = None, - loc: Location | None = None, - ) -> VectorType: - """ - Create a vector type - """ - @staticmethod - def isinstance(other: Type) -> bool: ... - def __init__(self, cast_from_type: Type) -> None: ... - @property - def scalable(self) -> bool: ... - @property - def scalable_dims(self) -> list[bool]: ... - @property - def typeid(self) -> TypeID: ... - -class _GlobalDebug: - flag: ClassVar[bool] = False diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi deleted file mode 100644 index 1010dadda..000000000 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ /dev/null @@ -1,36 +0,0 @@ -# Originally imported via: -# stubgen {...} -m mlir._mlir_libs._mlir.passmanager -# Local modifications: -# * Relative imports for cross-module references. -# * Add __all__ - - -from . import ir as _ir - -__all__ = [ - "PassManager", -] - -class PassManager: - def __init__(self, context: _ir.Context | None = None) -> None: ... - def _CAPICreate(self) -> object: ... - def _testing_release(self) -> None: ... - def enable_ir_printing( - self, - print_before_all: bool = False, - print_after_all: bool = True, - print_module_scope: bool = False, - print_after_change: bool = False, - print_after_failure: bool = False, - large_elements_limit: int | None = None, - large_resource_limit: int | None = None, - enable_debug_info: bool = False, - print_generic_op_form: bool = False, - tree_printing_dir_path: str | None = None, - ) -> None: ... - def enable_verifier(self, enable: bool) -> None: ... - @staticmethod - def parse(pipeline: str, context: _ir.Context | None = None) -> PassManager: ... - def run(self, module: _ir._OperationBase) -> None: ... - @property - def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/py.typed b/mlir/python/mlir/_mlir_libs/_mlir/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/mlir/python/mlir/dialects/IRDLOps.td b/mlir/python/mlir/dialects/IRDLOps.td new file mode 100644 index 000000000..7b061fcf3 --- /dev/null +++ b/mlir/python/mlir/dialects/IRDLOps.td @@ -0,0 +1,14 @@ +//===-- IRDLOps.td - Entry point for IRDL binding ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_IRDL_OPS +#define PYTHON_BINDINGS_IRDL_OPS + +include "mlir/Dialect/IRDL/IR/IRDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/TransformSMTExtensionOps.td b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td new file mode 100644 index 000000000..3e92417a3 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformSMTExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformSMTExtensionOps.td - Binding entry point --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the SMT extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS + +include "mlir/Dialect/Transform/SMTExtension/SMTExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_SMT_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/UBOps.td b/mlir/python/mlir/dialects/UBOps.td new file mode 100644 index 000000000..b84e7f15f --- /dev/null +++ b/mlir/python/mlir/dialects/UBOps.td @@ -0,0 +1,14 @@ +//===-- UBOps.td - Entry point for UB bindings -------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_UB_OPS +#define PYTHON_BINDINGS_UB_OPS + +include "mlir/Dialect/UB/IR/UBOps.td" + +#endif // PYTHON_BINDINGS_UB_OPS diff --git a/mlir/python/mlir/dialects/irdl.py b/mlir/python/mlir/dialects/irdl.py new file mode 100644 index 000000000..1ec951b69 --- /dev/null +++ b/mlir/python/mlir/dialects/irdl.py @@ -0,0 +1,92 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._irdl_ops_gen import * +from ._irdl_ops_gen import _Dialect +from ._irdl_enum_gen import * +from .._mlir_libs._mlirDialectsIRDL import * +from ..ir import register_attribute_builder +from ._ods_common import _cext as _ods_cext +from typing import Union, Sequence + +_ods_ir = _ods_cext.ir + + +@_ods_cext.register_operation(_Dialect, replace=True) +class DialectOp(DialectOp): + __doc__ = DialectOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def dialect(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> DialectOp: + return DialectOp(sym_name=sym_name, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperationOp(OperationOp): + __doc__ = OperationOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def operation_( + sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None +) -> OperationOp: + return OperationOp(sym_name=sym_name, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypeOp(TypeOp): + __doc__ = TypeOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def type_(sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None) -> TypeOp: + return TypeOp(sym_name=sym_name, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AttributeOp(AttributeOp): + __doc__ = AttributeOp.__doc__ + + def __init__(self, sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None): + super().__init__(sym_name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self) -> _ods_ir.Block: + return self.regions[0].blocks[0] + + +def attribute( + sym_name: Union[str, _ods_ir.Attribute], *, loc=None, ip=None +) -> AttributeOp: + return AttributeOp(sym_name=sym_name, loc=loc, ip=ip) + + +@register_attribute_builder("VariadicityArrayAttr") +def _variadicity_array_attr(x: Sequence[Variadicity], context) -> _ods_ir.Attribute: + return _ods_ir.Attribute.parse( + f"#irdl", context + ) diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py index ae7a4c41c..38970d17a 100644 --- a/mlir/python/mlir/dialects/smt.py +++ b/mlir/python/mlir/dialects/smt.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._smt_ops_gen import * +from ._smt_enum_gen import * from .._mlir_libs._mlirDialectsSMT import * from ..extras.meta import region_op diff --git a/mlir/python/mlir/dialects/transform/smt.py b/mlir/python/mlir/dialects/transform/smt.py new file mode 100644 index 000000000..1f0b7f066 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/smt.py @@ -0,0 +1,38 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Sequence + +from ...ir import Type, Block +from .._transform_smt_extension_ops_gen import * +from .._transform_smt_extension_ops_gen import _Dialect +from ...dialects import transform + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstrainParamsOp(ConstrainParamsOp): + def __init__( + self, + params: Sequence[transform.AnyParamType], + arg_types: Sequence[Type], + loc=None, + ip=None, + ): + if len(params) != len(arg_types): + raise ValueError(f"{params=} not same length as {arg_types=}") + super().__init__( + params, + loc=loc, + ip=ip, + ) + self.regions[0].blocks.append(*arg_types) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index bf40cc532..e3bacb577 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -44,18 +44,12 @@ def __init__( loc=None, ip=None, ): - # No other types are allowed, so hard-code those here. - allocated_buffer_type = transform.AnyValueType.get() - new_ops_type = transform.AnyOpType.get() - if isinstance(memory_space, int): memory_space = str(memory_space) if isinstance(memory_space, str): memory_space = Attribute.parse(memory_space) super().__init__( - allocated_buffer_type, - new_ops_type, target, memory_space=memory_space, memcpy_op=memcpy_op, diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py index f63f88a38..b3bfa8015 100644 --- a/mlir/python/mlir/dialects/transform/tune.py +++ b/mlir/python/mlir/dialects/transform/tune.py @@ -6,6 +6,9 @@ from ...ir import ( Type, + Value, + Operation, + OpView, Attribute, ArrayAttr, StringAttr, @@ -19,7 +22,10 @@ from .._transform_tune_extension_ops_gen import _Dialect try: - from .._ods_common import _cext as _ods_cext + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _cext as _ods_cext, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -36,7 +42,7 @@ def __init__( ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): @@ -75,8 +81,62 @@ def knob( ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute ], *, - selected: Optional[Attribute] = None, + selected: Optional[Union[Attribute, bool, int, float, str]] = None, loc=None, ip=None, ): return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AlternativesOp(AlternativesOp): + def __init__( + self, + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[ + Union[int, IntegerAttr, Value, Operation, OpView] + ] = None, + loc=None, + ip=None, + ): + if isinstance(name, str): + name = StringAttr.get(name) + + selected_region_attr = selected_region_param = None + if isinstance(selected_region, IntegerAttr): + selected_region_attr = selected_region + elif isinstance(selected_region, int): + selected_region_attr = IntegerAttr.get( + IntegerType.get_signless(32), selected_region + ) + elif isinstance(selected_region, (Value, Operation, OpView)): + selected_region_param = _get_op_result_or_value(selected_region) + + super().__init__( + results, + name, + num_alternatives, + selected_region_attr=selected_region_attr, + selected_region_param=selected_region_param, + loc=loc, + ip=ip, + ) + for region in self.regions: + region.blocks.append() + + +def alternatives( + results: Sequence[Type], + name: Union[StringAttr, str], + num_alternatives: int, + *, + selected_region: Optional[Union[int, IntegerAttr, Value, Operation, OpView]] = None, + loc=None, + ip=None, +): + return AlternativesOp( + results, name, num_alternatives, selected_region=selected_region, loc=loc, ip=ip + ) diff --git a/mlir/python/mlir/dialects/ub.py b/mlir/python/mlir/dialects/ub.py new file mode 100644 index 000000000..32e870674 --- /dev/null +++ b/mlir/python/mlir/dialects/ub.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._ub_ops_gen import * diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 6f37266d5..11477d061 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -2,9 +2,18 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from __future__ import annotations + +from collections.abc import Iterable +from contextlib import contextmanager + from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs._mlir import ( + register_type_caster, + register_value_caster, + globals as _globals, +) from ._mlir_libs import ( get_dialect_registry, append_load_on_create_dialect, @@ -12,6 +21,30 @@ ) +@contextmanager +def loc_tracebacks(*, max_depth: int | None = None) -> Iterable[None]: + """Enables automatic traceback-based locations for MLIR operations. + + Operations created within this context will have their location + automatically set based on the Python call stack. + + Args: + max_depth: Maximum number of frames to include in the location. + If None, the default limit is used. + """ + old_enabled = _globals.loc_tracebacks_enabled() + old_limit = _globals.loc_tracebacks_frame_limit() + try: + _globals.set_loc_tracebacks_frame_limit(max_depth) + if not old_enabled: + _globals.set_loc_tracebacks_enabled(True) + yield + finally: + if not old_enabled: + _globals.set_loc_tracebacks_enabled(False) + _globals.set_loc_tracebacks_frame_limit(old_limit) + + # Convenience decorator for registering user-friendly Attribute builders. def register_attribute_builder(kind, replace=False): def decorator_builder(func): diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 1a0075e82..abe09259b 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,6 +1,7 @@ -nanobind>=2.4, <3.0 +nanobind>=2.9, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16 -ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" \ No newline at end of file +ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" +typing_extensions>=4.12.2 diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp index 3140f12c0..8ec2e0309 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeDefGen.cpp @@ -10,12 +10,12 @@ #include "CppGenUtilities.h" #include "mlir/TableGen/AttrOrTypeDef.h" #include "mlir/TableGen/Class.h" -#include "mlir/TableGen/CodeGenHelpers.h" #include "mlir/TableGen/Format.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Interfaces.h" #include "llvm/ADT/StringSet.h" #include "llvm/Support/CommandLine.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" @@ -71,14 +71,14 @@ class DefGen { void emitDecl(raw_ostream &os) const { if (storageCls && def.genStorageClass()) { - NamespaceEmitter ns(os, def.getStorageNamespace()); + llvm::NamespaceEmitter ns(os, def.getStorageNamespace()); os << "struct " << def.getStorageClassName() << ";\n"; } defCls.writeDeclTo(os); } void emitDef(raw_ostream &os) const { if (storageCls && def.genStorageClass()) { - NamespaceEmitter ns(os, def.getStorageNamespace()); + llvm::NamespaceEmitter ns(os, def.getStorageNamespace()); storageCls->writeDeclTo(os); // everything is inline } defCls.writeDefTo(os); @@ -130,6 +130,9 @@ class DefGen { void emitTraitMethods(const InterfaceTrait &trait); /// Emit a trait method. void emitTraitMethod(const InterfaceMethod &method); + /// Generate a using declaration for a trait method. + void genTraitMethodUsingDecl(const InterfaceTrait &trait, + const InterfaceMethod &method); //===--------------------------------------------------------------------===// // OpAsm{Type,Attr}Interface Default Method Emission @@ -176,6 +179,9 @@ class DefGen { StringRef valueType; /// The prefix/suffix of the TableGen def name, either "Attr" or "Type". StringRef defType; + + /// The set of using declarations for trait methods. + llvm::StringSet<> interfaceUsingNames; }; } // namespace @@ -513,14 +519,57 @@ getCustomBuilderParams(std::initializer_list prefix, return builderParams; } +static std::string getSignature(const Method &m) { + std::string signature; + llvm::raw_string_ostream os(signature); + raw_indented_ostream indentedOs(os); + m.writeDeclTo(indentedOs); + return signature; +} + +static void emitDuplicatedBuilderError(const Method ¤tMethod, + StringRef methodName, + const Class &defCls, + const AttrOrTypeDef &def) { + + // Try to search for method that makes `get` redundant. + auto loc = def.getDef()->getFieldLoc("builders"); + for (auto &method : defCls.getMethods()) { + if (method->getName() == methodName && + method->makesRedundant(currentMethod)) { + PrintError(loc, llvm::Twine("builder `") + methodName + + "` conflicts with an existing builder. "); + PrintFatalNote(llvm::Twine("A new builder with signature:\n") + + getSignature(currentMethod) + + "\nis shadowed by an existing builder with signature:\n" + + getSignature(*method) + + "\nPlease remove one of the conflicting " + "definitions."); + } + } + + // This code shouldn't be reached, but leaving this here for potential future + // use. + PrintFatalError(loc, "Failed to generate builder " + methodName); +} + void DefGen::emitCustomBuilder(const AttrOrTypeBuilder &builder) { // Don't emit a body if there isn't one. auto props = builder.getBody() ? Method::Static : Method::StaticDeclaration; StringRef returnType = def.getCppClassName(); if (std::optional builderReturnType = builder.getReturnType()) returnType = *builderReturnType; - Method *m = defCls.addMethod(returnType, "get", props, - getCustomBuilderParams({}, builder)); + + llvm::StringRef methodName = "get"; + const auto parameters = getCustomBuilderParams({}, builder); + Method *m = defCls.addMethod(returnType, methodName, props, parameters); + + // If method is pruned, report error and terminate. + if (!m) { + auto curMethod = Method(returnType, methodName, props, parameters); + emitDuplicatedBuilderError(curMethod, methodName, defCls, def); + } + if (!builder.getBody()) return; @@ -547,11 +596,19 @@ void DefGen::emitCheckedCustomBuilder(const AttrOrTypeBuilder &builder) { StringRef returnType = def.getCppClassName(); if (std::optional builderReturnType = builder.getReturnType()) returnType = *builderReturnType; - Method *m = defCls.addMethod( - returnType, "getChecked", props, - getCustomBuilderParams( - {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}, - builder)); + + llvm::StringRef methodName = "getChecked"; + auto parameters = getCustomBuilderParams( + {{"::llvm::function_ref<::mlir::InFlightDiagnostic()>", "emitError"}}, + builder); + Method *m = defCls.addMethod(returnType, methodName, props, parameters); + + // If method is pruned, report error and terminate. + if (!m) { + auto curMethod = Method(returnType, methodName, props, parameters); + emitDuplicatedBuilderError(curMethod, methodName, defCls, def); + } + if (!builder.getBody()) return; @@ -581,8 +638,10 @@ void DefGen::emitTraitMethods(const InterfaceTrait &trait) { // Don't declare if the method has a body. Or if the method has a default // implementation and the def didn't request that it always be declared. if (method.getBody() || (method.getDefaultImplementation() && - !alwaysDeclared.count(method.getName()))) + !alwaysDeclared.count(method.getName()))) { + genTraitMethodUsingDecl(trait, method); continue; + } emitTraitMethod(method); } } @@ -598,6 +657,15 @@ void DefGen::emitTraitMethod(const InterfaceMethod &method) { std::move(params)); } +void DefGen::genTraitMethodUsingDecl(const InterfaceTrait &trait, + const InterfaceMethod &method) { + std::string name = (llvm::Twine(trait.getFullyQualifiedTraitName()) + "<" + + def.getCppClassName() + ">::" + method.getName()) + .str(); + if (interfaceUsingNames.insert(name).second) + defCls.declare(std::move(name)); +} + //===----------------------------------------------------------------------===// // OpAsm{Type,Attr}Interface Default Method Emission @@ -799,7 +867,7 @@ class AsmPrinter; bool DefGenerator::emitDecls(StringRef selectedDialect) { emitSourceFileHeader((defType + "Def Declarations").str(), os); - IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os); + llvm::IfDefEmitter scope(os, "GET_" + defType.upper() + "DEF_CLASSES"); // Output the common "header". os << typeDefDeclHeader; @@ -809,15 +877,12 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) { if (defs.empty()) return false; { - NamespaceEmitter nsEmitter(os, defs.front().getDialect()); + DialectNamespaceEmitter nsEmitter(os, defs.front().getDialect()); // Declare all the def classes first (in case they reference each other). for (const AttrOrTypeDef &def : defs) { - std::string comments = tblgen::emitSummaryAndDescComments( - def.getSummary(), def.getDescription()); - if (!comments.empty()) { - os << comments << "\n"; - } + tblgen::emitSummaryAndDescComments(os, def.getSummary(), + def.getDescription()); os << "class " << def.getCppClassName() << ";\n"; } @@ -841,7 +906,7 @@ bool DefGenerator::emitDecls(StringRef selectedDialect) { //===----------------------------------------------------------------------===// void DefGenerator::emitTypeDefList(ArrayRef defs) { - IfDefScope scope("GET_" + defType.upper() + "DEF_LIST", os); + llvm::IfDefEmitter scope(os, "GET_" + defType.upper() + "DEF_LIST"); auto interleaveFn = [&](const AttrOrTypeDef &def) { os << def.getDialect().getCppNamespace() << "::" << def.getCppClassName(); }; @@ -1032,11 +1097,11 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { return false; emitTypeDefList(defs); - IfDefScope scope("GET_" + defType.upper() + "DEF_CLASSES", os); + llvm::IfDefEmitter scope(os, "GET_" + defType.upper() + "DEF_CLASSES"); emitParsePrintDispatch(defs); for (const AttrOrTypeDef &def : defs) { { - NamespaceEmitter ns(os, def.getDialect()); + DialectNamespaceEmitter ns(os, def.getDialect()); DefGen gen(def); gen.emitDef(os); } @@ -1051,7 +1116,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { // Emit the default parser/printer for Attributes if the dialect asked for it. if (isAttrGenerator && firstDialect.useDefaultAttributePrinterParser()) { - NamespaceEmitter nsEmitter(os, firstDialect); + DialectNamespaceEmitter nsEmitter(os, firstDialect); if (firstDialect.isExtensible()) { os << llvm::formatv(dialectDefaultAttrPrinterParserDispatch, firstDialect.getCppClassName(), @@ -1065,7 +1130,7 @@ bool DefGenerator::emitDefs(StringRef selectedDialect) { // Emit the default parser/printer for Types if the dialect asked for it. if (!isAttrGenerator && firstDialect.useDefaultTypePrinterParser()) { - NamespaceEmitter nsEmitter(os, firstDialect); + DialectNamespaceEmitter nsEmitter(os, firstDialect); if (firstDialect.isExtensible()) { os << llvm::formatv(dialectDefaultTypePrinterParserDispatch, firstDialect.getCppClassName(), @@ -1115,7 +1180,7 @@ getAllCppAttrConstraints(const RecordKeeper &records) { /// Emit the declarations for the given constraints, of the form: /// `bool ( );` -static void emitConstraintDecls(const std::vector &constraints, +static void emitConstraintDecls(ArrayRef constraints, raw_ostream &os, StringRef parameterTypeName, StringRef parameterName) { static const char *const constraintDecl = "bool {0}({1} {2});\n"; @@ -1141,7 +1206,7 @@ static void emitAttrConstraintDecls(const RecordKeeper &records, /// return (); }` /// where `` is the condition template with the `self` variable /// replaced with the `selfName` parameter. -static void emitConstraintDefs(const std::vector &constraints, +static void emitConstraintDefs(ArrayRef constraints, raw_ostream &os, StringRef parameterTypeName, StringRef selfName) { static const char *const constraintDef = R"( diff --git a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp index 10a162f81..34547e9fe 100644 --- a/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/AttrOrTypeFormatGen.cpp @@ -89,10 +89,7 @@ static ParameterElement *getEncapsulatedParameterElement(FormatElement *el) { .Case([&](auto param) { return param; }) .Case( [&](auto ref) { return cast(ref->getArg()); }) - .Default([&](auto el) { - assert(false && "unexpected struct element type"); - return nullptr; - }); + .DefaultUnreachable("unexpected struct element type"); } /// Shorthand functions that can be used with ranged-based conditions. @@ -403,6 +400,7 @@ void DefFormat::genLiteralParser(StringRef value, FmtContext &ctx, .Case("]", "RSquare") .Case("?", "Question") .Case("+", "Plus") + .Case("-", "Minus") .Case("*", "Star") .Case("...", "Ellipsis") << "()"; @@ -585,7 +583,7 @@ void DefFormat::genStructParser(StructDirective *el, FmtContext &ctx, os.getStream().printReindented(strfmt(checkParamKey, param->getName())); if (isa(arg)) genVariableParser(param, ctx, os.indent()); - else if (auto custom = dyn_cast(arg)) + else if (auto *custom = dyn_cast(arg)) genCustomParser(custom, ctx, os.indent()); os.unindent() << "} else "; // Print the check for duplicate or unknown parameter. @@ -877,9 +875,9 @@ void DefFormat::genCommaSeparatedPrinter( extra(arg); shouldEmitSpace = false; lastWasPunctuation = true; - if (auto realParam = dyn_cast(arg)) + if (auto *realParam = dyn_cast(arg)) genVariablePrinter(realParam, ctx, os); - else if (auto custom = dyn_cast(arg)) + else if (auto *custom = dyn_cast(arg)) genCustomPrinter(custom, ctx, os); if (param->isOptional()) os.unindent() << "}\n"; @@ -1124,7 +1122,7 @@ DefFormatParser::verifyStructArguments(SMLoc loc, return emitError(loc, "expected a parameter, custom directive or params " "directive in `struct` arguments list"); } - if (auto custom = dyn_cast(el)) { + if (auto *custom = dyn_cast(el)) { if (custom->getNumElements() != 1) { return emitError(loc, "`struct` can only contain `custom` directives " "with a single argument"); diff --git a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp index da28ca3a7..533a9cff5 100644 --- a/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp +++ b/mlir/tools/mlir-tblgen/BytecodeDialectGen.cpp @@ -151,9 +151,9 @@ void Generator::emitParse(StringRef kind, const Record &x) { os << "\n\n"; } -void printParseConditional(mlir::raw_indented_ostream &ios, - ArrayRef args, - ArrayRef argNames) { +static void printParseConditional(mlir::raw_indented_ostream &ios, + ArrayRef args, + ArrayRef argNames) { ios << "if "; auto parenScope = ios.scope("(", ") {"); ios.indent(); diff --git a/mlir/tools/mlir-tblgen/CppGenUtilities.cpp b/mlir/tools/mlir-tblgen/CppGenUtilities.cpp index ebca20cc6..fddd7790a 100644 --- a/mlir/tools/mlir-tblgen/CppGenUtilities.cpp +++ b/mlir/tools/mlir-tblgen/CppGenUtilities.cpp @@ -14,26 +14,31 @@ #include "CppGenUtilities.h" #include "mlir/Support/IndentedOstream.h" -std::string -mlir::tblgen::emitSummaryAndDescComments(llvm::StringRef summary, - llvm::StringRef description) { +void mlir::tblgen::emitSummaryAndDescComments(llvm::raw_ostream &os, + llvm::StringRef summary, + llvm::StringRef description, + bool terminateComment) { std::string comments = ""; StringRef trimmedSummary = summary.trim(); StringRef trimmedDesc = description.trim(); - llvm::raw_string_ostream os(comments); raw_indented_ostream ros(os); + bool empty = true; if (!trimmedSummary.empty()) { ros.printReindented(trimmedSummary, "/// "); + empty = false; } if (!trimmedDesc.empty()) { - if (!trimmedSummary.empty()) { + if (!empty) { // If there is a summary, add a newline after it. ros << "\n"; } ros.printReindented(trimmedDesc, "/// "); + empty = false; } - return comments; + + if (!empty && terminateComment) + ros << "\n"; } diff --git a/mlir/tools/mlir-tblgen/CppGenUtilities.h b/mlir/tools/mlir-tblgen/CppGenUtilities.h index 231c59a9e..69d8cd85e 100644 --- a/mlir/tools/mlir-tblgen/CppGenUtilities.h +++ b/mlir/tools/mlir-tblgen/CppGenUtilities.h @@ -15,14 +15,16 @@ #define MLIR_TOOLS_MLIRTBLGEN_CPPGENUTILITIES_H_ #include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" namespace mlir { namespace tblgen { -// Emit the summary and description as a C++ comment, perperly aligned placed -// adjacent to the class declaration of generated classes. -std::string emitSummaryAndDescComments(llvm::StringRef summary, - llvm::StringRef description); +// Emit the summary and description as a C++ comment. If `terminateComment` is +// true, terminates the comment with a `\n`. +void emitSummaryAndDescComments(llvm::raw_ostream &os, llvm::StringRef summary, + llvm::StringRef description, + bool terminateComment = true); } // namespace tblgen } // namespace mlir diff --git a/mlir/tools/mlir-tblgen/DialectGen.cpp b/mlir/tools/mlir-tblgen/DialectGen.cpp index 02941ec12..c2c0c1f41 100644 --- a/mlir/tools/mlir-tblgen/DialectGen.cpp +++ b/mlir/tools/mlir-tblgen/DialectGen.cpp @@ -109,9 +109,7 @@ tblgen::findDialectToGenerate(ArrayRef dialects) { /// {0}: The name of the dialect class. /// {1}: The dialect namespace. /// {2}: The dialect parent class. -/// {3}: The summary and description comments. static const char *const dialectDeclBeginStr = R"( -{3} class {0} : public ::mlir::{2} { explicit {0}(::mlir::MLIRContext *context); @@ -242,17 +240,18 @@ static const char *const discardableAttrHelperDecl = R"( static void emitDialectDecl(Dialect &dialect, raw_ostream &os) { // Emit all nested namespaces. { - NamespaceEmitter nsEmitter(os, dialect); + DialectNamespaceEmitter nsEmitter(os, dialect); // Emit the start of the decl. std::string cppName = dialect.getCppClassName(); StringRef superClassName = dialect.isExtensible() ? "ExtensibleDialect" : "Dialect"; - std::string comments = tblgen::emitSummaryAndDescComments( - dialect.getSummary(), dialect.getDescription()); + tblgen::emitSummaryAndDescComments(os, dialect.getSummary(), + dialect.getDescription(), + /*terminateCmment=*/false); os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(), - superClassName, comments); + superClassName); // If the dialect requested the default attribute printer and parser, emit // the declarations for the hooks. @@ -358,7 +357,7 @@ static void emitDialectDef(Dialect &dialect, const RecordKeeper &records, << "::" << cppClassName << ")\n"; // Emit all nested namespaces. - NamespaceEmitter nsEmitter(os, dialect); + DialectNamespaceEmitter nsEmitter(os, dialect); /// Build the list of dependent dialects. std::string dependentDialectRegistrations; diff --git a/mlir/tools/mlir-tblgen/EnumsGen.cpp b/mlir/tools/mlir-tblgen/EnumsGen.cpp index 06dc588f9..d55ad482f 100644 --- a/mlir/tools/mlir-tblgen/EnumsGen.cpp +++ b/mlir/tools/mlir-tblgen/EnumsGen.cpp @@ -46,8 +46,7 @@ static std::string makeIdentifier(StringRef str) { static void emitEnumClass(const Record &enumDef, StringRef enumName, StringRef underlyingType, StringRef description, - const std::vector &enumerants, - raw_ostream &os) { + ArrayRef enumerants, raw_ostream &os) { os << "// " << description << "\n"; os << "enum class " << enumName; @@ -55,14 +54,13 @@ static void emitEnumClass(const Record &enumDef, StringRef enumName, os << " : " << underlyingType; os << " {\n"; - for (const auto &enumerant : enumerants) { + for (const EnumCase &enumerant : enumerants) { auto symbol = makeIdentifier(enumerant.getSymbol()); auto value = enumerant.getValue(); - if (value >= 0) { + if (value >= 0) os << formatv(" {0} = {1},\n", symbol, value); - } else { + else os << formatv(" {0},\n", symbol); - } } os << "};\n\n"; } @@ -222,7 +220,7 @@ inline ::llvm::raw_ostream &operator<<(::llvm::raw_ostream &p, {0} value) {{ llvm::StringSwitch(separator.trim()) .Case("|", "parseOptionalVerticalBar") .Case(",", "parseOptionalComma") - .Default("error, enum seperator must be '|' or ','"); + .Default("error, enum separator must be '|' or ','"); os << formatv(parsedAndPrinterStartUnquotedBitEnum, qualName, cppNamespace, enumInfo.getSummary(), casesList, separator, parseSeparatorFn); @@ -364,6 +362,9 @@ getAllBitsUnsetCase(llvm::ArrayRef cases) { // inline constexpr operator|( a, b); // inline constexpr operator&( a, b); // inline constexpr operator^( a, b); +// inline constexpr &operator|=( &a, b); +// inline constexpr &operator&=( &a, b); +// inline constexpr &operator^=( &a, b); // inline constexpr operator~( bits); // inline constexpr bool bitEnumContainsAll( bits, bit); // inline constexpr bool bitEnumContainsAny( bits, bit); @@ -385,6 +386,15 @@ inline constexpr {0} operator&({0} a, {0} b) {{ inline constexpr {0} operator^({0} a, {0} b) {{ return static_cast<{0}>(static_cast<{1}>(a) ^ static_cast<{1}>(b)); } +inline constexpr {0} &operator|=({0} &a, {0} b) {{ + return a = a | b; +} +inline constexpr {0} &operator&=({0} &a, {0} b) {{ + return a = a & b; +} +inline constexpr {0} &operator^=({0} &a, {0} b) {{ + return a = a ^ b; +} inline constexpr {0} operator~({0} bits) {{ // Ensure only bits that can be present in the enum are set return static_cast<{0}>(~static_cast<{1}>(bits) & static_cast<{1}>({2}u)); diff --git a/mlir/tools/mlir-tblgen/FormatGen.cpp b/mlir/tools/mlir-tblgen/FormatGen.cpp index 4dfdde214..04d3ed1f3 100644 --- a/mlir/tools/mlir-tblgen/FormatGen.cpp +++ b/mlir/tools/mlir-tblgen/FormatGen.cpp @@ -518,7 +518,7 @@ bool mlir::tblgen::isValidLiteral(StringRef value, // If there is only one character, this must either be punctuation or a // single character bare identifier. if (value.size() == 1) { - StringRef bare = "_:,=<>()[]{}?+*"; + StringRef bare = "_:,=<>()[]{}?+-*"; if (isalpha(front) || bare.contains(front)) return true; if (emitError) diff --git a/mlir/tools/mlir-tblgen/OmpOpGen.cpp b/mlir/tools/mlir-tblgen/OmpOpGen.cpp index 91bc61bcd..e1be11a0d 100644 --- a/mlir/tools/mlir-tblgen/OmpOpGen.cpp +++ b/mlir/tools/mlir-tblgen/OmpOpGen.cpp @@ -342,7 +342,7 @@ static bool verifyDecls(const RecordKeeper &records, raw_ostream &) { /// structures according to the `clauses` argument of each definition deriving /// from `OpenMP_Op`. static bool genClauseOps(const RecordKeeper &records, raw_ostream &os) { - mlir::tblgen::NamespaceEmitter ns(os, "mlir::omp"); + llvm::NamespaceEmitter ns(os, "mlir::omp"); for (const Record *clause : records.getAllDerivedDefinitions("OpenMP_Clause")) genClauseOpsStruct(clause, os); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 8ea4eb7b3..969011546 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -36,6 +36,7 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/Signals.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/CodeGenHelpers.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" #include "llvm/TableGen/TableGenBackend.h" @@ -789,6 +790,14 @@ class OpEmitter { Method *genOpInterfaceMethod(const tblgen::InterfaceMethod &method, bool declaration = true); + // Generate a `using` declaration for the op interface method to include + // the default implementation from the interface trait. + // This is needed when the interface defines multiple methods with the same + // name, but some have a default implementation and some don't. + UsingDeclaration * + genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait, + const tblgen::InterfaceMethod &method); + // Generate the side effect interface methods. void genSideEffectInterfaceMethods(); @@ -815,6 +824,10 @@ class OpEmitter { // Helper for emitting op code. OpOrAdaptorHelper emitHelper; + + // Keep track of the interface using declarations that have been generated to + // avoid duplicates. + llvm::StringSet<> interfaceUsingNames; }; } // namespace @@ -3104,8 +3117,8 @@ void OpEmitter::genBuilder() { std::optional body = builder.getBody(); auto properties = body ? Method::Static : Method::StaticDeclaration; auto *method = opClass.addMethod("void", "build", properties, arguments); - if (body) - ERROR_IF_PRUNED(method, "build", op); + + ERROR_IF_PRUNED(method, "build", op); if (method) method->setDeprecated(builder.getDeprecatedMessage()); @@ -3672,8 +3685,10 @@ void OpEmitter::genOpInterfaceMethods(const tblgen::InterfaceTrait *opTrait) { // Don't declare if the method has a default implementation and the op // didn't request that it always be declared. if (method.getDefaultImplementation() && - !alwaysDeclaredMethods.count(method.getName())) + !alwaysDeclaredMethods.count(method.getName())) { + genOpInterfaceMethodUsingDecl(opTrait, method); continue; + } // Interface methods are allowed to overlap with existing methods, so don't // check if pruned. (void)genOpInterfaceMethod(method); @@ -3692,6 +3707,17 @@ Method *OpEmitter::genOpInterfaceMethod(const InterfaceMethod &method, std::move(paramList)); } +UsingDeclaration * +OpEmitter::genOpInterfaceMethodUsingDecl(const tblgen::InterfaceTrait *opTrait, + const InterfaceMethod &method) { + std::string name = (llvm::Twine(opTrait->getFullyQualifiedTraitName()) + "<" + + op.getCppClassName() + ">::" + method.getName()) + .str(); + if (interfaceUsingNames.insert(name).second) + return opClass.declare(std::move(name)); + return nullptr; +} + void OpEmitter::genOpInterfaceMethods() { for (const auto &trait : op.getTraits()) { if (const auto *opTrait = dyn_cast(&trait)) @@ -3849,9 +3875,9 @@ void OpEmitter::genTypeInterfaceMethods() { const InferredResultType &infer = op.getInferredResultType(i); if (!infer.isArg()) continue; - Operator::OperandOrAttribute arg = - op.getArgToOperandOrAttribute(infer.getIndex()); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + Operator::OperandAttrOrProp arg = + op.getArgToOperandAttrOrProp(infer.getIndex()); + if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) { maxAccessedIndex = std::max(maxAccessedIndex, arg.operandOrAttributeIndex()); } @@ -3877,17 +3903,16 @@ void OpEmitter::genTypeInterfaceMethods() { if (infer.isArg()) { // If this is an operand, just index into operand list to access the // type. - Operator::OperandOrAttribute arg = - op.getArgToOperandOrAttribute(infer.getIndex()); - if (arg.kind() == Operator::OperandOrAttribute::Kind::Operand) { + Operator::OperandAttrOrProp arg = + op.getArgToOperandAttrOrProp(infer.getIndex()); + if (arg.kind() == Operator::OperandAttrOrProp::Kind::Operand) { typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) + "].getType()") .str(); // If this is an attribute, index into the attribute dictionary. - } else { - auto *attr = - cast(op.getArg(arg.operandOrAttributeIndex())); + } else if (auto *attr = dyn_cast( + op.getArg(arg.operandOrAttributeIndex()))) { body << " ::mlir::TypedAttr odsInferredTypeAttr" << inferredTypeIdx << " = "; if (op.getDialect().usePropertiesForAttributes()) { @@ -3907,6 +3932,9 @@ void OpEmitter::genTypeInterfaceMethods() { typeStr = ("odsInferredTypeAttr" + Twine(inferredTypeIdx) + ".getType()") .str(); + } else { + llvm::PrintFatalError(&op.getDef(), + "Properties cannot be used for type inference"); } } else if (std::optional builder = op.getResult(infer.getResultIndex()) @@ -4798,11 +4826,9 @@ void OpOperandAdaptorEmitter::emitDef( } /// Emit the class declarations or definitions for the given op defs. -static void -emitOpClasses(const RecordKeeper &records, - const std::vector &defs, raw_ostream &os, - const StaticVerifierFunctionEmitter &staticVerifierEmitter, - bool emitDecl) { +static void emitOpClasses( + const RecordKeeper &records, ArrayRef defs, raw_ostream &os, + const StaticVerifierFunctionEmitter &staticVerifierEmitter, bool emitDecl) { if (defs.empty()) return; @@ -4837,23 +4863,19 @@ emitOpClasses(const RecordKeeper &records, /// Emit the declarations for the provided op classes. static void emitOpClassDecls(const RecordKeeper &records, - const std::vector &defs, - raw_ostream &os) { + ArrayRef defs, raw_ostream &os) { // First emit forward declaration for each class, this allows them to refer // to each others in traits for example. - for (auto *def : defs) { + for (const Record *def : defs) { Operator op(*def); NamespaceEmitter emitter(os, op.getCppNamespace()); - std::string comments = tblgen::emitSummaryAndDescComments( - op.getSummary(), op.getDescription()); - if (!comments.empty()) { - os << comments << "\n"; - } + tblgen::emitSummaryAndDescComments(os, op.getSummary(), + op.getDescription()); os << "class " << op.getCppClassName() << ";\n"; } // Emit the op class declarations. - IfDefScope scope("GET_OP_CLASSES", os); + IfDefEmitter scope(os, "GET_OP_CLASSES"); if (defs.empty()) return; StaticVerifierFunctionEmitter staticVerifierEmitter(os, records); @@ -4896,7 +4918,7 @@ static bool emitOpDecls(const RecordKeeper &records, raw_ostream &os) { return false; Dialect dialect = Operator(defs.front()).getDialect(); - NamespaceEmitter ns(os, dialect); + DialectNamespaceEmitter ns(os, dialect); const char *const opRegistrationHook = "void register{0}Operations{1}({2}::{0} *dialect);\n"; @@ -4919,7 +4941,7 @@ static void emitOpDefShard(const RecordKeeper &records, std::string shardGuard = "GET_OP_DEFS_"; std::string indexStr = std::to_string(shardIndex); shardGuard += indexStr; - IfDefScope scope(shardGuard, os); + IfDefEmitter scope(os, shardGuard); // Emit the op registration hook in the first shard. const char *const opRegistrationHook = @@ -4960,14 +4982,14 @@ static bool emitOpDefs(const RecordKeeper &records, raw_ostream &os) { // If no shard was requested, emit the regular op list and class definitions. if (shardedDefs.size() == 1) { { - IfDefScope scope("GET_OP_LIST", os); + IfDefEmitter scope(os, "GET_OP_LIST"); interleave( defs, os, [&](const Record *def) { os << Operator(def).getQualCppClassName(); }, ",\n"); } { - IfDefScope scope("GET_OP_CLASSES", os); + IfDefEmitter scope(os, "GET_OP_CLASSES"); emitOpClassDefs(records, defs, os); } return false; diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 67fc7636a..ccf21d160 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -672,7 +672,7 @@ const char *const inferReturnTypesParserCode = R"( /// The code snippet used to generate a parser call for a region list. /// /// {0}: The name for the region list. -const char *regionListParserCode = R"( +static const char *regionListParserCode = R"( { std::unique_ptr<::mlir::Region> region; auto firstRegionResult = parser.parseOptionalRegion(region); @@ -695,7 +695,7 @@ const char *regionListParserCode = R"( /// The code snippet used to ensure a list of regions have terminators. /// /// {0}: The name of the region list. -const char *regionListEnsureTerminatorParserCode = R"( +static const char *regionListEnsureTerminatorParserCode = R"( for (auto ®ion : {0}Regions) ensureTerminator(*region, parser.getBuilder(), result.location); )"; @@ -703,7 +703,7 @@ const char *regionListEnsureTerminatorParserCode = R"( /// The code snippet used to ensure a list of regions have a block. /// /// {0}: The name of the region list. -const char *regionListEnsureSingleBlockParserCode = R"( +static const char *regionListEnsureSingleBlockParserCode = R"( for (auto ®ion : {0}Regions) if (region->empty()) region->emplaceBlock(); )"; @@ -711,7 +711,7 @@ const char *regionListEnsureSingleBlockParserCode = R"( /// The code snippet used to generate a parser call for an optional region. /// /// {0}: The name of the region. -const char *optionalRegionParserCode = R"( +static const char *optionalRegionParserCode = R"( { auto parseResult = parser.parseOptionalRegion(*{0}Region); if (parseResult.has_value() && failed(*parseResult)) @@ -722,7 +722,7 @@ const char *optionalRegionParserCode = R"( /// The code snippet used to generate a parser call for a region. /// /// {0}: The name of the region. -const char *regionParserCode = R"( +static const char *regionParserCode = R"( if (parser.parseRegion(*{0}Region)) return ::mlir::failure(); )"; @@ -730,21 +730,21 @@ const char *regionParserCode = R"( /// The code snippet used to ensure a region has a terminator. /// /// {0}: The name of the region. -const char *regionEnsureTerminatorParserCode = R"( +static const char *regionEnsureTerminatorParserCode = R"( ensureTerminator(*{0}Region, parser.getBuilder(), result.location); )"; /// The code snippet used to ensure a region has a block. /// /// {0}: The name of the region. -const char *regionEnsureSingleBlockParserCode = R"( +static const char *regionEnsureSingleBlockParserCode = R"( if ({0}Region->empty()) {0}Region->emplaceBlock(); )"; /// The code snippet used to generate a parser call for a successor list. /// /// {0}: The name for the successor list. -const char *successorListParserCode = R"( +static const char *successorListParserCode = R"( { ::mlir::Block *succ; auto firstSucc = parser.parseOptionalSuccessor(succ); @@ -766,7 +766,7 @@ const char *successorListParserCode = R"( /// The code snippet used to generate a parser call for a successor. /// /// {0}: The name of the successor. -const char *successorParserCode = R"( +static const char *successorParserCode = R"( if (parser.parseSuccessor({0}Successor)) return ::mlir::failure(); )"; @@ -774,7 +774,7 @@ const char *successorParserCode = R"( /// The code snippet used to generate a parser for OIList /// /// {0}: literal keyword corresponding to a case for oilist -const char *oilistParserCode = R"( +static const char *oilistParserCode = R"( if ({0}Clause) { return parser.emitError(parser.getNameLoc()) << "`{0}` clause can appear at most once in the expansion of the " @@ -852,6 +852,7 @@ static void genLiteralParser(StringRef value, MethodBody &body) { .Case("]", "RSquare()") .Case("?", "Question()") .Case("+", "Plus()") + .Case("-", "Minus()") .Case("*", "Star()") .Case("...", "Ellipsis()"); } @@ -1976,7 +1977,7 @@ void OperationFormat::genParserVariadicSegmentResolution(Operator &op, // operation that has the SingleBlockImplicitTerminator trait. /// /// {0}: The name of the region. -const char *regionSingleBlockImplicitTerminatorPrinterCode = R"( +static const char *regionSingleBlockImplicitTerminatorPrinterCode = R"( { bool printTerminator = true; if (auto *term = {0}.empty() ? nullptr : {0}.begin()->getTerminator()) {{ @@ -1994,7 +1995,7 @@ const char *regionSingleBlockImplicitTerminatorPrinterCode = R"( /// /// {0}: The name of the enum attribute. /// {1}: The name of the enum attributes symbolToString function. -const char *enumAttrBeginPrinterCode = R"( +static const char *enumAttrBeginPrinterCode = R"( { auto caseValue = {0}(); auto caseValueStr = {1}(caseValue); @@ -2386,8 +2387,8 @@ static void genOptionalGroupPrinterAnchor(FormatElement *anchor, }); } -void collect(FormatElement *element, - SmallVectorImpl &variables) { +static void collect(FormatElement *element, + SmallVectorImpl &variables) { TypeSwitch(element) .Case([&](VariableElement *var) { variables.emplace_back(var); }) .Case([&](CustomDirective *ele) { diff --git a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp index 4dfa1908b..730b5b26a 100644 --- a/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp +++ b/mlir/tools/mlir-tblgen/OpInterfacesGen.cpp @@ -42,10 +42,10 @@ static raw_ostream &emitCPPType(StringRef type, raw_ostream &os) { /// Emit the method name and argument list for the given method. If 'addThisArg' /// is true, then an argument is added to the beginning of the argument list for /// the concrete value. -static void emitMethodNameAndArgs(const InterfaceMethod &method, +static void emitMethodNameAndArgs(const InterfaceMethod &method, StringRef name, raw_ostream &os, StringRef valueType, bool addThisArg, bool addConst) { - os << method.getName() << '('; + os << name << '('; if (addThisArg) { if (addConst) os << "const "; @@ -96,9 +96,9 @@ class InterfaceGenerator { void emitConceptDecl(const Interface &interface); void emitModelDecl(const Interface &interface); void emitModelMethodsDef(const Interface &interface); - void emitTraitDecl(const Interface &interface, StringRef interfaceName, - StringRef interfaceTraitsName); + void forwardDeclareInterface(const Interface &interface); void emitInterfaceDecl(const Interface &interface); + void emitInterfaceTraitDecl(const Interface &interface); /// The set of interface records to emit. std::vector defs; @@ -183,11 +183,13 @@ static void emitInterfaceDefMethods(StringRef interfaceQualName, emitInterfaceMethodDoc(method, os); emitCPPType(method.getReturnType(), os); os << interfaceQualName << "::"; - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + emitMethodNameAndArgs(method, method.getName(), os, valueType, + /*addThisArg=*/false, /*addConst=*/!isOpInterface); // Forward to the method on the concrete operation type. - os << " {\n return " << implValue << "->" << method.getName() << '('; + os << " {\n return " << implValue << "->" << method.getUniqueName() + << '('; if (!method.isStatic()) { os << implValue << ", "; os << (isOpInterface ? "getOperation()" : "*this"); @@ -239,7 +241,7 @@ void InterfaceGenerator::emitConceptDecl(const Interface &interface) { for (auto &method : interface.getMethods()) { os << " "; emitCPPType(method.getReturnType(), os); - os << "(*" << method.getName() << ")("; + os << "(*" << method.getUniqueName() << ")("; if (!method.isStatic()) { os << "const Concept *impl, "; emitCPPType(valueType, os) << (method.arg_empty() ? "" : ", "); @@ -289,13 +291,13 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { os << " " << modelClass << "() : Concept{"; llvm::interleaveComma( interface.getMethods(), os, - [&](const InterfaceMethod &method) { os << method.getName(); }); + [&](const InterfaceMethod &method) { os << method.getUniqueName(); }); os << "} {}\n\n"; // Insert each of the virtual method overrides. for (auto &method : interface.getMethods()) { emitCPPType(method.getReturnType(), os << " static inline "); - emitMethodNameAndArgs(method, os, valueType, + emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << ";\n"; @@ -319,7 +321,7 @@ void InterfaceGenerator::emitModelDecl(const Interface &interface) { if (method.isStatic()) os << "static "; emitCPPType(method.getReturnType(), os); - os << method.getName() << "("; + os << method.getUniqueName() << "("; if (!method.isStatic()) { emitCPPType(valueType, os); os << "tablegen_opaque_val"; @@ -350,7 +352,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { emitCPPType(method.getReturnType(), os); os << "detail::" << interface.getName() << "InterfaceTraits::Model<" << valueTemplate << ">::"; - emitMethodNameAndArgs(method, os, valueType, + emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << " {\n "; @@ -384,7 +386,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { emitCPPType(method.getReturnType(), os); os << "detail::" << interface.getName() << "InterfaceTraits::FallbackModel<" << valueTemplate << ">::"; - emitMethodNameAndArgs(method, os, valueType, + emitMethodNameAndArgs(method, method.getUniqueName(), os, valueType, /*addThisArg=*/!method.isStatic(), /*addConst=*/false); os << " {\n "; @@ -396,7 +398,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { os << "return static_cast(impl)->"; // Add the arguments to the call. - os << method.getName() << '('; + os << method.getUniqueName() << '('; if (!method.isStatic()) os << "tablegen_opaque_val" << (method.arg_empty() ? "" : ", "); llvm::interleaveComma( @@ -416,7 +418,7 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { << "InterfaceTraits::ExternalModel::"; - os << method.getName() << "("; + os << method.getUniqueName() << "("; if (!method.isStatic()) { emitCPPType(valueType, os); os << "tablegen_opaque_val"; @@ -445,9 +447,16 @@ void InterfaceGenerator::emitModelMethodsDef(const Interface &interface) { os << "} // namespace " << ns << "\n"; } -void InterfaceGenerator::emitTraitDecl(const Interface &interface, - StringRef interfaceName, - StringRef interfaceTraitsName) { +void InterfaceGenerator::emitInterfaceTraitDecl(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + + os << "namespace detail {\n"; + + StringRef interfaceName = interface.getName(); + auto interfaceTraitsName = (interfaceName + "InterfaceTraits").str(); os << llvm::formatv(" template \n" " struct {0}Trait : public ::mlir::{2}<{0}," " detail::{1}>::Trait<{3}> {{\n", @@ -470,7 +479,8 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface, emitInterfaceMethodDoc(method, os, " "); os << " " << (method.isStatic() ? "static " : ""); emitCPPType(method.getReturnType(), os); - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + emitMethodNameAndArgs(method, method.getName(), os, valueType, + /*addThisArg=*/false, /*addConst=*/!isOpInterface && !method.isStatic()); os << " {\n " << tblgen::tgfmt(defaultImpl->trim(), &traitMethodFmt) << "\n }\n"; @@ -494,6 +504,10 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface, os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n"; os << " };\n"; + os << "}// namespace detail\n"; + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; } static void emitInterfaceDeclMethods(const Interface &interface, @@ -503,7 +517,8 @@ static void emitInterfaceDeclMethods(const Interface &interface, for (auto &method : interface.getMethods()) { emitInterfaceMethodDoc(method, os, " "); emitCPPType(method.getReturnType(), os << " "); - emitMethodNameAndArgs(method, os, valueType, /*addThisArg=*/false, + emitMethodNameAndArgs(method, method.getName(), os, valueType, + /*addThisArg=*/false, /*addConst=*/!isOpInterface); os << ";\n"; } @@ -517,6 +532,24 @@ static void emitInterfaceDeclMethods(const Interface &interface, os << tblgen::tgfmt(extraDecls->rtrim(), &extraDeclsFmt) << "\n"; } +void InterfaceGenerator::forwardDeclareInterface(const Interface &interface) { + llvm::SmallVector namespaces; + llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); + for (StringRef ns : namespaces) + os << "namespace " << ns << " {\n"; + + // Emit a forward declaration of the interface class so that it becomes usable + // in the signature of its methods. + tblgen::emitSummaryAndDescComments(os, "", + interface.getDescription().value_or("")); + + StringRef interfaceName = interface.getName(); + os << "class " << interfaceName << ";\n"; + + for (StringRef ns : llvm::reverse(namespaces)) + os << "} // namespace " << ns << "\n"; +} + void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { llvm::SmallVector namespaces; llvm::SplitString(interface.getCppNamespace(), namespaces, "::"); @@ -528,12 +561,8 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { // Emit a forward declaration of the interface class so that it becomes usable // in the signature of its methods. - std::string comments = tblgen::emitSummaryAndDescComments( - "", interface.getDescription().value_or("")); - if (!comments.empty()) { - os << comments << "\n"; - } - os << "class " << interfaceName << ";\n"; + tblgen::emitSummaryAndDescComments(os, "", + interface.getDescription().value_or("")); // Emit the traits struct containing the concept and model declarations. os << "namespace detail {\n" @@ -603,10 +632,6 @@ void InterfaceGenerator::emitInterfaceDecl(const Interface &interface) { os << "};\n"; - os << "namespace detail {\n"; - emitTraitDecl(interface, interfaceName, interfaceTraitsName); - os << "}// namespace detail\n"; - for (StringRef ns : llvm::reverse(namespaces)) os << "} // namespace " << ns << "\n"; } @@ -619,10 +644,15 @@ bool InterfaceGenerator::emitInterfaceDecls() { llvm::sort(sortedDefs, [](const Record *lhs, const Record *rhs) { return lhs->getID() < rhs->getID(); }); + for (const Record *def : sortedDefs) + forwardDeclareInterface(Interface(def)); for (const Record *def : sortedDefs) emitInterfaceDecl(Interface(def)); + for (const Record *def : sortedDefs) + emitInterfaceTraitDecl(Interface(def)); for (const Record *def : sortedDefs) emitModelMethodsDef(Interface(def)); + return false; } diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 038f56d5a..0172b3fa3 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -13,6 +13,7 @@ #include "OpGenHelpers.h" +#include "mlir/Support/IndentedOstream.h" #include "mlir/TableGen/GenInfo.h" #include "mlir/TableGen/Operator.h" #include "llvm/ADT/StringSet.h" @@ -20,6 +21,7 @@ #include "llvm/Support/FormatVariadic.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/Record.h" +#include using namespace mlir; using namespace mlir::tblgen; @@ -36,7 +38,6 @@ from ._ods_common import _cext as _ods_cext from ._ods_common import ( equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, - get_op_result_or_op_results as _get_op_result_or_op_results, get_op_results_or_values as _get_op_results_or_values, segmented_accessor as _ods_segmented_accessor, ) @@ -44,7 +45,7 @@ _ods_ir = _ods_cext.ir _ods_cext.globals.register_traceback_file_exclusion(__file__) import builtins -from typing import Sequence as _Sequence, Union as _Union +from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional )Py"; @@ -62,10 +63,11 @@ from ._{0}_ops_gen import _Dialect /// Template for operation class: /// {0} is the Python class name; -/// {1} is the operation name. +/// {1} is the operation name; +/// {2} is the docstring for this operation. constexpr const char *opClassTemplate = R"Py( @_ods_cext.register_operation(_Dialect) -class {0}(_ods_ir.OpView): +class {0}(_ods_ir.OpView):{2} OPERATION_NAME = "{1}" )Py"; @@ -93,9 +95,10 @@ constexpr const char *opClassRegionSpecTemplate = R"Py( /// {0} is the name of the accessor; /// {1} is either 'operand' or 'result'; /// {2} is the position in the element list. +/// {3} is the type hint. constexpr const char *opSingleTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {3}: return self.operation.{1}s[{2}] )Py"; @@ -104,11 +107,12 @@ constexpr const char *opSingleTemplate = R"Py( /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. +/// {4} is the type hint. /// This works for both a single variadic group (non-negative length) and an /// single optional element (zero length if the element is absent). constexpr const char *opSingleAfterVariableTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {4}: _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 return self.operation.{1}s[{3} + _ods_variadic_group_length - 1] )Py"; @@ -118,12 +122,13 @@ constexpr const char *opSingleAfterVariableTemplate = R"Py( /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. +/// {4} is the type hint. /// This works if we have only one variable-length group (and it's the optional /// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is /// smaller than the total number of groups. constexpr const char *opOneOptionalTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> _Optional[{4}]: return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}] )Py"; @@ -132,9 +137,10 @@ constexpr const char *opOneOptionalTemplate = R"Py( /// {1} is either 'operand' or 'result'; /// {2} is the total number of element groups; /// {3} is the position of the current group in the group list. +/// {4} is the type hint. constexpr const char *opOneVariadicTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {4}: _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1 return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length] )Py"; @@ -146,9 +152,10 @@ constexpr const char *opOneVariadicTemplate = R"Py( /// {3} is the total number of variadic groups; /// {4} is the number of non-variadic groups preceding the current group; /// {5} is the number of variadic groups preceding the current group. +/// {6} is the type hint. constexpr const char *opVariadicEqualPrefixTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {6}: start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py"; /// Second part of the template for equally-sized case, accessing a single @@ -171,9 +178,10 @@ constexpr const char *opVariadicEqualVariadicTemplate = R"Py( /// {2} is the position of the group in the group list; /// {3} is a return suffix (expected [0] for single-element, empty for /// variadic, and opVariadicSegmentOptionalTrailingTemplate for optional). +/// {4} is the type hint. constexpr const char *opVariadicSegmentTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {4}: {1}_range = _ods_segmented_accessor( self.operation.{1}s, self.operation.attributes["{1}SegmentSizes"], {2}) @@ -189,18 +197,20 @@ constexpr const char *opVariadicSegmentOptionalTrailingTemplate = /// Template for an operation attribute getter: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. +/// {2} is the type hint. constexpr const char *attributeGetterTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {2}: return self.operation.attributes["{1}"] )Py"; /// Template for an optional operation attribute getter: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. +/// {2} is the type hint. constexpr const char *optionalAttributeGetterTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> _Optional[{2}]: if "{1}" not in self.operation.attributes: return None return self.operation.attributes["{1}"] @@ -213,16 +223,17 @@ constexpr const char *optionalAttributeGetterTemplate = R"Py( /// {1} is the original name of the attribute. constexpr const char *unitAttributeGetterTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> bool: return "{1}" in self.operation.attributes )Py"; /// Template for an operation attribute setter: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. +/// {2} is the type hint. constexpr const char *attributeSetterTemplate = R"Py( @{0}.setter - def {0}(self, value): + def {0}(self, value: {2}): if value is None: raise ValueError("'None' not allowed as value for mandatory attributes") self.operation.attributes["{1}"] = value @@ -232,9 +243,10 @@ constexpr const char *attributeSetterTemplate = R"Py( /// removes the attribute: /// {0} is the name of the attribute sanitized for Python; /// {1} is the original name of the attribute. +/// {2} is the type hint. constexpr const char *optionalAttributeSetterTemplate = R"Py( @{0}.setter - def {0}(self, value): + def {0}(self, value: _Optional[{2}]): if value is not None: self.operation.attributes["{1}"] = value elif "{1}" in self.operation.attributes: @@ -266,7 +278,7 @@ constexpr const char *attributeDeleterTemplate = R"Py( constexpr const char *regionAccessorTemplate = R"Py( @builtins.property - def {0}(self): + def {0}(self) -> {2}: return self.regions[{1}] )Py"; @@ -276,8 +288,9 @@ def {0}({2}) -> {4}: )Py"; constexpr const char *valueBuilderVariadicTemplate = R"Py( -def {0}({2}) -> {4}: - return _get_op_result_or_op_results({1}({3})) +def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]: + op = {1}({3}); results = op.results + return results if len(results) > 1 else (results[0] if len(results) == 1 else op) )Py"; static llvm::cl::OptionCategory @@ -357,15 +370,24 @@ static void emitElementAccessors( seenVariableLength = true; if (element.name.empty()) continue; + const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" + : "_ods_ir.OpResult"; if (element.isVariableLength()) { - os << formatv(element.isOptional() ? opOneOptionalTemplate - : opOneVariadicTemplate, - sanitizeName(element.name), kind, numElements, i); + if (element.isOptional()) { + os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind, + numElements, i, type); + } else { + type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList" + : "_ods_ir.OpResultList"; + os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind, + numElements, i, type); + } } else if (seenVariableLength) { os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name), - kind, numElements, i); + kind, numElements, i, type); } else { - os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i); + os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i, + type); } } return; @@ -388,9 +410,17 @@ static void emitElementAccessors( for (unsigned i = 0; i < numElements; ++i) { const NamedTypeConstraint &element = getElement(op, i); if (!element.name.empty()) { + std::string type; + if (element.isVariableLength()) { + type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList" + : "_ods_ir.OpResultList"; + } else { + type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" + : "_ods_ir.OpResult"; + } os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name), kind, numSimpleLength, numVariadicGroups, - numPrecedingSimple, numPrecedingVariadic); + numPrecedingSimple, numPrecedingVariadic, type); os << formatv(element.isVariableLength() ? opVariadicEqualVariadicTemplate : opVariadicEqualSimpleTemplate, @@ -413,13 +443,23 @@ static void emitElementAccessors( if (element.name.empty()) continue; std::string trailing; - if (!element.isVariableLength()) - trailing = "[0]"; - else if (element.isOptional()) - trailing = std::string( - formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); + std::string type = std::strcmp(kind, "operand") == 0 + ? "_ods_ir.OpOperandList" + : "_ods_ir.OpResultList"; + if (!element.isVariableLength() || element.isOptional()) { + type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value" + : "_ods_ir.OpResult"; + if (!element.isVariableLength()) { + trailing = "[0]"; + } else if (element.isOptional()) { + type = "_Optional[" + type + "]"; + trailing = std::string( + formatv(opVariadicSegmentOptionalTrailingTemplate, kind)); + } + } + os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind, - i, trailing); + i, trailing, type); } return; } @@ -449,6 +489,72 @@ static void emitResultAccessors(const Operator &op, raw_ostream &os) { getNumResults(op), getResult); } +static std::string getPythonAttrName(mlir::tblgen::Attribute attr) { + auto storageTypeStr = attr.getStorageType(); + if (storageTypeStr == "::mlir::AffineMapAttr") + return "AffineMapAttr"; + if (storageTypeStr == "::mlir::ArrayAttr") + return "ArrayAttr"; + if (storageTypeStr == "::mlir::BoolAttr") + return "BoolAttr"; + if (storageTypeStr == "::mlir::DenseBoolArrayAttr") + return "DenseBoolArrayAttr"; + if (storageTypeStr == "::mlir::DenseElementsAttr") { + llvm::StringSet<> superClasses; + for (const Record *sc : attr.getDef().getSuperClasses()) + superClasses.insert(sc->getNameInitAsString()); + if (superClasses.contains("FloatElementsAttr") || + superClasses.contains("RankedFloatElementsAttr")) { + return "DenseFPElementsAttr"; + } + return "DenseElementsAttr"; + } + if (storageTypeStr == "::mlir::DenseF32ArrayAttr") + return "DenseF32ArrayAttr"; + if (storageTypeStr == "::mlir::DenseF64ArrayAttr") + return "DenseF64ArrayAttr"; + if (storageTypeStr == "::mlir::DenseFPElementsAttr") + return "DenseFPElementsAttr"; + if (storageTypeStr == "::mlir::DenseI16ArrayAttr") + return "DenseI16ArrayAttr"; + if (storageTypeStr == "::mlir::DenseI32ArrayAttr") + return "DenseI32ArrayAttr"; + if (storageTypeStr == "::mlir::DenseI64ArrayAttr") + return "DenseI64ArrayAttr"; + if (storageTypeStr == "::mlir::DenseI8ArrayAttr") + return "DenseI8ArrayAttr"; + if (storageTypeStr == "::mlir::DenseIntElementsAttr") + return "DenseIntElementsAttr"; + if (storageTypeStr == "::mlir::DenseResourceElementsAttr") + return "DenseResourceElementsAttr"; + if (storageTypeStr == "::mlir::DictionaryAttr") + return "DictAttr"; + if (storageTypeStr == "::mlir::FlatSymbolRefAttr") + return "FlatSymbolRefAttr"; + if (storageTypeStr == "::mlir::FloatAttr") + return "FloatAttr"; + if (storageTypeStr == "::mlir::IntegerAttr") { + if (attr.getAttrDefName().str() == "I1Attr") + return "BoolAttr"; + return "IntegerAttr"; + } + if (storageTypeStr == "::mlir::IntegerSetAttr") + return "IntegerSetAttr"; + if (storageTypeStr == "::mlir::OpaqueAttr") + return "OpaqueAttr"; + if (storageTypeStr == "::mlir::StridedLayoutAttr") + return "StridedLayoutAttr"; + if (storageTypeStr == "::mlir::StringAttr") + return "StringAttr"; + if (storageTypeStr == "::mlir::SymbolRefAttr") + return "SymbolRefAttr"; + if (storageTypeStr == "::mlir::TypeAttr") + return "TypeAttr"; + if (storageTypeStr == "::mlir::UnitAttr") + return "UnitAttr"; + return "Attribute"; +} + /// Emits accessors to Op attributes. static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { for (const auto &namedAttr : op.getAttributes()) { @@ -470,15 +576,18 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { continue; } + std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr); if (namedAttr.attr.isOptional()) { os << formatv(optionalAttributeGetterTemplate, sanitizedName, - namedAttr.name); + namedAttr.name, type); os << formatv(optionalAttributeSetterTemplate, sanitizedName, - namedAttr.name); + namedAttr.name, type); os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name); } else { - os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name); - os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name); + os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name, + type); + os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name, + type); // Non-optional attributes cannot be deleted. } } @@ -492,7 +601,6 @@ static void emitAttributeAccessors(const Operator &op, raw_ostream &os) { constexpr const char *initTemplate = R"Py( def __init__(self, {0}): operands = [] - results = [] attributes = {{} regions = None {1} @@ -738,18 +846,24 @@ populateBuilderLinesOperand(const Operator &op, ArrayRef names, } } -/// Python code template for deriving the operation result types from its -/// attribute: +/// Python code template of generating result types for +/// FirstAttrDerivedResultType trait /// - {0} is the name of the attribute from which to derive the types. -constexpr const char *deriveTypeFromAttrTemplate = - R"Py(_ods_result_type_source_attr = attributes["{0}"] -_ods_derived_result_type = ( +/// - {1} is the number of results. +constexpr const char *firstAttrDerivedResultTypeTemplate = + R"Py(if results is None: + _ods_result_type_source_attr = attributes["{0}"] + _ods_derived_result_type = ( _ods_ir.TypeAttr(_ods_result_type_source_attr).value if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else - _ods_result_type_source_attr.type))Py"; + _ods_result_type_source_attr.type) + results = [_ods_derived_result_type] * {1})Py"; -/// Python code template appending {0} type {1} times to the results list. -constexpr const char *appendSameResultsTemplate = "results.extend([{0}] * {1})"; +/// Python code template of generating result types for +/// SameOperandsAndResultType trait +/// - {0} is the number of results. +constexpr const char *sameOperandsAndResultTypeTemplate = + R"Py(if results is None: results = [operands[0].type] * {0})Py"; /// Appends the given multiline string as individual strings into /// `builderLines`. @@ -768,11 +882,10 @@ static void appendLineByLine(StringRef string, static void populateBuilderLinesResult(const Operator &op, ArrayRef names, SmallVectorImpl &builderLines) { - bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; - if (hasSameArgumentAndResultTypes(op)) { - builderLines.push_back(formatv(appendSameResultsTemplate, - "operands[0].type", op.getNumResults())); + appendLineByLine( + formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(), + builderLines); return; } @@ -780,17 +893,19 @@ populateBuilderLinesResult(const Operator &op, ArrayRef names, const NamedAttribute &firstAttr = op.getAttribute(0); assert(!firstAttr.name.empty() && "unexpected empty name for the attribute " "from which the type is derived"); - appendLineByLine(formatv(deriveTypeFromAttrTemplate, firstAttr.name).str(), + appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name, + op.getNumResults()) + .str(), builderLines); - builderLines.push_back(formatv(appendSameResultsTemplate, - "_ods_derived_result_type", - op.getNumResults())); return; } if (hasInferTypeInterface(op)) return; + bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr; + builderLines.push_back("results = []"); + // For each element, find or generate a name. for (int i = 0, e = op.getNumResults(); i < e; ++i) { const NamedTypeConstraint &element = op.getResult(i); @@ -909,6 +1024,9 @@ static SmallVector emitDefaultOpBuilder(const Operator &op, functionArgs.push_back(builderArgs[i]); } } + if (canInferType(op)) { + functionArgs.push_back("results=None"); + } functionArgs.push_back("loc=None"); functionArgs.push_back("ip=None"); @@ -918,8 +1036,7 @@ static SmallVector emitDefaultOpBuilder(const Operator &op, initArgs.push_back("self._ODS_OPERAND_SEGMENTS"); initArgs.push_back("self._ODS_RESULT_SEGMENTS"); initArgs.push_back("attributes=attributes"); - if (!hasInferTypeInterface(op)) - initArgs.push_back("results=results"); + initArgs.push_back("results=results"); initArgs.push_back("operands=operands"); initArgs.push_back("successors=_ods_successors"); initArgs.push_back("regions=regions"); @@ -972,8 +1089,9 @@ static void emitRegionAccessors(const Operator &op, raw_ostream &os) { assert((!region.isVariadic() || en.index() == op.getNumRegions() - 1) && "expected only the last region to be variadic"); os << formatv(regionAccessorTemplate, sanitizeName(region.name), - std::to_string(en.index()) + - (region.isVariadic() ? ":" : "")); + std::to_string(en.index()) + (region.isVariadic() ? ":" : ""), + region.isVariadic() ? "_ods_ir.RegionSequence" + : "_ods_ir.Region"); } } @@ -1005,30 +1123,49 @@ static void emitValueBuilder(const Operator &op, nameWithoutDialect += "_"; std::string params = llvm::join(valueBuilderParams, ", "); std::string args = llvm::join(opBuilderArgs, ", "); - const char *type = - (op.getNumResults() > 1 - ? "_Sequence[_ods_ir.Value]" - : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")); - if (op.getNumVariableLengthResults() > 0) { + if (op.getNumVariableLengthResults()) { os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect, - op.getCppClassName(), params, args, type); + op.getCppClassName(), params, args); } else { - const char *results; - if (op.getNumResults() == 0) { - results = ""; + std::string type = op.getCppClassName().str(); + const char *results = ""; + if (op.getNumResults() > 1) { + type = "_ods_ir.OpResultList"; + results = ".results"; } else if (op.getNumResults() == 1) { + type = "_ods_ir.OpResult"; results = ".result"; - } else { - results = ".results"; } os << formatv(valueBuilderTemplate, nameWithoutDialect, op.getCppClassName(), params, args, type, results); } } +/// Retrieve the description of the given op and generate a docstring for it. +static std::string makeDocStringForOp(const Operator &op) { + if (!op.hasDescription()) + return ""; + + auto desc = op.getDescription().rtrim(" \t").str(); + // Replace all """ with \"\"\" to avoid early termination of the literal. + desc = std::regex_replace(desc, std::regex(R"(""")"), R"(\"\"\")"); + + std::string docString = "\n"; + llvm::raw_string_ostream os(docString); + raw_indented_ostream identedOs(os); + os << R"( r""")" << "\n"; + identedOs.printReindented(desc, " "); + if (!StringRef(desc).ends_with("\n")) + os << "\n"; + os << R"( """)" << "\n"; + + return docString; +} + /// Emits bindings for a specific Op to the given output stream. static void emitOpBindings(const Operator &op, raw_ostream &os) { - os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName()); + os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(), + makeDocStringForOp(op)); // Sized segments. if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) { diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 605033daa..40bc1a9c3 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1024,6 +1024,32 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { int depth = 0; emitMatch(tree, opName, depth); + // Some of the operands could be bound to the same symbol name, we need + // to enforce equality constraint on those. + // This has to happen before user provided constraints, which may assume the + // same name checks are already performed, since in the pattern source code + // the user provided constraints appear later. + // TODO: we should be able to emit equality checks early + // and short circuit unnecessary work if vars are not equal. + for (auto symbolInfoIt = symbolInfoMap.begin(); + symbolInfoIt != symbolInfoMap.end();) { + auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); + auto startRange = range.first; + auto endRange = range.second; + + auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); + for (++startRange; startRange != endRange; ++startRange) { + auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); + emitMatchCheck( + opName, + formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), + formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, + secondOperand)); + } + + symbolInfoIt = endRange; + } + for (auto &appliedConstraint : pattern.getConstraints()) { auto &constraint = appliedConstraint.constraint; auto &entities = appliedConstraint.entities; @@ -1068,29 +1094,6 @@ void PatternEmitter::emitMatchLogic(DagNode tree, StringRef opName) { } } - // Some of the operands could be bound to the same symbol name, we need - // to enforce equality constraint on those. - // TODO: we should be able to emit equality checks early - // and short circuit unnecessary work if vars are not equal. - for (auto symbolInfoIt = symbolInfoMap.begin(); - symbolInfoIt != symbolInfoMap.end();) { - auto range = symbolInfoMap.getRangeOfEqualElements(symbolInfoIt->first); - auto startRange = range.first; - auto endRange = range.second; - - auto firstOperand = symbolInfoIt->second.getVarName(symbolInfoIt->first); - for (++startRange; startRange != endRange; ++startRange) { - auto secondOperand = startRange->second.getVarName(symbolInfoIt->first); - emitMatchCheck( - opName, - formatv("*{0}.begin() == *{1}.begin()", firstOperand, secondOperand), - formatv("\"Operands '{0}' and '{1}' must be equal\"", firstOperand, - secondOperand)); - } - - symbolInfoIt = endRange; - } - LLVM_DEBUG(llvm::dbgs() << "--- done emitting match logic ---\n"); } diff --git a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp index 41ffdfcbd..3ead2f0e3 100644 --- a/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/SPIRVUtilsGen.cpp @@ -50,7 +50,6 @@ using mlir::tblgen::EnumCase; using mlir::tblgen::EnumInfo; using mlir::tblgen::NamedAttribute; using mlir::tblgen::NamedTypeConstraint; -using mlir::tblgen::NamespaceEmitter; using mlir::tblgen::Operator; //===----------------------------------------------------------------------===// @@ -261,7 +260,7 @@ static void emitInterfaceDecl(const Availability &availability, std::string(formatv("{0}Traits", interfaceName)); StringRef cppNamespace = availability.getInterfaceClassNamespace(); - NamespaceEmitter nsEmitter(os, cppNamespace); + llvm::NamespaceEmitter nsEmitter(os, cppNamespace); os << "class " << interfaceName << ";\n\n"; // Emit the traits struct containing the concept and model declarations. diff --git a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp index 17261ed40..dc8cc5849 100644 --- a/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp +++ b/mlir/tools/mlir-tblgen/TosaUtilsGen.cpp @@ -1,4 +1,4 @@ -//===- TosaUtilsGen.cpp - Tosa utility generator -===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -122,7 +122,7 @@ StringRef Availability::getMergeInstance() const { } // Returns the availability spec of the given `def`. -std::vector getAvailabilities(const Record &def) { +static std::vector getAvailabilities(const Record &def) { std::vector availabilities; if (def.getValue("availability")) { diff --git a/mlir/tools/mlir-tblgen/mlir-tblgen.cpp b/mlir/tools/mlir-tblgen/mlir-tblgen.cpp index 6c4b61959..9c5cc6a6f 100644 --- a/mlir/tools/mlir-tblgen/mlir-tblgen.cpp +++ b/mlir/tools/mlir-tblgen/mlir-tblgen.cpp @@ -18,10 +18,11 @@ using namespace llvm; using namespace mlir; // Generator that prints records. -GenRegistration printRecords("print-records", "Print all records to stdout", - [](const RecordKeeper &records, raw_ostream &os) { - os << records; - return false; - }); +static GenRegistration + printRecords("print-records", "Print all records to stdout", + [](const RecordKeeper &records, raw_ostream &os) { + os << records; + return false; + }); int main(int argc, char **argv) { return MlirTblgenMain(argc, argv); }