Skip to content

Commit

Permalink
update rasterize graph node to properly color edges
Browse files Browse the repository at this point in the history
adds a test for this behavior as well
  • Loading branch information
pattonw committed Jan 3, 2024
1 parent 218c813 commit 3686130
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 10 deletions.
17 changes: 17 additions & 0 deletions gunpowder/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,20 @@ def u(self):
def v(self):
return self.__v

@property
def attrs(self):
return self.__attrs

@property
def all(self):
return self.__attrs

@classmethod
def from_attrs(cls, attrs: Dict[str, Any]):
u = attrs["u"]
v = attrs["v"]
return cls(u, v, attrs=attrs)

def __iter__(self):
return iter([self.u, self.v])

Expand Down Expand Up @@ -287,6 +297,13 @@ def node(self, id: int):
attrs = self.__graph.nodes[id]
return Node.from_attrs(attrs)

def edge(self, id: tuple[int, int]):
"""
Get specific edge
"""
attrs = self.__graph.edges[id]
return Edge.from_attrs(attrs)

def contains(self, node_id: int):
return node_id in self.__graph.nodes

Expand Down
39 changes: 29 additions & 10 deletions gunpowder/nodes/rasterize_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def __rasterize(
settings.mode == "ball"
and settings.inner_radius_fraction is None
and len(list(graph.edges)) == 0
and settings.color_attr is None
)

if use_fast_rasterization:
Expand Down Expand Up @@ -347,7 +348,7 @@ def __rasterize(

else:
if settings.color_attr is not None:
c = graph.nodes[node].get(settings.color_attr)
c = node.attrs.get(settings.color_attr)
if c is None:
logger.debug(f"Skipping node: {node}")
continue
Expand All @@ -363,7 +364,7 @@ def __rasterize(
if settings.edges:
for e in graph.edges:
if settings.color_attr is not None:
c = graph.edges[e].get(settings.color_attr)
c = e.attrs.get(settings.color_attr)
if c is None:
continue
elif np.isclose(c, 1) and not np.isclose(settings.fg_value, 1):
Expand All @@ -372,26 +373,44 @@ def __rasterize(
f"attribute {settings.color_attr} "
f"but color 1 will be replaced with fg_value: {settings.fg_value}"
)
else:
c = 1

u = graph.node(e.u)
v = graph.node(e.v)
u_coord = Coordinate(u.location / voxel_size)
v_coord = Coordinate(v.location / voxel_size)
line = draw.line_nd(u_coord, v_coord, endpoint=True)
rasterized_graph[line] = 1
rasterized_graph[line] = c

# grow graph
if not use_fast_rasterization:
if settings.mode == "ball":
enlarge_binary_map(
rasterized_graph,
settings.radius,
voxel_size,
settings.inner_radius_fraction,
in_place=True,
)
if settings.color_attr is not None:
for color in np.unique(rasterized_graph):
if color == 0:
continue
assert color in [2,3], np.unique(rasterized_graph)
mask = rasterized_graph == color
enlarge_binary_map(
mask,
settings.radius,
voxel_size,
settings.inner_radius_fraction,
in_place=True,
)
rasterized_graph[mask] = color
else:
enlarge_binary_map(
rasterized_graph,
settings.radius,
voxel_size,
settings.inner_radius_fraction,
in_place=True,
)

else:

sigmas = settings.radius / voxel_size

gaussian_filter(
Expand Down
30 changes: 30 additions & 0 deletions tests/cases/rasterize_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,36 @@
import numpy as np


def test_rasterize_graph_colors():
graph = Graph(
[
Node(id=1, location=np.array((0.5, 0.5)), attrs={"color": 2}),
Node(id=2, location=np.array((0.5, 4.5)), attrs={"color": 2}),
Node(id=3, location=np.array((4.5, 0.5)), attrs={"color": 3}),
Node(id=4, location=np.array((4.5, 4.5)), attrs={"color": 3}),
],
[Edge(1, 2, attrs={"color": 2}), Edge(3, 4, attrs={"color": 3})],
GraphSpec(roi=Roi((0, 0), (5, 5))),
)

graph_key = GraphKey("G")
array_key = ArrayKey("A")
graph_source = GraphSource(graph_key, graph)
pipeline = graph_source + RasterizeGraph(
graph_key,
array_key,
ArraySpec(roi=Roi((0, 0), (5, 5)), voxel_size=Coordinate(1, 1), dtype=np.uint8),
settings=RasterizationSettings(1, color_attr="color"),
)
with build(pipeline):
request = BatchRequest()
request[array_key] = ArraySpec(Roi((0, 0), (5, 5)))
rasterized = pipeline.request_batch(request)[array_key].data
assert rasterized[0, 0] == 2
assert rasterized[0, :].sum() == 10
assert rasterized[4, 0] == 3
assert rasterized[4, :].sum() == 15

def test_3d():
graph_key = GraphKey("TEST_GRAPH")
array_key = ArrayKey("TEST_ARRAY")
Expand Down

0 comments on commit 3686130

Please sign in to comment.