From 729151e792732310d652c963cc6cdce554ac1e16 Mon Sep 17 00:00:00 2001 From: Marcin Wojdyr Date: Mon, 7 Oct 2024 15:00:34 +0200 Subject: [PATCH] python: add __array__() where we had buffer protocol previously --- python/mtz.cpp | 30 +++++++++++++++++------------- python/unitcell.cpp | 21 ++++++++++++++------- 2 files changed, 31 insertions(+), 20 deletions(-) diff --git a/python/mtz.cpp b/python/mtz.cpp index 3b635195..01d3f69f 100644 --- a/python/mtz.cpp +++ b/python/mtz.cpp @@ -62,6 +62,19 @@ auto make_new_column(const Mtz& mtz, int dataset, Func func) { return numpy_arr; } +auto mtz_to_array(Mtz& self) { + size_t nrow = self.has_data() ? (size_t) self.nreflections : 0; + size_t ncol = self.columns.size(); + return nb::ndarray>(self.data.data(), {nrow, ncol}, nb::handle()); +} + +auto column_to_array(Mtz::Column& self) { + return nb::ndarray>(self.parent->data.data() + self.idx, + {(size_t)self.size()}, + nb::handle(), + {(int64_t) self.stride()}); +} + } // anonymous namespace void add_mtz(nb::module_& m) { @@ -78,12 +91,8 @@ void add_mtz(nb::module_& m) { mtz .def(nb::init(), nb::arg("with_base")=false) - .def_prop_ro("array", [](Mtz& self) { - size_t nrow = self.has_data() ? (size_t) self.nreflections : 0; - size_t ncol = self.columns.size(); - return nb::ndarray>( - self.data.data(), {nrow, ncol}, nb::handle()); - }, nb::rv_policy::reference_internal) + .def_prop_ro("array", &mtz_to_array, nb::rv_policy::reference_internal) + .def("__array__", &mtz_to_array, nb::rv_policy::reference_internal) .def_rw("title", &Mtz::title) .def_rw("nreflections", &Mtz::nreflections) .def_rw("sort_order", &Mtz::sort_order) @@ -292,13 +301,8 @@ void add_mtz(nb::module_& m) { }) ; pyMtzColumn - .def_prop_ro("array", [](const Mtz::Column& self) { - return nb::ndarray>( - self.parent->data.data() + self.idx, - {(size_t)self.size()}, - nb::handle(), - {(int64_t) self.stride()}); - }, nb::rv_policy::reference_internal) + .def_prop_ro("array", &column_to_array, nb::rv_policy::reference_internal) + .def("__array__", &column_to_array, nb::rv_policy::reference_internal) .def_prop_ro("dataset", (Mtz::Dataset& (Mtz::Column::*)()) &Mtz::Column::dataset) .def_rw("dataset_id", &Mtz::Column::dataset_id) diff --git a/python/unitcell.cpp b/python/unitcell.cpp index 474e7d58..638ffd81 100644 --- a/python/unitcell.cpp +++ b/python/unitcell.cpp @@ -17,23 +17,32 @@ using namespace gemmi; -static std::string triple(double x, double y, double z) { +namespace { + +std::string triple(double x, double y, double z) { char buf[128]; auto r = [](double d) { return std::fabs(d) < 1e-15 ? 0 : d; }; snprintf_z(buf, 128, "%g, %g, %g", r(x), r(y), r(z)); return std::string(buf); } -static void mat33_from_list(Mat33& self, std::array,3>& m) { +void mat33_from_list(Mat33& self, std::array,3>& m) { for (int i = 0; i < 3; ++i) for (int j = 0; j < 3; ++j) self.a[i][j] = m[i][j]; } -static nb::tuple make_six_tuple(const std::array& v) { +nb::tuple make_six_tuple(const std::array& v) { return nb::make_tuple(v[0], v[1], v[2], v[3], v[4], v[5]); } +auto mat33_to_array(Mat33& self) { + return nb::ndarray, nb::c_contig>( + &self.a[0][0], {3, 3}, nb::handle()); +} + +} // anonymous namespace + template void add_smat33(nb::module_& m, const char* name) { using M = SMat33; nb::class_(m, name) @@ -135,10 +144,8 @@ void add_unitcell(nb::module_& m) { new(mat) Mat33(); mat33_from_list(*mat, arr); }) - .def_prop_ro("array", [](Mat33& self) { - return nb::ndarray, nb::c_contig>( - &self.a[0][0], {3, 3}, nb::handle()); - }, nb::rv_policy::reference_internal) + .def_prop_ro("array", &mat33_to_array, nb::rv_policy::reference_internal) + .def("__array__", &mat33_to_array, nb::rv_policy::reference_internal) .def("row_copy", &Mat33::row_copy) .def("column_copy", &Mat33::column_copy) .def(nb::self + nb::self)