Skip to content

Commit debd521

Browse files
committed
[doc/fix]add a Jupyternotebook for bezierfit and fix some small bugs in bezierfit
1 parent 687a45c commit debd521

29 files changed

+2229
-173
lines changed
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
import sys
2-
sys.path.insert(0, '/data3/Zhen/MemXTerminator/src/')

build/lib/memxterminator/GUI/particle_membrane_subtraction_bezierfit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def retranslateUi(self, ParticleMembraneSubtraction_bezierfit):
107107
self.controlpoints_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Control points file"))
108108
self.template_browse_pushButton.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Browse..."))
109109
self.particle.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Particle .cs file"))
110-
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.005"))
110+
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.001"))
111111
self.points_step_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Points_step"))
112112
self.physical_membrane_dist_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Physical membrane distance(Å)"))
113113
self.physical_membrane_dist_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "35"))

build/lib/memxterminator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
__author__ = 'Zhen Huang'
22
__email__ = 'zhen.victor.huang@gmail.com'
3-
__version__ = '1.2.1'
3+
__version__ = '1.2.2'

build/lib/memxterminator/bezierfit/bin/mem_subtract_main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def membrane_subtract(particle_filename):
8080
shifts = shift_list[mask]
8181
classes = class_list[mask]
8282
for particle_idx, psi, pixel_size, shift, class_ in zip(particle_idxes, psis, pixel_sizes, shifts, classes):
83-
# class_得根据control_points.json找到control_points
8483
# if str(class_) in control_points_dict:
8584
control_points = np.array(control_points_dict[str(class_)])
8685
# print(control_points)

build/lib/memxterminator/bezierfit/lib/bezierfit.py

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,13 @@
1414
import json
1515
from scipy.ndimage import zoom
1616

17-
18-
# def bezier_curve(control_points, t):
19-
# B = np.outer((1 - t) ** 3, control_points[0]) + \
20-
# np.outer(3 * (1 - t) ** 2 * t, control_points[1]) + \
21-
# np.outer(3 * (1 - t) * t ** 2, control_points[2]) + \
22-
# np.outer(t ** 3, control_points[3])
23-
# return B.squeeze()
2417
def bezier_curve(control_points, t):
2518
n = len(control_points) - 1
2619
B = np.zeros_like(control_points[0], dtype=float)
2720
for i, point in enumerate(control_points):
2821
B += comb(n, i) * (1 - t) ** (n - i) * t ** i * point
2922
return B
30-
# def bezier_curve_derivative(control_points, t):
31-
# control_points = np.array(control_points)
32-
# B_prime = 3 * (1 - t) ** 2 * (control_points[1] - control_points[0]) + \
33-
# 6 * (1 - t) * t * (control_points[2] - control_points[1]) + \
34-
# 3 * t ** 2 * (control_points[3] - control_points[2])
35-
# return B_prime
23+
3624
def bezier_curve_derivative(control_points, t):
3725
n = len(control_points) - 1
3826
B_prime = np.zeros(2)
@@ -41,25 +29,6 @@ def bezier_curve_derivative(control_points, t):
4129
B_prime += coef * (control_points[i+1] - control_points[i])
4230
return B_prime
4331

44-
# def bezier_curvature(control_points, t):
45-
# dB0 = -3 * (1 - t) ** 2
46-
# dB1 = 3 * (1 - t) ** 2 - 6 * t * (1 - t)
47-
# dB2 = 6 * t * (1 - t) - 3 * t ** 2
48-
# dB3 = 3 * t ** 2
49-
50-
# ddB0 = 6 * (1 - t)
51-
# ddB1 = 6 - 18 * t
52-
# ddB2 = 18 * t - 6
53-
# ddB3 = 6 * t
54-
55-
# p = control_points
56-
# dx = sum([p[i, 0] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
57-
# dy = sum([p[i, 1] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
58-
# ddx = sum([p[i, 0] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])
59-
# ddy = sum([p[i, 1] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])
60-
61-
# curvature = abs(dx * ddy - dy * ddx) / (dx * dx + dy * dy) ** 1.5
62-
# return curvature
6332
def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6):
6433
n = len(control_points) - 1
6534

