Skip to content

Commit d50ac4e

Browse files
committed
improve voxel indexing via kdtree
1 parent de890e7 commit d50ac4e

File tree

4 files changed

+282
-90
lines changed

4 files changed

+282
-90
lines changed

cylinder.py

Lines changed: 83 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -88,89 +88,86 @@ def transform_points_to_C0(points: np.ndarray, base_point: np.ndarray, axis_poin
8888
return points_transformed
8989

9090

91-
RCSB_ID = '3J7Z'
92-
radius = 40
93-
height = 80
94-
voxel_size = 1
95-
96-
# residues, base, axis = get_npet_cylinder_residues(RCSB_ID, radius=radius, height=height)
97-
98-
base_point = np.array(PTC_location(RCSB_ID).location)
99-
axis_point = np.array( get_constriction(RCSB_ID) )
100-
# translation, rotation = get_transformation_to_C0(base, axis)
101-
# t_base = ( base + translation ) @ rotation.T
102-
# t_axis = ( axis + translation ) @ rotation.T
103-
104-
if os.path.exists('points.npy'):
105-
points = np.load('points.npy')
106-
print("Loaded")
107-
else:
108-
residues= filter_residues_parallel( ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height, )
109-
points = np.array([atom.get_coord() for residue in residues for atom in residue.child_list])
110-
np.save('points.npy', points)
111-
print("Saved")
112-
...
113-
114-
115-
nx = ny = int(2 * radius / voxel_size) + 1
116-
nz = int(height / voxel_size) + 1
117-
x = np.linspace(-radius, radius, nx)
118-
y = np.linspace(-radius, radius, ny)
119-
z = np.linspace(0, height, nz)
120-
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
121-
122-
123-
transformed = transform_points_to_C0(points, base, axis)
124-
X_I = np.round(transformed[:,0])
125-
Y_I = np.round(transformed[:,1])
126-
Z_I = np.round(transformed[:,2])
127-
128-
cylinder_mask = (np.sqrt(X**2 + Y**2) <= radius)
129-
hollow_cylinder = ~cylinder_mask
130-
131-
# !-------------
132-
# 3. Create point cloud mask
133-
# point_cloud_mask = np.zeros_like(X, dtype=bool)
134-
# for point in zip(X_I, Y_I, Z_I):
135-
# point_cloud_mask |= (X == point[0]) & (Y == point[1]) & (Z == point[2])
136-
137-
138-
# !-------------
139-
radius_around_point = 2.0 # radius of sphere around each point
140-
# point_cloud_mask = np.zeros_like(X, dtype=bool)
141-
# for point in zip(X_I, Y_I, Z_I):
142-
# distance_to_point = np.sqrt(
143-
# (X - point[0])**2 +
144-
# (Y - point[1])**2 +
145-
# (Z - point[2])**2
146-
# )
147-
# point_cloud_mask |= (distance_to_point <= radius_around_point)
148-
149-
150-
# !-------------
151-
points = np.column_stack((X_I, Y_I, Z_I)) # Shape: (N, 3)
152-
point_cloud_mask = np.zeros_like(X, dtype=bool)
153-
154-
# Reshape grid coordinates for broadcasting
155-
grid_coords = np.stack([X, Y, Z]) # Shape: (3, nx, ny, nz)
156-
grid_coords = grid_coords.reshape(3, -1) # Shape: (3, nx*ny*nz)
157-
158-
for point in points:
159-
# Calculate distances using broadcasting
160-
distances = np.sqrt(np.sum((grid_coords.T - point)**2, axis=1))
161-
# Reshape back to grid shape and add to mask
162-
point_cloud_mask |= (distances.reshape(X.shape) <= radius_around_point)
163-
164-
165-
# !-------------
166-
167-
final_mask = hollow_cylinder | point_cloud_mask
168-
occupied = np.where(final_mask)
169-
170-
points = np.column_stack((
171-
x[occupied[0]],
172-
y[occupied[1]],
173-
z[occupied[2]]
174-
))
175-
occupied_points = pv.PolyData(points)
176-
visualize_pointcloud(occupied_points)
91+
if __name__ == '__main__':
    # Brute-force voxelisation demo: every atom is compared against every
    # grid voxel, one atom at a time. Earlier commented-out variants of the
    # point-cloud mask (exact voxel match, meshgrid-based distance) computed
    # the same kind of mask and have been dropped from this listing.
    RCSB_ID = '3J7Z'
    radius = 40
    height = 80
    voxel_size = 1
    radius_around_point = 2.0  # radius of sphere around each point

    base_point = np.array(PTC_location(RCSB_ID).location)
    axis_point = np.array(get_constriction(RCSB_ID))

    # Atom coordinates are cached in 'points.npy' in the working directory.
    # NOTE(review): the cache name does not encode RCSB_ID / radius / height,
    # so a stale file from another run would be reused silently - confirm.
    if not os.path.exists('points.npy'):
        residues = filter_residues_parallel(
            ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height
        )
        points = np.array(
            [atom.get_coord() for residue in residues for atom in residue.child_list]
        )
        np.save('points.npy', points)
        print("Saved")
    else:
        points = np.load('points.npy')
        print("Loaded")

    # Regular grid spanning [-radius, radius]^2 x [0, height].
    nx = ny = int(2 * radius / voxel_size) + 1
    nz = int(height / voxel_size) + 1
    x = np.linspace(-radius, radius, nx)
    y = np.linspace(-radius, radius, ny)
    z = np.linspace(0, height, nz)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    # presumably maps the atoms into the frame the grid is defined in
    # (cylinder axis along z) - verify against transform_points_to_C0
    transformed = transform_points_to_C0(points, base_point, axis_point)
    X_I = np.round(transformed[:, 0])
    Y_I = np.round(transformed[:, 1])
    Z_I = np.round(transformed[:, 2])
    atom_centers = np.column_stack((X_I, Y_I, Z_I))  # Shape: (N, 3)

    # Voxels whose (x, y) lies outside the cylinder radius.
    hollow_cylinder = ~(np.sqrt(X**2 + Y**2) <= radius)

    # Flatten the grid once to (nx*ny*nz, 3) so each atom needs one subtraction.
    flat_grid = np.stack([X, Y, Z]).reshape(3, -1).T
    point_cloud_mask = np.zeros_like(X, dtype=bool)
    for center in atom_centers:
        distances = np.sqrt(np.sum((flat_grid - center)**2, axis=1))
        # Stamp a sphere of radius_around_point around this atom.
        point_cloud_mask |= (distances.reshape(X.shape) <= radius_around_point)

    final_mask = hollow_cylinder | point_cloud_mask
    ix, iy, iz = np.where(final_mask)

    occupied_points = pv.PolyData(np.column_stack((x[ix], y[iy], z[iz])))
    visualize_pointcloud(occupied_points)

cylinder_parallel.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
import os
2+
import numpy as np
3+
from typing import Tuple
4+
import multiprocessing as mp
5+
from cylinder import transform_points_to_C0
6+
from mesh_generation.mes_visualization import visualize_pointcloud
7+
from numba import jit
8+
import pyvista as pv
9+
10+
from ribctl.lib.landmarks.constriction import get_constriction
11+
from ribctl.lib.landmarks.ptc_via_trna import PTC_location
12+
from ribctl.lib.npet.tunnel_bbox_ptc_constriction import filter_residues_parallel, ribosome_entities
13+
14+
def chunk_points(points: np.ndarray, n_chunks: int) -> list:
    """Split points into contiguous, non-empty chunks for parallel processing.

    Args:
        points: Array whose first axis is split (one row per point).
        n_chunks: Requested number of chunks. Values below 1 are treated as
            1, and the count is capped at ``len(points)`` so no chunk is
            ever empty.

    Returns:
        A list of at most ``n_chunks`` arrays which, concatenated in order,
        reproduce ``points``. Chunk lengths differ by at most one. An empty
        ``points`` yields an empty list.
    """
    n_points = len(points)
    if n_points == 0:
        return []

    # The previous integer-division chunk size became 0 when there were
    # fewer points than chunks (range() then raised ValueError), and produced
    # more than n_chunks pieces whenever the length was not evenly divisible.
    n_chunks = max(1, min(int(n_chunks), n_points))
    return np.array_split(points, n_chunks)
18+
19+
@jit(nopython=True)
def process_points_chunk(points: np.ndarray, grid_coords: np.ndarray, radius: float) -> np.ndarray:
    """Flag every grid location lying within ``radius`` of any point in the chunk.

    ``grid_coords`` is indexed as (coordinate, location); its second axis
    length determines the length of the returned boolean mask.
    """
    n_locations = grid_coords.shape[1]
    mask = np.zeros(n_locations, dtype=np.bool_)
    # One row per grid location, so a single point broadcasts against it.
    locations = grid_coords.T
    for idx in range(points.shape[0]):
        offsets = locations - points[idx]
        distances = np.sqrt(np.sum(offsets**2, axis=1))
        mask |= (distances <= radius)
    return mask
27+
28+
def parallel_point_cloud_mask(points: np.ndarray, X: np.ndarray, Y: np.ndarray, Z: np.ndarray,
                              radius_around_point: float) -> np.ndarray:
    """Generate point cloud mask using parallel processing.

    Args:
        points: Point coordinates, one row per point, in the same frame as
            the grid.
        X, Y, Z: Coordinate arrays of identical shape describing the grid
            (e.g. the output of ``np.meshgrid``).
        radius_around_point: A grid location is marked when it lies within
            this distance of at least one point.

    Returns:
        Boolean array with the shape of ``X``; True where a grid location is
        within ``radius_around_point`` of any point. All False when
        ``points`` is empty.
    """
    original_shape = X.shape
    n_points = len(points)
    if n_points == 0:
        # Nothing to stamp; also avoids starting a pool with no work.
        return np.zeros(original_shape, dtype=bool)

    # Prepare grid coordinates once: (3, number_of_grid_locations)
    grid_coords = np.stack([X, Y, Z]).reshape(3, -1)

    # Leave four cores free when possible, but always keep at least one
    # worker (cpu_count() - 4 is <= 0 on small machines and mp.Pool rejects
    # that) and never use more workers than there are points.
    n_workers = max(1, min(mp.cpu_count() - 4, n_points))
    chunks = chunk_points(points, n_workers)

    # Process chunks in parallel
    with mp.Pool(n_workers) as pool:
        results = pool.starmap(
            process_points_chunk,
            [(chunk, grid_coords, radius_around_point) for chunk in chunks],
        )

    # A location is marked if any chunk marked it.
    final_mask = np.any(results, axis=0)
    return final_mask.reshape(original_shape)
48+
49+
def main():
    """Voxelise the cylinder for 3J7Z with the multiprocessing mask and show the free voxels."""
    RCSB_ID = '3J7Z'
    radius = 40
    height = 80
    voxel_size = 1
    ATOM_RADIUS = 2

    base_point = np.array(PTC_location(RCSB_ID).location)
    axis_point = np.array(get_constriction(RCSB_ID))

    # Atom coordinates are cached in 'points.npy' in the working directory.
    if not os.path.exists('points.npy'):
        residues = filter_residues_parallel(
            ribosome_entities(RCSB_ID, 'R'), base_point, axis_point, radius, height
        )
        points = np.array(
            [atom.get_coord() for residue in residues for atom in residue.child_list]
        )
        np.save('points.npy', points)
        print("Saved")
    else:
        points = np.load('points.npy')
        print("Loaded")

    # Regular grid spanning [-radius, radius]^2 x [0, height].
    nx = ny = int(2 * radius / voxel_size) + 1
    nz = int(height / voxel_size) + 1
    x = np.linspace(-radius, radius, nx)
    y = np.linspace(-radius, radius, ny)
    z = np.linspace(0, height, nz)
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    # presumably maps the atoms into the grid's frame - verify against
    # transform_points_to_C0
    transformed = transform_points_to_C0(points, base_point, axis_point)
    xs, ys, zs = transformed.T
    atom_coords = np.column_stack((xs, ys, zs))

    # Voxels whose (x, y) lies outside the cylinder radius.
    hollow_cylinder = ~(np.sqrt(X**2 + Y**2) <= radius)

    # Voxels within ATOM_RADIUS of any atom, computed across worker processes.
    point_cloud_mask = parallel_point_cloud_mask(atom_coords, X, Y, Z, ATOM_RADIUS)

    final_mask = hollow_cylinder | point_cloud_mask

    # The mask is inverted here: what gets displayed are the voxels inside
    # the cylinder that are NOT within ATOM_RADIUS of an atom.
    ix, iy, iz = np.where(~final_mask)
    visualization_points = np.column_stack((x[ix], y[iy], z[iz]))
    occupied_points = pv.PolyData(visualization_points)
    visualize_pointcloud(occupied_points)


if __name__ == '__main__':
    main()

kdtree_approach.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
from scipy.spatial import cKDTree
3+
import pyvista as pv
4+
5+
from cylinder import transform_points_to_C0
6+
from mesh_generation.mes_visualization import visualize_pointcloud
7+
from ribctl.lib.landmarks.constriction import get_constriction
8+
from ribctl.lib.landmarks.ptc_via_trna import PTC_location
9+
10+
def generate_voxel_centers(radius: float, height: float, voxel_size: float) -> tuple:
    """Build the regular grid covering [-radius, radius]^2 x [0, height].

    Returns a pair ``(voxel_centers, (grid_shape, x, y, z))``:
    ``voxel_centers`` is the flattened (n, 3) list of grid coordinates in
    'ij' (x-major, z-fastest) order, ``grid_shape`` is the 3-D shape needed
    to un-flatten it, and ``x``, ``y``, ``z`` are the per-axis coordinates.
    """
    n_xy = int(2 * radius / voxel_size) + 1
    n_z = int(height / voxel_size) + 1

    x = np.linspace(-radius, radius, n_xy)
    y = np.linspace(-radius, radius, n_xy)
    z = np.linspace(0, height, n_z)

    grid_x, grid_y, grid_z = np.meshgrid(x, y, z, indexing='ij')
    # Stacking on the last axis and flattening keeps the same row order as
    # ravel()-ing each coordinate array separately.
    voxel_centers = np.stack([grid_x, grid_y, grid_z], axis=-1).reshape(-1, 3)

    return voxel_centers, (grid_x.shape, x, y, z)
24+
25+
def create_point_cloud_mask(points: np.ndarray,
                            radius: float,
                            height: float,
                            voxel_size: float = 1.0,
                            radius_around_point: float = 2.0):
    """
    Create point cloud mask using KDTree for efficient spatial queries

    Args:
        points: (N, 3) point coordinates, already expressed in the grid's
            frame (x, y in [-radius, radius], z in [0, height]).
        radius: Cylinder radius; also the half-width of the grid in x and y.
        height: Extent of the grid along z.
        voxel_size: Requested grid spacing.
        radius_around_point: A voxel is marked when at least one point lies
            within this distance of its centre.

    Returns:
        ``(final_mask, (x, y, z))`` where ``final_mask`` is a boolean grid
        that is True for voxels outside the cylinder radius or close to a
        point, and ``x``, ``y``, ``z`` are the grid's per-axis coordinates.
    """
    # Generate voxel centers
    voxel_centers, (grid_shape, x, y, z) = generate_voxel_centers(radius, height, voxel_size)

    # Create KDTree from the transformed points
    tree = cKDTree(points)

    # Only the *number* of neighbours per voxel matters, so ask the tree for
    # counts instead of materialising one Python list of indices per voxel.
    neighbour_counts = tree.query_ball_point(
        voxel_centers, radius_around_point, return_length=True
    )
    point_cloud_mask = (neighbour_counts > 0).reshape(grid_shape)

    # The radial test depends on x and y only; evaluate it on the 2-D plane
    # and let broadcasting extend it along z instead of rebuilding the grid.
    outside_cylinder = ~(np.sqrt(x[:, None]**2 + y[None, :]**2) <= radius)

    # Combine masks
    final_mask = outside_cylinder[:, :, None] | point_cloud_mask

    return final_mask, (x, y, z)
59+
60+
def main():
    """Voxelise the cylinder for 3J7Z with the KD-tree mask and show the free voxels."""
    RCSB_ID = '3J7Z'

    # Unlike the other entry points, this one requires the cached
    # coordinates in 'points.npy' to exist already.
    points = np.load('points.npy')
    base_point = np.array(PTC_location(RCSB_ID).location)
    axis_point = np.array(get_constriction(RCSB_ID))
    print("loaded and got axis")

    # presumably maps the atoms into the grid's frame - verify against
    # transform_points_to_C0
    transformed_points = transform_points_to_C0(points, base_point, axis_point)

    final_mask, (x, y, z) = create_point_cloud_mask(
        transformed_points,
        radius=40,
        height=80,
        voxel_size=1.0,
        radius_around_point=2.0,
    )

    # The mask is inverted here: what gets displayed are the voxels that
    # create_point_cloud_mask did NOT mark (inside the cylinder, away from
    # every point).
    ix, iy, iz = np.where(~final_mask)
    visualization_points = np.column_stack((x[ix], y[iy], z[iz]))

    occupied_points = pv.PolyData(visualization_points)
    visualize_pointcloud(occupied_points)


if __name__ == '__main__':
    main()

ribctl/lib/landmarks/constriction.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@ def get_constriction(rcsb_id: str)->np.ndarray:
1212
else:
1313
uL4 = ro.get_poly_by_polyclass('uL4')
1414
uL22 = ro.get_poly_by_polyclass('uL22')
15-
1615
if uL4 is None or uL22 is None:
1716
raise ValueError("Could not find uL4 or uL22 in {}".format(rcsb_id))
18-
1917
structure = ro.assets.biopython_structure()
20-
21-
uL4_c :Chain = structure[0][uL4.auth_asym_id]
18+
uL4_c :Chain = structure[0][uL4.auth_asym_id]
2219
uL22_c :Chain = structure[0][uL22.auth_asym_id]
2320

2421
uL4_coords = [(r.center_of_mass() ) for r in uL4_c.child_list]

0 commit comments

Comments
 (0)