Skip to content

Commit

Permalink
Add support to current numpy __array__ ndarray method
Browse files Browse the repository at this point in the history
  • Loading branch information
Antoine DECHAUME committed Feb 12, 2025
1 parent 6503cb1 commit 74331dc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
5 changes: 3 additions & 2 deletions rpyc/core/netref.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,14 @@ def method(self, start, stop, *args):
method.__doc__ = doc
return method
elif name == "__array__":
def __array__(self):
def __array__(self, *args, **kwargs):
# Note that protocol=-1 will only work between python
# interpreters of the same version.
if not object.__getattribute__(self,'____conn__')._config["allow_pickle"]:
# Security check that server side allows pickling per #551
raise ValueError("pickling is disabled")
return pickle.loads(syncreq(self, consts.HANDLE_PICKLE, -1))
array = pickle.loads(syncreq(self, consts.HANDLE_PICKLE, -1))
return array.__array__(*args, **kwargs)
__array__.__doc__ = doc
return __array__
else:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import with_statement
import rpyc
import unittest
try:
import numpy as np
_numpy_import_failed = False
except Exception:
_numpy_import_failed = True


class MyService(rpyc.Service):

def exposed_create_array(self, array):
return np.array(array, dtype=np.int64, copy=True)


@unittest.skipIf(_numpy_import_failed, "Skipping since numpy cannot be imported")
class TestNumpy(unittest.TestCase):
def setUp(self):
self.server = rpyc.utils.server.OneShotServer(MyService, port=0, protocol_config={"allow_pickle":True})
self.server.logger.quiet = False
self.server._start_in_thread()
self.conn = rpyc.connect("localhost", port=self.server.port, config={"allow_pickle":True})

def tearDown(self):
self.conn.close()

def test_numpy(self):
remote_array = self.conn.root.create_array(np.array([0.]))
self.assertEqual(remote_array[0], 0)
self.assertIsInstance(remote_array[0], np.int64)


if __name__ == "__main__":
unittest.main()

0 comments on commit 74331dc

Please sign in to comment.