diff --git a/gunpowder/graph.py b/gunpowder/graph.py index 3321c5ac..52eadd29 100644 --- a/gunpowder/graph.py +++ b/gunpowder/graph.py @@ -150,10 +150,20 @@ def u(self): def v(self): return self.__v + @property + def attrs(self): + return self.__attrs + @property def all(self): return self.__attrs + @classmethod + def from_attrs(cls, attrs: Dict[str, Any]): + u = attrs["u"] + v = attrs["v"] + return cls(u, v, attrs=attrs) + def __iter__(self): return iter([self.u, self.v]) @@ -287,6 +297,13 @@ def node(self, id: int): attrs = self.__graph.nodes[id] return Node.from_attrs(attrs) + def edge(self, id: tuple[int, int]): + """ + Get specific edge + """ + attrs = self.__graph.edges[id] + return Edge.from_attrs(attrs) + def contains(self, node_id: int): return node_id in self.__graph.nodes diff --git a/gunpowder/nodes/rasterize_graph.py b/gunpowder/nodes/rasterize_graph.py index bb2473f6..50c1c5eb 100644 --- a/gunpowder/nodes/rasterize_graph.py +++ b/gunpowder/nodes/rasterize_graph.py @@ -299,6 +299,7 @@ def __rasterize( settings.mode == "ball" and settings.inner_radius_fraction is None and len(list(graph.edges)) == 0 + and settings.color_attr is None ) if use_fast_rasterization: @@ -347,7 +348,7 @@ def __rasterize( else: if settings.color_attr is not None: - c = graph.nodes[node].get(settings.color_attr) + c = node.attrs.get(settings.color_attr) if c is None: logger.debug(f"Skipping node: {node}") continue @@ -363,7 +364,7 @@ def __rasterize( if settings.edges: for e in graph.edges: if settings.color_attr is not None: - c = graph.edges[e].get(settings.color_attr) + c = e.attrs.get(settings.color_attr) if c is None: continue elif np.isclose(c, 1) and not np.isclose(settings.fg_value, 1): @@ -372,26 +373,44 @@ def __rasterize( f"attribute {settings.color_attr} " f"but color 1 will be replaced with fg_value: {settings.fg_value}" ) + else: + c = 1 u = graph.node(e.u) v = graph.node(e.v) u_coord = Coordinate(u.location / voxel_size) v_coord = Coordinate(v.location / voxel_size) line = draw.line_nd(u_coord, v_coord, endpoint=True) - rasterized_graph[line] = 1 + rasterized_graph[line] = c # grow graph if not use_fast_rasterization: if settings.mode == "ball": - enlarge_binary_map( - rasterized_graph, - settings.radius, - voxel_size, - settings.inner_radius_fraction, - in_place=True, - ) + if settings.color_attr is not None: + for color in np.unique(rasterized_graph): + if color == 0: + continue + assert color in [2,3], np.unique(rasterized_graph) + mask = rasterized_graph == color + enlarge_binary_map( + mask, + settings.radius, + voxel_size, + settings.inner_radius_fraction, + in_place=True, + ) + rasterized_graph[mask] = color + else: + enlarge_binary_map( + rasterized_graph, + settings.radius, + voxel_size, + settings.inner_radius_fraction, + in_place=True, + ) else: + sigmas = settings.radius / voxel_size gaussian_filter( diff --git a/tests/cases/rasterize_points.py b/tests/cases/rasterize_points.py index a57906f8..f6cb85e2 100644 --- a/tests/cases/rasterize_points.py +++ b/tests/cases/rasterize_points.py @@ -17,6 +17,36 @@ import numpy as np +def test_rasterize_graph_colors(): + graph = Graph( + [ + Node(id=1, location=np.array((0.5, 0.5)), attrs={"color": 2}), + Node(id=2, location=np.array((0.5, 4.5)), attrs={"color": 2}), + Node(id=3, location=np.array((4.5, 0.5)), attrs={"color": 3}), + Node(id=4, location=np.array((4.5, 4.5)), attrs={"color": 3}), + ], + [Edge(1, 2, attrs={"color": 2}), Edge(3, 4, attrs={"color": 3})], + GraphSpec(roi=Roi((0, 0), (5, 5))), + ) + + graph_key = GraphKey("G") + array_key = ArrayKey("A") + graph_source = GraphSource(graph_key, graph) + pipeline = graph_source + RasterizeGraph( + graph_key, + array_key, + ArraySpec(roi=Roi((0, 0), (5, 5)), voxel_size=Coordinate(1, 1), dtype=np.uint8), + settings=RasterizationSettings(1, color_attr="color"), + ) + with build(pipeline): + request = BatchRequest() + request[array_key] = ArraySpec(Roi((0, 0), (5, 5))) + rasterized = pipeline.request_batch(request)[array_key].data + assert rasterized[0, 0] == 2 + assert rasterized[0, :].sum() == 10 + assert rasterized[4, 0] == 3 + assert rasterized[4, :].sum() == 15 + def test_3d(): graph_key = GraphKey("TEST_GRAPH") array_key = ArrayKey("TEST_ARRAY")