@@ -79,7 +48,6 @@ def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6
7948

8049
magnitude_squared = dx * dx + dy * dy
8150

82-
# 规避除数接近零的问题
8351
if magnitude_squared < threshold:
8452
return high_curvature_value
8553

@@ -331,22 +299,3 @@ def generate_curve_within_boundaries(control_points, image_shape, step):
331299
break
332300
fitted_curve_points = np.array([bezier_curve(control_points, t_val) for t_val in t_values])
333301
return np.array(fitted_curve_points), np.array(t_values)
334-
335-
if __name__ == '__main__':
336-
multiprocessing.set_start_method('spawn', force=True)
337-
with mrcfile.open('/data3/kzhang/cryosparc/CS-vsv/J354/templates_selected.mrc') as f:
338-
image = f.data[2]
339-
image = zoom(image, 2)
340-
coarsefit = Coarsefit(image, 600, 3, 300, 20)
341-
initial_control_points = coarsefit()
342-
ga_refine = GA_Refine(image, 1.068, 0.05, 50, 700, 18)
343-
refined_control_points = ga_refine(initial_control_points, image)
344-
refined_control_points = np.array(refined_control_points)
345-
fitted_curve_points, t_values = generate_curve_within_boundaries(refined_control_points, image.shape, 0.01)
346-
# save the control points in JSON format
347-
with open('control_points.json', 'w') as f:
348-
json.dump(refined_control_points.tolist(), f)
349-
plt.imshow(image, cmap='gray')
350-
plt.plot(fitted_curve_points[:, 0], fitted_curve_points[:, 1], 'r-')
351-
plt.plot(refined_control_points[:, 0], refined_control_points[:, 1], 'g.')
352-
plt.show()

build/lib/memxterminator/bezierfit/lib/subtraction.py

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -153,9 +153,9 @@ def generate_2d_mask(self, image, fitted_points, membrane_distance):
153153
membrane_mask[mask_outside_distance & ~mask_small_gray_value] = gray_value[mask_outside_distance & ~mask_small_gray_value]
154154
return membrane_mask
155155

156-
def average_1d(self, image_gpu, fitted_points, normals, extra_mem_dist):
156+
def average_1d(self, image_gpu, fitted_points, normals, mem_dist):
157157
average_1d_lst = []
158-
for membrane_dist in range(-extra_mem_dist, extra_mem_dist+1):
158+
for membrane_dist in range(-mem_dist, mem_dist+1):
159159
normals_points = fitted_points + membrane_dist * normals
160160
# Ensure the points are within the image boundaries
161161
mask = (normals_points[:, 0] >= 0) & (normals_points[:, 0] < image_gpu.shape[1]) & \
@@ -227,11 +227,11 @@ def average_1d(self, image_gpu, fitted_points, normals, extra_mem_dist):
227227
# new_image[cp.isnan(new_image)] = 0
228228
# return new_image.astype(image.dtype)
229229

