Skip to content

Commit 180b44b

Browse files
authored
Merge pull request #182 from ahoust17/main
fixed atom_tools
2 parents 496417a + ce48099 commit 180b44b

File tree

5 files changed

+284
-23
lines changed

5 files changed

+284
-23
lines changed

notebooks/4Dstem_File_Reader.ipynb

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Basics of reading an MRC file with 4D STEM data from the Spectra300 at UTK\n",
8+
"## By Austin Houston\n",
9+
"### Last updated 2024-09-14"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": null,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import os\n",
19+
"import sys\n",
20+
"\n",
21+
"%matplotlib ipympl\n",
22+
"import numpy as np\n",
23+
"import matplotlib.pyplot as plt\n",
24+
"\n",
25+
"sys.path.insert(0, '/Users/austin/Documents/GitHub/SciFiReaders/')\n",
26+
"import SciFiReaders\n",
27+
"\n",
28+
"sys.path.insert(0, '/Users/austin/Documents/GitHub/pyTEMlib/')\n",
29+
"import pyTEMlib\n",
30+
"import pyTEMlib.file_tools as ft\n",
31+
"\n",
32+
"print(\"SciFiReaders version: \", SciFiReaders.__version__)\n",
33+
"print(\"pyTEMlib version: \", pyTEMlib.__version__)\n",
34+
"\n",
35+
"# for beginning analysis\n",
36+
"from sklearn.cluster import KMeans\n",
37+
"from sklearn.decomposition import PCA\n",
38+
"from sklearn.cluster import KMeans\n"
39+
]
40+
},
41+
{
42+
"cell_type": "code",
43+
"execution_count": null,
44+
"metadata": {},
45+
"outputs": [],
46+
"source": [
47+
"mrc_filepath = '/Users/austin/Dropbox/GaTech_colabs/SnSe_MgO/2024_06_19_data/4D_STEM/'\n",
48+
"\n",
49+
"files = os.listdir(mrc_filepath)\n",
50+
"files = [f for f in files if f.endswith('.mrc')]\n",
51+
"\n",
52+
"# Load the first file\n",
53+
"dset = ft.open_file(mrc_filepath + files[1])"
54+
]
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"metadata": {},
60+
"outputs": [],
61+
"source": [
62+
"data = dset['Channel_000']\n",
63+
"\n",
64+
"view = data.plot()"
65+
]
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": null,
70+
"metadata": {},
71+
"outputs": [],
72+
"source": [
73+
"mrc_array = np.array(data)\n",
74+
"N, M, height, width = data.shape\n",
75+
"datacube_flat = mrc_array.reshape(N * M, -1)"
76+
]
77+
},
78+
{
79+
"cell_type": "code",
80+
"execution_count": null,
81+
"metadata": {},
82+
"outputs": [],
83+
"source": [
84+
"# Perform KMeans clustering\n",
85+
"clusters = 3 \n",
86+
"kmeans = KMeans(n_clusters=clusters, random_state=0).fit(datacube_flat)\n",
87+
"labels = kmeans.labels_\n",
88+
"cluster_centers = kmeans.cluster_centers_\n",
89+
"\n",
90+
"# Reduce the data to 3D using PCA\n",
91+
"pca = PCA(n_components=3)\n",
92+
"datacube_reduced = pca.fit_transform(datacube_flat)\n",
93+
"cluster_centers_reduced = pca.transform(cluster_centers)\n",
94+
"\n",
95+
"# Create a 3D plot\n",
96+
"fig = plt.figure()\n",
97+
"ax = fig.add_subplot(111, projection='3d')\n",
98+
"scatter = ax.scatter(datacube_reduced[:, 0], datacube_reduced[:, 1], datacube_reduced[:, 2], c=labels, cmap='viridis', marker='o')\n",
99+
"ax.set_xlabel('PCA Component 1')\n",
100+
"ax.set_ylabel('PCA Component 2')\n",
101+
"ax.set_zlabel('PCA Component 3')\n",
102+
"ax.set_xticks([])\n",
103+
"ax.set_yticks([])\n",
104+
"ax.set_zticks([])\n",
105+
"plt.show()\n",
106+
"\n",
107+
"\n",
108+
"label_image = labels.reshape((M, N))\n",
109+
"\n",
110+
"plt.figure()\n",
111+
"plt.imshow(label_image, cmap='viridis')\n",
112+
"plt.colorbar()\n",
113+
"plt.show()\n",
114+
"\n",
115+
"# Reshape cluster centers back to original image dimensions\n",
116+
"cluster_center_images = cluster_centers.reshape((kmeans.n_clusters, height, width))\n",
117+
"\n",
118+
"# Plot the average images\n",
119+
"fig, axes = plt.subplots(1, kmeans.n_clusters, figsize=(15, 5))\n",
120+
"\n",
121+
"for i, ax in enumerate(axes):\n",
122+
" ax.imshow(cluster_center_images[i], cmap='viridis')\n",
123+
" ax.set_title(f'Cluster Center {i+1}')\n",
124+
" ax.axis('off')\n",
125+
"\n",
126+
"plt.show()"
127+
]
128+
},
129+
{
130+
"cell_type": "code",
131+
"execution_count": null,
132+
"metadata": {},
133+
"outputs": [],
134+
"source": []
135+
}
136+
],
137+
"metadata": {
138+
"kernelspec": {
139+
"display_name": "pytemlib",
140+
"language": "python",
141+
"name": "python3"
142+
},
143+
"language_info": {
144+
"codemirror_mode": {
145+
"name": "ipython",
146+
"version": 3
147+
},
148+
"file_extension": ".py",
149+
"mimetype": "text/x-python",
150+
"name": "python",
151+
"nbconvert_exporter": "python",
152+
"pygments_lexer": "ipython3",
153+
"version": "3.11.0"
154+
}
155+
},
156+
"nbformat": 4,
157+
"nbformat_minor": 2
158+
}

