diff --git a/docs/api.rst b/docs/api.rst
index 867c062..fa7ec8d 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -2,6 +2,6 @@ API
 ===
 
 .. toctree::
-   :maxdepth: 4
+   :maxdepth: 5
 
-   widetrax
\ No newline at end of file
+   widetrax
diff --git a/docs/index.rst b/docs/index.rst
index ffad0bb..ed2a6f6 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -7,14 +7,14 @@ It computes the power spectral density (PSD) of Sea Surface Height (SSH) for any
 
 .. toctree::
-   :maxdepth: 4
+   :maxdepth: 5
    :caption: Contents:
 
 Contents:
 ---------
 
 .. toctree::
-   :maxdepth: 4
+   :maxdepth: 5
 
    Overview
    api
diff --git a/docs/widetrax.DataPreprocessing.rst b/docs/widetrax.DataPreprocessing.rst
index 7d3e817..9c73e6a 100644
--- a/docs/widetrax.DataPreprocessing.rst
+++ b/docs/widetrax.DataPreprocessing.rst
@@ -5,4 +5,3 @@ widetrax.DataPreprocessing module
    :members:
    :undoc-members:
    :show-inheritance:
-
\ No newline at end of file
diff --git a/docs/widetrax.rst b/docs/widetrax.rst
index 306a94f..6fadd41 100644
--- a/docs/widetrax.rst
+++ b/docs/widetrax.rst
@@ -1,11 +1,11 @@
 widetrax package
 =================
 
-Submodules
-----------
+Modules
+-------
 
 .. toctree::
-   :maxdepth: 4
+   :maxdepth: 5
 
-   widetrax.DataPreprocessing
    widetrax.Spectram
+   widetrax.DataPreprocessing
diff --git a/requirements.txt b/requirements.txt
index 304dc46..d6b19d4 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,4 @@ netCDF4
 numpy
 scipy
 xarray
-zarr
\ No newline at end of file
+zarr
diff --git a/widetrax/DataPreprocessing.py b/widetrax/DataPreprocessing.py
index 165ca96..e39584b 100644
--- a/widetrax/DataPreprocessing.py
+++ b/widetrax/DataPreprocessing.py
@@ -1,99 +1,53 @@
-from collections import defaultdict
-from datetime import datetime, timedelta
 import os
-import re
-import sys
-
-import cartopy.crs as ccrs
-import cartopy.feature as cfeature
-import matplotlib.pyplot as plt
 import numpy as np
-import pyinterp
 import pyinterp.fill as fill
 import xarray as xr
 
 
 # =============================================================================
-# extract_xarray_in_region
+# filtre_donnees
 # =============================================================================
-
-def extract_xarray_in_region(directory, area):
+def filtre_donnees(donnees, seuil_min, seuil_max, type_filtre="passe-bas"):
     """
-    Extracts xarray datasets from SWOT NetCDF data for a specific region
+    Filters the data based on the specified thresholds and the chosen filter type.
 
+    This function applies a filter to the provided data, either in low-pass or
+    high-pass mode. Data outside the thresholds will be excluded.
 
     Parameters
-    ------------
-    directory : str
-        Path to the directory containing the NetCDF files
-    area : list
-        List with the boundaries of the region of interest [longitude_min, latitude_min, longitude_max, latitude_max]
+    ----------
+    donnees : list or numpy.ndarray
+        Values to filter
+    seuil_min : float
+        Lower threshold; values below it are excluded in high-pass mode
+    seuil_max : float
+        Upper threshold; values above it are excluded in low-pass mode
+    type_filtre : str, optional
+        Either "passe-bas" (low-pass) or "passe-haut" (high-pass).
+        Defaults to "passe-bas".
 
     Returns
-    ---------
-    datasets : Dict
-        Dictionary containing the xarray.Datasets for the region
+    -------
+    donnees_filtrees : list
+        Values that passed the filter
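+
+    Examples
+    --------
+    A minimal illustration of both modes (values equal to a threshold are kept):
+
+    >>> filtre_donnees([1, 5, 10], seuil_min=2, seuil_max=6)
+    [1, 5]
+    >>> filtre_donnees([1, 5, 10], seuil_min=2, seuil_max=6, type_filtre="passe-haut")
+    [5, 10]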
     """
-
-    lon_min, lat_min, lon_max, lat_max = area
-    datasets = {}
-    i = 0
-
-    variables_to_load = ["ssha", "mdt", "latitude", "longitude"]
-    files_in_dir = os.listdir(directory)
-
-    for filename in files_in_dir:
-        file_path = os.path.join(directory, filename)
-
-        ds_tmp = xr.open_dataset(file_path, chunks={})
-        variables_to_drop = [var for var in ds_tmp.variables if var not in variables_to_load]
-        ds_tmp.close()
-        del ds_tmp
-
-        # Open the file (lazy loading) excluding unnecessary variables
-        ds = xr.open_dataset(file_path, chunks={}, drop_variables=variables_to_drop)
-
-        if ds:
-            if lon_min < lon_max:
-                selection = (
-
-                        (ds['latitude'] >= lat_min) &
-                        (ds['latitude'] <= lat_max) &
-                        (ds['longitude'] >= lon_min) &
-                        (ds['longitude'] <= lon_max)
-                )
-
-            else:
-                selection = (
-
-                        (ds['latitude'] >= lat_min) &
-                        (ds['latitude'] <= lat_max) &
-                        (((ds['longitude'] >= lon_min) & (ds['longitude'] <= 360)) | (ds['longitude'] <= lon_max))
-
-                )
-
-            selection = selection.compute()
-
-            ds_area = ds.where(selection, drop=True)
-            ds.close()
-
-            if ds_area['latitude'].size > 0:
-                datasets[i] = ds_area
-                i += 1
-
-            ds_area.close()
-
-    return datasets
-
+    if type_filtre not in ["passe-bas", "passe-haut"]:
+        raise ValueError("type_filtre must be 'passe-bas' or 'passe-haut'.")
+
+    if not isinstance(donnees, (list, np.ndarray)):
+        raise TypeError("donnees must be a list or a numpy array.")
+
+    if type_filtre == "passe-bas":
+        # Keep only the values at or below seuil_max
+        donnees_filtrees = [x for x in donnees if x <= seuil_max]
+    else:
+        # "passe-haut": keep only the values at or above seuil_min
+        donnees_filtrees = [x for x in donnees if x >= seuil_min]
+
+    return donnees_filtrees
 
 
 # =============================================================================
 # count_observations
 # =============================================================================
-
 def count_observations(datasets, area, resolution):
     """
     Calculates the number of available observations per bin in the region of interest.
@@ -158,12 +112,84 @@ def count_observations(datasets, area, resolution):
     return obs_count
 
 
+# =============================================================================
+# extract_xarray_in_region
+# =============================================================================
+
+def extract_xarray_in_region(directory, area):
+    """
+    Extracts xarray datasets from SWOT NetCDF data for a specific region.
+
+    Parameters
+    ----------
+    directory : str
+        Path to the directory containing the NetCDF files
+    area : list
+        List with the boundaries of the region of interest [longitude_min, latitude_min, longitude_max, latitude_max]
+
+    Returns
+    -------
+    datasets : dict
+        Dictionary containing the xarray.Datasets for the region
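+
+    Examples
+    --------
+    Illustrative usage only; the directory path and the bounding box below are
+    placeholders, so the example is skipped under doctest:
+
+    >>> area = [30.0, -40.0, 40.0, -30.0]  # [lon_min, lat_min, lon_max, lat_max]
+    >>> datasets = extract_xarray_in_region("/path/to/swot_netcdf", area)  # doctest: +SKIP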
+    """
+
+    lon_min, lat_min, lon_max, lat_max = area
+    datasets = {}
+    i = 0
+
+    variables_to_load = ["ssha", "mdt", "latitude", "longitude"]
+    files_in_dir = os.listdir(directory)
+
+    for filename in files_in_dir:
+        file_path = os.path.join(directory, filename)
+
+        ds_tmp = xr.open_dataset(file_path, chunks={})
+        variables_to_drop = [var for var in ds_tmp.variables if var not in variables_to_load]
+        ds_tmp.close()
+        del ds_tmp
+
+        # Open the file (lazy loading) excluding unnecessary variables
+        ds = xr.open_dataset(file_path, chunks={}, drop_variables=variables_to_drop)
+
+        if ds:
+            if lon_min < lon_max:
+                selection = (
+                    (ds['latitude'] >= lat_min) &
+                    (ds['latitude'] <= lat_max) &
+                    (ds['longitude'] >= lon_min) &
+                    (ds['longitude'] <= lon_max)
+                )
+            else:
+                # The box crosses the 0/360 meridian: wrap the longitude test
+                selection = (
+                    (ds['latitude'] >= lat_min) &
+                    (ds['latitude'] <= lat_max) &
+                    (((ds['longitude'] >= lon_min) & (ds['longitude'] <= 360)) | (ds['longitude'] <= lon_max))
+                )
+
+            selection = selection.compute()
+
+            ds_area = ds.where(selection, drop=True)
+            ds.close()
+
+            if ds_area['latitude'].size > 0:
+                datasets[i] = ds_area
+                i += 1
+
+            ds_area.close()
+
+    return datasets
 
 
 # =============================================================================
 # fill_nan
 # =============================================================================
-
 def fill_nan(datasets, varname: str = "ssha"):
     """
     Fills in missing values (NaN) in each xarray.Dataset using Gauss-Seidel method.
@@ -222,424 +248,3 @@ def fill_nan(datasets, varname: str = "ssha"):
         print(f"Size of longitudes/latitudes is zero for dict number {key}")
 
     return has_converged, filled_datasets
-
-
-# =============================================================================
-# check_directory
-# =============================================================================
-
-
-def check_directory(database_path, start_date_str, end_date_str):
-    """
-
-    Scans the folders in the `database_path` directory, identifies the folders
-    containing NetCDF files whose dates are between `start_date_str` and `end_date_str`,
-    and returns a list of these folder names.
-
-
-    Parameters
-    ------------
-    database_path : str
-        Path to the `database` directory
-    start_date_str : str
-        Start date in 'YYYYMMDD' format
-    end_date_str : str
-        End date in 'YYYYMMDD' format
-
-    Returns
-    ---------
-
-    matching_folders : list
-        List of folder names containing NetCDF files within the specified date range
-        If an error occurs, an error message is printed and an empty list is returned.
-
-    """
-
-    # Regex pattern to match folder names like cycle_001, cycle_002, etc.
-    folder_pattern = re.compile(r'cycle_\d{3}')
-    # Regex pattern to match the date in the file name SWOT........_YYYYMMDD...
-    file_pattern = re.compile(r'SWOT.*_(\d{8})T.*\.nc')
-
-    # Convert date strings to datetime objects
-    start_date = datetime.strptime(start_date_str, '%Y%m%d')
-    end_date = datetime.strptime(end_date_str, '%Y%m%d')
-
-    # List to store the names of folders that meet the criteria
-    matching_folders = []
-
-    try:
-        # List all items in the database directory
-        items = os.listdir(database_path)
-
-        for item in items:
-            folder_path = os.path.join(database_path, item)
-
-            # Check if the item is a directory and matches the pattern
-            if os.path.isdir(folder_path) and folder_pattern.match(item):
-                netcdf_files = [f for f in os.listdir(folder_path) if f.endswith('.nc')]
-                for nc_file in netcdf_files:
-                    match = file_pattern.search(nc_file)
-                    if match:
-                        file_date_str = match.group(1)
-                        file_date = datetime.strptime(file_date_str, '%Y%m%d')
-
-                        if start_date <= file_date <= end_date:
-                            matching_folders.append(item)
-                            break  # Stop checking files in this folder once a match is found
-
-        return matching_folders
-
-    except Exception as e:
-        print(f"An error occurred: {e}")
-        return []
-
-
-# =============================================================================
-# extract_xarrays_by_time
-# =============================================================================
-
-
-def extract_xarrays_by_time(database_path, start_date_str, end_date_str, area):
-    """
-    Processes folders in the `database_path` directory, applies the `extract_xarray_in_region`
-    function to each folder that contains NetCDF files within the date range specified
-    by `start_date_str` and `end_date_str`, and combines the results into a single dictionary.
-
-
-    Parameters
-    ------------
-    database_path : str
-        Path to the `database` directory
-    start_date_str : str
-        Start date in 'YYYYMMDD' format
-    end_date_str : str
-        End date in 'YYYYMMDD' format
-    area : list
-        List with the boundaries of the region of interest [longitude_min, latitude_min, longitude_max, latitude_max]
-
-
-    Returns
-    ---------
-
-    combined_datasets_dict : Dict
-        A dictionary of xarray.Datasets combining the results from `extract_xarray_in_region` function for each folder.
-
-    """
-
-    matching_folders = check_directory(database_path, start_date_str, end_date_str)
-    combined_datasets_dict = defaultdict(list)
-    current_key = 0
-
-    for folder in matching_folders:
-        folder_path = os.path.join(database_path, folder)
-        result_dict = extract_xarray_in_region(folder_path, area)
-
-        # Add entries to the combined_dict with sequential keys
-        for value in result_dict.values():
-            combined_datasets_dict[current_key] = value
-            current_key += 1
-
-    # Convert defaultdict back to dict
-    combined_datasets_dict = dict(combined_datasets_dict)
-
-    return combined_datasets_dict
-
-
-# =============================================================================
-# plot_obs_count
-# =============================================================================
-
-
-def plot_obs_count(obs_count, area, obs_count2=None, title=None, title2=None, save_fig=None):
-    """
-    Plots the number of observations on a geographical map
-
-    Parameters
-    ------------
-    obs_count : numpy.ndarray
-        A 2D array containing the count of observations in each geographical bin.
-    obs_count2 : numpy.ndarray, optional
-        A second 2D array containing the count of observations in each geographical bin.
-    area : list
-        List with the boundaries of the region of interest [longitude_min, latitude_min, longitude_max, latitude_max]
-    title : str, optional
-        The title of the plot. Defaults to None.
-    title2 : str, optional
-        The title of the 2nd plot. Defaults to None.
-    save_fig: Optional[str], default None
-        Name of the file to save the plot to.
-        Does not save if None.
-
-
-    Returns
-    ---------
-    None
-
-    """
-    lon_min, lat_min, lon_max, lat_max = area
-
-    if lon_min > 180:
-        lon_min = lon_min - 360
-
-    if lon_max > 180:
-        lon_max = lon_max - 360
-
-    # Créer des sous-graphes côte à côte
-    if obs_count2 is not None:
-        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 12), subplot_kw={'projection': ccrs.PlateCarree()})
-        maxobs = max(obs_count.max(), obs_count2.max())
-    else:
-        fig, ax1 = plt.subplots(1, 1, figsize=(8, 7), subplot_kw={'projection': ccrs.PlateCarree()})
-        maxobs = obs_count.max()
-
-    ax1.add_feature(cfeature.LAND, facecolor='gray')
-    ax1.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
-    # Create a mask for zero values to present the continent in its color
-    obs_count_masked = np.ma.masked_where(obs_count == 0, obs_count)
-
-    im1 = ax1.imshow(obs_count_masked, extent=[lon_min, lon_max, lat_min, lat_max],
-                     origin='lower', cmap="jet", transform=ccrs.PlateCarree(),
-                     vmin=0, vmax=maxobs)
-
-    ax1.coastlines()
-    gl1 = ax1.gridlines(draw_labels=True)
-    gl1.right_labels = False
-    gl1.top_labels = False
-
-    plt.colorbar(im1, ax=ax1, label='Number of observations per bin', shrink=0.5)
-
-    if title:
-        if len(title) >= 40:
-            ax1.set_title(title, fontsize=11, fontweight='bold', color='black')
-        else:
-            ax1.set_title(title, fontsize=15, fontweight='bold', color='black')
-
-    if obs_count2 is not None:
-        ax2.add_feature(cfeature.LAND, facecolor='gray')
-        ax2.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
-        # Create a mask for zero values to present the continent in its color
-        obs_count_masked2 = np.ma.masked_where(obs_count2 == 0, obs_count2)
-
-        im2 = ax2.imshow(obs_count_masked2, extent=[lon_min, lon_max, lat_min, lat_max],
-                         origin='lower', cmap="jet", transform=ccrs.PlateCarree(),
-                         vmin=0, vmax=maxobs)
-
-        ax2.coastlines()
-        gl2 = ax2.gridlines(draw_labels=True)
-        gl2.right_labels = False
-        gl2.top_labels = False
-        plt.colorbar(im2, ax=ax2, label='Number of observations per bin', shrink=0.5)
-        if title2:
-            ax2.set_title(title2, fontsize=15, fontweight='bold', color='black')
-
-    # Optionally save the figure
-    if save_fig is not None:
-        plt.savefig(save_fig)
-
-
-# =============================================================================
-# read_zarr_to_xarray_dict
-# =============================================================================
-
-
-# Fonction pour lire les fichiers Zarr pour une plage de dates spécifique
-def read_zarr_to_xarray_dict(base_directory, area, start_date_str, end_date_str, variables_to_keep):
-    """
-
-    Reads Zarr files from a directory structure organized by month and day, converts them into a dictionnary of xarray.Dataset objects, retains only specified variables, and extracts a specific geographical region based on latitude and longitude limits.
-
-    Parameters
-    ------------
-    base_directory : str
-        The path to the base directory containing Zarr data organized by month and day.
-    area : list
-        List with the boundaries of the region of interest [longitude_min, latitude_min, longitude_max, latitude_max]
-    start_date_str : str
-        The desired start date in the format 'YYYYMMDD'.
-
-    end_date_str : str
-        The desired end date in the format 'YYYYMMDD'.
-
-    variables_to_keep : list of str
-        A list of variable names to retain in each xarray.Dataset.
-
-    Returns
-    ---------
-    datasets_dict : Dict
-        A dictionary containing the resulting xarray.Dataset objects, indexed by unique integers.
-
-    """
-
-    lon_min = area[0]
-    lon_max = area[2]
-    lat_min = area[1]
-    lat_max = area[3]
-
-    # le nombre de fichiers
-    nfiles = 0
-
-    datasets_dict = {}
-    index = 0
-
-    # Convertir les chaînes de date en objets datetime
-    start_date = datetime.strptime(start_date_str, '%Y%m%d')
-    end_date = datetime.strptime(end_date_str, '%Y%m%d')
-
-    # Parcourir chaque jour entre les dates de début et de fin
-    current_date = start_date
-    while current_date <= end_date:
-        month = current_date.strftime('%m')
-        day = current_date.strftime('%d')
-
-        month_directory = os.path.join(base_directory, f'month={month}')
-        if os.path.exists(month_directory):
-            day_directory = os.path.join(month_directory, f'day={day}')
-            if os.path.exists(day_directory):
-
-                zarr_ds = xr.open_zarr(day_directory)
-
-                # Garder uniquement les variables spécifiées
-                zarr_ds = zarr_ds[variables_to_keep]
-
-                coord_vars = ['latitude', 'longitude']
-                zarr_ds = zarr_ds.set_coords(coord_vars)
-
-                if lon_min > lon_max:
-
-                    selection = (
-
-                            (zarr_ds['latitude'] >= lat_min) &
-                            (zarr_ds['latitude'] <= lat_max) &
-                            (((zarr_ds['longitude'] >= lon_min) & (zarr_ds['longitude'] <= 360)) | (
-                                    zarr_ds['longitude'] <= lon_max))
-                    )
-                else:
-                    selection = (
-
-                            (zarr_ds['latitude'] >= lat_min) &
-                            (zarr_ds['latitude'] <= lat_max) &
-                            (zarr_ds['longitude'] >= lon_min) &
-                            (zarr_ds['longitude'] <= lon_max)
-                    )
-
-                selection = selection.compute()
-
-                zarr_ds = zarr_ds.where(selection, drop=True)
-
-                # créer la variable ssha
-                zarr_ds['ssha'] = zarr_ds.duacs_ssha_karin_2_calibrated.where(zarr_ds.duacs_editing_flag == 0)
-                zarr_ds = zarr_ds.drop_vars('duacs_ssha_karin_2_calibrated')
-                zarr_ds = zarr_ds.drop_vars('duacs_editing_flag')
-
-                # créer la variable mdt
-                zarr_ds['mdt'] = zarr_ds.cvl_mean_dynamic_topography_cnes_cls_22
-                zarr_ds = zarr_ds.drop_vars('cvl_mean_dynamic_topography_cnes_cls_22')
-
-                datasets_dict[index] = zarr_ds
-
-                index += 1
-                nfiles += 1
-
-            else:
-                print(f'The directory for day{day} does not exist in month {month}')
-        else:
-            print(f'The directory for month{month} does not exist')
-
-        # Passer au jour suivant
-        current_date += timedelta(days=1)
-
-    return datasets_dict
-
-
-# =============================================================================
-# split_dsets_based_cnum
-# =============================================================================
-
-
-def split_dsets_based_cnum(datasets_dict):
-    """
-    Splits xarray.dataset objects based on unique cycle and pass numbers.
-
-    The function takes a dictionary of xarray.dataset objects and splits each dataset into
-    smaller datasets based on unique values of 'cycle_number' and 'pass_number'.
-    The resulting datasets are stored in a new dictionary with sequential keys.
-
-    Conditions:
-    - A dataset is split if it contains at least 2 different 'cycle_number' values.
-    - For each unique 'cycle_number', the dataset is further split if it contains
-      at least 2 different 'pass_number' values.
-
-    If an xarray.dataset in the input dictionary meets the splitting conditions (having at least 2 different 'cycle_number' and 'pass_number'), it is split into smaller xarray datasets. Otherwise, the original dataset is included as is.
-
-    Parameters
-    ------------
-    datasets_dict : Dict
-        A dictionary where each key corresponds to an xarray Dataset.
-        Each xarray Dataset is expected to have 'cycle_number' and 'pass_number' attributes.
-
-    Returns
-    ---------
-    splited_dict : Dict
-        A new dictionary containing the split xarray.dataset objects.
-
-    """
-
-    splited_dict = {}
-    index = 0
-
-    for key in range(len(datasets_dict)):
-        # Check if the condition is met (you can define your own condition here)
-        if len(np.unique(datasets_dict[key].cycle_number.compute())) >= 2:
-            # recuperer le num de cycle
-            long_cycle = len(np.unique(datasets_dict[key].cycle_number.compute()))
-            # print(f"la longuer de cycle number est {long_cycle}")
-
-            for i in np.arange(long_cycle - 1):
-                cycle = np.unique(datasets_dict[key].cycle_number.compute())[i]
-                ds_cycle = datasets_dict[key].where(datasets_dict[key]['cycle_number'].compute() == cycle.astype(float),
-                                                    drop=True)
-
-                # condition séparation sur les pass number
-                if len(np.unique(ds_cycle.pass_number.compute())) >= 2:
-
-                    long_pass = len(np.unique(ds_cycle.pass_number.compute()))
-                    # print(f"la longuer de pass number est {long_pass}")
-
-                    for i in np.arange(long_pass - 1):
-                        passs = np.unique(ds_cycle.pass_number.compute())[i]
-                        ds_cyclepass = ds_cycle.where(ds_cycle.pass_number.compute() == passs.astype(float), drop=True)
-
-                        # Add the new datasets to the output dictionary with sequential keys
-                        splited_dict[index] = ds_cyclepass
-                        index += 1
-        else:
-            # If the dataset doesn't meet the condition, add it to the output dictionary as is
-            splited_dict[index] = datasets_dict[key]
-            index += 1
-
-    return splited_dict
-
-
-# =============================================================================
-# remove_duplicates_from_sys_path
-# =============================================================================
-
-def remove_duplicates_from_sys_path():
-    """
-    Removes duplicates from the sys.path list while preserving the order of elements.
-
-    Iterates through the sys.path list and constructs a new list without duplicates.
-    Updates sys.path with this new list.
-
-    Returns
-    --------
-    None
-
-    """
-    seen = set()
-    new_sys_path = []
-    for path in sys.path:
-        if path not in seen:
-            new_sys_path.append(path)
-            seen.add(path)
-    sys.path = new_sys_path
diff --git a/widetrax/__init__.py b/widetrax/__init__.py
index 2a6b05a..01b50d2 100644
--- a/widetrax/__init__.py
+++ b/widetrax/__init__.py
@@ -1,5 +1,6 @@
 from . import Spectram
 from . import DataPreprocessing
+from . import trynewname
 from .version import __version__  # noqa
 
-__all__ = ["Spectram", "DataPreprocessing"]
+__all__ = ["Spectram", "DataPreprocessing", "trynewname"]
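Taken together, the module that survives this patch covers a small preprocessing pipeline: extract a regional subset, fill gaps, then threshold values. Below is a minimal sketch of how the refactored API might be driven end to end; the directory path, the bounding box, the thresholds, and the assumption that key `0` exists in the result are all illustrative, not values taken from this patch.

```python
import numpy as np

from widetrax import DataPreprocessing as dp

# Hypothetical inputs: a local SWOT NetCDF directory and a region given as
# [lon_min, lat_min, lon_max, lat_max].
area = [30.0, -40.0, 40.0, -30.0]
datasets = dp.extract_xarray_in_region("/path/to/swot_netcdf", area)

# Fill NaN gaps in each dataset's "ssha" field (Gauss-Seidel via pyinterp).
has_converged, filled = dp.fill_nan(datasets, varname="ssha")

# Flatten one filled field and keep only anomalies at or below 1 m using the
# new low-pass helper (seuil_min is not used in "passe-bas" mode).
ssha = filled[0]["ssha"].values.ravel()
ssha = ssha[~np.isnan(ssha)]
small_anomalies = dp.filtre_donnees(ssha, seuil_min=-1.0, seuil_max=1.0)
```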