Skip to content

Commit

Permalink
Improved GREIT rasterize algorithm. (#92)
Browse files Browse the repository at this point in the history
* using barycentric method to calculate point in triangles.
* correct lint warnings
  • Loading branch information
liubenyuan authored Aug 17, 2023
1 parent ac7e82e commit a3b631c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 16 deletions.
4 changes: 2 additions & 2 deletions examples/figures_of_merit_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ def main():
fig.set_size_inches(10, 4)

fig, axs = plt.subplots(1, 2)
im_simulation = create_image_plot(axs[0], sim_render, title="Target image")
im_recon = create_image_plot(axs[1], recon_render, title="Reconstruction image")
create_image_plot(axs[0], sim_render, title="Target image")
create_image_plot(axs[1], recon_render, title="Reconstruction image")
fig.set_size_inches(10, 4)

fig, axs = plt.subplots(1, 2, constrained_layout=True)
Expand Down
4 changes: 2 additions & 2 deletions pyeit/eit/greit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import scipy.linalg as la
from .base import EitBase
from .interp2d import meshgrid, weight_sigmod
from .interp2d import rasterize, weight_sigmod


class GREIT(EitBase):
Expand Down Expand Up @@ -91,7 +91,7 @@ def setup(
}

# Build grids and mask
self.xg, self.yg, self.mask = meshgrid(self.mesh.node, n=n)
self.xg, self.yg, self.mask = rasterize(self.mesh.node, self.mesh.element, n=n)

w_mat = self._compute_grid_weights(self.xg, self.yg)
self.J, self.v0 = self.fwd.compute_jac(perm=perm, normalize=jac_normalized)
Expand Down
68 changes: 57 additions & 11 deletions pyeit/eit/interp2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,52 @@
from scipy.spatial import ConvexHull


def meshgrid(
pts: np.ndarray, n: int = 32, ext_ratio: float = 0.0, gc: bool = False
class TriangleRasterizer:
def __init__(self, pts, tri):
tp = pts[:, np.newaxis][tri].squeeze()
tri_vec = tp[:, [1, 2, 0]] - tp
self.tp = tp
self.atot = np.abs(self._tri_area(tri_vec[:, 0], tri_vec[:, 1]))

@staticmethod
def _tri_area(bar0, bar1):
return bar0[:, 0] * bar1[:, 1] - bar0[:, 1] * bar1[:, 0]

def _point_in_triangle(self, v):
tv = self.tp - v
a0 = self._tri_area(tv[:, 0], tv[:, 1])
a1 = self._tri_area(tv[:, 1], tv[:, 2])
a2 = self._tri_area(tv[:, 2], tv[:, 0])
asum = np.sum(np.abs(np.vstack([a0, a1, a2])), axis=0)
# add a margin for in-triangle test
return np.any(asum <= 1.01 * self.atot)

def points_in_triangles(self, varray):
return np.array([self._point_in_triangle(v) for v in varray])


def rasterize(
pts: np.ndarray,
tri: np.ndarray,
method: str = "cg",
n: int = 32,
ext_ratio: float = 0.0,
gc: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
build xg, yg, mask grids from triangles point cloud
rasterize triangles point cloud and returns (xg, yg, mask)
function for interpolating regular grids
Parameters
----------
pts: np.ndarray
nx2 array of points (x, y)
nx2 array of points {(x, y)}
tri: np.ndarray
nx3 array of points connection {(i0, i1, i2)}
method: str
"cg", test a point in a triangle using barycentric coordinates
"quick": test the distance from a point to centers of elements
"qhull": using convex hull
n: int
the number of meshgrid per dimension, by default 32
ext_ratio: float
Expand All @@ -47,8 +82,16 @@ def meshgrid(
mask denotes points outside mesh.
"""
xg, yg = _build_grid(pts, n=n, ext_ratio=ext_ratio, gc=gc)
pts_edges = _hull_points(pts)
mask = _build_mask(pts_edges, xg, yg)
points = np.vstack((xg.flatten(), yg.flatten())).T

# perform rasterize on meshgrids
if method == "cg":
TR = TriangleRasterizer(pts[:, :2], tri)
mask = ~TR.points_in_triangles(points)
else:
pts_edges = _hull_points(pts)
mask = _build_mask(pts_edges, xg, yg)

return xg, yg, mask


Expand Down Expand Up @@ -550,22 +593,24 @@ def demo() -> None:

# plot mesh and interpolated mesh (tri2pts)
fig_size = (6, 4)
fig = plt.figure(figsize=fig_size)
fig = plt.figure(figsize=fig_size, dpi=200)
ax = fig.add_subplot(111)
ax.set_aspect("equal")
ax.triplot(pts[:, 0], pts[:, 1], tri)
im1 = ax.tripcolor(pts[:, 0], pts[:, 1], tri, mesh_new.perm)
ax.set_title("mesh_obj and anomaly")
im1 = ax.tripcolor(pts[:, 0], pts[:, 1], tri, mesh_new.perm, alpha=0.8)
fig.colorbar(im1, orientation="vertical")

fig = plt.figure(figsize=fig_size)
fig = plt.figure(figsize=fig_size, dpi=200)
ax2 = fig.add_subplot(111)
ax2.set_aspect("equal")
ax2.triplot(pts[:, 0], pts[:, 1], tri)
ax2.set_title("mesh_obj and anomaly on nodes")
im2 = ax2.tripcolor(pts[:, 0], pts[:, 1], tri, perm_node, shading="flat")
fig.colorbar(im2, orientation="vertical")

# 3. interpolate on grids (irregular or regular) using IDW, sigmod
xg, yg, mask = meshgrid(pts)
xg, yg, mask = rasterize(pts, tri)
im = np.ones_like(mask)
# mapping from values on xy to values on xyi
xy = np.mean(pts[tri], axis=1)
Expand All @@ -579,9 +624,10 @@ def demo() -> None:
im = im.reshape(xg.shape)

# plot interpolated values
fig, ax = plt.subplots(figsize=fig_size)
fig, ax = plt.subplots(figsize=fig_size, dpi=200)
ax.set_aspect("equal")
ax.triplot(pts[:, 0], pts[:, 1], tri, alpha=0.5)
ax.set_title("mesh_obj and anomaly rasterized")
im3 = ax.pcolor(xg, yg, im, edgecolors=None, linewidth=0, alpha=0.8)
fig.colorbar(im3, orientation="vertical")
plt.show()
Expand Down
2 changes: 1 addition & 1 deletion pyeit/mesh/mesh_circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def _get_electrodes(self):

# place electrodes uniformly on the boundary
n = np.linspace(
el_start, el_start + el_len, num=self.n_el, endpoint=False, dtype=np.int
el_start, el_start + el_len, num=self.n_el, endpoint=False, dtype=int
)

# for FMMU, electrodes should be placed clockwise
Expand Down

0 comments on commit a3b631c

Please sign in to comment.