Skip to content
This repository was archived by the owner on Oct 11, 2025. It is now read-only.

Commit e1d8506

Browse files
authored
[MLIR][Python] Add bindings for PDL constraint function registering (#160520)
This is a follow-up to #159926. That PR (#159926) exposed native rewrite function registration in PDL through the C API and Python, enabling use with `pdl.apply_native_rewrite`. In this PR, we add support for native constraint functions in PDL via `pdl.apply_native_constraint`, further completing the PDL API.
1 parent 1aeb695 commit e1d8506

3 files changed

Lines changed: 71 additions & 10 deletions

File tree

mlir/include/mlir-c/Rewrite.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,20 @@ MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterRewriteFunction(
375375
MlirPDLPatternModule pdlModule, MlirStringRef name,
376376
MlirPDLRewriteFunction rewriteFn, void *userData);
377377

378+
/// This function type is used as callbacks for PDL native constraint functions.
379+
/// Input values can be accessed by `values` with its size `nValues`;
380+
/// output values can be added into `results` by `mlirPDLResultListPushBack*`
381+
/// APIs. And the return value indicates whether the constraint holds.
382+
typedef MlirLogicalResult (*MlirPDLConstraintFunction)(
383+
MlirPatternRewriter rewriter, MlirPDLResultList results, size_t nValues,
384+
MlirPDLValue *values, void *userData);
385+
386+
/// Register a constraint function into the given PDL pattern module.
387+
/// `userData` will be provided as an argument to the constraint function.
388+
MLIR_CAPI_EXPORTED void mlirPDLPatternModuleRegisterConstraintFunction(
389+
MlirPDLPatternModule pdlModule, MlirStringRef name,
390+
MlirPDLConstraintFunction constraintFn, void *userData);
391+
378392
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
379393

380394
#undef DEFINE_C_API_STRUCT

mlir/lib/Bindings/Python/Rewrite.cpp

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,15 @@ static nb::object objectFromPDLValue(MlirPDLValue value) {
4040
throw std::runtime_error("unsupported PDL value type");
4141
}
4242

43+
static std::vector<nb::object> objectsFromPDLValues(size_t nValues,
44+
MlirPDLValue *values) {
45+
std::vector<nb::object> args;
46+
args.reserve(nValues);
47+
for (size_t i = 0; i < nValues; ++i)
48+
args.push_back(objectFromPDLValue(values[i]));
49+
return args;
50+
}
51+
4352
// Convert the Python object to a boolean.
4453
// If it evaluates to False, treat it as success;
4554
// otherwise, treat it as failure.
@@ -74,11 +83,22 @@ class PyPDLPatternModule {
7483
size_t nValues, MlirPDLValue *values,
7584
void *userData) -> MlirLogicalResult {
7685
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
77-
std::vector<nb::object> args;
78-
args.reserve(nValues);
79-
for (size_t i = 0; i < nValues; ++i)
80-
args.push_back(objectFromPDLValue(values[i]));
81-
return logicalResultFromObject(f(rewriter, results, args));
86+
return logicalResultFromObject(
87+
f(rewriter, results, objectsFromPDLValues(nValues, values)));
88+
},
89+
fn.ptr());
90+
}
91+
92+
void registerConstraintFunction(const std::string &name,
93+
const nb::callable &fn) {
94+
mlirPDLPatternModuleRegisterConstraintFunction(
95+
get(), mlirStringRefCreate(name.data(), name.size()),
96+
[](MlirPatternRewriter rewriter, MlirPDLResultList results,
97+
size_t nValues, MlirPDLValue *values,
98+
void *userData) -> MlirLogicalResult {
99+
nb::handle f = nb::handle(static_cast<PyObject *>(userData));
100+
return logicalResultFromObject(
101+
f(rewriter, results, objectsFromPDLValues(nValues, values)));
82102
},
83103
fn.ptr());
84104
}
@@ -199,6 +219,13 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) {
199219
const nb::callable &fn) {
200220
self.registerRewriteFunction(name, fn);
201221
},
222+
nb::keep_alive<1, 3>())
223+
.def(
224+
"register_constraint_function",
225+
[](PyPDLPatternModule &self, const std::string &name,
226+
const nb::callable &fn) {
227+
self.registerConstraintFunction(name, fn);
228+
},
202229
nb::keep_alive<1, 3>());
203230
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
204231
nb::class_<PyFrozenRewritePatternSet>(m, "FrozenRewritePatternSet")

mlir/lib/CAPI/Transforms/Rewrite.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -398,21 +398,41 @@ void mlirPDLResultListPushBackAttribute(MlirPDLResultList results,
398398
unwrap(results)->push_back(unwrap(value));
399399
}
400400

401+
inline std::vector<MlirPDLValue> wrap(ArrayRef<PDLValue> values) {
402+
std::vector<MlirPDLValue> mlirValues;
403+
mlirValues.reserve(values.size());
404+
for (auto &value : values) {
405+
mlirValues.push_back(wrap(&value));
406+
}
407+
return mlirValues;
408+
}
409+
401410
void mlirPDLPatternModuleRegisterRewriteFunction(
402411
MlirPDLPatternModule pdlModule, MlirStringRef name,
403412
MlirPDLRewriteFunction rewriteFn, void *userData) {
404413
unwrap(pdlModule)->registerRewriteFunction(
405414
unwrap(name),
406415
[userData, rewriteFn](PatternRewriter &rewriter, PDLResultList &results,
407416
ArrayRef<PDLValue> values) -> LogicalResult {
408-
std::vector<MlirPDLValue> mlirValues;
409-
mlirValues.reserve(values.size());
410-
for (auto &value : values) {
411-
mlirValues.push_back(wrap(&value));
412-
}
417+
std::vector<MlirPDLValue> mlirValues = wrap(values);
413418
return unwrap(rewriteFn(wrap(&rewriter), wrap(&results),
414419
mlirValues.size(), mlirValues.data(),
415420
userData));
416421
});
417422
}
423+
424+
void mlirPDLPatternModuleRegisterConstraintFunction(
425+
MlirPDLPatternModule pdlModule, MlirStringRef name,
426+
MlirPDLConstraintFunction constraintFn, void *userData) {
427+
unwrap(pdlModule)->registerConstraintFunction(
428+
unwrap(name),
429+
[userData, constraintFn](PatternRewriter &rewriter,
430+
PDLResultList &results,
431+
ArrayRef<PDLValue> values) -> LogicalResult {
432+
std::vector<MlirPDLValue> mlirValues = wrap(values);
433+
return unwrap(constraintFn(wrap(&rewriter), wrap(&results),
434+
mlirValues.size(), mlirValues.data(),
435+
userData));
436+
});
437+
}
418438
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH

0 commit comments

Comments
 (0)