Skip to content

Commit 563bd6c

Browse files
committed
read_mesh: switch to using BaseReader subclass
1 parent 218012c commit 563bd6c

File tree

1 file changed

+150
-133
lines changed

1 file changed

+150
-133
lines changed

navis/io/mesh_io.py

Lines changed: 150 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -12,42 +12,132 @@
1212
# GNU General Public License for more details.
1313

1414
import os
15+
import io
1516

16-
import multiprocessing as mp
1717
import trimesh as tm
1818

19-
from pathlib import Path
2019
from typing import Union, Iterable, Optional, Dict, Any
2120
from typing_extensions import Literal
21+
from urllib3 import HTTPResponse
2222

2323
from .. import config, utils, core
2424
from . import base
2525

2626
# Set up logging
2727
logger = config.get_logger(__name__)
2828

29+
# Mesh files can have all sort of extensions
30+
DEFAULT_FMT = "{name}.{file_ext}"
31+
32+
# Mesh extensions supported by trimesh
33+
MESH_LOAD_EXT = tuple(tm.exchange.load.mesh_loaders.keys())
34+
MESH_WRITE_EXT = tuple(tm.exchange.export._mesh_exporters.keys())
35+
36+
37+
class MeshReader(base.BaseReader):
38+
def __init__(
39+
self,
40+
output: str,
41+
fmt: str = DEFAULT_FMT,
42+
attrs: Optional[Dict[str, Any]] = None,
43+
):
44+
super().__init__(
45+
fmt=fmt,
46+
attrs=attrs,
47+
file_ext=MESH_LOAD_EXT,
48+
name_fallback="MESH",
49+
read_binary=True,
50+
)
51+
self.output = output
52+
53+
def format_output(self, x):
54+
# This function replaces the BaseReader.format_output()
55+
# This is to avoid trying to convert multiple (image, header) to NeuronList
56+
if self.output == "trimesh":
57+
return x
58+
elif x:
59+
return core.NeuronList(x)
60+
else:
61+
return core.NeuronList([])
62+
63+
@base.handle_errors
64+
def read_buffer(
65+
self, f, attrs: Optional[Dict[str, Any]] = None
66+
) -> Union[tm.Trimesh, "core.Volume", "core.MeshNeuron"]:
67+
"""Read buffer into mesh.
68+
69+
Parameters
70+
----------
71+
f : IO
72+
Readable buffer (must be bytes).
73+
attrs : dict | None
74+
Arbitrary attributes to include in the neurons.
75+
76+
Returns
77+
-------
78+
Trimesh | MeshNeuron | Volume
79+
80+
"""
81+
if isinstance(f, HTTPResponse):
82+
f = io.StringIO(f.content)
83+
84+
if isinstance(f, bytes):
85+
f = io.BytesIO(f)
86+
87+
# We need to tell trimesh what file type we are reading
88+
if "file" not in attrs:
89+
raise KeyError(
90+
f'Unable to parse file type. "file" not in attributes: {attrs}'
91+
)
92+
93+
file_type = attrs["file"].split(".")[-1]
94+
95+
mesh = tm.load_mesh(f, file_type=file_type)
96+
97+
if self.output == "trimesh":
98+
return mesh
99+
elif self.output == "volume":
100+
return core.Volume(mesh.vertices, mesh.faces, **attrs)
101+
102+
# Turn into a MeshNeuron
103+
n = core.MeshNeuron(mesh)
104+
105+
# Try adding properties one-by-one. If one fails, we'll keep track of it
106+
# in the `.meta` attribute
107+
meta = {}
108+
for k, v in attrs.items():
109+
try:
110+
n._register_attr(k, v)
111+
except (AttributeError, ValueError, TypeError):
112+
meta[k] = v
113+
114+
if meta:
115+
n.meta = meta
116+
117+
return n
29118