pyTEMlib/atom_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def find_atoms(image, atom_size=0.1, threshold=0.):
5353
if not isinstance(threshold, float):
5454
raise TypeError('threshold parameter has to be a float number')
5555

56-
scale_x = ft.get_slope(image.dim_0)
56+
scale_x = np.unique(np.gradient(image.dim_0.values))[0]
5757
im = np.array(image-image.min())
5858
im = im/im.max()
5959
if threshold <= 0.:

pyTEMlib/file_tools.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,9 @@
4444

4545
Dimension = sidpy.Dimension
4646

47-
get_slope = sidpy.base.num_utils.get_slope
48-
__version__ = '2022.3.3'
47+
# Austin commented the line below - it is not used anywhere in the code, and it gives import errors 9-14-2024
48+
# get_slope = sidpy.base.num_utils.get_slopes
49+
__version__ = '2024.9.14'
4950

5051
from traitlets import Unicode, Bool, validate, TraitError
5152
import ipywidgets
@@ -787,7 +788,7 @@ def h5_group_to_dict(group, group_dict={}):
787788

788789

789790
def open_file(filename=None, h5_group=None, write_hdf_file=False, sum_frames=False): # save_file=False,
790-
"""Opens a file if the extension is .hf5, .ndata, .dm3 or .dm4
791+
"""Opens a file if the extension is .emd, .mrc, .hf5, .ndata, .dm3 or .dm4
791792
792793
If no filename is provided the QT open_file windows opens (if QT_available==True)
793794
Everything will be stored in a NSID style hf5 file.
@@ -850,7 +851,7 @@ def open_file(filename=None, h5_group=None, write_hdf_file=False, sum_frames=Fa
850851
if not write_hdf_file:
851852
file.close()
852853
return dataset_dict
853-
elif extension in ['.dm3', '.dm4', '.ndata', '.ndata1', '.h5', '.emd', '.emi', '.edaxh5']:
854+
elif extension in ['.dm3', '.dm4', '.ndata', '.ndata1', '.h5', '.emd', '.emi', '.edaxh5', '.mrc']:
854855
# tags = open_file(filename)
855856
if extension in ['.dm3', '.dm4']:
856857
reader = SciFiReaders.DMReader(filename)
@@ -886,6 +887,9 @@ def open_file(filename=None, h5_group=None, write_hdf_file=False, sum_frames=Fa
886887
elif extension in ['.ndata', '.h5']:
887888
reader = SciFiReaders.NionReader(filename)
888889

890+
elif extension in ['.mrc']:
891+
reader = SciFiReaders.MRCReader(filename)
892+
889893
else:
890894
raise NotImplementedError('extension not supported')
891895

pyTEMlib/image_tools.py

Lines changed: 90 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,13 @@
5555
from scipy.optimize import leastsq
5656
from sklearn.cluster import DBSCAN
5757

58+
from ase.build import fcc110
59+
from pyTEMlib import probe_tools
60+
61+
from scipy.ndimage import rotate
62+
from scipy.interpolate import RegularGridInterpolator
63+
from scipy.signal import fftconvolve
64+
5865

5966
_SimpleITK_present = True
6067
try:
@@ -68,6 +75,72 @@
6875
'install with: conda install -c simpleitk simpleitk ')
6976

7077

78+
def get_atomic_pseudo_potential(fov, atoms, size=512, rotation=0):
79+
# Big assumption: the atoms are not near the edge of the unit cell
80+
# If any atoms are close to the edge (ex. [0,0]) then the potential will be clipped
81+
# before calling the function, shift the atoms to the center of the unit cell
82+
83+
pixel_size = fov / size
84+
max_size = int(size * np.sqrt(2) + 1) # Maximum size to accommodate rotation
85+
86+
# Create unit cell potential
87+
positions = atoms.get_positions()[:, :2]
88+
atomic_numbers = atoms.get_atomic_numbers()
89+
unit_cell_size = atoms.cell.cellpar()[:2]
90+
91+
unit_cell_potential = np.zeros((max_size, max_size))
92+
for pos, atomic_number in zip(positions, atomic_numbers):
93+
x = pos[0] / pixel_size
94+
y = pos[1] / pixel_size
95+
atom_width = 0.5 # Angstrom
96+
gauss_width = atom_width/pixel_size # important for images at various fov. Room for improvement with theory
97+
gauss = probe_tools.make_gauss(max_size, max_size, width = gauss_width, x0=x, y0=y)
98+
unit_cell_potential += gauss * atomic_number # gauss is already normalized to 1
99+
100+
# Create interpolation function for unit cell potential
101+
x_grid = np.linspace(0, fov * max_size / size, max_size)
102+
y_grid = np.linspace(0, fov * max_size / size, max_size)
103+
interpolator = RegularGridInterpolator((x_grid, y_grid), unit_cell_potential, bounds_error=False, fill_value=0)
104+
105+
# Vectorized computation of the full potential map with max_size
106+
x_coords, y_coords = np.meshgrid(np.linspace(0, fov, max_size), np.linspace(0, fov, max_size), indexing="ij")
107+
xtal_x = x_coords % unit_cell_size[0]
108+
xtal_y = y_coords % unit_cell_size[1]
109+
potential_map = interpolator((xtal_x.ravel(), xtal_y.ravel())).reshape(max_size, max_size)
110+
111+
# Rotate and crop the potential map
112+
potential_map = rotate(potential_map, rotation, reshape=False)
113+
center = potential_map.shape[0] // 2
114+
potential_map = potential_map[center - size // 2:center + size // 2, center - size // 2:center + size // 2]
115+
116+
potential_map = scipy.ndimage.gaussian_filter(potential_map,3)
117+
118+
return potential_map
119+
120+
def convolve_probe(ab, potential):
121+
# the pixel sizes should be the exact same as the potential
122+
final_sizes = potential.shape
123+
124+
# Perform FFT-based convolution
125+
pad_height = pad_width = potential.shape[0] // 2
126+
potential = np.pad(potential, ((pad_height, pad_height), (pad_width, pad_width)), mode='constant')
127+
128+
probe, A_k, chi = probe_tools.get_probe(ab, potential.shape[0], potential.shape[1], scale = 'mrad', verbose= False)
129+
130+
131+
convolved = fftconvolve(potential, probe, mode='same')
132+
133+
# Crop to original potential size
134+
start_row = pad_height
135+
start_col = pad_width
136+
end_row = start_row + final_sizes[0]
137+
end_col = start_col + final_sizes[1]
138+
139+
image = convolved[start_row:end_row, start_col:end_col]
140+
141+
return probe, image
142+
143+
71144
# Wavelength in 1/nm
72145
def get_wavelength(e0):
73146
"""
@@ -280,20 +353,21 @@ def diffractogram_spots(dset, spot_threshold, return_center=True, eps=0.1):
280353
return spots, center
281354

282355

283-
def center_diffractogram(dset, return_plot = True, histogram_factor = None, smoothing = 1, min_samples = 100):
356+
def center_diffractogram(dset, return_plot = True, smoothing = 1, min_samples = 10, beamstop_size = 0.1):
284357
try:
285358
diff = np.array(dset).T.astype(np.float16)
286359
diff[diff < 0] = 0
287-
288-
if histogram_factor is not None:
289-
hist, bins = np.histogram(np.ravel(diff), bins=256, range=(0, 1), density=True)
290-
threshold = threshold_otsu(diff, hist = hist * histogram_factor)
291-
else:
292-
threshold = threshold_otsu(diff)
360+
threshold = threshold_otsu(diff)
293361
binary = (diff > threshold).astype(float)
294362
smoothed_image = ndimage.gaussian_filter(binary, sigma=smoothing) # Smooth before edge detection
295363
smooth_threshold = threshold_otsu(smoothed_image)
296364
smooth_binary = (smoothed_image > smooth_threshold).astype(float)
365+
366+
# add a circle to mask the beamstop
367+
x, y = np.meshgrid(np.arange(dset.shape[0]), np.arange(dset.shape[1]))
368+
circle = (x - dset.shape[0] / 2) ** 2 + (y - dset.shape[1] / 2) ** 2 < (beamstop_size * dset.shape[0] / 2) ** 2
369+
smooth_binary[circle] = 1
370+
297371
# Find the edges using the Sobel operator
298372
edges = sobel(smooth_binary)
299373
edge_points = np.argwhere(edges)
@@ -322,18 +396,21 @@ def calc_distance(c, x, y):
322396

323397
finally:
324398
if return_plot:
325-
fig, ax = plt.subplots(1, 4, figsize=(10, 4))
399+
fig, ax = plt.subplots(1, 5, figsize=(14, 4), sharex=True, sharey=True)
326400
ax[0].set_title('Diffractogram')
327401
ax[0].imshow(dset.T, cmap='viridis')
328402
ax[1].set_title('Otsu Binary Image')
329403
ax[1].imshow(binary, cmap='gray')
330404
ax[2].set_title('Smoothed Binary Image')
331-
ax[2].imshow(smooth_binary, cmap='gray')
332-
ax[3].set_title('Edge Detection and Fitting')
333-
ax[3].imshow(edges, cmap='gray')
334-
ax[3].scatter(center[0], center[1], c='r', s=10)
405+
ax[2].imshow(smoothed_image, cmap='gray')
406+
407+
ax[3].set_title('Smoothed Binary Image')
408+
ax[3].imshow(smooth_binary, cmap='gray')
409+
ax[4].set_title('Edge Detection and Fitting')
410+
ax[4].imshow(edges, cmap='gray')
411+
ax[4].scatter(center[0], center[1], c='r', s=10)
335412
circle = plt.Circle(center, mean_radius, color='red', fill=False)
336-
ax[3].add_artist(circle)
413+
ax[4].add_artist(circle)
337414
for axis in ax:
338415
axis.axis('off')
339416
fig.tight_layout()

0 commit comments

Comments
 (0)