From 13742818df3242d09af9c3ac880387dd60c6e3a7 Mon Sep 17 00:00:00 2001 From: Nicola VIGANO Date: Mon, 23 Aug 2021 18:28:53 +0200 Subject: [PATCH] PSF: refactored and simplified implementation Signed-off-by: Nicola VIGANO --- doc_sources/examples.md | 2 +- examples/example_02_flower_refocus.py | 6 +- examples/example_04_letters.py | 2 +- plenoptomos/psf.py | 350 +++++++++++++++----------- plenoptomos/unsupported/testing.py | 2 +- tests/test_plenoptomos.py | 2 +- 6 files changed, 213 insertions(+), 151 deletions(-) diff --git a/doc_sources/examples.md b/doc_sources/examples.md index 3cdc1d4..9911ef5 100755 --- a/doc_sources/examples.md +++ b/doc_sources/examples.md @@ -25,7 +25,7 @@ By choosing the `mode='rgb'` we obtain three light-fields (one per RGB channel). We then create a (v, u) PSF for each color channel, using the following two lines: ``` -psf_ml_r = pleno.psf.PSF.create_theo_psf(lf_r.camera, coordinates='vu', airy_rings=2) +psf_ml_r = pleno.psf.PSF(lf_r.camera, coordinates='vu', airy_rings=2) psf_ml_r = pleno.psf.PSFApply2D(psf_d=psf_ml_r) ``` which first create the theoretical PSF for an incoherent light source (for the wavelengths indicated in `lf_r.camera`), including only the first two orders of the Airy function. diff --git a/examples/example_02_flower_refocus.py b/examples/example_02_flower_refocus.py index ae71cc0..2292564 100755 --- a/examples/example_02_flower_refocus.py +++ b/examples/example_02_flower_refocus.py @@ -45,13 +45,13 @@ (lf_r, lf_g, lf_b) = pleno.import_lf.from_lytro(dpath, jpath, source="eslf", mode="rgb") print("Creating the theoretical PSFs for the different color channels..") -psf_ml_r = pleno.psf.PSF.create_theo_psf(lf_r.camera, coordinates="vu", airy_rings=2) +psf_ml_r = pleno.psf.PSF(lf_r.camera, coordinates="vu", airy_rings=2) psf_ml_r = pleno.psf.PSFApply2D(psf_d=psf_ml_r) -psf_ml_g = pleno.psf.PSF.create_theo_psf(lf_g.camera, coordinates="vu", airy_rings=2) +psf_ml_g = pleno.psf.PSF(lf_g.camera, coordinates="vu", airy_rings=2) psf_ml_g = pleno.psf.PSFApply2D(psf_d=psf_ml_g) -psf_ml_b = pleno.psf.PSF.create_theo_psf(lf_b.camera, coordinates="vu", airy_rings=2) +psf_ml_b = pleno.psf.PSF(lf_b.camera, coordinates="vu", airy_rings=2) psf_ml_b = pleno.psf.PSFApply2D(psf_d=psf_ml_b) print("Computing refocusing distances..") diff --git a/examples/example_04_letters.py b/examples/example_04_letters.py index dffd5c6..b270a0a 100644 --- a/examples/example_04_letters.py +++ b/examples/example_04_letters.py @@ -53,7 +53,7 @@ z0 = lfv.camera.get_focused_distance() -psf_ml_raw = pleno.psf.PSF.create_theo_psf(lfv.camera, coordinates="vu", airy_rings=1) +psf_ml_raw = pleno.psf.PSF(lfv.camera, coordinates="vu", airy_rings=1) psf_ml = pleno.psf.PSFApply2D(psf_d=psf_ml_raw) print("Computing refocusing distances..") diff --git a/plenoptomos/psf.py b/plenoptomos/psf.py index 41cef82..a7db4f2 100755 --- a/plenoptomos/psf.py +++ b/plenoptomos/psf.py @@ -25,65 +25,82 @@ # functionality in the PSF creation from mpl_toolkits.mplot3d import Axes3D # noqa: F401 unused import -from collections import namedtuple +from dataclasses import dataclass from . import solvers +from . import lightfield import time as tm -class PSF(object): - """Data class that allows to store n-dimensional PSFs, and their - accompanying information. - """ +@dataclass +class PSFParameters(object): + airy_rings: int = 2 + over_sampling: int = 25 + is_beam_coherent: bool = False + wavelength_intensity: float = 1 + pixels_defocus: float = None + refocus_distance: float = None + guarantee_pixel_multiple: int = 1 + beam_transverse_shape: str = "circle" + data_type: np.floating = np.float32 - def __init__(self, coordinates, data=None, conf=None, data_format=None): - self.coordinates = (coordinates[0], coordinates[1]) - self.data = data - self.conf = conf - self.data_format = data_format - def clone(self): - return copy.deepcopy(self) +class PSF(object): + """Data class that allows to store n-dimensional PSFs, and their accompanying information.""" - @staticmethod - def create_theo_psf( - camera, - coordinates, - wavelength_steps=10, - wavelength_intensity=1, - airy_rings=2, - refocus_distance=None, - up_sampling=1, - beam_coherence="incoherent", - shape="circle", - over_sampling=25, - data_type=np.float32, - plot=False, + def __init__( + self, + camera: lightfield.Camera, + coordinates: str, + airy_rings: int = 2, + wavelength_intensity: float = 1, + refocus_distance: float = None, + up_sampling: int = 1, + is_beam_coherent: bool = False, + over_sampling: int = 25, + wavelength_steps: int = 10, + beam_transverse_shape: str = "circle", + data_type: np.floating = np.float32, + plot: bool = False, ): - """Compute the theoretical PSF for the given coordinates in the given - camera setup. - - :param camera: The camera setup object (lightfield.Camera) - :param coordinates: The coordinates where to compute the PSF. Options: 'vu' | 'ts' (string) - :param airy_rings: Orders of the Airy function to consider (int, default:2) - :param beam_coherence: coherence of the light source. Options: 'incoherent' | coherent (string, default: 'incoherent') - :param refocus_distance: Distance of refocusing, for fine tuning the disk of confusion (float, default: None) - - :returns: The computed PSF. - :rtype: PSF + """Initialize the theoretical PSF for the given coordinates in the given camera setup. + + Parameters + ---------- + camera : lightfield.Camera + The camera setup object. + coordinates : str + The coordinates where to compute the PSF. Options: 'vu' | 'ts'. + airy_rings : int, optional + Orders of the Airy function to consider. The default is 2. + wavelength_intensity : float, optional + Relative intensity of the beam across the different wavelengths. The default is 1. + refocus_distance : float, optional + Distance of refocusing, for fine tuning the disk of confusion. The default is None. + up_sampling : int, optional + Expected up-sampling of the PSF. The default is 1. + is_beam_coherent : bool, optional + Coherence of the light source. The default is False. + over_sampling : int, optional + Spatial oversampling. The default is 25. + wavelength_steps : int, optional + Wavelength oversampling. The default is 10. + beam_transverse_shape : str, optional + Beam's transverse shape. The default is "circle". + data_type : np.floating, optional + Data type of the PSF. The default is np.float32. + plot : bool, optional + Whether to plot the PSF. The default is False. + + Raises + ------ + ValueError + In case the camera wavelength unit is outside of the allowed range, + or the beam transverse shape is unknown. """ # We think in mm - if camera.wavelength_unit.lower() == "mm": - ls_unit = 1 - elif camera.wavelength_unit.lower() == "um": - ls_unit = 1e-3 - elif camera.wavelength_unit.lower() == "nm": - ls_unit = 1e-6 - elif camera.wavelength_unit.lower() == "pm": - ls_unit = 1e-9 - else: - raise ValueError("Unknown wavelength unit: %s" % camera.wavelength_unit) + ls_unit = self.get_unit_length(camera.wavelength_unit) if len(camera.wavelength_range) == 2: ls = np.linspace(camera.wavelength_range[0], camera.wavelength_range[1], wavelength_steps) * ls_unit @@ -101,130 +118,167 @@ def create_theo_psf( else: defocus_size = 0 - Conf = namedtuple( - "Conf", - ( - "airy_rings", - "over_sampling", - "beam_coherence", - "wavelength_intensity", - "pixels_defocus", - "guarantee_pixel_multiple", - "data_type", - ), + self.coordinates = (coordinates[0], coordinates[1]) + self.params = PSFParameters( + airy_rings=airy_rings, + over_sampling=over_sampling, + is_beam_coherent=is_beam_coherent, + wavelength_intensity=wavelength_intensity, + pixels_defocus=defocus_size, + refocus_distance=refocus_distance, + guarantee_pixel_multiple=1, + beam_transverse_shape=beam_transverse_shape, + data_type=data_type, ) - conf = Conf(airy_rings, over_sampling, beam_coherence, wavelength_intensity, defocus_size, 1, data_type) + self.data_format = None print( - "- Creating Theoretical PSFs for (%s, %s) coordinates (defocus size: %g).." - % (coordinates[0], coordinates[1], defocus_size), + "- Creating Theoretical PSFs for (%s, %s) coordinates (defocus size: %g).." % (*self.coordinates, defocus_size), end="", flush=True, ) c_in = tm.time() - if coordinates.lower() in ("uv", "vu"): - # Micro lenses PSF - d2 = camera.f2 / camera.aperture_f2 - if camera.is_focused(): - psf_data = PSF.compute_fraunhofer_psf(conf, d2, camera.b, camera.pixel_size_yx, ls) - else: - psf_data = PSF.compute_fraunhofer_psf(conf, d2, camera.f2, camera.pixel_size_yx, ls) - data_format = "raw" - elif coordinates.lower() in ("st", "ts"): - # Main lens PSF - d1 = camera.f1 / camera.aperture_f1 - if up_sampling > 1: - data_format = "subpixel" - else: - data_format = None - if camera.is_focused(): - psf_data = PSF.compute_fraunhofer_psf(conf, d1, camera.z1 + camera.a, camera.pixel_size_ts / up_sampling, ls) - else: - psf_data = PSF.compute_fraunhofer_psf(conf, d1, camera.z1, camera.pixel_size_ts / up_sampling, ls) + if self.params.airy_rings > 0: + if coordinates.lower() in ("uv", "vu"): + # Micro lenses / raw detector PSF + if camera.is_focused(): + effective_z = camera.b + else: + effective_z = camera.f2 + + d2 = camera.f2 / camera.aperture_f2 + h = PSF.compute_fraunhofer_psf(self.params, d2, effective_z, camera.pixel_size_yx, ls) + + self.data_format = "raw" + + elif coordinates.lower() in ("st", "ts"): + # Main lens PSF + if camera.is_focused(): + effective_z = camera.z1 + camera.a + else: + effective_z = camera.z1 + + d1 = camera.f1 / camera.aperture_f1 + effective_pixel_size_ts = camera.pixel_size_ts / up_sampling + h = PSF.compute_fraunhofer_psf(self.params, d1, effective_z, effective_pixel_size_ts, ls) + + if up_sampling > 1: + self.data_format = "subpixel" + else: + h = np.array(1, ndmin=2) + + if defocus_size > 0: + if self.params.beam_transverse_shape.lower() == "circle": + defocus_norm = 2 + elif self.params.beam_transverse_shape.lower() in ("rectangle", "square"): + defocus_norm = np.inf + h_defocus = self.compute_defocus_psf(self.params, norm=defocus_norm) + h = spsig.convolve2d(h, h_defocus, "full") + + # Producing the final impulse response h (=> psf), at the requested resolution + self.data = self._rebin_psf(self.params, h) c_out = tm.time() print("\b\b: Done in %g seconds." % (c_out - c_in)) - psf = PSF(coordinates, data=psf_data, conf=conf, data_format=data_format) - if plot: - pixels_distance = (psf_data.shape[0] - 1) / 2 - grid_p = np.linspace(-pixels_distance, pixels_distance, 2 * pixels_distance + 1) - [grid_p1, grid_p2] = np.meshgrid(grid_p, grid_p, indexing="ij") + self.plot() + + def plot(self): + pixels_distance = (self.data.shape[0] - 1) / 2 + grid_p = np.linspace(-pixels_distance, pixels_distance, 2 * pixels_distance + 1) + [grid_p1, grid_p2] = np.meshgrid(grid_p, grid_p, indexing="ij") - f = plt.figure() - ax = f.add_subplot(1, 1, 1, projection="3d") - ax.plot_surface(grid_p1, grid_p2, psf_data) - ax.view_init(12, -7.5) - plt.show() + f = plt.figure() + ax = f.add_subplot(1, 1, 1, projection="3d") + ax.plot_surface(grid_p1, grid_p2, self.data) + ax.view_init(12, -7.5) + plt.show() - return psf + def clone(self): + return copy.deepcopy(self) + + @staticmethod + def get_unit_length(unit_length): + # We think in mm + if unit_length.lower() == "mm": + return 1 + elif unit_length.lower() == "um": + return 1e-3 + elif unit_length.lower() == "nm": + return 1e-6 + elif unit_length.lower() == "pm": + return 1e-9 + else: + raise ValueError("Unknown wavelength unit: %s" % unit_length) + + @staticmethod + def _rebin_psf(params, h): + base_block_size = (np.array(h.shape) / params.over_sampling).astype(int) + + h = np.reshape(h, (base_block_size[0], params.over_sampling, base_block_size[1], params.over_sampling)) + h = np.sum(h, axis=(1, 3)) + + return h.astype(params.data_type) @staticmethod - def compute_fraunhofer_psf(conf, d, z, pixel_size, ls): + def compute_fraunhofer_psf(params, d, z, pixel_size, ls): # Computing real pixel_distance of first zero disk_d = np.mean(z * np.arcsin(1.22 * ls / d)) / np.mean(pixel_size) - pixels_distance = np.ceil(disk_d * conf.airy_rings).astype(np.int) + pixels_distance = np.ceil(disk_d * params.airy_rings).astype(np.int) base_block_size = 2 * pixels_distance + 1 - sampled_pixels = base_block_size * conf.over_sampling + sampled_pixels = base_block_size * params.over_sampling - samp_1 = np.linspace(-pixels_distance, pixels_distance, sampled_pixels) - samp_2 = np.linspace(-pixels_distance, pixels_distance, sampled_pixels) - - samp_1 = samp_1 * pixel_size[0] - samp_2 = samp_2 * pixel_size[1] + samp_1 = np.linspace(-pixels_distance, pixels_distance, sampled_pixels) * pixel_size[0] + samp_2 = np.linspace(-pixels_distance, pixels_distance, sampled_pixels) * pixel_size[1] [grid_1, grid_2] = np.meshgrid(samp_1, samp_2, indexing="ij") - data_center = pixels_distance * conf.over_sampling + np.floor(conf.over_sampling / 2) + data_center = pixels_distance * params.over_sampling + np.floor(params.over_sampling / 2) data_center = data_center.astype(np.int32) - # Airy function - r = np.sqrt(grid_1 ** 2 + grid_2 ** 2) - r = np.reshape(r, np.concatenate(((1,), r.shape))) - ls = np.reshape(ls, (-1, 1, 1)) - h = np.pi * d * r / (ls * z) - J10 = spspecial.jv(1, h) - h[:, data_center, data_center] = 1 # avoid warning abou NaN - h = 2 * J10 / h - # Setting central pixel to 1 (otherwise it would be NaN or 0) - h[:, data_center, data_center] = 1 - - if conf.beam_coherence.lower() == "coherent": + if params.beam_transverse_shape.lower() == "circle": + # Airy function + r = np.sqrt(grid_1 ** 2 + grid_2 ** 2) + r = np.reshape(r, np.concatenate(((1,), r.shape))) + ls = np.reshape(ls, (-1, 1, 1)) + h = np.pi * d * r / (ls * z) + J10 = spspecial.jv(1, h) + h[:, data_center, data_center] = 1 # avoid warning abou NaN + h = 2 * J10 / h + # Setting central pixel to 1 (otherwise it would be NaN or 0) + h[:, data_center, data_center] = 1 + elif params.beam_transverse_shape.lower() in ("rectangle", "square"): + # Sinc functions + raise NotImplementedError("Rectangular beam shape support not implemented, yet.") + else: + raise ValueError("Unknown beam transverse shape: %s" % params.beam_transverse_shape) + + if params.is_beam_coherent: int_exp = 2 else: int_exp = 4 - # The abs is not really needed, since the Airy function is a real - # valued function + # The abs is not really needed, since the Airy and Sinc functions are real valued h = h ** int_exp # Summing the contribution from all the wavelengths, and renormalizing - h *= np.reshape(conf.wavelength_intensity, (-1, 1, 1)) + h *= np.reshape(params.wavelength_intensity, (-1, 1, 1)) h = np.sum(h, axis=0) h /= np.sum(h) - if conf.pixels_defocus > 0: - hd = PSF.compute_defocus_psf(conf, keep_oversampling=True) - h = spsig.convolve2d(h, hd, "same") - - # Producing the final impulse response h, at the given resolution - h = np.reshape(h, (base_block_size, conf.over_sampling, base_block_size, conf.over_sampling)) - h = np.sum(h, axis=(1, 3)) - h = h.astype(conf.data_type) - return h @staticmethod - def compute_defocus_psf(conf, norm=2, keep_oversampling=False): - sampling_distance = np.ceil(conf.pixels_defocus).astype(np.intp) + def compute_defocus_psf(params, norm=2): + sampling_distance = np.ceil(params.pixels_defocus).astype(np.intp) h_size = 2 * sampling_distance + 1 - pixel_target = conf.guarantee_pixel_multiple + pixel_target = params.guarantee_pixel_multiple h_size = h_size + (pixel_target - (h_size % pixel_target)) % pixel_target h_size = (h_size + pixel_target * ((h_size / pixel_target - 1) % 2)).astype(np.intp) sampling_distance = (h_size - 1) / 2 - sampled_pixels = h_size * conf.over_sampling + sampled_pixels = h_size * params.over_sampling samp_1 = np.linspace(-sampling_distance, sampling_distance, sampled_pixels) samp_2 = np.linspace(-sampling_distance, sampling_distance, sampled_pixels) @@ -232,19 +286,13 @@ def compute_defocus_psf(conf, norm=2, keep_oversampling=False): [grid_1, grid_2] = np.meshgrid(samp_1, samp_2, indexing="ij") if (isinstance(norm, str) and norm.lower() == "inf") or norm == np.inf: - h = (np.fmax(np.abs(grid_1), np.abs(grid_2)) <= conf.pixels_defocus).astype(conf.data_type) + h = (np.fmax(np.abs(grid_1), np.abs(grid_2)) <= params.pixels_defocus).astype(params.data_type) else: - h = ((grid_1 ** norm + grid_2 ** norm) <= (conf.pixels_defocus ** norm)).astype(conf.data_type) + h = ((grid_1 ** norm + grid_2 ** norm) <= (params.pixels_defocus ** norm)).astype(params.data_type) # Summing the contribution from all the wavelengths, and renormalizing h /= np.sum(h) - if keep_oversampling is False: - # Producing the final impulse response h, at the given resolution - h = np.reshape(h, (h_size, conf.over_sampling, h_size, conf.over_sampling)) - h = np.sum(h, axis=(1, 3)) - h = h.astype(conf.data_type) - return h @@ -252,12 +300,13 @@ class PSFApply(object): """Class PSFApply handles all PSF/OTF applications """ - def __init__(self, psf_d=None, img_size=None, use_otf=True, data_format=None, use_fftconv=True): + def __init__(self, psf_d=None, params=None, img_size=None, use_otf=True, data_format=None, use_fftconv=True): print("- Initializing PSF application class..", end="", flush=True) c_in = tm.time() self._reset() + self.params = params self.data_format = data_format self.use_otf = use_otf self.use_fftconv = use_fftconv @@ -498,7 +547,7 @@ def _apply_otf(self, imgs, is_direct): # slicing images to remove padding used during convolution psf_edge_shape = ((self._get_psf_datashape() - 1) / 2).astype(np.int) - fslice = [slice(None),] * len(self.image_size) + fslice = [slice(None)] * len(self.image_size) for ii in self.otf_axes: fslice[ii] = slice(psf_edge_shape[ii], self.image_size[ii] + psf_edge_shape[ii]) return imgs[tuple(fslice)] @@ -524,16 +573,22 @@ class PSFApply2D(PSFApply): """Class PSFApply2D handles all PSF applications and """ - def __init__(self, psf_d, img_size=None, use_otf=True, data_format=None, use_fftconv=True): + def __init__(self, psf_d, img_size=None, params=None, use_otf=True, data_format=None, use_fftconv=True): self.otf_axes = (-2, -1) + if isinstance(psf_d, PSF): psf_inst = psf_d.clone() psf = np.squeeze(psf_inst.data) if data_format is None: data_format = psf_inst.data_format + if params is None: + params = psf_inst.params else: psf = np.squeeze(psf_d) - PSFApply.__init__(self, psf, img_size=img_size, use_otf=use_otf, data_format=data_format, use_fftconv=use_fftconv) + + PSFApply.__init__( + self, psf, img_size=img_size, params=params, use_otf=use_otf, data_format=data_format, use_fftconv=use_fftconv + ) def _check_incoming_psf(self, psf_d): if len(psf_d) == 0 or not len(psf_d.shape) in (2, 3): @@ -548,17 +603,24 @@ class PSFApply4D(PSFApply): """Class PSFApply4D handles all PSF applications and """ - def __init__(self, psf_d, img_size=None, use_otf=True, data_format=None, use_fftconv=True): + def __init__(self, psf_d, img_size=None, params=None, use_otf=True, data_format=None, use_fftconv=True): self.otf_axes = (-4, -3, -2, -1) + if isinstance(psf_d, PSF): - psf = psf_d.data + psf_inst = psf_d.clone() + psf = np.squeeze(psf_inst.data) if data_format is None: - data_format = psf_d.data_format + data_format = psf_inst.data_format + if params is None: + params = psf_inst.params else: if not len(psf_d.shape) == 4: psf_d = np.squeeze(psf_d) psf = psf_d - PSFApply.__init__(self, psf, img_size=img_size, use_otf=use_otf, data_format=data_format, use_fftconv=use_fftconv) + + PSFApply.__init__( + self, psf, img_size=img_size, params=params, use_otf=use_otf, data_format=data_format, use_fftconv=use_fftconv + ) def _check_incoming_psf(self, psf_d): if len(psf_d) == 0 or not len(psf_d.shape) == 4: diff --git a/plenoptomos/unsupported/testing.py b/plenoptomos/unsupported/testing.py index 0516e7f..037c5e0 100755 --- a/plenoptomos/unsupported/testing.py +++ b/plenoptomos/unsupported/testing.py @@ -233,7 +233,7 @@ def __init__(self, cameras, verbose=False): def create_psf_lenses( self, coordinates="vu", airy_rings=3, refocus_distance=None, up_sampling=1, plot_psfs=False, use_otf=False ): - make_psf = lambda x: psf.PSF.create_theo_psf( + make_psf = lambda x: psf.PSF( x, coordinates=coordinates, airy_rings=airy_rings, refocus_distance=refocus_distance, up_sampling=up_sampling ) self.psfs[coordinates] = [make_psf(c) for c in self.cameras] diff --git a/tests/test_plenoptomos.py b/tests/test_plenoptomos.py index 6169a59..07f2071 100755 --- a/tests/test_plenoptomos.py +++ b/tests/test_plenoptomos.py @@ -138,7 +138,7 @@ def setUp(self): """Set up test fixtures, if any.""" super().setUp() - self.psf_ml_raw = pleno.psf.PSF.create_theo_psf(lf.camera, coordinates='vu', airy_rings=2) + self.psf_ml_raw = pleno.psf.PSF(lf.camera, coordinates='vu', airy_rings=2) self.psf_ml_no = pleno.psf.PSFApply2D(psf_d=self.psf_ml_raw, use_otf=False) self.psf_ml_yo = pleno.psf.PSFApply2D(psf_d=self.psf_ml_raw, use_otf=True)