230-
def average_2d(self, image_gpu, fitted_points, normals, average_1d_lst, extra_mem_dist):
230+
def average_2d(self, image_gpu, fitted_points, normals, average_1d_lst, mem_dist):
231231
image = image_gpu.get()
232232
new_image = np.zeros_like(image)
233233
count_image = np.zeros_like(image)
234-
for membrane_dist, average_1d in zip(range(-extra_mem_dist, extra_mem_dist+1), average_1d_lst):
234+
for membrane_dist, average_1d in zip(range(-mem_dist, mem_dist+1), average_1d_lst):
235235
# start_time = time.time()
236236
normals_points = fitted_points + membrane_dist * normals
237237
mask = (normals_points[:, 0] >= 0) & (normals_points[:, 0] < image.shape[1]) & \
@@ -247,11 +247,11 @@ def average_2d(self, image_gpu, fitted_points, normals, average_1d_lst, extra_me
247247
new_image[np.isnan(new_image)] = 0
248248
return new_image.astype(image.dtype)
249249

250-
def average_2d_gpu(self, image_gpu, fitted_points, normals, average_1d_lst, extra_mem_dist):
250+
def average_2d_gpu(self, image_gpu, fitted_points, normals, average_1d_lst, mem_dist):
251251
new_image = cp.zeros_like(image_gpu)
252252
count_image = cp.zeros_like(image_gpu)
253253
fitted_points = cp.asarray(fitted_points)
254-
membrane_dists = cp.arange(-extra_mem_dist, extra_mem_dist + 1)
254+
membrane_dists = cp.arange(-mem_dist, mem_dist + 1)
255255
# Expand dimensions for broadcasting
256256
membrane_dists = membrane_dists[:, cp.newaxis, cp.newaxis]
257257
# Calculate all normals_points at once
@@ -275,14 +275,11 @@ def average_2d_gpu(self, image_gpu, fitted_points, normals, average_1d_lst, extr
275275
def mem_subtract(self):
276276
control_points = self.control_points_trasf(self.control_points, self.psi, self.origin_x, self.origin_y)
277277
fitted_curve_points, t_values = generate_curve_within_boundaries(control_points, self.image.shape, self.points_step)
278-
# plt.imshow(self.image, cmap='gray')
279-
# plt.plot(fitted_curve_points[:, 0], fitted_curve_points[:, 1], 'r-')
280-
# plt.plot(control_points[:, 0], control_points[:, 1], 'g.')
281-
# plt.show()
278+
extra_mem_dist = 10
282279
mem_mask = self.generate_2d_mask(self.image_gpu, fitted_curve_points, self.mem_dist)
283-
raw_image_average_1d_lst = self.average_1d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), self.mem_dist)
284-
raw_image_average_2d = self.average_2d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), raw_image_average_1d_lst, self.mem_dist)
285-
raw_image_average_2d = cp.asarray(raw_image_average_2d)
280+
raw_image_average_1d_lst = self.average_1d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), self.mem_dist+extra_mem_dist)
281+
raw_image_average_2d = self.average_2d(self.image_gpu, fitted_curve_points, points_along_normal(control_points, t_values).get(), raw_image_average_1d_lst, self.mem_dist+extra_mem_dist)
282+
raw_image_average_2d = cp.asarray(raw_image_average_2d) * mem_mask
286283
kernel = gaussian_kernel(5, 1)
287284
image_conv = convolve2d(self.image_gpu, kernel, mode = 'same')
288285
raw_image_average_2d_conv = convolve2d(raw_image_average_2d, kernel, mode = 'same')
Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
import sys
2-
sys.path.insert(0, '/data3/Zhen/MemXTerminator/src/')

build/lib/memxterminator/cli/bezierfit_cli.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,11 @@ def kill_process(self):
105105
os.remove(self.PID_FILE)
106106
self.timer.stop()
107107
def update_log(self):
108-
# 读取日志文件内容
109108
try:
110109
with open('run.out', 'r') as f:
111-
f.seek(self.last_read_position) # 跳转到上次读取的位置
112-
new_content = f.read() # 读取新内容
113-
self.last_read_position = f.tell() # 更新读取的位置
110+
f.seek(self.last_read_position)
111+
new_content = f.read()
112+
self.last_read_position = f.tell()
114113
if new_content:
115114
self.LOG_textBrowser.append(new_content)
116115
except FileNotFoundError:
@@ -198,12 +197,11 @@ def kill_process(self):
198197
os.remove(self.PID_FILE)
199198
self.timer.stop()
200199
def update_log(self):
201-
# 读取日志文件内容
202200
try:
203201
with open('run.out', 'r') as f:
204-
f.seek(self.last_read_position) # 跳转到上次读取的位置
205-
new_content = f.read() # 读取新内容
206-
self.last_read_position = f.tell() # 更新读取的位置
202+
f.seek(self.last_read_position)
203+
new_content = f.read()
204+
self.last_read_position = f.tell()
207205
if new_content:
208206
self.LOG_textBrowser.append(new_content)
209207
except FileNotFoundError:

