diff --git a/decode/fit.py b/decode/fit.py index 01986bb..db2b90a 100644 --- a/decode/fit.py +++ b/decode/fit.py @@ -93,3 +93,49 @@ def dtau_dpwv(freq: NDArray[np.float_]) -> xr.DataArray: tau = load.atm(type="tau").interp(freq=freq, method="linear") fit = tau.curvefit("pwv", lambda x, a, b: a * x + b) return fit["curvefit_coefficients"].sel(param="a", drop=True) + + +def cube( + cube: xr.DataArray, + /, + *, + init_amp: float = 1.0, + init_x0: float = 0.0, + init_y0: float = 0.0, + init_sigma_x: float = 20.0, + init_sigma_y: float = 20.0, + init_theta: float = 0.0, + init_offset: float = 0.0, +) -> xr.Dataset: + """Apply 2D Gaussian fit to each channel of a 3D spectral cube.""" + return cube.curvefit( + coords=("lon", "lat"), + func=gaussian_2d, + p0={ + "amp": init_amp, + "x0": init_x0, + "y0": init_y0, + "sigma_x": init_sigma_x, + "sigma_y": init_sigma_y, + "theta": init_theta, + "offset": init_offset, + }, + errors="ignore", + ) + + +def gaussian_2d(xy, amp, x0, y0, sigma_x, sigma_y, theta, offset): + x, y = xy + x0 = float(x0) + y0 = float(y0) + a = (np.cos(theta) ** 2) / (2 * sigma_x**2) + (np.sin(theta) ** 2) / ( + 2 * sigma_y**2 + ) + b = -(np.sin(2 * theta)) / (4 * sigma_x**2) + (np.sin(2 * theta)) / (4 * sigma_y**2) + c = (np.sin(theta) ** 2) / (2 * sigma_x**2) + (np.cos(theta) ** 2) / ( + 2 * sigma_y**2 + ) + g = offset + amp * np.exp( + -(a * ((x - x0) ** 2) + 2 * b * (x - x0) * (y - y0) + c * ((y - y0) ** 2)) + ) + return g.ravel() diff --git a/decode/qlook.py b/decode/qlook.py index 83c2cb9..6fd32ed 100644 --- a/decode/qlook.py +++ b/decode/qlook.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import Any, Literal, Optional, Sequence, Union, cast from warnings import catch_warnings, simplefilter - +import copy # dependencies import numpy as np @@ -26,8 +26,9 @@ from astropy.units import Quantity from fire import Fire from matplotlib.figure import Figure -from . import assign, convert, load, make, plot, select, utils - +from scipy.optimize import curve_fit +from . import assign, convert, load, make, plot, select, utils, fit +import pandas as pd # constants DATA_FORMATS = "csv", "nc", "zarr", "zarr.zip" @@ -223,6 +224,10 @@ def daisy( ) cont = cube.weighted(weight.fillna(0)).mean("chan") + ### GaussFit (all chan) + fitted_cube = fit.cube(cube) + # to toml here + # save result suffixes = f".{suffix}.{format}" file = Path(outdir) / Path(dems).with_suffix(suffixes).name @@ -446,6 +451,10 @@ def raster( ) cont = cube.weighted(weight.fillna(0)).mean("chan") + ### GaussFit (all chan) + fitted_cube = fit.cube(cube) + # to toml here + # save result suffixes = f".{suffix}.{format}" file = Path(outdir) / Path(dems).with_suffix(suffixes).name