@@ -395,6 +395,15 @@ def align_zero_loss(dataset: sidpy.Dataset) -> sidpy.Dataset:
395
395
new_si .metadata .update ({'zero_loss' : {'shifted' : shifts }})
396
396
return new_si
397
397
398
+ from numba import jit
399
+
400
+ def get_zero_losses (energy , z_loss_params ):
401
+ z_loss_dset = np .zeros ((z_loss_params .shape [0 ], z_loss_params .shape [1 ], energy .shape [0 ]))
402
+ for x in range (z_loss_params .shape [0 ]):
403
+ for y in range (z_loss_params .shape [1 ]):
404
+ z_loss_dset [x , y ] += zl_func (energy , * z_loss_params [x , y ])
405
+ return z_loss_dset
406
+
398
407
399
408
400
409
@@ -488,11 +497,12 @@ def guess_function(xvec, yvec):
488
497
z_loss_dset = dataset .copy ()
489
498
z_loss_dset *= 0.0
490
499
491
- energy_grid = np .broadcast_to (energy .reshape ((1 , 1 , - 1 )), (z_loss_dset .shape [0 ],
492
- z_loss_dset .shape [1 ], energy .shape [0 ]))
493
- z_loss_peaks = zl_func (energy_grid , * z_loss_params )
494
- z_loss_dset += z_loss_peaks
495
-
500
+ #energy_grid = np.broadcast_to(energy.reshape((1, 1, -1)), (z_loss_dset.shape[0],
501
+ # z_loss_dset.shape[1], energy.shape[0]))
502
+ #z_loss_peaks = zl_func(energy_grid, *z_loss_params)
503
+ z_loss_params = np .array (z_loss_params )
504
+ z_loss_dset += get_zero_losses (np .array (energy ), np .array (z_loss_params ))
505
+
496
506
shifts = z_loss_params [:, :, 0 ] * z_loss_params [:, :, 3 ]
497
507
widths = z_loss_params [:, :, 2 ] * z_loss_params [:, :, 5 ]
498
508
@@ -522,7 +532,15 @@ def drude_lorentz(eps_inf, leng, ep, eb, gamma, e, amplitude):
522
532
return eps
523
533
524
534
525
- def fit_plasmon (dataset : Union [sidpy .Dataset , np .ndarray ], startFitEnergy : float , endFitEnergy : float , plot_result : bool = False , number_workers : int = 4 , number_threads : int = 8 ) -> Union [sidpy .Dataset , np .ndarray ]:
535
+ def get_plasmon_losses (energy , params ):
536
+ dset = np .zeros ((params .shape [0 ], params .shape [1 ], energy .shape [0 ]))
537
+ for x in range (params .shape [0 ]):
538
+ for y in range (params .shape [1 ]):
539
+ dset [x , y ] += energy_loss_function (energy , params [x , y ])
540
+ return dset
541
+
542
+
543
+ def fit_plasmon (dataset : Union [sidpy .Dataset , np .ndarray ], startFitEnergy : float , endFitEnergy : float , number_workers : int = 4 , number_threads : int = 8 ) -> Union [sidpy .Dataset , np .ndarray ]:
526
544
"""
527
545
Fit plasmon peak positions and widths in a TEM dataset using a Drude model.
528
546
@@ -567,8 +585,6 @@ def energy_loss_function(E: np.ndarray, Ep: float, Ew: float, A: float) -> np.nd
567
585
elf = (- 1 / eps ).imag
568
586
return A * elf
569
587
570
-
571
-
572
588
# define window for fitting
573
589
energy = dataset .get_spectral_dims (return_axis = True )[0 ].values
574
590
start_fit_pixel = np .searchsorted (energy , startFitEnergy )
@@ -589,37 +605,46 @@ def energy_loss_function(E: np.ndarray, Ep: float, Ew: float, A: float) -> np.nd
589
605
guess_pos = energy [guess_pos ]
590
606
if guess_width > 8 :
591
607
guess_width = 8
592
- popt , pcov = curve_fit (energy_loss_function , energy [start_fit_pixel :end_fit_pixel ], fit_dset ,
593
- p0 = [guess_pos , guess_width , guess_amplitude ])
608
+ try :
609
+ popt , pcov = curve_fit (energy_loss_function , energy [start_fit_pixel :end_fit_pixel ], fit_dset ,
610
+ p0 = [guess_pos , guess_width , guess_amplitude ])
611
+ except :
612
+ end_fit_pixel = np .searchsorted (energy , 30 )
613
+ fit_dset = np .array (dataset [start_fit_pixel :end_fit_pixel ]/ anglog [start_fit_pixel :end_fit_pixel ])
614
+ try :
615
+ popt , pcov = curve_fit (energy_loss_function , energy [start_fit_pixel :end_fit_pixel ], fit_dset ,
616
+ p0 = [guess_pos , guess_width , guess_amplitude ])
617
+ except :
618
+ popt = [0 ,0 ,0 ]
594
619
595
620
plasmon = dataset .like_data (energy_loss_function (energy , popt [0 ], popt [1 ], popt [2 ]))
596
621
plasmon *= anglog
597
622
start_plasmon = np .searchsorted (energy , 0 )+ 1
598
-
599
-
600
623
plasmon [:start_plasmon ] = 0.
624
+
601
625
epsilon = drude (energy , popt [0 ], popt [1 ], 1 ) * popt [2 ]
602
626
epsilon [:start_plasmon ] = 0.
603
-
627
+
604
628
plasmon .metadata ['plasmon' ] = {'parameter' : popt , 'epsilon' :epsilon }
605
629
return plasmon
606
630
607
631
# if it can be parallelized:
608
632
fitter = SidFitter (fit_dset , energy_loss_function , num_workers = number_workers ,
609
633
threads = number_threads , return_cov = False , return_fit = False , return_std = False ,
610
634
km_guess = False , num_fit_parms = 3 )
611
- [fitted_dataset ] = fitter .do_fit ()
635
+ [fit_parameter ] = fitter .do_fit ()
636
+
637
+ plasmon_dset = dataset * 0.0
638
+ fit_parameter = np .array (fit_parameter )
639
+ plasmon_dset += get_plasmon_losses (np .array (energy ), fit_parameter )
640
+ if 'plasmon' not in plasmon_dset .metadata :
641
+ plasmon_dset .metadata ['plasmon' ] = {}
642
+ plasmon_dset .metadata ['plasmon' ].update ({'startFitEnergy' : startFitEnergy ,
643
+ 'endFitEnergy' : endFitEnergy ,
644
+ 'fit_parameter' : fit_parameter ,
645
+ 'original_low_loss' : dataset .title })
612
646
613
- if plot_result :
614
- fig , (ax1 , ax2 , ax3 ) = plt .subplots (1 , 3 , sharex = True , sharey = True )
615
- ax1 .imshow (fitted_dataset [:, :, 0 ], cmap = 'jet' )
616
- ax1 .set_title ('Ep - Peak Position' )
617
- ax2 .imshow (fitted_dataset [:, :, 1 ], cmap = 'jet' )
618
- ax2 .set_title ('Ew - Peak Width' )
619
- ax3 .imshow (fitted_dataset [:, :, 2 ], cmap = 'jet' )
620
- ax3 .set_title ('A - Amplitude' )
621
- plt .show ()
622
- return fitted_dataset
647
+ return plasmon_dset
623
648
624
649
625
650
def angle_correction (spectrum ):
@@ -722,8 +747,11 @@ def multiple_scattering(energy_scale: np.ndarray, p: list, core_loss=False)-> np
722
747
ssd = ssd * ssd2
723
748
724
749
PSD /= tmfp * np .exp (- tmfp )
725
- BGDcoef = scipy .interpolate .splrep (LLene , PSD , s = 0 )
726
- return scipy .interpolate .splev (energy_scale , BGDcoef )
750
+ BGDcoef = scipy .interpolate .splrep (LLene , PSD , s = 0 )
751
+ msd = scipy .interpolate .splev (energy_scale , BGDcoef )
752
+ start_plasmon = np .searchsorted (energy_scale , 0 )+ 1
753
+ msd [:start_plasmon ] = 0.
754
+ return msd
727
755
728
756
def fit_multiple_scattering (dataset : Union [sidpy .Dataset , np .ndarray ], startFitEnergy : float , endFitEnergy : float ,pin = None , number_workers : int = 4 , number_threads : int = 8 ) -> Union [sidpy .Dataset , np .ndarray ]:
729
757
"""
0 commit comments