build/lib/memxterminator/cli/radonfit_cli.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,11 @@ def kill_process(self):
222222
os.remove(self.PID_FILE)
223223
self.timer.stop()
224224
def update_log(self):
225-
# 读取日志文件内容
226225
try:
227226
with open('run.out', 'r') as f:
228-
f.seek(self.last_read_position) # 跳转到上次读取的位置
229-
new_content = f.read() # 读取新内容
230-
self.last_read_position = f.tell() # 更新读取的位置
227+
f.seek(self.last_read_position)
228+
new_content = f.read()
229+
self.last_read_position = f.tell()
231230
if new_content:
232231
self.textBrowser_log.append(new_content)
233232
except FileNotFoundError:
@@ -274,7 +273,7 @@ def __init__(self, parent=None):
274273
self.last_read_position = 0
275274
self.timer = QtCore.QTimer(self)
276275
self.timer.timeout.connect(self.update_log)
277-
self.timer.start(1000) # 每秒更新一次
276+
self.timer.start(1000)
278277

279278

280279
def browse_mem_analysis_starfile(self):
78.2 KB
Binary file not shown.

dist/MemXTerminator-1.2.2.tar.gz

59.3 KB
Binary file not shown.

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,5 +124,6 @@ nav:
124124
- Frequently Asked Questions: ./tutorials/faq.md
125125
- Reference Pages:
126126
- "Visualize Radonfit (.ipynb)": ./tutorials/reference/radonfit-mem-analysis-visualizer.ipynb
127+
- "Visualize Bezierfit (.ipynb)": ./tutorials/reference/bezierfit-mem-analysis-visualizer.ipynb
127128
- "Conventions": ./tutorials/reference/conventions.md
128129
- "API Reference": ./tutorials/reference/api.md

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name='MemXTerminator',
5-
version='1.2.1',
5+
version='1.2.2',
66
packages=find_packages(where='src'),
77
package_dir={'': 'src'},
88
author='Zhen Huang',

src/MemXTerminator.egg-info/PKG-INFO

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Metadata-Version: 2.1
22
Name: MemXTerminator
3-
Version: 1.2.1
3+
Version: 1.2.2
44
Summary: A software for membrane analysis and subtraction in cryo-EM
55
Home-page: https://github.com/ZhenHuangLab/MemXTerminator
66
Author: Zhen Huang

src/memxterminator/GUI/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
import sys
2-
sys.path.insert(0, '/data3/Zhen/MemXTerminator/src/')

src/memxterminator/GUI/particle_membrane_subtraction_bezierfit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def retranslateUi(self, ParticleMembraneSubtraction_bezierfit):
107107
self.controlpoints_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Control points file"))
108108
self.template_browse_pushButton.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Browse..."))
109109
self.particle.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Particle .cs file"))
110-
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.005"))
110+
self.points_step_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "0.001"))
111111
self.points_step_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Points_step"))
112112
self.physical_membrane_dist_label.setText(_translate("ParticleMembraneSubtraction_bezierfit", "Physical membrane distance(Å)"))
113113
self.physical_membrane_dist_lineEdit.setText(_translate("ParticleMembraneSubtraction_bezierfit", "35"))

src/memxterminator/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
__author__ = 'Zhen Huang'
22
__email__ = 'zhen.victor.huang@gmail.com'
3-
__version__ = '1.2.1'
3+
__version__ = '1.2.2'

src/memxterminator/bezierfit/bin/mem_subtract_main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ def membrane_subtract(particle_filename):
8080
shifts = shift_list[mask]
8181
classes = class_list[mask]
8282
for particle_idx, psi, pixel_size, shift, class_ in zip(particle_idxes, psis, pixel_sizes, shifts, classes):
83-
# class_得根据control_points.json找到control_points
8483
# if str(class_) in control_points_dict:
8584
control_points = np.array(control_points_dict[str(class_)])
8685
# print(control_points)

src/memxterminator/bezierfit/lib/bezierfit.py

Lines changed: 1 addition & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,13 @@
1414
import json
1515
from scipy.ndimage import zoom
1616

