diff --git a/lib/Bindings/Python/SupportModule.cpp b/lib/Bindings/Python/SupportModule.cpp index e32010373985..ed159c5aa554 100644 --- a/lib/Bindings/Python/SupportModule.cpp +++ b/lib/Bindings/Python/SupportModule.cpp @@ -28,16 +28,32 @@ void circt::python::populateSupportSubmodule(py::module &m) { m.def( "_walk_with_filter", [](MlirOperation operation, const std::vector &opNames, - std::function callback, - MlirWalkOrder walkOrder) { + std::function callback, + py::object walkOrderRaw) { struct UserData { - std::function callback; + std::function callback; bool gotException; std::string exceptionWhat; py::object exceptionType; std::vector opNames; }; + // As we transition from pybind11 to nanobind, the WalkOrder enum and + // automatic casting will be defined as a nanobind enum upstream. Do a + // manual conversion that works with either pybind11 or nanobind for + // now. When we're on nanobind in CIRCT, we can go back to automatic + // casting. + MlirWalkOrder walkOrder; + auto walkOrderRawValue = py::cast(walkOrderRaw.attr("value")); + switch (walkOrderRawValue) { + case 0: + walkOrder = MlirWalkOrder::MlirWalkPreOrder; + break; + case 1: + walkOrder = MlirWalkOrder::MlirWalkPostOrder; + break; + } + std::vector opNamesIdentifiers; opNamesIdentifiers.reserve(opNames.size()); @@ -68,7 +84,27 @@ void circt::python::populateSupportSubmodule(py::module &m) { return MlirWalkResult::MlirWalkResultAdvance; try { - return (calleeUserData->callback)(op); + // As we transition from pybind11 to nanobind, the WalkResult enum + // and automatic casting will be defined as a nanobind enum + // upstream. Do a manual conversion that works with either pybind11 + // or nanobind for now. When we're on nanobind in CIRCT, we can go + // back to automatic casting. + MlirWalkResult walkResult; + auto walkResultRaw = (calleeUserData->callback)(op); + auto walkResultRawValue = + py::cast(walkResultRaw.attr("value")); + switch (walkResultRawValue) { + case 0: + walkResult = MlirWalkResult::MlirWalkResultAdvance; + break; + case 1: + walkResult = MlirWalkResult::MlirWalkResultInterrupt; + break; + case 2: + walkResult = MlirWalkResult::MlirWalkResultSkip; + break; + } + return walkResult; } catch (py::error_already_set &e) { calleeUserData->gotException = true; calleeUserData->exceptionWhat = e.what();