Skip to content

Commit

Permalink
Replace Enum with map of tuples
Browse files Browse the repository at this point in the history
  • Loading branch information
josephburkhart committed Sep 20, 2024
1 parent a7aa0c2 commit 4c4099c
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions gsd/pygsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,19 @@
gsd_index_entry = namedtuple('gsd_index_entry', 'frame N location M id type flags')
gsd_index_entry_struct = struct.Struct('QQqIHBB')

gsd_type_mapping = Enum(
"gsd_type",
[
["uint8", numpy.dtype('uint8')],
["uint16", numpy.dtype('uint16')],
["uint32", numpy.dtype('uint32')],
["uint64", numpy.dtype('uint64')],
["int8", numpy.dtype('int8')],
["int16", numpy.dtype('int16')],
["int32", numpy.dtype('int32')],
["int64", numpy.dtype('int64')],
["float32", numpy.dtype('float32')],
["float64", numpy.dtype('float64')],
["str", numpy.dtype('int8')], # used for strings
]
)
gsd_type_mapping = {
1: ("uint8", numpy.dtype('uint8')),
2: ("uint16", numpy.dtype('uint16')),
3: ("uint32", numpy.dtype('uint32')),
4: ("uint64", numpy.dtype('uint64')),
5: ("int8", numpy.dtype('int8')),
6: ("int16", numpy.dtype('int16')),
7: ("int32", numpy.dtype('int32')),
8: ("int64", numpy.dtype('int64')),
9: ("float32", numpy.dtype('float32')),
10: ("float64", numpy.dtype('float64')),
11: ("str", numpy.dtype('int8')),
}


class GSDFile:
Expand Down Expand Up @@ -338,7 +335,7 @@ def read_chunk(self, frame, name):
'read chunk: ' + str(self.__file) + ' - ' + str(frame) + ' - ' + name
)

size = chunk.N * chunk.M * gsd_type_mapping[chunk.type].itemsize
size = chunk.N * chunk.M * gsd_type_mapping[chunk.type][1].itemsize
if chunk.location == 0:
raise RuntimeError(
'Corrupt chunk: '
Expand All @@ -350,7 +347,7 @@ def read_chunk(self, frame, name):
)

if size == 0:
return numpy.array([], dtype=gsd_type_mapping[chunk.type])
return numpy.array([], dtype=gsd_type_mapping[chunk.type][1])

self.__file.seek(chunk.location, 0)
data_raw = self.__file.read(size)
Expand All @@ -359,10 +356,10 @@ def read_chunk(self, frame, name):
raise OSError

# If gsd type is character, decode it here
if chunk.type == 11:
if gsd_type_mapping[chunk.type][0] == "str":
data_npy = data_raw.decode('utf-8')
else:
data_npy = numpy.frombuffer(data_raw, dtype=gsd_type_mapping[chunk.type])
data_npy = numpy.frombuffer(data_raw, dtype=gsd_type_mapping[chunk.type][1])

if chunk.M == 1:
return data_npy
Expand Down

0 comments on commit 4c4099c

Please sign in to comment.