Skip to content

Commit

Permalink
Ensure that np.prod returns a 64-bit integer
Browse files Browse the repository at this point in the history
This can cause overflow issues when checking the size of large arrays.
  • Loading branch information
rayosborn committed Dec 14, 2023
1 parent 64160ca commit ab282b1
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/nexusformat/nexus/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ def _readdata(self, name):
else:
field = self.get(self.nxpath)
# Read in the data if it's not too large
if np.prod(field.shape) < 1000: # i.e., less than 1k dims
if _getsize(field.shape) < 1000: # i.e., less than 1k dims
try:
value = self.readvalue(self.nxpath)
except Exception:
Expand Down Expand Up @@ -1204,7 +1204,7 @@ def readvalues(self, attrs=None):
return None, None, None, {}
shape, dtype = field.shape, field.dtype
# Read in the data if it's not too large
if np.prod(shape) < 1000: # i.e., less than 1k dims
if _getsize(shape) < 1000: # i.e., less than 1k dims
try:
value = self.readvalue(self.nxpath)
except Exception:
Expand Down Expand Up @@ -1706,7 +1706,7 @@ def _getsize(shape):
return 1
else:
try:
return np.prod(shape)
return np.prod(shape, dtype=np.int64)
except Exception:
return 1

Expand Down Expand Up @@ -3129,7 +3129,7 @@ def _get_uncopied_data(self, idx=None):
self._create_memfile()
f.copy(_path, self._memfile, name='data')
self._uncopied_data = None
if (np.prod(self.shape) * np.dtype(self.dtype).itemsize
if (_getsize(self.shape) * np.dtype(self.dtype).itemsize
<= NX_CONFIG['memory']*1000*1000):
return f.readvalue(_path)
else:
Expand Down Expand Up @@ -3744,7 +3744,7 @@ def nxdata(self):
if self._value is None:
if self.dtype is None or self.shape is None:
return None
if (np.prod(self.shape) * np.dtype(self.dtype).itemsize
if (_getsize(self.shape) * np.dtype(self.dtype).itemsize
<= NX_CONFIG['memory']*1000*1000):
try:
if self.nxfilemode:
Expand Down Expand Up @@ -4037,7 +4037,7 @@ def ndim(self):
@property
def size(self):
"""Total size of the NXfield."""
return int(np.prod(self.shape))
return _getsize(self.shape)

@property
def nbytes(self):
Expand Down

0 comments on commit ab282b1

Please sign in to comment.