30-
def read_mesh(f: Union[str, Iterable],
31-
include_subdirs: bool = False,
32-
parallel: Union[bool, int] = 'auto',
33-
output: Union[Literal['neuron'],
34-
Literal['volume'],
35-
Literal['trimesh']] = 'neuron',
36-
errors: Union[Literal['raise'],
37-
Literal['log'],
38-
Literal['ignore']] = 'log',
39-
limit: Optional[int] = None,
40-
**kwargs) -> 'core.NeuronObject':
41-
"""Create Neuron/List from mesh.
119+
120+
def read_mesh(
121+
f: Union[str, Iterable],
122+
include_subdirs: bool = False,
123+
parallel: Union[bool, int] = "auto",
124+
output: Union[Literal["neuron"], Literal["volume"], Literal["trimesh"]] = "neuron",
125+
errors: Literal["raise", "log", "ignore"] = "raise",
126+
limit: Optional[int] = None,
127+
fmt: str = "{name}.",
128+
**kwargs,
129+
) -> "core.NeuronObject":
130+
"""Load mesh file into Neuron/List.
42131
43132
This is a thin wrapper around `trimesh.load_mesh` which supports most
44-
common formats (obj, ply, stl, etc.).
133+
commonly used formats (obj, ply, stl, etc.).
45134
46135
Parameters
47136
----------
48137
f : str | iterable
49-
Filename(s) or folder. If folder must include file
50-
extension (e.g. `my/dir/*.ply`).
138+
Filename(s) or folder. If folder should include file
139+
extension (e.g. `my/dir/*.ply`) otherwise all
140+
mesh files in the folder will be read.
51141
include_subdirs : bool, optional
52142
If True and `f` is a folder, will also search
53143
subdirectories for meshes.
@@ -59,9 +149,10 @@ def read_mesh(f: Union[str, Iterable],
59149
neurons. Integer will be interpreted as the number of
60150
cores (otherwise defaults to `os.cpu_count() - 2`).
61151
output : "neuron" | "volume" | "trimesh"
62-
Determines function's output. See Returns.
152+
Determines function's output - see `Returns`.
63153
errors : "raise" | "log" | "ignore"
64-
If "log" or "ignore", errors will not be raised.
154+
If "log" or "ignore", errors will not be raised and the
155+
mesh will be skipped. Can result in empty output.
65156
limit : int | str | slice | list, optional
66157
When reading from a folder or archive you can use this parameter to
67158
restrict the which files read:
@@ -81,19 +172,24 @@ def read_mesh(f: Union[str, Iterable],
81172
82173
Returns
83174
-------
84-
navis.MeshNeuron
175+
MeshNeuron
85176
If `output="neuron"` (default).
86-
navis.Volume
177+
Volume
87178
If `output="volume"`.
88-
trimesh.Trimesh
89-
If `output='trimesh'`.
90-
navis.NeuronList
179+
Trimesh
180+
If `output="trimesh"`.
181+
NeuronList
91182
If `output="neuron"` and import has multiple meshes
92183
will return NeuronList of MeshNeurons.
93184
list
94185
If `output!="neuron"` and import has multiple meshes
95186
will return list of Volumes or Trimesh.
96187
188+
See Also
189+
--------
190+
[`navis.read_precomputed`][]
191+
Read meshes and skeletons from Neuroglancer's precomputed format.
192+
97193
Examples
98194
--------
99195
@@ -114,101 +210,19 @@ def read_mesh(f: Union[str, Iterable],
114210
>>> nl = navis.read_mesh('mesh.obj', output='volume') # doctest: +SKIP
115211
116212
"""
117-
utils.eval_param(output, name='output',
118-
allowed_values=('neuron', 'volume', 'trimesh'))
119-
120-
# If is directory, compile list of filenames
121-
if isinstance(f, str) and '*' in f:
122-
f, ext = f.split('*')
123-
f = Path(f).expanduser()
124-
125-
if not f.is_dir():
126-
raise ValueError(f'{f} does not appear to exist')
127-
128-
if not include_subdirs:
129-
f = list(f.glob(f'*{ext}'))
130-
else:
131-
f = list(f.rglob(f'*{ext}'))
132-
133-
if limit:
134-
f = f[:limit]
135-
136-
if utils.is_iterable(f):
137-
# Do not use if there is only a small batch to import
138-
if isinstance(parallel, str) and parallel.lower() == 'auto':
139-
if len(f) < 100:
140-
parallel = False
141-
142-
if parallel:
143-
# Do not swap this as `isinstance(True, int)` returns `True`
144-
if isinstance(parallel, (bool, str)):
145-
n_cores = os.cpu_count() - 2
146-
else:
147-
n_cores = int(parallel)
148-
149-
with mp.Pool(processes=n_cores) as pool:
150-
results = pool.imap(_worker_wrapper, [dict(f=x,
151-
output=output,
152-
errors=errors,
153-
include_subdirs=include_subdirs,
154-
parallel=False) for x in f],
155-
chunksize=1)
156-
157-
res = list(config.tqdm(results,
158-
desc='Importing',
159-
total=len(f),
160-
disable=config.pbar_hide,
161-
leave=config.pbar_leave))
162-
163-
else:
164-
# If not parallel just import the good 'ole way: sequentially
165-
res = [read_mesh(x,
166-
include_subdirs=include_subdirs,
167-
output=output,
168-
errors=errors,
169-
parallel=parallel,
170-
**kwargs)
171-
for x in config.tqdm(f, desc='Importing',
172-
disable=config.pbar_hide,
173-
leave=config.pbar_leave)]
174-
175-
if output == 'neuron':
176-
return core.NeuronList([r for r in res if r])
177-
178-
return res
179-
180-
try:
181-
# Open the file
182-
fname = '.'.join(os.path.basename(f).split('.')[:-1])
183-
mesh = tm.load_mesh(f)
184-
185-
if output == 'trimesh':
186-
return mesh
187-
188-
attrs = {'name': fname, 'origin': f}
189-
attrs.update(kwargs)
190-
if output == 'volume':
191-
return core.Volume(mesh.vertices, mesh.faces, **attrs)
192-
else:
193-
return core.MeshNeuron(mesh, **attrs)
194-
except BaseException as e:
195-
msg = f'Error reading file {fname}.'
196-
if errors == 'raise':
197-
raise ImportError(msg) from e
198-
elif errors == 'log':
199-
logger.error(f'{msg}: {e}')
200-
return
201-
213+
utils.eval_param(
214+
output, name="output", allowed_values=("neuron", "volume", "trimesh")
215+
)
202216

203-
def _worker_wrapper(kwargs):
204-
"""Helper for importing meshes using multiple processes."""
205-
return read_mesh(**kwargs)
217+
reader = MeshReader(fmt=fmt, output=output, errors=errors, attrs=kwargs)
218+
return reader.read_any(f, include_subdirs, parallel, limit=limit)
206219

207220

208-
def write_mesh(x: Union['core.NeuronList', 'core.MeshNeuron', 'core.Volume', 'tm.Trimesh'],
209-
filepath: Optional[str] = None,
210-
filetype: str = None,
211-
) -> None:
221+
def write_mesh(
222+
x: Union["core.NeuronList", "core.MeshNeuron", "core.Volume", "tm.Trimesh"],
223+
filepath: Optional[str] = None,
224+
filetype: str = None,
225+
) -> None:
212226
"""Export meshes (MeshNeurons, Volumes, Trimeshes) to disk.
213227
214228
Under the hood this is using trimesh to export meshes.
@@ -264,41 +278,44 @@ def write_mesh(x: Union['core.NeuronList', 'core.MeshNeuron', 'core.Volume', 'tm
264278
>>> navis.write_mesh(nl, tmp_dir / 'meshes.zip', filetype='obj')
265279
266280
"""
267-
ALLOWED_FILETYPES = ('stl', 'ply', 'obj')
268281
if filetype is not None:
269-
utils.eval_param(filetype, name='filetype', allowed_values=ALLOWED_FILETYPES)
282+
utils.eval_param(filetype, name="filetype", allowed_values=MESH_WRITE_EXT)
270283
else:
271284
# See if we can get filetype from filepath
272285
if filepath is not None:
273-
for f in ALLOWED_FILETYPES:
274-
if str(filepath).endswith(f'.{f}'):
286+
for f in MESH_WRITE_EXT:
287+
if str(filepath).endswith(f".{f}"):
275288
filetype = f
276289
break
277290

278291
if not filetype:
279-
raise ValueError('Must provide mesh type either explicitly via '
280-
'`filetype` variable or implicitly via the '
281-
'file extension in `filepath`')
292+
raise ValueError(
293+
"Must provide mesh type either explicitly via "
294+
"`filetype` variable or implicitly via the "
295+
"file extension in `filepath`"
296+
)
282297

283-
writer = base.Writer(_write_mesh, ext=f'.{filetype}')
298+
writer = base.Writer(_write_mesh, ext=f".{filetype}")
284299

285-
return writer.write_any(x,
286-
filepath=filepath)
300+
return writer.write_any(x, filepath=filepath)
287301

288302

289-
def _write_mesh(x: Union['core.MeshNeuron', 'core.Volume', 'tm.Trimesh'],
290-
filepath: Optional[str] = None) -> None:
303+
def _write_mesh(
304+
x: Union["core.MeshNeuron", "core.Volume", "tm.Trimesh"],
305+
filepath: Optional[str] = None,
306+
) -> None:
291307
"""Write single mesh to disk."""
292308
if filepath and os.path.isdir(filepath):
293309
if isinstance(x, core.MeshNeuron):
294310
if not x.id:
295-
raise ValueError('Neuron(s) must have an ID when destination '
296-
'is a folder')
297-
filepath = os.path.join(filepath, f'{x.id}')
311+
raise ValueError(
312+
"Neuron(s) must have an ID when destination " "is a folder"
313+
)
314+
filepath = os.path.join(filepath, f"{x.id}")
298315
elif isinstance(x, core.Volume):
299-
filepath = os.path.join(filepath, f'{x.name}')
316+
filepath = os.path.join(filepath, f"{x.name}")
300317
else:
301-
raise ValueError(f'Unable to generate filename for {type(x)}')
318+
raise ValueError(f"Unable to generate filename for {type(x)}")
302319

303320
if isinstance(x, core.MeshNeuron):
304321
mesh = x.trimesh

0 commit comments

Comments
 (0)