17-
18-
# def bezier_curve(control_points, t):
19-
# B = np.outer((1 - t) ** 3, control_points[0]) + \
20-
# np.outer(3 * (1 - t) ** 2 * t, control_points[1]) + \
21-
# np.outer(3 * (1 - t) * t ** 2, control_points[2]) + \
22-
# np.outer(t ** 3, control_points[3])
23-
# return B.squeeze()
2417
def bezier_curve(control_points, t):
2518
n = len(control_points) - 1
2619
B = np.zeros_like(control_points[0], dtype=float)
2720
for i, point in enumerate(control_points):
2821
B += comb(n, i) * (1 - t) ** (n - i) * t ** i * point
2922
return B
30-
# def bezier_curve_derivative(control_points, t):
31-
# control_points = np.array(control_points)
32-
# B_prime = 3 * (1 - t) ** 2 * (control_points[1] - control_points[0]) + \
33-
# 6 * (1 - t) * t * (control_points[2] - control_points[1]) + \
34-
# 3 * t ** 2 * (control_points[3] - control_points[2])
35-
# return B_prime
23+
3624
def bezier_curve_derivative(control_points, t):
3725
n = len(control_points) - 1
3826
B_prime = np.zeros(2)
@@ -41,25 +29,6 @@ def bezier_curve_derivative(control_points, t):
4129
B_prime += coef * (control_points[i+1] - control_points[i])
4230
return B_prime
4331

44-
# def bezier_curvature(control_points, t):
45-
# dB0 = -3 * (1 - t) ** 2
46-
# dB1 = 3 * (1 - t) ** 2 - 6 * t * (1 - t)
47-
# dB2 = 6 * t * (1 - t) - 3 * t ** 2
48-
# dB3 = 3 * t ** 2
49-
50-
# ddB0 = 6 * (1 - t)
51-
# ddB1 = 6 - 18 * t
52-
# ddB2 = 18 * t - 6
53-
# ddB3 = 6 * t
54-
55-
# p = control_points
56-
# dx = sum([p[i, 0] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
57-
# dy = sum([p[i, 1] * [dB0, dB1, dB2, dB3][i] for i in range(4)])
58-
# ddx = sum([p[i, 0] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])
59-
# ddy = sum([p[i, 1] * [ddB0, ddB1, ddB2, ddB3][i] for i in range(4)])
60-
61-
# curvature = abs(dx * ddy - dy * ddx) / (dx * dx + dy * dy) ** 1.5
62-
# return curvature
6332
def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6):
6433
n = len(control_points) - 1
6534

@@ -79,7 +48,6 @@ def bezier_curvature(control_points, t, threshold=1e-6, high_curvature_value=1e6
7948

8049
magnitude_squared = dx * dx + dy * dy
8150

82-
# 规避除数接近零的问题
8351
if magnitude_squared < threshold:
8452
return high_curvature_value
8553

@@ -331,22 +299,3 @@ def generate_curve_within_boundaries(control_points, image_shape, step):
331299
break
332300
fitted_curve_points = np.array([bezier_curve(control_points, t_val) for t_val in t_values])
333301
return np.array(fitted_curve_points), np.array(t_values)
334-
335-
if __name__ == '__main__':
336-
multiprocessing.set_start_method('spawn', force=True)
337-
with mrcfile.open('/data3/kzhang/cryosparc/CS-vsv/J354/templates_selected.mrc') as f:
338-
image = f.data[2]
339-
image = zoom(image, 2)
340-
coarsefit = Coarsefit(image, 600, 3, 300, 20)
341-
initial_control_points = coarsefit()
342-
ga_refine = GA_Refine(image, 1.068, 0.05, 50, 700, 18)
343-
refined_control_points = ga_refine(initial_control_points, image)
344-
refined_control_points = np.array(refined_control_points)
345-
fitted_curve_points, t_values = generate_curve_within_boundaries(refined_control_points, image.shape, 0.01)
346-
# save the control points in JSON format
347-
with open('control_points.json', 'w') as f:
348-
json.dump(refined_control_points.tolist(), f)
349-
plt.imshow(image, cmap='gray')
350-
plt.plot(fitted_curve_points[:, 0], fitted_curve_points[:, 1], 'r-')
351-
plt.plot(refined_control_points[:, 0], refined_control_points[:, 1], 'g.')
352-
plt.show()

0 commit comments

Comments
 (0)