Skip to content

Commit

Permalink
Use nb::iterator instead of Py_Iterator (#3120)
Browse files Browse the repository at this point in the history
  • Loading branch information
alkino authored Oct 9, 2024
1 parent ab0a571 commit 31c9245
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 55 deletions.
69 changes: 16 additions & 53 deletions src/nrnpython/nrnpy_p2h.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "parse.hpp"

#include <nanobind/nanobind.h>

namespace nb = nanobind;

static char* nrnpyerr_str();
Expand Down Expand Up @@ -926,8 +927,8 @@ static Object* py_alltoall_type(int size, int type) {
size = 0; // calculate dest size (cannot be -1 so cannot return it)
}

char* s = NULL;
int* scnt = NULL;
std::vector<char> s{};
std::vector<int> scnt{};
int* sdispl = NULL;
char* r = NULL;
int* rcnt = NULL;
Expand All @@ -937,49 +938,18 @@ static Object* py_alltoall_type(int size, int type) {
// for alltoall, each rank handled identically
// for scatter, root handled as list all, other ranks handled as None
if (type == 1 || nrnmpi_myid == root) { // psrc is list of nhost items

scnt = new int[np];
for (int i = 0; i < np; ++i) {
scnt[i] = 0;
}

PyObject* iterator = PyObject_GetIter(psrc);
PyObject* p;

size_t bufsz = 100000; // 100k buffer to start with
if (size > 0) { // or else the positive number specified
bufsz = size;
}
if (size >= 0) { // otherwise count only
s = new char[bufsz];
}
size_t curpos = 0;
for (size_t i = 0; (p = PyIter_Next(iterator)) != NULL; ++i) {
if (p == Py_None) {
scnt[i] = 0;
Py_DECREF(p);
scnt.reserve(np);
for (const nb::handle& p: nb::list(psrc)) {
if (p.is_none()) {
scnt.push_back(0);
continue;
}
auto b = pickle(p);
const std::vector<char> b = pickle(p.ptr());
if (size >= 0) {
if (curpos + b.size() >= bufsz) {
bufsz = bufsz * 2 + b.size();
char* s2 = new char[bufsz];
for (size_t i = 0; i < curpos; ++i) {
s2[i] = s[i];
}
delete[] s;
s = s2;
}
for (size_t j = 0; j < b.size(); ++j) {
s[curpos + j] = b[j];
}
s.insert(std::end(s), std::begin(b), std::end(b));
}
curpos += b.size();
scnt[i] = static_cast<int>(b.size());
Py_DECREF(p);
scnt.push_back(static_cast<int>(b.size()));
}
Py_DECREF(iterator);

// scatter equivalent to alltoall NONE list for not root ranks.
} else if (type == 5 && nrnmpi_myid != root) {
Expand All @@ -996,26 +966,23 @@ static Object* py_alltoall_type(int size, int type) {
}
sdispl = mk_displ(ones);
rcnt = new int[np];
nrnmpi_int_alltoallv(scnt, ones, sdispl, rcnt, ones, sdispl);
nrnmpi_int_alltoallv(scnt.data(), ones, sdispl, rcnt, ones, sdispl);
delete[] ones;
delete[] sdispl;

// exchange
sdispl = mk_displ(scnt);
sdispl = mk_displ(scnt.data());
rdispl = mk_displ(rcnt);
if (size < 0) {
pdest = PyTuple_New(2);
PyTuple_SetItem(pdest, 0, Py_BuildValue("l", (long) sdispl[np]));
PyTuple_SetItem(pdest, 1, Py_BuildValue("l", (long) rdispl[np]));
delete[] scnt;
delete[] sdispl;
delete[] rcnt;
delete[] rdispl;
} else {
char* r = new char[rdispl[np] + 1]; // force > 0 for all None case
nrnmpi_char_alltoallv(s, scnt, sdispl, r, rcnt, rdispl);
delete[] s;
delete[] scnt;
nrnmpi_char_alltoallv(s.data(), scnt.data(), sdispl, r, rcnt, rdispl);
delete[] sdispl;

pdest = char2pylist(r, np, rcnt, rdispl);
Expand All @@ -1029,18 +996,14 @@ static Object* py_alltoall_type(int size, int type) {

// destination counts
rcnt = new int[1];
nrnmpi_int_scatter(scnt, rcnt, 1, root);
nrnmpi_int_scatter(scnt.data(), rcnt, 1, root);
std::vector<char> r(rcnt[0] + 1); // rcnt[0] can be 0

// exchange
if (nrnmpi_myid == root) {
sdispl = mk_displ(scnt);
sdispl = mk_displ(scnt.data());
}
nrnmpi_char_scatterv(s, scnt, sdispl, r.data(), rcnt[0], root);
if (s)
delete[] s;
if (scnt)
delete[] scnt;
nrnmpi_char_scatterv(s.data(), scnt.data(), sdispl, r.data(), rcnt[0], root);
if (sdispl)
delete[] sdispl;

Expand Down
4 changes: 2 additions & 2 deletions src/utils/enumerate.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ constexpr auto enumerate(T&& iterable) {
++iter;
}
auto operator*() const {
return std::tie(i, *iter);
return std::forward_as_tuple(i, *iter);
}
};
struct iterable_wrapper {
Expand Down Expand Up @@ -129,7 +129,7 @@ constexpr auto renumerate(T&& iterable) {
++iter;
}
auto operator*() const {
return std::tie(i, *iter);
return std::forward_as_tuple(i, *iter);
}
};
struct iterable_wrapper {
Expand Down

0 comments on commit 31c9245

Please sign in to comment.