diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..c144172 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,25 @@ +[run] +source = pytorch_caney +omit = + */site-packages/* + */dist-packages/* + */tests/* + setup.py + +[report] +exclude_lines = + pragma: no cover + def __repr__ + if self.debug: + if __name__ == .__main__.: + raise NotImplementedError + pass + except ImportError: + +show_missing = True + +[html] +directory = htmlcov + +[xml] +output = coverage.xml \ No newline at end of file diff --git a/.github/workflows/dockerhub-dev.yml b/.github/workflows/dockerhub-dev.yml index ffd2cd3..31a30b6 100644 --- a/.github/workflows/dockerhub-dev.yml +++ b/.github/workflows/dockerhub-dev.yml @@ -28,10 +28,24 @@ jobs: - name: Lower github-runner storage run: | - sudo rm -rf /usr/share/dotnet - sudo rm -rf /opt/ghc - sudo rm -rf "/usr/local/share/boost" - sudo rm -rf "$AGENT_TOOLSDIRECTORY" + # Remove software and language runtimes we're not using + sudo rm -rf \ + "$AGENT_TOOLSDIRECTORY" \ + /opt/google/chrome \ + /opt/microsoft/msedge \ + /opt/microsoft/powershell \ + /opt/pipx \ + /usr/lib/mono \ + /usr/local/julia* \ + /usr/local/lib/android \ + /usr/local/lib/node_modules \ + /usr/local/share/chromium \ + /usr/local/share/powershell \ + /usr/share/dotnet \ + /usr/share/swift + df -h / + + - name: Build and push diff --git a/README.md b/README.md index 344f546..de3eda0 100644 --- a/README.md +++ b/README.md @@ -13,23 +13,24 @@ Python package for lots of Pytorch tools. - Latest: https://nasa-nccs-hpda.github.io/pytorch-caney/latest -## Objectives +# pytorch-caney + +Python package for a variety of PyTorch tools for geospatial science problems. + +[![DOI](https://zenodo.org/badge/472450059.svg)](https://zenodo.org/badge/latestdoi/472450059) +## Objectives - Library to process remote sensing imagery using GPU and CPU parallelization. - Machine Learning and Deep Learning image classification and regression. - Agnostic array and vector-like data structures. -- User interface environments via Notebooks for easy to use AI/ML projects. -- Example notebooks for quick AI/ML start with your own data. +- User interface environments via Notebooks for easy-to-use AI/ML projects. +- Example notebooks for a quick AI/ML start with your own data. ## Installation -The following library is intended to be used to accelerate the development of data science products -for remote sensing satellite imagery, or any other applications. pytorch-caney can be installed -by itself, but instructions for installing the full environments are listed under the requirements -directory so projects, examples, and notebooks can be run. +The following library is intended to be used to accelerate the development of data science products for remote sensing satellite imagery, or other applications. `pytorch-caney` can be installed by itself, but instructions for installing the full environments are listed under the `requirements` directory so projects, examples, and notebooks can be run. -Note: PIP installations do not include CUDA libraries for GPU support. Make sure NVIDIA libraries -are installed locally in the system if not using conda/mamba. +**Note:** PIP installations do not include CUDA libraries for GPU support. Make sure NVIDIA libraries are installed locally in the system if not using conda/mamba. ```bash module load singularity # if a module needs to be loaded @@ -42,91 +43,679 @@ singularity build --sandbox pytorch-caney-container docker://nasanccs/pytorch-ca ## Contributors -- Jordan Alexis Caraballo-Vega, jordan.a.caraballo-vega@nasa.gov -- Caleb Spradlin, caleb.s.spradlin@nasa.gov +- **Jordan Alexis Caraballo-Vega**: [jordan.a.caraballo-vega@nasa.gov](mailto:jordan.a.caraballo-vega@nasa.gov) +- **Caleb Spradlin**: [caleb.s.spradlin@nasa.gov](mailto:caleb.s.spradlin@nasa.gov) +- **Jian Li**: [jian.li@nasa.gov](mailto:jian.li@nasa.gov) ## Contributing -Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md). +Please see our [guide for contributing to pytorch-caney](CONTRIBUTING.md). -## SatVision +# User Guide +--- -| name | pretrain | resolution | #params | -| :---: | :---: | :---: | :---: | -| SatVision-B | MODIS-1.9-M | 192x192 | 84.5M | +## 1. SatVision-TOA -## SatVision Datasets +|Name|Pretrain|Resolution|Channels | Parameters| +|---|---|---|---|---| +|SatVision-TOA-GIANT|MODIS-TOA-100-M|128x128|14|3B| -| name | bands | resolution | #chips | -| :---: | :---: | :---: | :---: | -| MODIS-Small | 7 | 128x128 | 1,994,131 | +### Accessing the Model -## MODIS Surface Reflectance (MOD09GA) Band Details +Model Repository: [HuggingFace](https://huggingface.co/nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128) -| Band Name | Bandwidth | -| :------------: | :-----------: | -| sur_refl_b01_1 | 0.620 - 0.670 | -| sur_refl_b02_1 | 0.841 - 0.876 | -| sur_refl_b03_1 | 0.459 - 0.479 | -| sur_refl_b04_1 | 0.545 - 0.565 | -| sur_refl_b05_1 | 1.230 - 1.250 | -| sur_refl_b06_1 | 1.628 - 1.652 | -| sur_refl_b07_1 | 2.105 - 2.155 | +#### **Clone the Model Checkpoint** -## Pre-training with Masked Image Modeling +1. Load `git-lfs`: +```bash + module load git-lfs +``` +```bash + git lfs install +``` -To pre-train the swinv2 base model with masked image modeling pre-training, run: +2. Clone the repository: ```bash -torchrun --nproc_per_node pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg --dataset --data-paths --batch-size --output --enable-amp + git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128 ``` -For example to run on a compute node with 4 GPUs and a batch size of 128 on the MODIS SatVision pre-training dataset with a base swinv2 model, run: + Note: Using SSH authentication +Ensure SSH keys are configured. Troubleshooting steps: +- Check SSH connection: ```bash -singularity shell --nv -B /path/to/container/pytorch-caney-container -Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney -Singularity> torchrun --nproc_per_node 4 pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py --cfg pytorch-caney/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml --dataset MODIS --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* --batch-size 128 --output . --enable-amp + ssh -T git@hf.co # If reports back as anonymous follow the next steps ``` - -This example script runs the exact configuration used to make the SatVision-base model pre-training with MiM and the MODIS pre-training dataset. +- Add your SSH key: ```bash -singularity shell --nv -B /path/to/container/pytorch-caney-container -Singularity> cd pytorch-caney/examples/satvision -Singularity> ./run_satvision_pretrain.sh + eval $(ssh-agent) + ssh-add ~/.ssh/your-key # Path to your SSH key ``` -## Fine-tuning Satvision-base -To fine-tune the satvision-base pre-trained model, run: +## Running SatVision-TOA Pipelines + +### Command-Line Interface (CLI) + +To run tasks with **PyTorch-Caney**, use the following command: + ```bash -torchrun --nproc_per_node pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py --cfg --pretrained --dataset --data-paths --batch-size --output --enable-amp +$ python pytorch-caney/pytorch_caney/ptc_cli.py --config-path ``` -See example config files pytorch-caney/examples/satvision/finetune_satvision_base_*.yaml to see how to structure your config file for fine-tuning. +### Common CLI Arguments +| Command-line-argument | Description |Required/Optional/Flag | Default | Example | +| --------------------- |:----------------------------------------------------|:---------|:---------|:--------------------------------------| +| `-config-path` | Path to training config | Required | N/A |`--config-path pytorch-caney/configs/3dcloudtask_swinv2_satvision_gaint_test.yaml` | +| `-h, --help` | show this help message and exit | Optional | N/a |`--help`, `-h` | -## Testing -For unittests, run this bash command to run linting and unit test runs. This will execute unit tests and linting in a temporary venv environment only used for testing. +### Examples + +**Run 3D Cloud Task with Pretrained Model**: +```shell +$ python pytorch-caney/pytorch_caney/ptc_cli.py --config-path pytorch-caney/configs/3dcloudtask_swinv2_satvision_giant_test.yaml +``` +**Run 3D Cloud Task with baseline model**: +```shell +$ python pytorch-caney/pytorch_caney/ptc_cli.py --config-path pytorch-caney/configs/3dcloudtask_fcn_baseline_test.yaml +``` + +**Run SatVision-TOA Pretraining from Scratch**: +```shell +$ python pytorch-caney/pytorch_caney/ptc_cli.py --config-path pytorch-caney/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep.yaml +``` + +## **3. Using Singularity for Containerized Execution** + +**Shell Access** + ```bash -git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git -cd pytorch-caney; bash test.sh +$ singularity shell --nv -B +Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney ``` -or run unit tests directly with container or anaconda env +**Command Execution** ```bash -git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git -singularity build --sandbox pytorch-caney-container docker://nasanccs/pytorch-caney:latest -singularity shell --nv -B /path/to/container/pytorch-caney-container -cd pytorch-caney; python -m unittest discover pytorch_caney/tests +$ singularity exec --nv -B , --env PYTHONPATH=$PWD:$PWD/pytorch-caney COMMAND ``` +### **Example** + +Running the 3D Cloud Task inside the container: + +```bash +$ singularity shell --nv -B +Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney +Singularity> python pytorch-caney/pytorch_caney/ptc_cli.py --config-path pytorch-caney/configs/3dcloudtask_swinv2_satvision_giant_test.yaml +``` + +--- + +## 4. ThreeDCloudTask Pipeline + +This document describes how to run the `ThreeDCloudTask` pipeline using the provided configuration files and PyTorch Lightning setup. This requires downloading the 3D Cloud dataset from HuggingFace. + +## Pipeline Overview + +The `ThreeDCloudTask` is a PyTorch Lightning module designed for regression tasks predicting a 3D cloud vertical structure. The pipeline is configurable through YAML files and leverages custom components for the encoder, decoder, loss functions, and metrics. + +## Running the Pipeline + +Follow the steps below to train or validate the `ThreeDCloudTask` pipeline. + +### Prepare Configuration + +Two example configuration files are provided: + +- `3dcloudtask_swinv2_satvision_gaint_test.yaml`: Configures a pipeline using the SwinV2-based SatVision encoder. +- `3dcloudtask_fcn_baseline_test.yaml`: Configures a baseline pipeline with a fully convolutional network (FCN). + +Modify the configuration file to suit your dataset and training parameters. + +### Run the Training Script + +Example: ```bash -git clone git@github.com:nasa-nccs-hpda/pytorch-caney.git -cd pytorch-caney; conda env create -f requirements/environment_gpu.yml; -conda activate pytorch-caney -python -m unittest discover pytorch_caney/tests +$ singularity shell --nv -B +Singularity> export PYTHONPATH=$PWD:$PWD/pytorch-caney +Singularity> python python pytorch-caney/pytorch_caney/ptc_cli.py --config-path pytorch-caney/configs/3dcloudtask_swinv2_satvision_giant_test.yaml +``` + +### Script Behavior + +- **Pipeline Initialization**: The script initializes the pipeline using the `PIPELINES` registry, based on the `PIPELINE`value in the configuration file. +- **Model and Data Module Setup**: The script automatically detects and uses the appropriate `DATAMODULE` and `MODEL` components specified in the configuration. +- **Training Strategy**: The `get_strategy` function selects the optimal training strategy, including distributed training if applicable. +- **Checkpoints**: If a checkpoint path is provided in the configuration (`MODEL.RESUME`), training resumes from that checkpoint. +### Output + +The results, logs, and model checkpoints are saved in a directory specified by: +`///` + +Example: + +`./outputs/3dcloud-svtoa-finetune-giant/3dcloud_task_swinv2_g_satvision_128_scaled_bt_minmax/` + +### Configuration Details + +#### 3D Cloud Task Example Configurations + +```yaml +PIPELINE: '3dcloud' +DATAMODULE: 'abitoa3dcloud' +MODEL: + ENCODER: 'satvision' + DECODER: 'fcn' + PRETRAINED: satvision-toa-giant-patch8-window8-128/mp_rank_00_model_states.pt + TYPE: swinv2 + NAME: 3dcloud-svtoa-finetune-giant + IN_CHANS: 14 + DROP_PATH_RATE: 0.1 + SWINV2: + IN_CHANS: 14 + EMBED_DIM: 512 + DEPTHS: [ 2, 2, 42, 2 ] + NUM_HEADS: [ 16, 32, 64, 128 ] + WINDOW_SIZE: 8 + NORM_PERIOD: 6 +DATA: + BATCH_SIZE: 32 + DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + TEST_DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + IMG_SIZE: 128 +TRAIN: + USE_CHECKPOINT: True + EPOCHS: 50 + WARMUP_EPOCHS: 10 + BASE_LR: 3e-4 + MIN_LR: 2e-4 + WARMUP_LR: 1e-4 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] +LOSS: + NAME: 'bce' +PRECISION: 'bf16' +PRINT_FREQ: 10 +SAVE_FREQ: 50 +VALIDATION_FREQ: 20 +TAG: 3dcloud_task_swinv2_g_satvision_128_scaled_bt_minmax ``` -## References -- [Pytorch Lightning](https://github.com/Lightning-AI/lightning) -- [Swin Transformer](https://github.com/microsoft/Swin-Transformer) -- [SimMIM](https://github.com/microsoft/SimMIM) +#### FCN Baseline Configuration + +```yaml +PIPELINE: '3dcloud' +DATAMODULE: 'abitoa3dcloud' +MODEL: + ENCODER: 'fcn' + DECODER: 'fcn' + NAME: 3dcloud-fcn-baseline + IN_CHANS: 14 + DROP_PATH_RATE: 0.1 +DATA: + BATCH_SIZE: 32 + DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + TEST_DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + IMG_SIZE: 128 +TRAIN: + ACCELERATOR: 'gpu' + STRATEGY: 'auto' + EPOCHS: 50 + WARMUP_EPOCHS: 10 + BASE_LR: 3e-4 + MIN_LR: 2e-4 + WARMUP_LR: 1e-4 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] +LOSS: + NAME: 'bce' +PRINT_FREQ: 10 +SAVE_FREQ: 50 +VALIDATION_FREQ: 20 +TAG: 3dcloud_task_fcn_baseline_128_scaled_bt_minmax +``` + +### Key Components + +#### Model Components + +- **Encoder**: Handles feature extraction from input data. +- **Decoder**: Processes features into an intermediate representation. +- **Segmentation Head**: Produces the final output with a specific shape (91x40). + +#### Loss Function + +- **Binary Cross-Entropy Loss (`bce`)** is used for training. + +#### Metrics + +- **Jaccard Index (IoU)**: Evaluates model accuracy. +- **Mean Loss**: Tracks average loss during training and validation. + +#### Optimizer + +- Custom optimizer configurations are handled by `build_optimizer`. + +### Additional Notes + +- Customize your `DATAMODULE` and `MODEL` definitions as per the dataset and task requirements. +- To run the pipeline with GPUs, ensure your system has compatible hardware and CUDA installed. + +--- + +## Masked-Image-Modeling Pre-Training Pipeline + +--- + +For an example of how MiM pre-trained models work, see the example inference notebook in `pytorch-caney/notebooks/` + +# SatVision-TOA Model Input Data Generation and Pre-processing + +--- + +## Overview  + +- For expected model input see "Expected Model Input" section +- For steps taken for generating the MODIS-TOA pre-training dataset see "MODIS-TOA Dataset Generation" section + +![MODIS TOA Bands](docs/static/modis_toa_bands.png) + +## MODIS-TOA Dataset Generation + +The MODIS TOA dataset is derived from MODIS MOD02 Level 1B swaths, which provide calibrated and geolocated irradiances across 36 spectral bands. The data processing pipeline involves compositing, calibration, and normalization steps to convert raw data into a format suitable for deep learning model ingestion.  + +MODIS data comes in three spatial resolutions 250 m, 500 m and 1 km where bands 1 and 2 are natively 250 m, bands 3 – 7 are natively 500 m and bands 8 – 36 are natively 1 km.  For this work all bands need to be at the same spatial resolution so the finer resolution bands 1 – 7 have been aggregated to 1 km. + +The initial step involves compositing MODIS 5-minute swaths into daily global composites at 1 km spatial resolution. This step consolidates continuous swath data into a consistent daily global grid.  + +The SatVision TOA model is pre-trained on 14 MODIS band L1B Top-Of-Atmosphere (TOA) irradiance imageries. Bands were selected based on which ones were most similar to spectral profiles of other instruments such as GOES ABI. See Table 1 for mapping each band to one of the 14 indices and the central wavelength for each band. + +## Conversion to TOA Reflectance and Brightness Temperature + +After generating daily composites, digital numbers (DNs) from the MODIS bands are converted into Top-of-Atmosphere (TOA) reflectance for visible and Near-Infrared (NIR) bands, and brightness temperature (BT) for Thermal Infrared (TIR) bands. The conversion is guided by the MODIS Level 1B product user guide and implemented through the `SatPy` Python package. These transformations give the data physical units (reflectance and temperature).  + +## Expected Model Input + +The pre-processed data should closely match the bands listed in the table provided in the model documentation, ensuring that each input channel accurately corresponds to a specific MODIS band and spectral range. The exact bands required depend on the task; however, the general expectation is for consistency with the MODIS TOA reflectance and BT band specifications. + +## Equations for MODIS DN Conversion + +Radiance and reflectance scales and offsets are found in the MOD021KM metadata, specifically within each subdataset. + +Radiance: `radianceScales` and `radianceOffsets` + +Reflectance: `reflectanceScales` and `reflectanceOffsets` + +### Reflectance Calibration +The formula for converting MODIS DN values to TOA reflectance is: + +$$\text{Reflectance} = (DN - \text{reflectanceOffsets}) \times \text{reflectanceScales} \times 100$$ + +This formula scales and converts the values into percentage reflectance. + +### Brightness Temperature Calibration + +For TIR bands, the calibration to Brightness Temperature ($BT$) is more involved and relies on physical constants and the effective central wavenumber ($WN$). + +The equation for converting MODIS DN to BT is: + +$$\text{Radiance} = (DN - \text{radianceOffsets}) \times \text{radianceScales}$$ + + +$$BT = \frac{c_2}{\text{WN} \times \ln\left(\frac{c_1}{\text{Radiance} \times \text{WN}^5} + 1\right)}$$ + + +$$BT = \frac{(BT - tci)}{tcs}$$ + +Where:  + +- $c_1$ and $c_2$ are derived constants based on the Planck constant $h$, the speed of light $c$, and the Boltzmann constant $k$. +- $tcs$ is the temperature correction slope, and $tci$ is the temperature correction intercept. + +### Scaling for Machine Learning Compatibility + +Both TOA reflectance and BT values are scaled to a range of 0-1 to ensure compatibility with neural networks, aiding model convergence and training stability:  + + **TOA Reflectance Scaling** + +Reflectance values are scaled by a factor of 0.01, transforming the original 0-100 range to 0-1.  + +$$\text{TOA Reflectance (scaled)} = \text{TOA Reflectance} \times 0.01$$ + +**Brightness Temperature Scaling** + +Brightness temperatures are min-max scaled to a range of 0-1, based on global minimum and maximum values for each of the 8 TIR channels in the dataset. + +$$\text{Scaled Value} = \frac{\text{Original Value} - \text{Min}}{\text{Max} - \text{Min}}$$ + +This normalization process aligns the dynamic range of both feature types, contributing to more stable model performance.  + +## Example Python Code + +### MODIS L1B + +- https://github.com/pytroll/satpy/blob/main/satpy/readers/modis_l1b.py + +Below is an example of the Python code used in SatPy for calibrating radiance, reflectance, and BT for MODIS L1B products: + +```python + +def calibrate_radiance(array, attributes, index): + """Calibration for radiance channels.""" + offset = np.float32(attributes["radiance_offsets"][index]) + scale = np.float32(attributes["radiance_scales"][index]) + array = (array - offset) * scale + return array + + +def calibrate_refl(array, attributes, index): + """Calibration for reflective channels.""" + offset = np.float32(attributes["reflectance_offsets"][index]) + scale = np.float32(attributes["reflectance_scales"][index]) + # convert to reflectance and convert from 1 to % + array = (array - offset) * scale * 100 + return array + + +def calibrate_bt(array, attributes, index, band_name): + """Calibration for the emissive channels.""" + offset = np.float32(attributes["radiance_offsets"][index]) + scale = np.float32(attributes["radiance_scales"][index]) + + array = (array - offset) * scale + + # Planck constant (Joule second) + h__ = np.float32(6.6260755e-34) + + # Speed of light in vacuum (meters per second) + c__ = np.float32(2.9979246e+8) + + # Boltzmann constant (Joules per Kelvin) + k__ = np.float32(1.380658e-23) + + # Derived constants + c_1 = 2 * h__ * c__ * c__ + c_2 = (h__ * c__) / k__ + + # Effective central wavenumber (inverse centimeters) + cwn = np.array([ + 2.641775E+3, 2.505277E+3, 2.518028E+3, 2.465428E+3, + 2.235815E+3, 2.200346E+3, 1.477967E+3, 1.362737E+3, + 1.173190E+3, 1.027715E+3, 9.080884E+2, 8.315399E+2, + 7.483394E+2, 7.308963E+2, 7.188681E+2, 7.045367E+2], + dtype=np.float32) + + # Temperature correction slope (no units) + tcs = np.array([ + 9.993411E-1, 9.998646E-1, 9.998584E-1, 9.998682E-1, + 9.998819E-1, 9.998845E-1, 9.994877E-1, 9.994918E-1, + 9.995495E-1, 9.997398E-1, 9.995608E-1, 9.997256E-1, + 9.999160E-1, 9.999167E-1, 9.999191E-1, 9.999281E-1], + dtype=np.float32) + + # Temperature correction intercept (Kelvin) + tci = np.array([ + 4.770532E-1, 9.262664E-2, 9.757996E-2, 8.929242E-2, + 7.310901E-2, 7.060415E-2, 2.204921E-1, 2.046087E-1, + 1.599191E-1, 8.253401E-2, 1.302699E-1, 7.181833E-2, + 1.972608E-2, 1.913568E-2, 1.817817E-2, 1.583042E-2], + dtype=np.float32) + + # Transfer wavenumber [cm^(-1)] to wavelength [m] + cwn = 1. / (cwn * 100) + + # Some versions of the modis files do not contain all the bands. + emmissive_channels = ["20", "21", "22", "23", "24", "25", "27", "28", "29", + "30", "31", "32", "33", "34", "35", "36"] + global_index = emmissive_channels.index(band_name) + + cwn = cwn[global_index] + tcs = tcs[global_index] + tci = tci[global_index] + array = c_2 / (cwn * np.log(c_1 / (1000000 * array * cwn ** 5) + 1)) + array = (array - tci) / tcs + return array +``` + +### ABI L1B + +- https://github.com/pytroll/satpy/blob/main/satpy/readers/abi_l1b.py + +Below is an example of the Python code used in SatPy for calibrating radiance, reflectance, and BT for ABI L1B products: + +```python + def _rad_calibrate(self, data): + """Calibrate any channel to radiances. + + This no-op method is just to keep the flow consistent - + each valid cal type results in a calibration method call + """ + res = data + res.attrs = data.attrs + return res + + def _raw_calibrate(self, data): + """Calibrate any channel to raw counts. + + Useful for cases where a copy requires no calibration. + """ + res = data + res.attrs = data.attrs + res.attrs["units"] = "1" + res.attrs["long_name"] = "Raw Counts" + res.attrs["standard_name"] = "counts" + return res + + def _vis_calibrate(self, data): + """Calibrate visible channels to reflectance.""" + solar_irradiance = self["esun"] + esd = self["earth_sun_distance_anomaly_in_AU"] + + factor = np.pi * esd * esd / solar_irradiance + + res = data * np.float32(factor) + res.attrs = data.attrs + res.attrs["units"] = "1" + res.attrs["long_name"] = "Bidirectional Reflectance" + res.attrs["standard_name"] = "toa_bidirectional_reflectance" + return res + + def _get_minimum_radiance(self, data): + """Estimate minimum radiance from Rad DataArray.""" + attrs = data.attrs + scale_factor = attrs["scale_factor"] + add_offset = attrs["add_offset"] + count_zero_rad = - add_offset / scale_factor + count_pos = np.ceil(count_zero_rad) + min_rad = count_pos * scale_factor + add_offset + return min_rad + + def _ir_calibrate(self, data): + """Calibrate IR channels to BT.""" + fk1 = float(self["planck_fk1"]) + fk2 = float(self["planck_fk2"]) + bc1 = float(self["planck_bc1"]) + bc2 = float(self["planck_bc2"]) + + if self.clip_negative_radiances: + min_rad = self._get_minimum_radiance(data) + data = data.clip(min=data.dtype.type(min_rad)) + + res = (fk2 / np.log(fk1 / data + 1) - bc1) / bc2 + res.attrs = data.attrs + res.attrs["units"] = "K" + res.attrs["long_name"] = "Brightness Temperature" + res.attrs["standard_name"] = "toa_brightness_temperature" + return res +``` + +### Performing scaling as a torch transform + +For MODIS-TOA data: + +```python +import numpy as np + + +# ----------------------------------------------------------------------- +# MinMaxEmissiveScaleReflectance +# ----------------------------------------------------------------------- +class MinMaxEmissiveScaleReflectance(object): + """ + Performs scaling of MODIS TOA data + - Scales reflectance percentages to reflectance units (% -> (0,1)) + - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) + """ + + def __init__(self): + + self.reflectance_indices = [0, 1, 2, 3, 4, 6] + self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] + + self.emissive_mins = np.array( + [223.1222, 178.9174, 204.3739, 204.7677, + 194.8686, 202.1759, 201.3823, 203.3537], + dtype=np.float32) + + self.emissive_maxs = np.array( + [352.7182, 261.2920, 282.5529, 319.0373, + 295.0209, 324.0677, 321.5254, 285.9848], + dtype=np.float32) + + def __call__(self, img): + + # Reflectance % to reflectance units + img[:, :, self.reflectance_indices] = \ + img[:, :, self.reflectance_indices] * 0.01 + + # Brightness temp scaled to (0,1) range + img[:, :, self.emissive_indices] = \ + (img[:, :, self.emissive_indices] - self.emissive_mins) / \ + (self.emissive_maxs - self.emissive_mins) + + return img + + +# ------------------------------------------------------------------------ +# ModisToaTransform +# ------------------------------------------------------------------------ +class ModisToaTransform: + """ + torchvision transform which transforms the input imagery + """ + + def __init__(self, config): + + self.transform_img = \ + T.Compose([ + MinMaxEmissiveScaleReflectance(), + T.ToTensor(), + T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), + ]) + + def __call__(self, img): + + img = self.transform_img(img) + + return img +``` + +For ABI data + +```python +# ----------------------------------------------------------------------- +# ConvertABIToReflectanceBT +# ----------------------------------------------------------------------- +class ConvertABIToReflectanceBT(object): + """ + Performs scaling of MODIS TOA data + - Scales reflectance percentages to reflectance units (% -> (0,1)) + - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) + """ + + def __init__(self): + + self.reflectance_indices = [0, 1, 2, 3, 4, 6] + self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] + + def __call__(self, img): + + # Digital Numbers to TOA reflectance units + img[:, :, self.reflectance_indices] = \ + vis_calibrate(img[:, :, self.reflectance_indices]) + + # Digital Numbers -> Radiance -> Brightness Temp (K) + img[:, :, self.emissive_indices] = ir_calibrate(img[:, :, self.emissive_indices]) + + return img + +# ------------------------------------------------------------------------ +# MinMaxEmissiveScaleReflectance +# ------------------------------------------------------------------------ +class MinMaxEmissiveScaleReflectance(object): + """ + Performs scaling of MODIS TOA data + - Scales reflectance percentages to reflectance units (% -> (0,1)) + - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) + """ + + def __init__(self): + + self.reflectance_indices = [0, 1, 2, 3, 4, 6] + self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] + + self.emissive_mins = np.array( + [117.04327, 152.00592, 157.96591, 176.15349, + 210.60493, 210.52264, 218.10147, 225.9894], + dtype=np.float32) + + self.emissive_maxs = np.array( + [221.07022, 224.44113, 242.3326, 307.42004, + 290.8879, 343.72617, 345.72894, 323.5239], + dtype=np.float32) + + def __call__(self, img): + + # Reflectance % to reflectance units + img[:, :, self.reflectance_indices] = \ + img[:, :, self.reflectance_indices] * 0.01 + + # Brightness temp scaled to (0,1) range + img[:, :, self.emissive_indices] = \ + (img[:, :, self.emissive_indices] - self.emissive_mins) / \ + (self.emissive_maxs - self.emissive_mins) + + return img + +# ------------------------------------------------------------------------ +# AbiToaTransform +# ------------------------------------------------------------------------ +class AbiToaTransform: + """ + torchvision transform which transforms the input imagery into + addition to generating a MiM mask + """ + + def __init__(self, img_size): + + self.transform_img = \ + T.Compose([ + ConvertABIToReflectanceBT(), + MinMaxEmissiveScaleReflectance(), + T.ToTensor(), + T.Resize((img_size, img_size)), + ]) + + def __call__(self, img): + + img = self.transform_img(img) + + return img + +``` diff --git a/configs/3dcloudtask_fcn_baseline_test.yaml b/configs/3dcloudtask_fcn_baseline_test.yaml new file mode 100644 index 0000000..19d4ed7 --- /dev/null +++ b/configs/3dcloudtask_fcn_baseline_test.yaml @@ -0,0 +1,32 @@ +PIPELINE: '3dcloud' +DATAMODULE: 'abitoa3dcloud' +MODEL: + ENCODER: 'fcn' + DECODER: 'fcn' + NAME: 3dcloud-fcn-baseline + IN_CHANS: 14 + DROP_PATH_RATE: 0.1 +DATA: + BATCH_SIZE: 32 + DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + TEST_DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + IMG_SIZE: 128 +TRAIN: + ACCELERATOR: 'gpu' + STRATEGY: 'auto' + EPOCHS: 50 + WARMUP_EPOCHS: 10 + BASE_LR: 3e-4 + MIN_LR: 2e-4 + WARMUP_LR: 1e-4 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] +LOSS: + NAME: 'bce' +PRINT_FREQ: 10 +SAVE_FREQ: 50 +VALIDATION_FREQ: 20 +TAG: 3dcloud_task_fcn_baseline_128_scaled_bt_minmax diff --git a/configs/3dcloudtask_swinv2_satvision_giant_test.yaml b/configs/3dcloudtask_swinv2_satvision_giant_test.yaml new file mode 100644 index 0000000..2d67c57 --- /dev/null +++ b/configs/3dcloudtask_swinv2_satvision_giant_test.yaml @@ -0,0 +1,41 @@ +PIPELINE: '3dcloud' +DATAMODULE: 'abitoa3dcloud' +MODEL: + ENCODER: 'satvision' + DECODER: 'fcn' + PRETRAINED: /panfs/ccds02/nobackup/projects/ilab/projects/3DClouds/models/SV-TOA/3B_2M/mp_rank_00_model_states.pt + TYPE: swinv2 + NAME: 3dcloud-svtoa-finetune-giant + IN_CHANS: 14 + DROP_PATH_RATE: 0.1 + SWINV2: + IN_CHANS: 14 + EMBED_DIM: 512 + DEPTHS: [ 2, 2, 42, 2 ] + NUM_HEADS: [ 16, 32, 64, 128 ] + WINDOW_SIZE: 8 + NORM_PERIOD: 6 +DATA: + BATCH_SIZE: 32 + DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + TEST_DATA_PATHS: [/explore/nobackup/projects/ilab/data/satvision-toa/3dcloud.data/abiChipsNew/] + IMG_SIZE: 128 +TRAIN: + USE_CHECKPOINT: True + EPOCHS: 50 + WARMUP_EPOCHS: 10 + BASE_LR: 3e-4 + MIN_LR: 2e-4 + WARMUP_LR: 1e-4 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] +LOSS: + NAME: 'bce' +PRECISION: 'bf16' +PRINT_FREQ: 10 +SAVE_FREQ: 50 +VALIDATION_FREQ: 20 +TAG: 3dcloud_task_swinv2_g_satvision_128_scaled_bt_minmax diff --git a/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep.yaml b/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep.yaml new file mode 100644 index 0000000..cfa75d6 --- /dev/null +++ b/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep.yaml @@ -0,0 +1,48 @@ +PIPELINE: 'satvisiontoapretrain' + +MODEL: + TYPE: swinv2 + NAME: mim_satvision_pretrain-giant + DROP_PATH_RATE: 0.1 + SWINV2: + IN_CHANS: 14 + EMBED_DIM: 512 + DEPTHS: [ 2, 2, 42, 2 ] + NUM_HEADS: [ 16, 32, 64, 128 ] + WINDOW_SIZE: 8 + NORM_PERIOD: 6 + +DATA: + DATAMODULE: False + BATCH_SIZE: 64 + LENGTH: 1_920_000 + PIN_MEMORY: True + NUM_WORKERS: 4 + DATA_PATHS: [/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets/shards] + IMG_SIZE: 128 + MASK_PATCH_SIZE: 8 + MASK_RATIO: 0.6 + +TRAIN: + ACCELERATOR: 'gpu' + STRATEGY: 'deepspeed' + USE_CHECKPOINT: True + EPOCHS: 50 + WARMUP_EPOCHS: 10 + BASE_LR: 3e-4 + MIN_LR: 2e-4 + WARMUP_LR: 1e-4 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] + +DEEPSPEED: + STAGE: 2 + +PRECISION: 'bf16' + +PRINT_FREQ: 10 +SAVE_FREQ: 50 +TAG: mim_pretrain_giant_satvision_128_scaled_bt_minmax_50ep diff --git a/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep_resume.yaml b/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep_resume.yaml new file mode 100644 index 0000000..a338624 --- /dev/null +++ b/configs/mim_pretrain_swinv2_satvision_giant_128_onecycle_100ep_resume.yaml @@ -0,0 +1,49 @@ +PIPELINE: 'satvisiontoapretrain' + +MODEL: + TYPE: swinv2 + NAME: mim_satvision_pretrain-giant + DROP_PATH_RATE: 0.1 + PRETRAINED: /panfs/ccds02/nobackup/projects/ilab/projects/3DClouds/models/SV-TOA/3B_2M/mp_rank_00_model_states.pt + SWINV2: + IN_CHANS: 14 + EMBED_DIM: 512 + DEPTHS: [ 2, 2, 42, 2 ] + NUM_HEADS: [ 16, 32, 64, 128 ] + WINDOW_SIZE: 8 + NORM_PERIOD: 6 + +DATA: + DATAMODULE: False + BATCH_SIZE: 64 + LENGTH: 1_920_000 + PIN_MEMORY: True + NUM_WORKERS: 4 + DATA_PATHS: [/explore/nobackup/projects/ilab/projects/3DClouds/data/mosaic-v3/webdatasets/shards] + IMG_SIZE: 128 + MASK_PATCH_SIZE: 8 + MASK_RATIO: 0.6 + +TRAIN: + ACCELERATOR: 'gpu' + STRATEGY: 'deepspeed' + USE_CHECKPOINT: True + EPOCHS: 50 + WARMUP_EPOCHS: 10 + BASE_LR: 3e-4 + MIN_LR: 2e-4 + WARMUP_LR: 1e-4 + WEIGHT_DECAY: 0.05 + LR_SCHEDULER: + NAME: 'multistep' + GAMMA: 0.1 + MULTISTEPS: [700,] + +DEEPSPEED: + STAGE: 2 + +PRECISION: 'bf16' + +PRINT_FREQ: 10 +SAVE_FREQ: 50 +TAG: mim_pretrain_giant_satvision_128_scaled_bt_minmax_50ep_resume diff --git a/docs/static/modis_toa_bands.png b/docs/static/modis_toa_bands.png new file mode 100644 index 0000000..cee5f71 Binary files /dev/null and b/docs/static/modis_toa_bands.png differ diff --git a/examples/inference/satvision-toa-reconstruction_giant.py b/examples/inference/satvision-toa-reconstruction_giant.py deleted file mode 100644 index a47755e..0000000 --- a/examples/inference/satvision-toa-reconstruction_giant.py +++ /dev/null @@ -1,296 +0,0 @@ -import argparse -import glob -import datetime -import os -import logging -import numpy as np -import torch -import torchvision.transforms as T -from pytorch_caney.data.utils import SimmimMaskGenerator -import matplotlib.pyplot as plt -from matplotlib.backends.backend_pdf import PdfPages -from tqdm import tqdm - -import warnings -warnings.filterwarnings('ignore') - -from pytorch_caney.config import _C, _update_config_from_file -from pytorch_caney.models.build import build_model -from pytorch_caney.data.transforms import SimmimMaskGenerator - - -# Dictionary to map indices to band numbers -idx_to_band = { - 0: 1, - 1: 2, - 2: 3, - 3: 6, - 4: 7, - 5: 21, - 6: 26, - 7: 27, - 8: 28, - 9: 29, - 10: 30, - 11: 31, - 12: 32, - 13: 33 -} - - -def parse_args(): - parser = argparse.ArgumentParser(description="Predict and generate PDF using a pre-trained model.") - parser.add_argument('--pretrained_model_dir', type=str, required=True, help="Directory containing pre-trained model files (including .pt and .yaml)") - parser.add_argument("--output_dir", required=True, help="Directory where the output PDF will be saved.") - parser.add_argument("--data_path", default='/explore/nobackup/projects/ilab/projects/3DClouds/data/validation/sv_toa_128_chip_validation_04_24.npy', help="Path to validation data file.") - return parser.parse_args() - - -# Load model and config -def load_config_and_model(pretrained_model_dir, validation_data_path): - # Search for .pt and .yaml files - model_path = os.path.join(pretrained_model_dir, 'mp_rank_00_model_states.pt') - config_path = glob.glob(os.path.join(pretrained_model_dir, '*.yaml')) - if len(config_path) == 0: - raise FileNotFoundError(f"No YAML config found in {pretrained_model_dir}") - config_path = config_path[0] - - # Load config - config = _C.clone() - _update_config_from_file(config, config_path) - config.defrost() - config.MODEL.RESUME = model_path - config.DATA.DATA_PATHS = [validation_data_path] - config.OUTPUT = pretrained_model_dir - config.TAG = 'satvision-huge-toa-reconstruction' - config.freeze() - - # Load model - checkpoint = torch.load(model_path) - model = build_model(config, pretrain=True) - model.load_state_dict(checkpoint['module']) # Use 'model' if 'module' not present - model.eval() - - return model, config - - -def configure_logging(): - logging.basicConfig(filename='app.log', level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')) - logger = logging.getLogger('') - logger.addHandler(console) - return logger - - -class MinMaxEmissiveScaleReflectance(object): - """ - Performs scaling of MODIS TOA data - - Scales reflectance percentages to reflectance units (% -> (0,1)) - - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) - """ - - def __init__(self): - - self.reflectance_indices = [0, 1, 2, 3, 4, 6] - self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] - - self.emissive_mins = np.array( - [223.1222, 178.9174, 204.3739, 204.7677, - 194.8686, 202.1759, 201.3823, 203.3537], - dtype=np.float32) - - self.emissive_maxs = np.array( - [352.7182, 261.2920, 282.5529, 319.0373, - 295.0209, 324.0677, 321.5254, 285.9848], - dtype=np.float32) - - def __call__(self, img): - - # Reflectance % to reflectance units - img[:, :, self.reflectance_indices] = \ - img[:, :, self.reflectance_indices] * 0.01 - - # Brightness temp scaled to (0,1) range - img[:, :, self.emissive_indices] = \ - (img[:, :, self.emissive_indices] - self.emissive_mins) / \ - (self.emissive_maxs - self.emissive_mins) - - return img - - -class SimmimTransform: - """ - torchvision transform which transforms the input imagery into - addition to generating a MiM mask - """ - - def __init__(self, config): - - self.transform_img = \ - T.Compose([ - MinMaxEmissiveScaleReflectance(), # New transform for MinMax - T.ToTensor(), - T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), - ]) - - if config.MODEL.TYPE in ['swin', 'swinv2']: - - model_patch_size = config.MODEL.SWINV2.PATCH_SIZE - - else: - - raise NotImplementedError - - self.mask_generator = SimmimMaskGenerator( - input_size=config.DATA.IMG_SIZE, - mask_patch_size=config.DATA.MASK_PATCH_SIZE, - model_patch_size=model_patch_size, - mask_ratio=config.DATA.MASK_RATIO, - ) - - def __call__(self, img): - - img = self.transform_img(img) - mask = self.mask_generator() - - return img, mask - - -def get_batch_info(img): - channels = img.shape[1] - - for channelIdx in range(channels): - channel = idx_to_band.get(channelIdx, 'Unknown') # Retrieve band number or mark as 'Unknown' - img_band_array = img[:, channelIdx, :, :] - min_ = img_band_array.min().item() - mean_ = img_band_array.mean().item() - max_ = img_band_array.max().item() - print(f'Channel {channel}, min {min_}, mean {mean_}, max {max_}') - - -def load_validation_data(config): - validation_dataset_path = config.DATA.DATA_PATHS[0] - validation_dataset = np.load(validation_dataset_path) - transform = SimmimTransform(config) - imgMasks = [transform(validation_dataset[idx]) for idx in range(validation_dataset.shape[0])] - img = torch.stack([imgMask[0] for imgMask in imgMasks]) - mask = torch.stack([torch.from_numpy(imgMask[1]) for imgMask in imgMasks]) - return img, mask - - -def predict(model, img, mask): - inputs, outputs, masks, losses = [], [], [], [] - for i in tqdm(range(img.shape[0])): - single_img = img[i].unsqueeze(0) - single_mask = mask[i].unsqueeze(0) - - with torch.no_grad(): - z = model.encoder(single_img, single_mask) - img_recon = model.decoder(z) - loss = model(single_img, single_mask) - - inputs.extend(single_img.cpu()) - masks.extend(single_mask.cpu()) - outputs.extend(img_recon.cpu()) - losses.append(loss.cpu()) - - return inputs, outputs, masks, losses - -def process_mask(mask): - mask_img = mask.unsqueeze(0) - mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous() - mask_img = mask_img[0, 0, :, :] - mask_img = np.stack([mask_img, mask_img, mask_img], axis=-1) - return mask_img - - -def minmax_norm(img_arr): - arr_min = img_arr.min() - arr_max = img_arr.max() - img_arr_scaled = (img_arr - arr_min) / (arr_max - arr_min) - img_arr_scaled = img_arr_scaled * 255 - img_arr_scaled = img_arr_scaled.astype(np.uint8) - return img_arr_scaled - - -def reverse_transform(image): - minMaxTransform = MinMaxEmissiveScaleReflectance() - image = image.transpose((1,2,0)) - - image[:, :, minMaxTransform.reflectance_indices] = image[:, :, minMaxTransform.reflectance_indices] * 100 - image[:, :, minMaxTransform.emissive_indices] = ( - image[:, :, minMaxTransform.emissive_indices] * \ - (minMaxTransform.emissive_maxs - minMaxTransform.emissive_mins)) + minMaxTransform.emissive_mins - - image = image.transpose((2,0,1)) - return image - - -def process_prediction(image, img_recon, mask, rgb_index): - mask = process_mask(mask) - - red_idx = rgb_index[0] - blue_idx = rgb_index[1] - green_idx = rgb_index[2] - - image = reverse_transform(image.numpy()) - img_recon = reverse_transform(img_recon.numpy()) - - rgb_image = np.stack((image[red_idx, :, :], image[blue_idx, :, :], image[green_idx, :, :]), axis=-1) - rgb_image = minmax_norm(rgb_image) - - rgb_image_recon = np.stack((img_recon[red_idx, :, :], img_recon[blue_idx, :, :], img_recon[green_idx, :, :]), axis=-1) - rgb_image_recon = minmax_norm(rgb_image_recon) - - rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon) - rgb_image_masked = np.where(mask == 1, 0, rgb_image) - rgb_recon_masked = rgb_masked - - return rgb_image, rgb_image_masked, rgb_recon_masked, mask - - -def plot_export_pdf(path, inputs, outputs, masks, rgb_index): - pdf_plot_obj = PdfPages(path) - for idx in range(len(inputs)): - rgb_image, rgb_image_masked, rgb_recon_masked, mask = process_prediction(inputs[idx], outputs[idx], masks[idx], rgb_index) - - fig, (ax01, ax23) = plt.subplots(2, 2, figsize=(40, 30)) - ax0, ax1 = ax01 - ax2, ax3 = ax23 - - ax2.imshow(rgb_image) - ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}") - - ax0.imshow(rgb_recon_masked) - ax0.set_title(f"Idx: {idx} Model reconstruction") - - ax1.imshow(rgb_image_masked) - ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked") - - ax3.matshow(mask[:, :, 0]) - ax3.set_title(f"Idx: {idx} Reconstruction Mask") - - pdf_plot_obj.savefig() - - pdf_plot_obj.close() - - -if __name__ == "__main__": - args = parse_args() - model, config = load_config_and_model(args.pretrained_model_dir, args.data_path) - logger = configure_logging() - - img, mask = load_validation_data(config) - logger.info("Logging batch information before predictions:") - get_batch_info(img) - imgs = np.asarray(img) - channel_ranges = [abs(imgs[:, channel].max() - imgs[:, channel].min()) for channel in range(0, 14)] - - inputs, outputs, masks, losses = predict(model, img, mask) - - output_pdf_path = os.path.join(args.output_dir, f'satvision-toa-reconstruction-giant-{datetime.datetime.now().strftime("%Y-%m-%d")}.pdf') - rgb_index = [0, 2, 1] # Red, Green, Blue band indices - plot_export_pdf(output_pdf_path, inputs, outputs, masks, rgb_index) - logger.info(f"PDF saved to {output_pdf_path}") diff --git a/examples/inference/satvision-toa-reconstruction_huge.py b/examples/inference/satvision-toa-reconstruction_huge.py deleted file mode 100644 index f006f65..0000000 --- a/examples/inference/satvision-toa-reconstruction_huge.py +++ /dev/null @@ -1,207 +0,0 @@ -import argparse -import glob -import datetime -import os -import logging -import numpy as np -import torch -import matplotlib.pyplot as plt -from matplotlib.backends.backend_pdf import PdfPages -from tqdm import tqdm - -import warnings -warnings.filterwarnings('ignore') - -from pytorch_caney.config import _C, _update_config_from_file -from pytorch_caney.models.build import build_model -from pytorch_caney.data.transforms import SimmimTransform - - -# Dictionary to map indices to band numbers -idx_to_band = { - 0: 1, - 1: 2, - 2: 3, - 3: 6, - 4: 7, - 5: 21, - 6: 26, - 7: 27, - 8: 28, - 9: 29, - 10: 30, - 11: 31, - 12: 32, - 13: 33 -} - - -def parse_args(): - parser = argparse.ArgumentParser(description="Predict and generate PDF using a pre-trained model.") - parser.add_argument('--pretrained_model_dir', type=str, required=True, help="Directory containing pre-trained model files (including .pt and .yaml)") - parser.add_argument("--output_dir", required=True, help="Directory where the output PDF will be saved.") - parser.add_argument("--data_path", default='/explore/nobackup/projects/ilab/projects/3DClouds/data/validation/sv_toa_128_chip_validation_04_24.npy', help="Path to validation data file.") - return parser.parse_args() - - -# Load model and config -def load_config_and_model(pretrained_model_dir, validation_data_path): - # Search for .pt and .yaml files - model_path = os.path.join(pretrained_model_dir, 'mp_rank_00_model_states.pt') - config_path = glob.glob(os.path.join(pretrained_model_dir, '*.yaml')) - if len(config_path) == 0: - raise FileNotFoundError(f"No YAML config found in {pretrained_model_dir}") - config_path = config_path[0] - - # Load config - config = _C.clone() - _update_config_from_file(config, config_path) - config.defrost() - config.MODEL.RESUME = model_path - config.DATA.DATA_PATHS = [validation_data_path] - config.OUTPUT = pretrained_model_dir - config.TAG = 'satvision-huge-toa-reconstruction' - config.freeze() - - # Load model - checkpoint = torch.load(model_path) - model = build_model(config, pretrain=True) - model.load_state_dict(checkpoint['module']) # Use 'model' if 'module' not present - model.eval() - - return model, config - - -def configure_logging(): - logging.basicConfig(filename='app.log', level=logging.INFO, format='%(asctime)s [%(levelname)s] %(message)s', datefmt='%Y-%m-%d %H:%M:%S') - console = logging.StreamHandler() - console.setLevel(logging.INFO) - console.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')) - logger = logging.getLogger('') - logger.addHandler(console) - return logger - - -def load_validation_data(config): - validation_dataset_path = config.DATA.DATA_PATHS[0] - validation_dataset = np.load(validation_dataset_path) - transform = SimmimTransform(config) - imgMasks = [transform(validation_dataset[idx]) for idx in range(validation_dataset.shape[0])] - img = torch.stack([imgMask[0] for imgMask in imgMasks]) - mask = torch.stack([torch.from_numpy(imgMask[1]) for imgMask in imgMasks]) - return img, mask - - -def get_batch_info(img): - channels = img.shape[1] - - for channelIdx in range(channels): - channel = idx_to_band.get(channelIdx, 'Unknown') # Retrieve band number or mark as 'Unknown' - img_band_array = img[:, channelIdx, :, :] - min_ = img_band_array.min().item() - mean_ = img_band_array.mean().item() - max_ = img_band_array.max().item() - print(f'Channel {channel}, min {min_}, mean {mean_}, max {max_}') - - -def predict(model, img, mask): - inputs, outputs, masks, losses = [], [], [], [] - for i in tqdm(range(img.shape[0])): - single_img = img[i].unsqueeze(0) - single_mask = mask[i].unsqueeze(0) - - with torch.no_grad(): - z = model.encoder(single_img, single_mask) - img_recon = model.decoder(z) - loss = model(single_img, single_mask) - - inputs.extend(single_img.cpu()) - masks.extend(single_mask.cpu()) - outputs.extend(img_recon.cpu()) - losses.append(loss.cpu()) - - return inputs, outputs, masks, losses - - -def process_mask(mask): - mask_img = mask.unsqueeze(0) - mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2).unsqueeze(1).contiguous() - mask_img = mask_img[0, 0, :, :] - mask_img = np.stack([mask_img, mask_img, mask_img], axis=-1) - return mask_img - - -def minmax_norm(img_arr): - arr_min = img_arr.min() - arr_max = img_arr.max() - img_arr_scaled = (img_arr - arr_min) / (arr_max - arr_min) - img_arr_scaled = img_arr_scaled * 255 - img_arr_scaled = img_arr_scaled.astype(np.uint8) - return img_arr_scaled - - -def process_prediction(image, img_recon, mask, rgb_index): - mask = process_mask(mask) - - red_idx = rgb_index[0] - blue_idx = rgb_index[1] - green_idx = rgb_index[2] - - image = image.numpy() - rgb_image = np.stack((image[red_idx, :, :], image[blue_idx, :, :], image[green_idx, :, :]), axis=-1) - rgb_image = minmax_norm(rgb_image) - - img_recon = img_recon.numpy() - rgb_image_recon = np.stack((img_recon[red_idx, :, :], img_recon[blue_idx, :, :], img_recon[green_idx, :, :]), axis=-1) - rgb_image_recon = minmax_norm(rgb_image_recon) - - rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon) - rgb_image_masked = np.where(mask == 1, 0, rgb_image) - rgb_recon_masked = rgb_masked - - return rgb_image, rgb_image_masked, rgb_recon_masked, mask - - -def plot_export_pdf(path, inputs, outputs, masks, rgb_index): - pdf_plot_obj = PdfPages(path) - for idx in range(len(inputs)): - rgb_image, rgb_image_masked, rgb_recon_masked, mask = process_prediction(inputs[idx], outputs[idx], masks[idx], rgb_index) - - fig, (ax01, ax23) = plt.subplots(2, 2, figsize=(40, 30)) - ax0, ax1 = ax01 - ax2, ax3 = ax23 - - ax2.imshow(rgb_image) - ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}") - - ax0.imshow(rgb_recon_masked) - ax0.set_title(f"Idx: {idx} Model reconstruction") - - ax1.imshow(rgb_image_masked) - ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked") - - ax3.matshow(mask[:, :, 0]) - ax3.set_title(f"Idx: {idx} Reconstruction Mask") - - pdf_plot_obj.savefig() - - pdf_plot_obj.close() - - -if __name__ == "__main__": - args = parse_args() - model, config = load_config_and_model(args.pretrained_model_dir, args.data_path) - logger = configure_logging() - - img, mask = load_validation_data(config) - get_batch_info(img) - - imgs = np.asarray(img) - channel_ranges = [abs(imgs[:, channel].max() - imgs[:, channel].min()) for channel in range(0, 14)] - - inputs, outputs, masks, losses = predict(model, img, mask) - - output_pdf_path = os.path.join(args.output_dir, f'satvision-toa-reconstruction-huge-{datetime.datetime.now().strftime("%Y-%m-%d")}.pdf') - rgb_index = [0, 2, 1] # Red, Green, Blue band indices - plot_export_pdf(output_pdf_path, inputs, outputs, masks, rgb_index) - logger.info(f"PDF saved to {output_pdf_path}") diff --git a/examples/inference/svtoa_reconstruction_giant_slurm.sh b/examples/inference/svtoa_reconstruction_giant_slurm.sh deleted file mode 100644 index 3dc6ea2..0000000 --- a/examples/inference/svtoa_reconstruction_giant_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -#SBATCH -G 1 -#SBATCH --time 2:00:00 -#SBATCH -N 1 -#SBATCH -J sv-inference-giant - -module load singularity - -srun singularity exec --nv --env PYTHONPATH=$PWD:$PWD/pytorch-caney -B /explore,/panfs /explore/nobackup/projects/ilab/containers/pytorch-caney-2024-08.dev python pytorch-caney/examples/inference/satvision-toa-reconstruction_giant.py --pretrained_model_dir /explore/nobackup/people/cssprad1/projects/satvision-toa/models/satvision-toa-giant-patch8-window8-128 --output_dir . \ No newline at end of file diff --git a/examples/inference/svtoa_reconstruction_huge_slurm.sh b/examples/inference/svtoa_reconstruction_huge_slurm.sh deleted file mode 100644 index 2ad686d..0000000 --- a/examples/inference/svtoa_reconstruction_huge_slurm.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash -#SBATCH -G 1 -#SBATCH --time 2:00:00 -#SBATCH -N 1 -#SBATCH -J sv-inference-huge - -module load singularity - -srun singularity exec --nv --env PYTHONPATH=$PWD:$PWD/pytorch-caney -B /explore,/panfs /explore/nobackup/projects/ilab/containers/pytorch-caney-2024-08.dev python pytorch-caney/examples/inference/satvision-toa-reconstruction_giant.py --pretrained_model_dir /explore/nobackup/people/cssprad1/projects/satvision-toa/models/satvision-toa-giant-patch8-window8-128 --output_dir . \ No newline at end of file diff --git a/examples/satvision-giant/README.md b/examples/satvision-giant/README.md deleted file mode 100644 index 7fe3149..0000000 --- a/examples/satvision-giant/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# SatVision-Giant (SwinV2-Giant) - -`sbatch run_satvision_pretrain ` \ No newline at end of file diff --git a/examples/satvision-giant/mim_pretrain_swinv2_satvision_giant_192_window12_200ep.yaml b/examples/satvision-giant/mim_pretrain_swinv2_satvision_giant_192_window12_200ep.yaml deleted file mode 100644 index 0872262..0000000 --- a/examples/satvision-giant/mim_pretrain_swinv2_satvision_giant_192_window12_200ep.yaml +++ /dev/null @@ -1,29 +0,0 @@ -MODEL: - TYPE: swinv2 - NAME: mim_satvision_pretrain-giant - DROP_PATH_RATE: 0.1 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 512 - DEPTHS: [ 2, 2, 42, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 12 - NORM_PERIOD: 6 - -DATA: - IMG_SIZE: 192 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -TRAIN: - EPOCHS: 200 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.05 - LR_SCHEDULER: - NAME: 'multistep' - GAMMA: 0.1 - MULTISTEPS: [700,] -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: mim_pretrain_swinv2_g_satvision_192_window12__800ep \ No newline at end of file diff --git a/examples/satvision-giant/run_satvision_pretrain.sh b/examples/satvision-giant/run_satvision_pretrain.sh deleted file mode 100755 index 770c7b7..0000000 --- a/examples/satvision-giant/run_satvision_pretrain.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash - -#SBATCH -J deepspeed-satvision-giant -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - -module load singularity - -srun -n 1 singularity exec \ - --env PYTHONPATH="$PWD:$PWD/pytorch-caney" \ - --nv -B /lscratch,/explore,/panfs \ - $1 \ - deepspeed \ - pytorch-caney/pytorch_caney/pipelines/pretraining/mim_deepspeed.py \ - --cfg $2 \ - --dataset MODIS \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \ - --batch-size 32 \ - --output . \ - --enable-amp - - - diff --git a/examples/satvision-huge/README.md b/examples/satvision-huge/README.md deleted file mode 100644 index 6c607d5..0000000 --- a/examples/satvision-huge/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# SatVision-Huge (SwinV2-Huge) - -`sbatch run_satvision_pretrain ` diff --git a/examples/satvision-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml b/examples/satvision-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml deleted file mode 100644 index 9e31f10..0000000 --- a/examples/satvision-huge/mim_pretrain_swinv2_satvision_huge_192_window12_200ep.yaml +++ /dev/null @@ -1,29 +0,0 @@ -MODEL: - TYPE: swinv2 - NAME: mim_satvision_pretrain-huge - DROP_PATH_RATE: 0.1 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 352 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 12 - NORM_PERIOD: 6 - -DATA: - IMG_SIZE: 192 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -TRAIN: - EPOCHS: 200 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.05 - LR_SCHEDULER: - NAME: 'multistep' - GAMMA: 0.1 - MULTISTEPS: [700,] -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: mim_pretrain_swinv2_h_satvision_192_window12__800ep \ No newline at end of file diff --git a/examples/satvision-huge/run_satvision_pretrain.sh b/examples/satvision-huge/run_satvision_pretrain.sh deleted file mode 100755 index e91edbe..0000000 --- a/examples/satvision-huge/run_satvision_pretrain.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash - -#SBATCH -J deepspeed-satvision-giant -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - -module load singularity - -srun -n 1 singularity exec \ - --env PYTHONPATH="$PWD:$PWD/pytorch-caney" \ - --nv -B /lscratch,/explore,/panfs \ - $1 \ - deepspeed \ - pytorch-caney/pytorch_caney/pipelines/pretraining/mim_deepspeed.py \ - --cfg $2 \ - --dataset MODIS \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \ - --batch-size 32 \ - --output . \ - --enable-amp - diff --git a/examples/satvision-toa-finetune/finetune_satvision_base_landcover5class_192_window12_100ep.yaml b/examples/satvision-toa-finetune/finetune_satvision_base_landcover5class_192_window12_100ep.yaml deleted file mode 100644 index 5f41c64..0000000 --- a/examples/satvision-toa-finetune/finetune_satvision_base_landcover5class_192_window12_100ep.yaml +++ /dev/null @@ -1,33 +0,0 @@ -MODEL: - TYPE: swinv2 - DECODER: unet - NAME: satvision_finetune_lc5class - DROP_PATH_RATE: 0.1 - NUM_CLASSES: 5 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 128 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 14 - PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] -DATA: - IMG_SIZE: 224 - DATASET: MODISLC5 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -LOSS: - NAME: 'tversky' - MODE: 'multiclass' - ALPHA: 0.4 - BETA: 0.6 -TRAIN: - EPOCHS: 100 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.01 - LAYER_DECAY: 0.8 -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/examples/satvision-toa-finetune/finetune_satvision_base_landcover9class_192_window12_100ep.yaml b/examples/satvision-toa-finetune/finetune_satvision_base_landcover9class_192_window12_100ep.yaml deleted file mode 100644 index f55651a..0000000 --- a/examples/satvision-toa-finetune/finetune_satvision_base_landcover9class_192_window12_100ep.yaml +++ /dev/null @@ -1,33 +0,0 @@ -MODEL: - TYPE: swinv2 - DECODER: unet - NAME: satvision_toa_finetune_lc9class - DROP_PATH_RATE: 0.1 - NUM_CLASSES: 9 - SWINV2: - IN_CHANS: 14 - EMBED_DIM: 352 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 14 - NORM_PERIOD: 6 -DATA: - IMG_SIZE: 224 - DATASET: MODISLC9 - MASK_PATCH_SIZE: 8 - MASK_RATIO: 0.6 -LOSS: - NAME: 'tversky' - MODE: 'multiclass' - ALPHA: 0.4 - BETA: 0.6 -TRAIN: - EPOCHS: 100 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.01 - LAYER_DECAY: 0.8 -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: satvision_toa_finetune_land_cover_9class_swinv2_satvision_224_window12__100ep diff --git a/examples/satvision-toa-finetune/run_satvision_finetune_lc_fiveclass.sh b/examples/satvision-toa-finetune/run_satvision_finetune_lc_fiveclass.sh deleted file mode 100755 index 155abf6..0000000 --- a/examples/satvision-toa-finetune/run_satvision_finetune_lc_fiveclass.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -#SBATCH -J finetune_satvision_lc5 -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - - -export PYTHONPATH=$PWD:../../../:../../../pytorch-caney -export NGPUS=8 - -torchrun --nproc_per_node $NGPUS \ - ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ - --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \ - --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \ - --dataset MODISLC9 \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \ - --batch-size 4 \ - --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \ - --enable-amp \ No newline at end of file diff --git a/examples/satvision-toa-finetune/run_satvision_finetune_lc_nineclass.sh b/examples/satvision-toa-finetune/run_satvision_finetune_lc_nineclass.sh deleted file mode 100755 index 618dcab..0000000 --- a/examples/satvision-toa-finetune/run_satvision_finetune_lc_nineclass.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -#SBATCH -J finetune_satvision_lc9 -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - -export PYTHONPATH=$PWD:$PWD/pytorch-caney -export NGPUS=4 - -torchrun --nproc_per_node $NGPUS \ - pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ - --cfg $1 \ - --pretrained $2 \ - --dataset MODISLC9 \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \ - --batch-size 4 \ - --output . \ - --enable-amp diff --git a/examples/satvision/README.md b/examples/satvision/README.md deleted file mode 100644 index 136a140..0000000 --- a/examples/satvision/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# SatVision Examples - -The following is an example on how to run SatVision finetune. This is only an example and does not limit other decoder possibilities, or other ways of dealing with the encoder. - -## SatVision Finetune Land Cover Five Class - -The script run_satvision_finetune_lc_fiveclass.sh has an example on how to run the finetuning of a 5 class land cover model using a simple UNet architecture. The dependencies of this model are as follow: - -- finetune.py script (pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py): this script has the basics for training the finetuning model. Below you will find an example of this script: - -```bash -export PYTHONPATH=$PWD:pytorch-caney -export NGPUS=8 - -torchrun --nproc_per_node $NGPUS \ - pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ - --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \ - --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \ - --dataset MODISLC9 \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \ - --batch-size 4 \ - --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \ - --enable-amp -``` - -From these parameters note that: - -- the pretrained model path is given by --pretrained -- the data paths is given by --data-paths and is expecting a directory whose internal structure is one for images and one from labels, but this can be modified if both input and target files are stored in the same file -- the dataloader is simply called from the script using the --dataset option, which is simply calling build_finetune_dataloaders -from pytorch-caney - -These is simply a guide script on how to run a finetuning pipeline. If you want to get additional insights on how to build other -types of decoders, the build_model function from pytorch_caney/models/build.py has additional details on how to combine the different -encoder and decoders. \ No newline at end of file diff --git a/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml b/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml deleted file mode 100644 index 5f41c64..0000000 --- a/examples/satvision/finetune_satvision_base_landcover5class_192_window12_100ep.yaml +++ /dev/null @@ -1,33 +0,0 @@ -MODEL: - TYPE: swinv2 - DECODER: unet - NAME: satvision_finetune_lc5class - DROP_PATH_RATE: 0.1 - NUM_CLASSES: 5 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 128 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 14 - PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] -DATA: - IMG_SIZE: 224 - DATASET: MODISLC5 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -LOSS: - NAME: 'tversky' - MODE: 'multiclass' - ALPHA: 0.4 - BETA: 0.6 -TRAIN: - EPOCHS: 100 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.01 - LAYER_DECAY: 0.8 -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: satvision_finetune_land_cover_5class_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml b/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml deleted file mode 100644 index 2e96121..0000000 --- a/examples/satvision/finetune_satvision_base_landcover9class_192_window12_100ep.yaml +++ /dev/null @@ -1,33 +0,0 @@ -MODEL: - TYPE: swinv2 - DECODER: unet - NAME: satvision_finetune_lc9class - DROP_PATH_RATE: 0.1 - NUM_CLASSES: 9 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 128 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 14 - PRETRAINED_WINDOW_SIZES: [ 12, 12, 12, 6 ] -DATA: - IMG_SIZE: 224 - DATASET: MODISLC5 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -LOSS: - NAME: 'tversky' - MODE: 'multiclass' - ALPHA: 0.4 - BETA: 0.6 -TRAIN: - EPOCHS: 100 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.01 - LAYER_DECAY: 0.8 -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: satvision_finetune_land_cover_9class_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml b/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml deleted file mode 100644 index 4c188bf..0000000 --- a/examples/satvision/mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml +++ /dev/null @@ -1,27 +0,0 @@ -MODEL: - TYPE: swinv2 - NAME: mim_satvision_pretrain - DROP_PATH_RATE: 0.1 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 128 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 12 -DATA: - IMG_SIZE: 192 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -TRAIN: - EPOCHS: 800 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.05 - LR_SCHEDULER: - NAME: 'multistep' - GAMMA: 0.1 - MULTISTEPS: [700,] -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: mim_pretrain_swinv2_satvision_192_window12__800ep \ No newline at end of file diff --git a/examples/satvision/run_satvision_finetune_lc_fiveclass.sh b/examples/satvision/run_satvision_finetune_lc_fiveclass.sh deleted file mode 100755 index 155abf6..0000000 --- a/examples/satvision/run_satvision_finetune_lc_fiveclass.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -#SBATCH -J finetune_satvision_lc5 -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - - -export PYTHONPATH=$PWD:../../../:../../../pytorch-caney -export NGPUS=8 - -torchrun --nproc_per_node $NGPUS \ - ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ - --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \ - --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \ - --dataset MODISLC9 \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_9classes_224 \ - --batch-size 4 \ - --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \ - --enable-amp \ No newline at end of file diff --git a/examples/satvision/run_satvision_finetune_lc_nineclass.sh b/examples/satvision/run_satvision_finetune_lc_nineclass.sh deleted file mode 100755 index 7008967..0000000 --- a/examples/satvision/run_satvision_finetune_lc_nineclass.sh +++ /dev/null @@ -1,20 +0,0 @@ -#!/bin/bash - -#SBATCH -J finetune_satvision_lc9 -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - - -export PYTHONPATH=$PWD:../../../:../../../pytorch-caney -export NGPUS=8 - -torchrun --nproc_per_node $NGPUS \ - ../../../pytorch-caney/pytorch_caney/pipelines/finetuning/finetune.py \ - --cfg finetune_satvision_base_landcover5class_192_window12_100ep.yaml \ - --pretrained /explore/nobackup/people/cssprad1/projects/satnet/code/development/masked_image_modeling/development/models/simmim_satnet_pretrain_pretrain/simmim_pretrain__satnet_swinv2_base__img192_window12__800ep_v3_no_norm/ckpt_epoch_800.pth \ - --dataset MODISLC9 \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/finetuning/h18v04/labels_5classes_224 \ - --batch-size 4 \ - --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/finetune/models \ - --enable-amp \ No newline at end of file diff --git a/examples/satvision/run_satvision_pretrain.sh b/examples/satvision/run_satvision_pretrain.sh deleted file mode 100755 index 0ff9598..0000000 --- a/examples/satvision/run_satvision_pretrain.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -#SBATCH -J pretrain_satvision_swinv2 -#SBATCH -t 3-00:00:00 -#SBATCH -G 4 -#SBATCH -N 1 - - -export PYTHONPATH=$PWD:../../../:../../../pytorch-caney -export NGPUS=4 - -torchrun --nproc_per_node $NGPUS \ - ../../../pytorch-caney/pytorch_caney/pipelines/pretraining/mim.py \ - --cfg mim_pretrain_swinv2_satvision_base_192_window12_800ep.yaml \ - --dataset MODIS \ - --data-paths /explore/nobackup/projects/ilab/data/satvision/pretraining/training_* \ - --batch-size 128 \ - --output /explore/nobackup/people/cssprad1/projects/satnet/code/development/cleanup/trf/transformer/models \ - --enable-amp \ No newline at end of file diff --git a/pytorch_caney/console/__init__.py b/notebooks/README.md old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/console/__init__.py rename to notebooks/README.md diff --git a/notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb b/notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb new file mode 100644 index 0000000..853d0db --- /dev/null +++ b/notebooks/satvision_toa_modis_reconstruction_example_notebook.ipynb @@ -0,0 +1,368 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5facdc34-efbd-4082-91ef-e70a4f34c441", + "metadata": {}, + "source": [ + "# SatVision-TOA Reconstruction Example Notebook\n", + "\n", + "This notebook demonstrates the reconstruction capabilities of the SatVision-TOA model, designed to process and reconstruct MODIS TOA (Top of Atmosphere) imagery using Masked Image Modeling (MIM) for Earth observation tasks.\n", + "\n", + "Follow this step-by-step guide to install necessary dependencies, load model weights, transform data, make predictions, and visualize the results.\n", + "\n", + "## 1. Setup and Install Dependencies\n", + "\n", + "The following packages are required to run the notebook:\n", + "- `yacs` – for handling configuration\n", + "- `timm` – for Transformer and Image Models in PyTorch\n", + "- `segmentation-models-pytorch` – for segmentation utilities\n", + "- `termcolor` – for colored terminal text\n", + "- `webdataset==0.2.86` – for handling datasets from web sources" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5e08cd1-d8df-4dd8-b884-d452ef90943b", + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install yacs timm segmentation-models-pytorch termcolor webdataset==0.2.86" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4506576-5e30-417d-96de-8953d71c76c2", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import time\n", + "import random\n", + "import datetime\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import logging\n", + "\n", + "import torch\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.backends.backend_pdf import PdfPages\n", + "\n", + "import warnings\n", + "\n", + "warnings.filterwarnings('ignore') " + ] + }, + { + "cell_type": "markdown", + "id": "775cb720-5151-49fa-a7d5-7291ef663d45", + "metadata": {}, + "source": [ + "## 2. Model and Configuration Imports\n", + "\n", + "We load necessary modules from the pytorch-caney library, including the model, transformations, and plotting utilities." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edf47149-f489-497b-8601-89a7e8dbd9b9", + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append('../../pytorch-caney')\n", + "\n", + "from pytorch_caney.models.mim import build_mim_model\n", + "from pytorch_caney.transforms.mim_modis_toa import MimTransform\n", + "from pytorch_caney.configs.config import _C, _update_config_from_file\n", + "from pytorch_caney.plotting.modis_toa import plot_export_pdf" + ] + }, + { + "cell_type": "markdown", + "id": "fe00e78e-fca3-4221-86dd-da205fed4192", + "metadata": {}, + "source": [ + "## 2. Fetching the model\n", + "\n", + "### 2.1 Clone model ckpt from huggingface\n", + "\n", + "Model repo: https://huggingface.co/nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n", + "\n", + "```bash\n", + "# On prism/explore system\n", + "module load git-lfs\n", + "\n", + "git lfs install\n", + "\n", + "git clone git@hf.co:nasa-cisto-data-science-group/satvision-toa-giant-patch8-window8-128\n", + "```\n", + "\n", + "Note: If using git w/ ssh, make sure you have ssh keys enabled to clone using ssh auth.\n", + "https://huggingface.co/docs/hub/security-git-ssh\n", + "\n", + "```bash\n", + "# If this outputs as anon, follow the next steps.\n", + "ssh -T git@hf.co\n", + "```\n", + "\n", + "\n", + "```bash\n", + "eval $(ssh-agent)\n", + "\n", + "# Check if ssh-agent is using the proper key\n", + "ssh-add -l\n", + "\n", + "# If not\n", + "ssh-add ~/.ssh/your-key\n", + "\n", + "# Or if you want to use the default id_* key, just do\n", + "ssh-add\n", + "\n", + "```\n", + "\n", + "## 3. Fetching the validation dataset\n", + "\n", + "### 3.1 Clone dataset repo from huggingface\n", + "\n", + "Dataset repo: https://huggingface.co/datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n", + "\n", + "\n", + "```bash\n", + "# On prims/explore system\n", + "module load git-lfs\n", + "\n", + "git lfs install\n", + "\n", + "git clone git@hf.co:datasets/nasa-cisto-data-science-group/modis_toa_cloud_reconstruction_validation\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "abb754ff-1753-4a4c-804e-8e3e5461fd0a", + "metadata": {}, + "source": [ + "## 4. Define Model and Data Paths\n", + "\n", + "Specify paths to model checkpoint, configuration file, and the validation dataset. Customize these paths as needed for your environment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7ec267ce-ded1-40e6-8443-e1037297f710", + "metadata": {}, + "outputs": [], + "source": [ + "MODEL_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mp_rank_00_model_states.pt'\n", + "CONFIG_PATH: str = '../../satvision-toa-giant-patch8-window8-128/mim_pretrain_swinv2_satvision_giant_128_window08_50ep.yaml'\n", + "\n", + "OUTPUT: str = '.'\n", + "DATA_PATH: str = '../../modis_toa_cloud_reconstruction_validation/sv_toa_128_chip_validation_04_24.npy'" + ] + }, + { + "cell_type": "markdown", + "id": "bd7d0b93-7fd3-49cb-ab9e-7536820ec5f2", + "metadata": {}, + "source": [ + "## 5. Configure Model\n", + "\n", + "Load and update the configuration for the SatVision-TOA model, specifying model and data paths." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aac43f0e-dc4b-49ba-a482-933b5bab4b79", + "metadata": {}, + "outputs": [], + "source": [ + "# Update config given configurations\n", + "\n", + "config = _C.clone()\n", + "_update_config_from_file(config, CONFIG_PATH)\n", + "\n", + "config.defrost()\n", + "config.MODEL.PRETRAINED = MODEL_PATH\n", + "config.DATA.DATA_PATHS = [DATA_PATH]\n", + "config.OUTPUT = OUTPUT\n", + "config.freeze()" + ] + }, + { + "cell_type": "markdown", + "id": "1d596904-d1df-4f6d-8e88-4c647ac26924", + "metadata": {}, + "source": [ + "## 6. Load Model Weights from Checkpoint\n", + "\n", + "Build and initialize the model from the checkpoint to prepare for evaluation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfe245f7-589e-4b02-9990-15cb1733f6cb", + "metadata": {}, + "outputs": [], + "source": [ + "print('Building un-initialized model')\n", + "model = build_mim_model(config)\n", + "print('Successfully built uninitialized model')\n", + "\n", + "print(f'Attempting to load checkpoint from {config.MODEL.PRETRAINED}')\n", + "checkpoint = torch.load(config.MODEL.PRETRAINED)\n", + "model.load_state_dict(checkpoint['module'])\n", + "print('Successfully applied checkpoint')\n", + "model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "markdown", + "id": "20c26d1e-125a-4b4c-a21e-ab07d6977222", + "metadata": {}, + "source": [ + "## 7. Transform Validation Data\n", + "\n", + "The MODIS TOA dataset is loaded and transformed using MimTransform, generating a masked dataset for reconstruction." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b3b47b1-0690-4ef9-bed6-ec243b5d42cb", + "metadata": {}, + "outputs": [], + "source": [ + "# Use the Masked-Image-Modeling transform specific to MODIS TOA data\n", + "transform = MimTransform(config)\n", + "\n", + "# The reconstruction evaluation set is a single numpy file\n", + "validation_dataset_path = config.DATA.DATA_PATHS[0]\n", + "validation_dataset = np.load(validation_dataset_path)\n", + "len_batch = range(validation_dataset.shape[0])\n", + "\n", + "# Apply transform to each image in the batch\n", + "# A mask is auto-generated in the transform\n", + "imgMasks = [transform(validation_dataset[idx]) for idx \\\n", + " in len_batch]\n", + "\n", + "# Seperate img and masks, cast masks to torch tensor\n", + "img = torch.stack([imgMask[0] for imgMask in imgMasks])\n", + "mask = torch.stack([torch.from_numpy(imgMask[1]) for \\\n", + " imgMask in imgMasks])" + ] + }, + { + "cell_type": "markdown", + "id": "8b2148e4-da6d-4ae0-a194-c7adb62728a0", + "metadata": { + "tags": [] + }, + "source": [ + "## 8. Prediction\n", + "\n", + "Run predictions on each sample and calculate reconstruction losses. Each image is processed individually to track individual losses." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a3814751-f352-456e-850c-fe1d289b1d6b", + "metadata": {}, + "outputs": [], + "source": [ + "inputs = []\n", + "outputs = []\n", + "masks = []\n", + "losses = []\n", + "\n", + "# We could do this in a single batch however we\n", + "# want to report the loss per-image, in place of\n", + "# loss per-batch.\n", + "for i in tqdm(range(img.shape[0])):\n", + " single_img = img[i].unsqueeze(0)\n", + " single_mask = mask[i].unsqueeze(0)\n", + " single_img = single_img.cuda(non_blocking=True)\n", + " single_mask = single_mask.cuda(non_blocking=True)\n", + "\n", + " with torch.no_grad():\n", + " z = model.encoder(single_img, single_mask)\n", + " img_recon = model.decoder(z)\n", + " loss = model(single_img, single_mask)\n", + "\n", + " inputs.extend(single_img.cpu())\n", + " masks.extend(single_mask.cpu())\n", + " outputs.extend(img_recon.cpu())\n", + " losses.append(loss.cpu()) " + ] + }, + { + "cell_type": "markdown", + "id": "22329bb4-5c6e-42dc-a492-8863fc2bf672", + "metadata": {}, + "source": [ + "## 9. Export Reconstruction Results to PDF\n", + "\n", + "Save and visualize the reconstruction results. The output PDF will contain reconstructed images with original and masked versions." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8ac6a09d-5fe2-4aa9-ac37-f235d5a8020a", + "metadata": {}, + "outputs": [], + "source": [ + "pdfPath = '../../satvision-toa-reconstruction-validation-giant-example.pdf'\n", + "rgbIndex = [0, 2, 1] # Indices of [Red band, Blue band, Green band]\n", + "plot_export_pdf(pdfPath, inputs, outputs, masks, rgbIndex)" + ] + }, + { + "cell_type": "markdown", + "id": "1e0eb426-c7b4-47d4-aefa-2199ecfce2ab", + "metadata": {}, + "source": [ + "This notebook provides an end-to-end example for reconstructing satellite images with the SatVision-TOA model, from setup through prediction and output visualization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62065e24-ddf2-4bf1-8362-90dc0c9bf49e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ILAB Kernel (Pytorch)", + "language": "python", + "name": "pytorch-kernel" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 931e61b..df70096 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,4 +3,4 @@ requires = ["setuptools", "wheel"] [tool.black] -target_version = ['py39'] +target_version = ['py310'] diff --git a/pytorch_caney/__init__.py b/pytorch_caney/__init__.py old mode 100755 new mode 100644 diff --git a/pytorch_caney/data/datamodules/__init__.py b/pytorch_caney/configs/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/data/datamodules/__init__.py rename to pytorch_caney/configs/__init__.py diff --git a/pytorch_caney/config.py b/pytorch_caney/configs/config.py similarity index 82% rename from pytorch_caney/config.py rename to pytorch_caney/configs/config.py index 10562f3..f633293 100644 --- a/pytorch_caney/config.py +++ b/pytorch_caney/configs/config.py @@ -11,14 +11,24 @@ # Data settings # ----------------------------------------------------------------------------- _C.DATA = CN() +# Use a PL data module +_C.DATA.DATAMODULE = True # Batch size for a single GPU, could be overwritten by command line argument _C.DATA.BATCH_SIZE = 128 # Path(s) to dataset, could be overwritten by command line argument _C.DATA.DATA_PATHS = [''] +# Path(s) to the validation/test dataset +_C.DATA.TEST_DATA_PATHS = [''] +# Path(s) to dataset masks +_C.DATA.MASK_PATHS = [''] +# Path to validation numpy dataset +_C.DATA.VALIDATION_PATH = '' # Dataset name _C.DATA.DATASET = 'MODIS' # Input image size _C.DATA.IMG_SIZE = 224 +# Dataset length (for datasets where len cannot be used) +_C.DATA.LENGTH = 1920000 # Interpolation to resize image (random, bilinear, bicubic) _C.DATA.INTERPOLATION = 'bicubic' # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU. @@ -36,8 +46,10 @@ _C.MODEL = CN() # Model type _C.MODEL.TYPE = 'swinv2' -# Decoder type -_C.MODEL.DECODER = None +# Encoder type for fine-tuning +_C.MODEL.ENCODER = '' +# Decoder type for fine-tuning +_C.MODEL.DECODER = '' # Model name _C.MODEL.NAME = 'swinv2_base_patch4_window7_224' # Pretrained weight from checkpoint, could be from previous pre-training @@ -47,6 +59,8 @@ _C.MODEL.RESUME = '' # Number of classes, overwritten in data preparation _C.MODEL.NUM_CLASSES = 17 +# Number of channels the input image has +_C.MODEL.IN_CHANS = 3 # Dropout rate _C.MODEL.DROP_RATE = 0.0 # Drop path rate @@ -88,9 +102,14 @@ # Training settings # ----------------------------------------------------------------------------- _C.TRAIN = CN() +_C.TRAIN.ACCELERATOR = 'gpu' +_C.TRAIN.STRATEGY = 'deepspeed' +_C.TRAIN.LIMIT_TRAIN_BATCHES = True +_C.TRAIN.NUM_TRAIN_BATCHES = None _C.TRAIN.START_EPOCH = 0 _C.TRAIN.EPOCHS = 300 _C.TRAIN.WARMUP_EPOCHS = 20 +_C.TRAIN.WARMUP_STEPS = 200 _C.TRAIN.WEIGHT_DECAY = 0.05 _C.TRAIN.BASE_LR = 5e-4 _C.TRAIN.WARMUP_LR = 5e-7 @@ -101,7 +120,7 @@ _C.TRAIN.AUTO_RESUME = True # Gradient accumulation steps # could be overwritten by command line argument -_C.TRAIN.ACCUMULATION_STEPS = 1 +_C.TRAIN.ACCUMULATION_STEPS = 1 # Whether to use gradient checkpointing to save memory # could be overwritten by command line argument _C.TRAIN.USE_CHECKPOINT = False @@ -116,6 +135,8 @@ # Gamma / Multi steps value, used in MultiStepLRScheduler _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1 _C.TRAIN.LR_SCHEDULER.MULTISTEPS = [] +# OneCycle LR Scheduler max LR percentage +_C.TRAIN.LR_SCHEDULER.CYCLE_PERCENTAGE = 0.3 # Optimizer _C.TRAIN.OPTIMIZER = CN() @@ -130,6 +151,18 @@ # [SimMIM] Layer decay for fine-tuning _C.TRAIN.LAYER_DECAY = 1.0 +# Tensorboard settings +_C.TENSORBOARD = CN() +_C.TENSORBOARD.WRITER_DIR = '.' + +# DeepSpeed configuration settings +_C.DEEPSPEED = CN() +_C.DEEPSPEED.STAGE = 2 +_C.DEEPSPEED.REDUCE_BUCKET_SIZE = 5e8 +_C.DEEPSPEED.ALLGATHER_BUCKET_SIZE = 5e8 +_C.DEEPSPEED.CONTIGUOUS_GRADIENTS = True +_C.DEEPSPEED.OVERLAP_COMM = True + # ----------------------------------------------------------------------------- # Testing settings @@ -142,21 +175,29 @@ # Misc # ----------------------------------------------------------------------------- # Whether to enable pytorch amp, overwritten by command line argument -_C.ENABLE_AMP = False +_C.PRECISION = '32' # Enable Pytorch automatic mixed precision (amp). _C.AMP_ENABLE = True # Path to output folder, overwritten by command line argument -_C.OUTPUT = '' +_C.OUTPUT = '.' # Tag of experiment, overwritten by command line argument _C.TAG = 'pt-caney-default-tag' # Frequency to save checkpoint _C.SAVE_FREQ = 1 # Frequency to logging info _C.PRINT_FREQ = 10 +# Frequency for running validation step +_C.VALIDATION_FREQ = 1 # Fixed random seed _C.SEED = 42 # Perform evaluation only, overwritten by command line argument _C.EVAL_MODE = False +# Pipeline +_C.PIPELINE = 'satvisiontoapretrain' +# Data module +_C.DATAMODULE = 'abitoa3dcloud' +# Fast dev run +_C.FAST_DEV_RUN = False def _update_config_from_file(config, cfg_file): @@ -189,6 +230,8 @@ def _check_args(name): config.DATA.BATCH_SIZE = args.batch_size if _check_args('data_paths'): config.DATA.DATA_PATHS = args.data_paths + if _check_args('validation_path'): + config.DATA.VALIDATION_PATH = args.validation_path if _check_args('dataset'): config.DATA.DATASET = args.dataset if _check_args('resume'): @@ -211,6 +254,8 @@ def _check_args(name): config.EVAL_MODE = True if _check_args('enable_amp'): config.ENABLE_AMP = args.enable_amp + if _check_args('tensorboard_dir'): + config.TENSORBOARD.WRITER_DIR = args.tensorboard_dir # output folder config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) diff --git a/pytorch_caney/console/cli.py b/pytorch_caney/console/cli.py deleted file mode 100755 index e02d571..0000000 --- a/pytorch_caney/console/cli.py +++ /dev/null @@ -1,62 +0,0 @@ -from pytorch_lightning.utilities.cli import LightningCLI - -import torch - - -class TerraGPULightningCLI(LightningCLI): - - def add_arguments_to_parser(self, parser): - - # Trainer - performance - parser.set_defaults({"trainer.accelerator": "auto"}) - parser.set_defaults({"trainer.devices": "auto"}) - parser.set_defaults({"trainer.auto_select_gpus": True}) - parser.set_defaults({"trainer.precision": 32}) - - # Trainer - training - parser.set_defaults({"trainer.max_epochs": 500}) - parser.set_defaults({"trainer.min_epochs": 1}) - parser.set_defaults({"trainer.detect_anomaly": True}) - parser.set_defaults({"trainer.logger": True}) - parser.set_defaults({"trainer.default_root_dir": "output_model"}) - - # Trainer - optimizer - TODO - _ = { - "class_path": torch.optim.Adam, - "init_args": { - "lr": 0.01 - } - } - - # Trainer - callbacks - default_callbacks = [ - {"class_path": "pytorch_lightning.callbacks.DeviceStatsMonitor"}, - { - "class_path": "pytorch_lightning.callbacks.EarlyStopping", - "init_args": { - "monitor": "val_loss", - "patience": 5, - "mode": "min" - } - }, - # { - # "class_path": "pytorch_lightning.callbacks.ModelCheckpoint", - # "init_args": { - # "dirpath": "output_model", - # "monitor": "val_loss", - # "auto_insert_metric_name": True - # } - # }, - ] - parser.set_defaults({"trainer.callbacks": default_callbacks}) - - # { - # "class_path": "pytorch_lightning.callbacks.ModelCheckpoint", - # "init_args": { - # "dirpath": "output_model", - # "monitor": "val_loss", - # "auto_insert_metric_name": True - # } - # }, - # ] - # parser.set_defaults({"trainer.callbacks": default_callbacks}) diff --git a/pytorch_caney/console/dl_pipeline.py b/pytorch_caney/console/dl_pipeline.py deleted file mode 100755 index 4840a95..0000000 --- a/pytorch_caney/console/dl_pipeline.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- -# RF pipeline: preprocess, train, and predict. - -import sys -import logging - -# from terragpu import unet_model -# from terragpu.decorators import DuplicateFilter -# from terragpu.ai.deep_learning.datamodules.segmentation_datamodule \ -# import SegmentationDataModule - -from pytorch_lightning import seed_everything # , trainer -# from pytorch_lightning import LightningModule, LightningDataModule -from terragpu.ai.deep_learning.console.cli import TerraGPULightningCLI - - -# ----------------------------------------------------------------------------- -# main -# -# python rf_pipeline.py options here -# ----------------------------------------------------------------------------- -def main(): - - # ------------------------------------------------------------------------- - # Set logging - # ------------------------------------------------------------------------- - logger = logging.getLogger() - logger.setLevel(logging.INFO) - ch = logging.StreamHandler(sys.stdout) - ch.setLevel(logging.INFO) - - # Set formatter and handlers - formatter = logging.Formatter( - "%(asctime)s; %(levelname)s; %(message)s", "%Y-%m-%d %H:%M:%S") - ch.setFormatter(formatter) - logger.addHandler(ch) - - # ------------------------------------------------------------------------- - # Execute pipeline step - # ------------------------------------------------------------------------- - # Seed every library - seed_everything(1234, workers=True) - _ = TerraGPULightningCLI(save_config_callback=None) - # unet_model.UNetSegmentation, SegmentationDataModule) - - # train - # trainer = pl.Trainer() - # trainer.fit(model, datamodule=dm) - # validate - # trainer.validate(datamodule=dm) - # test - # trainer.test(datamodule=dm) - # predict - # predictions = trainer.predict(datamodule=dm) - return - - -# ----------------------------------------------------------------------------- -# Invoke the main -# ----------------------------------------------------------------------------- -if __name__ == "__main__": - sys.exit(main()) diff --git a/pytorch_caney/data/datamodules/finetune_datamodule.py b/pytorch_caney/data/datamodules/finetune_datamodule.py deleted file mode 100644 index bdbed40..0000000 --- a/pytorch_caney/data/datamodules/finetune_datamodule.py +++ /dev/null @@ -1,114 +0,0 @@ -from ..datasets.modis_dataset import MODISDataset -from ..datasets.modis_lc_five_dataset import MODISLCFiveDataset -from ..datasets.modis_lc_nine_dataset import MODISLCNineDataset - -from ..transforms import TensorResizeTransform - -import torch.distributed as dist -from torch.utils.data import DataLoader, DistributedSampler - - -DATASETS = { - 'modis': MODISDataset, - 'modislc9': MODISLCNineDataset, - 'modislc5': MODISLCFiveDataset, - # 'modis tree': MODISTree, -} - - -def get_dataset_from_dict(dataset_name: str): - """Gets the proper dataset given a dataset name. - - Args: - dataset_name (str): name of the dataset - - Raises: - KeyError: thrown if dataset key is not present in dict - - Returns: - dataset: pytorch dataset - """ - - dataset_name = dataset_name.lower() - - try: - - dataset_to_use = DATASETS[dataset_name] - - except KeyError: - - error_msg = f"{dataset_name} is not an existing dataset" - - error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}" - - raise KeyError(error_msg) - - return dataset_to_use - - -def build_finetune_dataloaders(config, logger): - """Builds the dataloaders and datasets for a fine-tuning task. - - Args: - config: config object - logger: logging logger - - Returns: - dataloader_train: training dataloader - dataloader_val: validation dataloader - """ - - transform = TensorResizeTransform(config) - - logger.info(f'Finetuning data transform:\n{transform}') - - dataset_name = config.DATA.DATASET - - logger.info(f'Dataset: {dataset_name}') - logger.info(f'Data Paths: {config.DATA.DATA_PATHS}') - - dataset_to_use = get_dataset_from_dict(dataset_name) - - logger.info(f'Dataset obj: {dataset_to_use}') - - dataset_train = dataset_to_use(data_paths=config.DATA.DATA_PATHS, - split="train", - img_size=config.DATA.IMG_SIZE, - transform=transform) - - dataset_val = dataset_to_use(data_paths=config.DATA.DATA_PATHS, - split="val", - img_size=config.DATA.IMG_SIZE, - transform=transform) - - logger.info(f'Build dataset: train images = {len(dataset_train)}') - - logger.info(f'Build dataset: val images = {len(dataset_val)}') - - sampler_train = DistributedSampler( - dataset_train, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=True) - - sampler_val = DistributedSampler( - dataset_val, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=False) - - dataloader_train = DataLoader(dataset_train, - config.DATA.BATCH_SIZE, - sampler=sampler_train, - num_workers=config.DATA.NUM_WORKERS, - pin_memory=True, - drop_last=True) - - dataloader_val = DataLoader(dataset_val, - config.DATA.BATCH_SIZE, - sampler=sampler_val, - num_workers=config.DATA.NUM_WORKERS, - pin_memory=True, - drop_last=False) - - return dataloader_train, dataloader_val diff --git a/pytorch_caney/data/datamodules/mim_datamodule.py b/pytorch_caney/data/datamodules/mim_datamodule.py deleted file mode 100644 index b70ee74..0000000 --- a/pytorch_caney/data/datamodules/mim_datamodule.py +++ /dev/null @@ -1,80 +0,0 @@ -from ..datasets.simmim_modis_dataset import MODISDataset - -from ..transforms import SimmimTransform - -import torch.distributed as dist -from torch.utils.data import DataLoader, DistributedSampler -from torch.utils.data._utils.collate import default_collate - - -DATASETS = { - 'MODIS': MODISDataset, -} - - -def collate_fn(batch): - if not isinstance(batch[0][0], tuple): - return default_collate(batch) - else: - batch_num = len(batch) - ret = [] - for item_idx in range(len(batch[0][0])): - if batch[0][0][item_idx] is None: - ret.append(None) - else: - ret.append(default_collate( - [batch[i][0][item_idx] for i in range(batch_num)])) - ret.append(default_collate([batch[i][1] for i in range(batch_num)])) - return ret - - -def get_dataset_from_dict(dataset_name): - - try: - - dataset_to_use = DATASETS[dataset_name] - - except KeyError: - - error_msg = f"{dataset_name} is not an existing dataset" - - error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}" - - raise KeyError(error_msg) - - return dataset_to_use - - -def build_mim_dataloader(config, logger): - - transform = SimmimTransform(config) - - logger.info(f'Pre-train data transform:\n{transform}') - - dataset_name = config.DATA.DATASET - - dataset_to_use = get_dataset_from_dict(dataset_name) - - dataset = dataset_to_use(config, - config.DATA.DATA_PATHS, - split="train", - img_size=config.DATA.IMG_SIZE, - transform=transform) - - logger.info(f'Build dataset: train images = {len(dataset)}') - - sampler = DistributedSampler( - dataset, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=True) - - dataloader = DataLoader(dataset, - config.DATA.BATCH_SIZE, - sampler=sampler, - num_workers=config.DATA.NUM_WORKERS, - pin_memory=True, - drop_last=True, - collate_fn=collate_fn) - - return dataloader diff --git a/pytorch_caney/data/datamodules/mim_webdataset_datamodule.py b/pytorch_caney/data/datamodules/mim_webdataset_datamodule.py deleted file mode 100644 index 47b9a35..0000000 --- a/pytorch_caney/data/datamodules/mim_webdataset_datamodule.py +++ /dev/null @@ -1,48 +0,0 @@ -from ..datasets.mim_modis_22m_dataset import MODIS22MDataset - -from ..transforms import SimmimTransform - -from torch.utils.data import DataLoader -from torch.utils.data._utils.collate import default_collate - -import os - - -def collate_fn(batch): - if not isinstance(batch[0][0], tuple): - return default_collate(batch) - else: - batch_num = len(batch) - ret = [] - for item_idx in range(len(batch[0][0])): - if batch[0][0][item_idx] is None: - ret.append(None) - else: - ret.append(default_collate( - [batch[i][0][item_idx] for i in range(batch_num)])) - ret.append(default_collate([batch[i][1] for i in range(batch_num)])) - return ret - - -def build_mim_dataloader(config, logger): - - transform = SimmimTransform(config) - - logger.info(f'Pre-train data transform:\n{transform}') - - dataset = MODIS22MDataset(config, - config.DATA.DATA_PATHS, - split="train", - img_size=config.DATA.IMG_SIZE, - transform=transform, - batch_size=config.DATA.BATCH_SIZE).dataset() - - dataloader = DataLoader(dataset, - batch_size=None, - shuffle=False, - num_workers=int(os.environ["SLURM_CPUS_PER_TASK"]), - pin_memory=True) - # NEED TO GET ACTUAL SIZE - # dataloader = dataloader.ddp_equalize(21643764 // config.DATA.BATCH_SIZE) - - return dataloader diff --git a/pytorch_caney/data/datamodules/segmentation_datamodule.py b/pytorch_caney/data/datamodules/segmentation_datamodule.py deleted file mode 100755 index fb6d166..0000000 --- a/pytorch_caney/data/datamodules/segmentation_datamodule.py +++ /dev/null @@ -1,164 +0,0 @@ -import os -import logging -from typing import Any, Union, Optional - -import torch -from torch.utils.data import DataLoader -from torch.utils.data.dataset import random_split -from pytorch_lightning import LightningDataModule -from pytorch_lightning.utilities.cli import DATAMODULE_REGISTRY - -from terragpu.ai.deep_learning.datasets.segmentation_dataset \ - import SegmentationDataset - - -@DATAMODULE_REGISTRY -class SegmentationDataModule(LightningDataModule): - - def __init__( - self, - - # Dataset parameters - dataset_dir: str = 'dataset/', - images_regex: str = 'dataset/images/*.tif', - labels_regex: str = 'dataset/labels/*.tif', - generate_dataset: bool = True, - tile_size: int = 256, - max_patches: Union[float, int] = 100, - augment: bool = True, - chunks: dict = {'band': 1, 'x': 2048, 'y': 2048}, - input_bands: list = ['CB', 'B', 'G', 'Y', 'R', 'RE', 'N1', 'N2'], - output_bands: list = ['B', 'G', 'R'], - seed: int = 24, - normalize: bool = True, - pytorch: bool = True, - - # Datamodule parameters - val_split: float = 0.2, - test_split: float = 0.1, - num_workers: int = os.cpu_count(), - batch_size: int = 32, - shuffle: bool = True, - pin_memory: bool = False, - drop_last: bool = False, - - # Inference parameters - raster_regex: str = 'rasters/*.tif', - - *args: Any, - **kwargs: Any, - - ) -> None: - - super().__init__(*args, **kwargs) - - # Dataset parameters - self.images_regex = images_regex - self.labels_regex = labels_regex - self.dataset_dir = dataset_dir - self.generate_dataset = generate_dataset - self.tile_size = tile_size - self.max_patches = max_patches - self.augment = augment - self.chunks = chunks - self.input_bands = input_bands - self.output_bands = output_bands - self.seed = seed - self.normalize = normalize - self.pytorch = pytorch - - self.val_split = val_split - self.test_split = test_split - self.raster_regex = raster_regex - - # Performance parameters - self.batch_size = batch_size - self.num_workers = num_workers - self.shuffle = shuffle - self.pin_memory = pin_memory - self.drop_last = drop_last - - def prepare_data(self): - if self.generate_dataset: - SegmentationDataset( - images_regex=self.images_regex, - labels_regex=self.labels_regex, - dataset_dir=self.dataset_dir, - generate_dataset=self.generate_dataset, - tile_size=self.tile_size, - max_patches=self.max_patches, - augment=self.augment, - chunks=self.chunks, - input_bands=self.input_bands, - output_bands=self.output_bands, - seed=self.seed, - normalize=self.normalize, - pytorch=self.pytorch, - ) - - def setup(self, stage: Optional[str] = None): - - # Split into train, val, test - segmentation_dataset = SegmentationDataset( - images_regex=self.images_regex, - labels_regex=self.labels_regex, - dataset_dir=self.dataset_dir, - generate_dataset=False, - tile_size=self.tile_size, - max_patches=self.max_patches, - augment=self.augment, - chunks=self.chunks, - input_bands=self.input_bands, - output_bands=self.output_bands, - seed=self.seed, - normalize=self.normalize, - pytorch=self.pytorch, - ) - - # Split datasets into train, val, and test sets - val_len = round(self.val_split * len(segmentation_dataset)) - test_len = round(self.test_split * len(segmentation_dataset)) - train_len = len(segmentation_dataset) - val_len - test_len - - # Initialize datasets - self.train_set, self.val_set, self.test_set = random_split( - segmentation_dataset, lengths=[train_len, val_len, test_len], - generator=torch.Generator().manual_seed(self.seed) - ) - logging.info("Initialized datasets...") - - def train_dataloader(self) -> DataLoader: - loader = DataLoader( - self.train_set, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - ) - return loader - - def val_dataloader(self) -> DataLoader: - loader = DataLoader( - self.val_set, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - ) - return loader - - def test_dataloader(self) -> DataLoader: - loader = DataLoader( - self.test_set, - batch_size=self.batch_size, - shuffle=False, - num_workers=self.num_workers, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - ) - return loader - - def predict_dataloader(self) -> DataLoader: - raise NotImplementedError diff --git a/pytorch_caney/data/datamodules/simmim_datamodule.py b/pytorch_caney/data/datamodules/simmim_datamodule.py deleted file mode 100644 index b70ee74..0000000 --- a/pytorch_caney/data/datamodules/simmim_datamodule.py +++ /dev/null @@ -1,80 +0,0 @@ -from ..datasets.simmim_modis_dataset import MODISDataset - -from ..transforms import SimmimTransform - -import torch.distributed as dist -from torch.utils.data import DataLoader, DistributedSampler -from torch.utils.data._utils.collate import default_collate - - -DATASETS = { - 'MODIS': MODISDataset, -} - - -def collate_fn(batch): - if not isinstance(batch[0][0], tuple): - return default_collate(batch) - else: - batch_num = len(batch) - ret = [] - for item_idx in range(len(batch[0][0])): - if batch[0][0][item_idx] is None: - ret.append(None) - else: - ret.append(default_collate( - [batch[i][0][item_idx] for i in range(batch_num)])) - ret.append(default_collate([batch[i][1] for i in range(batch_num)])) - return ret - - -def get_dataset_from_dict(dataset_name): - - try: - - dataset_to_use = DATASETS[dataset_name] - - except KeyError: - - error_msg = f"{dataset_name} is not an existing dataset" - - error_msg = f"{error_msg}. Available datasets: {DATASETS.keys()}" - - raise KeyError(error_msg) - - return dataset_to_use - - -def build_mim_dataloader(config, logger): - - transform = SimmimTransform(config) - - logger.info(f'Pre-train data transform:\n{transform}') - - dataset_name = config.DATA.DATASET - - dataset_to_use = get_dataset_from_dict(dataset_name) - - dataset = dataset_to_use(config, - config.DATA.DATA_PATHS, - split="train", - img_size=config.DATA.IMG_SIZE, - transform=transform) - - logger.info(f'Build dataset: train images = {len(dataset)}') - - sampler = DistributedSampler( - dataset, - num_replicas=dist.get_world_size(), - rank=dist.get_rank(), - shuffle=True) - - dataloader = DataLoader(dataset, - config.DATA.BATCH_SIZE, - sampler=sampler, - num_workers=config.DATA.NUM_WORKERS, - pin_memory=True, - drop_last=True, - collate_fn=collate_fn) - - return dataloader diff --git a/pytorch_caney/data/datasets/modis_dataset.py b/pytorch_caney/data/datasets/modis_dataset.py deleted file mode 100644 index 89a4923..0000000 --- a/pytorch_caney/data/datasets/modis_dataset.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -import random - -import numpy as np - -from torch.utils.data import Dataset - - -class MODISDataset(Dataset): - """ - MODIS Landcover 17-class pytorch fine-tuning dataset - """ - - IMAGE_PATH = os.path.join("images") - MASK_PATH = os.path.join("labels") - - def __init__( - self, - data_paths: list, - split: str, - img_size: tuple = (256, 256), - transform=None, - ): - self.img_size = img_size - self.transform = transform - self.split = split - self.data_paths = data_paths - self.img_list = [] - self.mask_list = [] - - self._init_data_paths(self.data_paths) - - # Split between train and valid set (80/20) - random_inst = random.Random(12345) # for repeatability - n_items = len(self.img_list) - idxs = set(random_inst.sample(range(n_items), n_items // 5)) - total_idxs = set(range(n_items)) - if self.split == "train": - idxs = total_idxs - idxs - - print(f'> Found {len(idxs)} patches for this dataset ({split})') - self.img_list = [self.img_list[i] for i in idxs] - self.mask_list = [self.mask_list[i] for i in idxs] - - def _init_data_paths(self, data_paths: list) -> None: - """ - Given a list of datapaths, get all filenames matching - regex from each subdatapath and compile to a single list. - """ - for data_path in data_paths: - img_path = os.path.join(data_path, self.IMAGE_PATH) - mask_path = os.path.join(data_path, self.MASK_PATH) - self.img_list.extend(self.get_filenames(img_path)) - self.mask_list.extend(self.get_filenames(mask_path)) - - def __len__(self): - return len(self.img_list) - - def __getitem__(self, idx, transpose=True): - - # load image - img = np.load(self.img_list[idx]) - - # load mask - mask = np.load(self.mask_list[idx]) - if len(mask.shape) > 2: - mask = np.argmax(mask, axis=-1) - - # perform transformations - if self.transform is not None: - img = self.transform(img) - - return img, mask - - def get_filenames(self, path): - """ - Returns a list of absolute paths to images inside given `path` - """ - files_list = [] - for filename in sorted(os.listdir(path)): - files_list.append(os.path.join(path, filename)) - return files_list diff --git a/pytorch_caney/data/datasets/modis_lc_five_dataset.py b/pytorch_caney/data/datasets/modis_lc_five_dataset.py deleted file mode 100644 index c8948a0..0000000 --- a/pytorch_caney/data/datasets/modis_lc_five_dataset.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -from torch.utils.data import Dataset - -import numpy as np -import random - - -class MODISLCFiveDataset(Dataset): - """ - MODIS Landcover five-class pytorch fine-tuning dataset - """ - - IMAGE_PATH = os.path.join("images") - MASK_PATH = os.path.join("labels") - - def __init__( - self, - data_paths: list, - split: str, - img_size: tuple = (224, 224), - transform=None, - ): - self.img_size = img_size - self.transform = transform - self.split = split - self.data_paths = data_paths - self.img_list = [] - self.mask_list = [] - for data_path in data_paths: - img_path = os.path.join(data_path, self.IMAGE_PATH) - mask_path = os.path.join(data_path, self.MASK_PATH) - self.img_list.extend(self.get_filenames(img_path)) - self.mask_list.extend(self.get_filenames(mask_path)) - # Split between train and valid set (80/20) - - random_inst = random.Random(12345) # for repeatability - n_items = len(self.img_list) - print(f'Found {n_items} possible patches to use') - range_n_items = range(n_items) - range_n_items = random_inst.sample(range_n_items, int(n_items*0.5)) - idxs = set(random_inst.sample(range_n_items, len(range_n_items) // 5)) - total_idxs = set(range_n_items) - if split == 'train': - idxs = total_idxs - idxs - print(f'> Using {len(idxs)} patches for this dataset ({split})') - self.img_list = [self.img_list[i] for i in idxs] - self.mask_list = [self.mask_list[i] for i in idxs] - print(f'>> {split}: {len(self.img_list)}') - - def __len__(self): - return len(self.img_list) - - def __getitem__(self, idx, transpose=True): - - # load image - img = np.load(self.img_list[idx]) - - img = np.clip(img, 0, 1.0) - - # load mask - mask = np.load(self.mask_list[idx]) - - mask = np.argmax(mask, axis=-1) - - mask = mask-1 - - # perform transformations - img = self.transform(img) - - return img, mask - - def get_filenames(self, path): - """ - Returns a list of absolute paths to images inside given `path` - """ - files_list = [] - for filename in sorted(os.listdir(path)): - files_list.append(os.path.join(path, filename)) - return files_list diff --git a/pytorch_caney/data/datasets/modis_lc_nine_dataset.py b/pytorch_caney/data/datasets/modis_lc_nine_dataset.py deleted file mode 100644 index ebbe8d5..0000000 --- a/pytorch_caney/data/datasets/modis_lc_nine_dataset.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -import random - -import numpy as np - -from torch.utils.data import Dataset - - -class MODISLCNineDataset(Dataset): - """ - MODIS Landcover nine-class pytorch fine-tuning dataset - """ - IMAGE_PATH = os.path.join("images") - MASK_PATH = os.path.join("labels") - - def __init__( - self, - data_paths: list, - split: str, - img_size: tuple = (224, 224), - transform=None, - ): - self.img_size = img_size - self.transform = transform - self.split = split - self.data_paths = data_paths - self.img_list = [] - self.mask_list = [] - for data_path in data_paths: - img_path = os.path.join(data_path, self.IMAGE_PATH) - mask_path = os.path.join(data_path, self.MASK_PATH) - self.img_list.extend(self.get_filenames(img_path)) - self.mask_list.extend(self.get_filenames(mask_path)) - # Split between train and valid set (80/20) - - random_inst = random.Random(12345) # for repeatability - n_items = len(self.img_list) - print(f'Found {n_items} possible patches to use') - range_n_items = range(n_items) - range_n_items = random_inst.sample(range_n_items, int(n_items*0.5)) - idxs = set(random_inst.sample(range_n_items, len(range_n_items) // 5)) - total_idxs = set(range_n_items) - if split == 'train': - idxs = total_idxs - idxs - print(f'> Using {len(idxs)} patches for this dataset ({split})') - self.img_list = [self.img_list[i] for i in idxs] - self.mask_list = [self.mask_list[i] for i in idxs] - print(f'>> {split}: {len(self.img_list)}') - - def __len__(self): - return len(self.img_list) - - def __getitem__(self, idx, transpose=True): - - # load image - img = np.load(self.img_list[idx]) - - # load mask - mask = np.load(self.mask_list[idx]) - - mask = np.argmax(mask, axis=-1) - - mask = mask-1 - - # perform transformations - img = self.transform(img) - - return img, mask - - def get_filenames(self, path): - """ - Returns a list of absolute paths to images inside given `path` - """ - files_list = [] - for filename in sorted(os.listdir(path)): - files_list.append(os.path.join(path, filename)) - return files_list diff --git a/pytorch_caney/data/datasets/segmentation_dataset.py b/pytorch_caney/data/datasets/segmentation_dataset.py deleted file mode 100755 index d81c757..0000000 --- a/pytorch_caney/data/datasets/segmentation_dataset.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import logging -from glob import glob -from pathlib import Path -from typing import Optional, Union - -import torch -import numpy as np -from torch.utils.data import Dataset -from torch.utils.dlpack import from_dlpack - -import xarray as xr -from terragpu.engine import array_module, df_module - -import terragpu.ai.preprocessing as preprocessing - -xp = array_module() -xf = df_module() - - -class PLSegmentationDataset(Dataset): - - def __init__( - self, - images_regex: Optional[str] = None, - labels_regex: Optional[str] = None, - dataset_dir: Optional[str] = None, - generate_dataset: bool = False, - tile_size: int = 256, - max_patches: Union[float, int] = 100, - augment: bool = True, - chunks: dict = {'band': 1, 'x': 2048, 'y': 2048}, - input_bands: list = ['CB', 'B', 'G', 'Y', 'R', 'RE', 'N1', 'N2'], - output_bands: list = ['B', 'G', 'R'], - seed: int = 24, - normalize: bool = True, - pytorch: bool = True): - - super().__init__() - - # Dataset metadata - self.input_bands = input_bands - self.output_bands = output_bands - self.chunks = chunks - self.tile_size = tile_size - self.seed = seed - self.max_patches = max_patches - - # Preprocessing metadata - self.generate_dataset = generate_dataset - self.normalize = normalize - - # Validate several input sources - assert dataset_dir is not None, \ - f'dataset_dir: {dataset_dir} does not exist.' - - # Setup directories structure - self.dataset_dir = dataset_dir # where to store dataset - self.images_dir = os.path.join(self.dataset_dir, 'images') - self.labels_dir = os.path.join(self.dataset_dir, 'labels') - - if self.generate_dataset: - - logging.info(f"Starting to prepare dataset: {self.dataset_dir}") - # Assert images_dir and labels_dir to be not None - self.images_regex = images_regex # images location - self.labels_regex = labels_regex # labels location - - # Create directories to store dataset - os.makedirs(self.images_dir, exist_ok=True) - os.makedirs(self.labels_dir, exist_ok=True) - - self.prepare_data() - - assert os.path.exists(self.images_dir), \ - f'{self.images_dir} does not exist. Make sure prepare_data: true.' - assert os.path.exists(self.labels_dir), \ - f'{self.labels_dir} does not exist. Make sure prepare_data: true.' - - self.files = self.get_filenames() - self.augment = augment - self.pytorch = pytorch - - # ------------------------------------------------------------------------- - # Dataset methods - # ------------------------------------------------------------------------- - def __len__(self): - return len(self.files) - - def __repr__(self): - s = 'Dataset class with {} files'.format(self.__len__()) - return s - - def __getitem__(self, idx): - - idx = idx % len(self.files) - x, y = self.open_image(idx), self.open_mask(idx) - - if self.augment: - x, y = self.transform(x, y) - return x, y - - def transform(self, x, y): - - if xp.random.random_sample() > 0.5: # flip left and right - x = torch.fliplr(x) - y = torch.fliplr(y) - if xp.random.random_sample() > 0.5: # reverse second dimension - x = torch.flipud(x) - y = torch.flipud(y) - if xp.random.random_sample() > 0.5: # rotate 90 degrees - x = torch.rot90(x, k=1, dims=[1, 2]) - y = torch.rot90(y, k=1, dims=[0, 1]) - if xp.random.random_sample() > 0.5: # rotate 180 degrees - x = torch.rot90(x, k=2, dims=[1, 2]) - y = torch.rot90(y, k=2, dims=[0, 1]) - if xp.random.random_sample() > 0.5: # rotate 270 degrees - x = torch.rot90(x, k=3, dims=[1, 2]) - y = torch.rot90(y, k=3, dims=[0, 1]) - - # standardize 0.70, 0.30 - # if np.random.random_sample() > 0.70: - # image = preprocess.standardizeLocalCalcTensor(image, means, stds) - # else: - # image = preprocess.standardizeGlobalCalcTensor(image) - return x, y - - # ------------------------------------------------------------------------- - # preprocess methods - # ------------------------------------------------------------------------- - def prepare_data(self): - - logging.info("Preparing dataset...") - images_list = sorted(glob(self.images_regex)) - labels_list = sorted(glob(self.labels_regex)) - - for image, label in zip(images_list, labels_list): - - # Read imagery from disk and process both image and mask - filename = Path(image).stem - image = xr.open_rasterio(image, chunks=self.chunks).load() - label = xr.open_rasterio(label, chunks=self.chunks).values - - # Modify bands if necessary - in a future version, add indices - image = preprocessing.modify_bands( - img=image, input_bands=self.input_bands, - output_bands=self.output_bands) - - # Asarray option to force array type - image = xp.asarray(image.values) - label = xp.asarray(label) - - # Move from chw to hwc, squeze mask if required - image = xp.moveaxis(image, 0, -1).astype(np.int16) - label = xp.squeeze(label) if len(label.shape) != 2 else label - logging.info(f'Label classes from image: {xp.unique(label)}') - - # Generate dataset tiles - image_tiles, label_tiles = preprocessing.gen_random_tiles( - image=image, label=label, tile_size=self.tile_size, - max_patches=self.max_patches, seed=self.seed) - logging.info(f"Tiles: {image_tiles.shape}, {label_tiles.shape}") - - # Save to disk - for id in range(image_tiles.shape[0]): - xp.save( - os.path.join(self.images_dir, f'{filename}_{id}.npy'), - image_tiles[id, :, :, :]) - xp.save( - os.path.join(self.labels_dir, f'{filename}_{id}.npy'), - label_tiles[id, :, :]) - return - - # ------------------------------------------------------------------------- - # dataset methods - # ------------------------------------------------------------------------- - def list_files(self, files_list: list = []): - - for i in os.listdir(self.images_dir): - files_list.append( - { - 'image': os.path.join(self.images_dir, i), - 'label': os.path.join(self.labels_dir, i) - } - ) - return files_list - - def open_image(self, idx: int, invert: bool = True): - # image = imread(self.files[idx]['image']) - image = xp.load(self.files[idx]['image'], allow_pickle=False) - image = image.transpose((2, 0, 1)) if invert else image - image = ( - image / xp.iinfo(image.dtype).max) if self.normalize else image - return from_dlpack(image.toDlpack()) # .to(torch.float32) - - def open_mask(self, idx: int, add_dims: bool = False): - # mask = imread(self.files[idx]['label']) - mask = xp.load(self.files[idx]['label'], allow_pickle=False) - mask = xp.expand_dims(mask, 0) if add_dims else mask - return from_dlpack(mask.toDlpack()) # .to(torch.torch.int64) - - -class SegmentationDataset(Dataset): - - def __init__( - self, dataset_dir, pytorch=True, augment=True): - - super().__init__() - - self.files: list = self.list_files(dataset_dir) - self.augment: bool = augment - self.pytorch: bool = pytorch - self.invert: bool = True - self.normalize: bool = True - self.standardize: bool = True - - # ------------------------------------------------------------------------- - # Common methods - # ------------------------------------------------------------------------- - def __len__(self): - return len(self.files) - - def __repr__(self): - s = 'Dataset class with {} files'.format(self.__len__()) - return s - - def __getitem__(self, idx): - - # get data - x = self.open_image(idx) - y = self.open_mask(idx) - - # augment the data - if self.augment: - - if xp.random.random_sample() > 0.5: # flip left and right - x = torch.fliplr(x) - y = torch.fliplr(y) - if xp.random.random_sample() > 0.5: # reverse second dimension - x = torch.flipud(x) - y = torch.flipud(y) - if xp.random.random_sample() > 0.5: # rotate 90 degrees - x = torch.rot90(x, k=1, dims=[1, 2]) - y = torch.rot90(y, k=1, dims=[0, 1]) - if xp.random.random_sample() > 0.5: # rotate 180 degrees - x = torch.rot90(x, k=2, dims=[1, 2]) - y = torch.rot90(y, k=2, dims=[0, 1]) - if xp.random.random_sample() > 0.5: # rotate 270 degrees - x = torch.rot90(x, k=3, dims=[1, 2]) - y = torch.rot90(y, k=3, dims=[0, 1]) - - return x, y - - # ------------------------------------------------------------------------- - # IO methods - # ------------------------------------------------------------------------- - def get_filenames(self, dataset_dir: str, files_list: list = []): - - images_dir = os.path.join(dataset_dir, 'images') - labels_dir = os.path.join(dataset_dir, 'labels') - - for i in os.listdir(images_dir): - files_list.append( - { - 'image': os.path.join(images_dir, i), - 'label': os.path.join(labels_dir, i) - } - ) - return files_list - - def open_image(self, idx: int): - image = xp.load(self.files[idx]['image'], allow_pickle=False) - if self.invert: - image = image.transpose((2, 0, 1)) - if self.normalize: - image = (image / xp.iinfo(image.dtype).max) - if self.standardize: - image = preprocessing.standardize_local(image) - return from_dlpack(image.toDlpack()).float() - - def open_mask(self, idx: int, add_dims: bool = False): - mask = xp.load(self.files[idx]['label'], allow_pickle=False) - mask = xp.expand_dims(mask, 0) if add_dims else mask - return from_dlpack(mask.toDlpack()).long() diff --git a/pytorch_caney/data/datasets/simmim_modis_dataset.py b/pytorch_caney/data/datasets/simmim_modis_dataset.py deleted file mode 100644 index ff69735..0000000 --- a/pytorch_caney/data/datasets/simmim_modis_dataset.py +++ /dev/null @@ -1,90 +0,0 @@ -from ..utils import SimmimMaskGenerator - -import os -import numpy as np - -from torch.utils.data import Dataset - - -class MODISDataset(Dataset): - """ - MODIS MOD09GA pre-training dataset - """ - IMAGE_PATH = os.path.join("images") - - def __init__( - self, - config, - data_paths: list, - split: str, - img_size: tuple = (192, 192), - transform=None, - ): - - self.config = config - - self.img_size = img_size - - self.transform = transform - - self.split = split - - self.data_paths = data_paths - - self.img_list = [] - - for data_path in data_paths: - - img_path = os.path.join(data_path, self.IMAGE_PATH) - - self.img_list.extend(self.get_filenames(img_path)) - - n_items = len(self.img_list) - - print(f'> Found {n_items} patches for this dataset ({split})') - - if config.MODEL.TYPE in ['swin', 'swinv2']: - - model_patch_size = config.MODEL.SWINV2.PATCH_SIZE - - else: - - raise NotImplementedError - - self.mask_generator = SimmimMaskGenerator( - input_size=config.DATA.IMG_SIZE, - mask_patch_size=config.DATA.MASK_PATCH_SIZE, - model_patch_size=model_patch_size, - mask_ratio=config.DATA.MASK_RATIO, - ) - - def __len__(self): - - return len(self.img_list) - - def __getitem__(self, idx, transpose=True): - - # load image - img = np.load(self.img_list[idx]) - - img = np.clip(img, 0, 1.0) - - # perform transformations - img = self.transform(img) - - mask = self.mask_generator() - - return img, mask - - def get_filenames(self, path): - """ - Returns a list of absolute paths to images inside given `path` - """ - - files_list = [] - - for filename in sorted(os.listdir(path)): - - files_list.append(os.path.join(path, filename)) - - return files_list diff --git a/pytorch_caney/datamodules/__init__.py b/pytorch_caney/datamodules/__init__.py new file mode 100644 index 0000000..b5633d2 --- /dev/null +++ b/pytorch_caney/datamodules/__init__.py @@ -0,0 +1,12 @@ +from .abi_3dcloud_datamodule import AbiToa3DCloudDataModule +from .modis_toa_mim_datamodule import ModisToaMimDataModule + + +DATAMODULES = { + 'abitoa3dcloud': AbiToa3DCloudDataModule, + 'modistoamimpretrain': ModisToaMimDataModule, +} + + +def get_available_datamodules(): + return {name: cls for name, cls in DATAMODULES.items()} diff --git a/pytorch_caney/datamodules/abi_3dcloud_datamodule.py b/pytorch_caney/datamodules/abi_3dcloud_datamodule.py new file mode 100644 index 0000000..2b23f03 --- /dev/null +++ b/pytorch_caney/datamodules/abi_3dcloud_datamodule.py @@ -0,0 +1,75 @@ +from torch.utils.data import DataLoader +import lightning as L + +from pytorch_caney.datasets.abi_3dcloud_dataset import AbiToa3DCloudDataset +from pytorch_caney.transforms.abi_toa import AbiToaTransform + + +# ----------------------------------------------------------------------------- +# AbiToa3DCloudDataModule +# ----------------------------------------------------------------------------- +class AbiToa3DCloudDataModule(L.LightningDataModule): + """NonGeo ABI TOA 3D cloud data module implementation""" + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__( + self, + config, + ) -> None: + super().__init__() + self.config = config + self.transform = AbiToaTransform(config.DATA.IMG_SIZE) + print(self.transform) + self.train_data_paths = config.DATA.DATA_PATHS + self.test_data_paths = config.DATA.TEST_DATA_PATHS + self.batch_size = config.DATA.BATCH_SIZE + self.num_workers = config.DATA.NUM_WORKERS + + # ------------------------------------------------------------------------- + # setup + # ------------------------------------------------------------------------- + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = AbiToa3DCloudDataset( + self.config, + self.train_data_paths, + self.transform, + ) + if stage in ["fit", "validate"]: + self.val_dataset = AbiToa3DCloudDataset( + self.config, + self.test_data_paths, + self.transform, + ) + if stage in ["test"]: + self.test_dataset = AbiToa3DCloudDataset( + self.config, + self.test_data_paths, + self.transform, + ) + + # ------------------------------------------------------------------------- + # train_dataloader + # ------------------------------------------------------------------------- + def train_dataloader(self): + return DataLoader(self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers) + + # ------------------------------------------------------------------------- + # val_dataloader + # ------------------------------------------------------------------------- + def val_dataloader(self): + return DataLoader(self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers) + + # ------------------------------------------------------------------------- + # test_dataloader + # ------------------------------------------------------------------------- + def test_dataloader(self): + return DataLoader(self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers) diff --git a/pytorch_caney/datamodules/modis_toa_mim_datamodule.py b/pytorch_caney/datamodules/modis_toa_mim_datamodule.py new file mode 100644 index 0000000..e77064e --- /dev/null +++ b/pytorch_caney/datamodules/modis_toa_mim_datamodule.py @@ -0,0 +1,48 @@ +import lightning as L +from torch.utils.data import DataLoader + +from pytorch_caney.datasets.sharded_dataset import ShardedDataset +from pytorch_caney.transforms.mim_modis_toa import MimTransform + + +# ----------------------------------------------------------------------------- +# SatVisionToaPretrain +# ----------------------------------------------------------------------------- +class ModisToaMimDataModule(L.LightningDataModule): + """NonGeo MODIS TOA Masked-Image-Modeling data module implementation""" + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__(self, config,) -> None: + super().__init__() + self.config = config + self.transform = MimTransform(config) + self.batch_size = config.DATA.BATCH_SIZE + self.num_workers = config.DATA.NUM_WORKERS + self.img_size = config.DATA.IMG_SIZE + self.train_data_paths = config.DATA.DATA_PATHS + self.train_data_length = config.DATA.LENGTH + self.pin_memory = config.DATA.PIN_MEMORY + + # ------------------------------------------------------------------------- + # setup + # ------------------------------------------------------------------------- + def setup(self, stage: str) -> None: + if stage in ["fit"]: + self.train_dataset = ShardedDataset( + self.config, + self.train_data_paths, + split='train', + length=self.train_data_length, + img_size=self.img_size, + transform=self.transform, + batch_size=self.batch_size).dataset() + + # ------------------------------------------------------------------------- + # train_dataloader + # ------------------------------------------------------------------------- + def train_dataloader(self) -> DataLoader: + return DataLoader(self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers) diff --git a/pytorch_caney/data/datasets/__init__.py b/pytorch_caney/datasets/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/data/datasets/__init__.py rename to pytorch_caney/datasets/__init__.py diff --git a/pytorch_caney/datasets/abi_3dcloud_dataset.py b/pytorch_caney/datasets/abi_3dcloud_dataset.py new file mode 100644 index 0000000..85056fc --- /dev/null +++ b/pytorch_caney/datasets/abi_3dcloud_dataset.py @@ -0,0 +1,77 @@ +import os +from pathlib import Path +from typing import Any, Dict + +import numpy as np +import rioxarray as rxr + +from torchgeo.datasets import NonGeoDataset + + +# ----------------------------------------------------------------------------- +# AbiToa3DCloudDataModule +# ----------------------------------------------------------------------------- +class AbiToa3DCloudDataset(NonGeoDataset): + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__(self, config, data_paths: list, transform=None) -> None: + + super().__init__() + + self.config = config + self.data_paths = data_paths + self.transform = transform + self.img_size = config.DATA.IMG_SIZE + + self.image_list = [] + self.mask_list = [] + + for image_mask_path in self.data_paths: + self.image_list.extend(self.get_filenames(image_mask_path)) + + self.rgb_indices = [0, 1, 2] + + # ------------------------------------------------------------------------- + # __len__ + # ------------------------------------------------------------------------- + def __len__(self) -> int: + return len(self.image_list) + + # ------------------------------------------------------------------------- + # __getitem__ + # ------------------------------------------------------------------------- + def __getitem__(self, index: int) -> Dict[str, Any]: + + npz_array = self._load_file(self.image_list[index]) + image = npz_array['chip'] + mask = npz_array['data'].item()['Cloud_mask'] + + if self.transform is not None: + image = self.transform(image) + + return image, mask + + # ------------------------------------------------------------------------- + # _load_file + # ------------------------------------------------------------------------- + def _load_file(self, path: Path): + if Path(path).suffix == '.npy' or Path(path).suffix == '.npz': + return np.load(path, allow_pickle=True) + elif Path(path).suffix == '.tif': + return rxr.open_rasterio(path) + else: + raise RuntimeError('Non-recognized dataset format. Expects npy or tif.') # noqa: E501 + + # ------------------------------------------------------------------------- + # get_filenames + # ------------------------------------------------------------------------- + def get_filenames(self, path): + """ + Returns a list of absolute paths to images inside given `path` + """ + files_list = [] + for filename in sorted(os.listdir(path)): + files_list.append(os.path.join(path, filename)) + return files_list diff --git a/pytorch_caney/data/datasets/mim_modis_22m_dataset.py b/pytorch_caney/datasets/sharded_dataset.py similarity index 65% rename from pytorch_caney/data/datasets/mim_modis_22m_dataset.py rename to pytorch_caney/datasets/sharded_dataset.py index 0407907..8cec063 100644 --- a/pytorch_caney/data/datasets/mim_modis_22m_dataset.py +++ b/pytorch_caney/datasets/sharded_dataset.py @@ -8,6 +8,9 @@ import torch.distributed as dist +# ----------------------------------------------------------------------------- +# nodesplitter +# ----------------------------------------------------------------------------- def nodesplitter(src, group=None): if dist.is_initialized(): if group is None: @@ -26,50 +29,52 @@ def nodesplitter(src, group=None): yield from src -class MODIS22MDataset(object): +# ----------------------------------------------------------------------------- +# ShardedDataset +# ----------------------------------------------------------------------------- +class ShardedDataset(object): """ - MODIS MOD09GA 22-million pre-training dataset + Base pre-training webdataset """ - SHARD_PATH = os.path.join("shards") - - INPUT_KEY = 'input.npy' - OUTPUT_KEY = 'output.npy' + SHARD_PATH = os.path.join("shards") + INPUT_KEY: str = 'input.npy' + OUTPUT_KEY: str = 'output.npy' + REPEAT: int = 2 def __init__( self, config, data_paths: list, split: str, + length: int, img_size: tuple = (192, 192), transform=None, batch_size=64, ): - self.random_state = 42 - + self.random_state = 1000 self.config = config - self.img_size = img_size - self.transform = transform - self.split = split + self.length = length - self.shard_path = pathlib.Path( - os.path.join(data_paths[0], self.SHARD_PATH)) - + self.shard_path = pathlib.Path(data_paths[0]) shards = self.shard_path.glob('*.tar') - self.shards = list(map(str, shards)) self.batch_size = batch_size + # ------------------------------------------------------------------------- + # dataset + # ------------------------------------------------------------------------- def dataset(self): dataset = ( wds.WebDataset(self.shards, shardshuffle=True, + repeat=True, handler=wds.ignore_and_continue, nodesplitter=nodesplitter) .shuffle(self.random_state) @@ -78,6 +83,8 @@ def dataset(self): .map_tuple(np.load) .map_tuple(self.transform) .batched(self.batch_size, partial=False) + .repeat(self.REPEAT) + .with_length(self.length) ) return dataset diff --git a/pytorch_caney/inference.py b/pytorch_caney/inference.py deleted file mode 100755 index 5abad4a..0000000 --- a/pytorch_caney/inference.py +++ /dev/null @@ -1,382 +0,0 @@ -import logging -import math -import numpy as np - -import torch - -from tiler import Tiler, Merger - -from pytorch_caney.processing import normalize -from pytorch_caney.processing import global_standardization -from pytorch_caney.processing import local_standardization -from pytorch_caney.processing import standardize_image - -__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" -__email__ = "jordan.a.caraballo-vega@nasa.gov" -__status__ = "Production" - -# --------------------------------------------------------------------------- -# module inference -# -# Data segmentation and prediction functions. -# --------------------------------------------------------------------------- - - -# --------------------------------------------------------------------------- -# Module Methods -# --------------------------------------------------------------------------- -def sliding_window_tiler_multiclass( - xraster, - model, - n_classes: int, - img_size: int, - pad_style: str = 'reflect', - overlap: float = 0.50, - constant_value: int = 600, - batch_size: int = 1024, - threshold: float = 0.50, - standardization: str = None, - mean=None, - std=None, - normalize: float = 1.0, - rescale: str = None, - window: str = 'triang', # 'overlap-tile' - probability_map: bool = False - ): - """ - Sliding window using tiler. - """ - - tile_channels = xraster.shape[-1] # model.layers[0].input_shape[0][-1] - print(f'Standardizing: {standardization}') - # n_classes = out of the output layer, output_shape - - tiler_image = Tiler( - data_shape=xraster.shape, - tile_shape=(img_size, img_size, tile_channels), - channel_dimension=-1, - overlap=overlap, - mode=pad_style, - constant_value=constant_value - ) - - # Define the tiler and merger based on the output size of the prediction - tiler_mask = Tiler( - data_shape=(xraster.shape[0], xraster.shape[1], n_classes), - tile_shape=(img_size, img_size, n_classes), - channel_dimension=-1, - overlap=overlap, - mode=pad_style, - constant_value=constant_value - ) - - merger = Merger(tiler=tiler_mask, window=window) - # xraster = normalize_image(xraster, normalize) - - # Iterate over the data in batches - for batch_id, batch_i in tiler_image(xraster, batch_size=batch_size): - - # Standardize - batch = batch_i.copy() - - if standardization is not None: - for item in range(batch.shape[0]): - batch[item, :, :, :] = standardize_image( - batch[item, :, :, :], standardization, mean, std) - - input_batch = batch.astype('float32') - input_batch_tensor = torch.from_numpy(input_batch) - input_batch_tensor = input_batch_tensor.transpose(-1, 1) - # input_batch_tensor = input_batch_tensor.cuda(non_blocking=True) - with torch.no_grad(): - y_batch = model(input_batch_tensor) - y_batch = y_batch.transpose(1, -1) # .cpu().numpy() - merger.add_batch(batch_id, batch_size, y_batch) - - prediction = merger.merge(unpad=True) - - if not probability_map: - if prediction.shape[-1] > 1: - prediction = np.argmax(prediction, axis=-1) - else: - prediction = np.squeeze( - np.where(prediction > threshold, 1, 0).astype(np.int16) - ) - else: - prediction = np.squeeze(prediction) - return prediction - - -# --------------------------- Segmentation Functions ----------------------- # - -def segment(image, model='model.h5', tile_size=256, channels=6, - norm_data=[], bsize=8): - """ - Applies a semantic segmentation model to an image. Ideal for non-scene - imagery. Leaves artifacts in boundaries if no post-processing is done. - :param image: image to classify (numpy array) - :param model: loaded model object - :param tile_size: tile size of patches - :param channels: number of channels - :param norm_data: numpy array with mean and std data - :param bsize: number of patches to predict at the same time - return numpy array with classified mask - """ - # Create blank array to store predicted label - seg = np.zeros((image.shape[0], image.shape[1])) - for i in range(0, image.shape[0], int(tile_size)): - for j in range(0, image.shape[1], int(tile_size)): - # If edge of tile beyond image boundary, shift it to boundary - if i + tile_size > image.shape[0]: - i = image.shape[0] - tile_size - if j + tile_size > image.shape[1]: - j = image.shape[1] - tile_size - - # Extract and normalise tile - tile = normalize( - image[i: i + tile_size, j: j + tile_size, :].astype(float), - norm_data - ) - out = model.predict( - tile.reshape( - (1, tile.shape[0], tile.shape[1], tile.shape[2]) - ).astype(float), - batch_size=4 - ) - out = out.argmax(axis=3) # get max prediction for pixel in classes - out = out.reshape(tile_size, tile_size) # reshape to tile size - seg[i: i + tile_size, j: j + tile_size] = out - return seg - - -def segment_binary(image, model='model.h5', norm_data=[], - tile_size=256, channels=6, bsize=8 - ): - """ - Applies binary semantic segmentation model to an image. Ideal for non-scene - imagery. Leaves artifacts in boundaries if no post-processing is done. - :param image: image to classify (numpy array) - :param model: loaded model object - :param tile_size: tile size of patches - :param channels: number of channels - :param norm_data: numpy array with mean and std data - return numpy array with classified mask - """ - # Create blank array to store predicted label - seg = np.zeros((image.shape[0], image.shape[1])) - for i in range(0, image.shape[0], int(tile_size)): - for j in range(0, image.shape[1], int(tile_size)): - # If edge of tile beyond image boundary, shift it to boundary - if i + tile_size > image.shape[0]: - i = image.shape[0] - tile_size - if j + tile_size > image.shape[1]: - j = image.shape[1] - tile_size - - # Extract and normalise tile - tile = normalize( - image[i:i + tile_size, j:j + tile_size, :].astype(float), - norm_data - ) - out = model.predict( - tile.reshape( - (1, tile.shape[0], tile.shape[1], tile.shape[2]) - ).astype(float), - batch_size=bsize - ) - out[out >= 0.5] = 1 - out[out < 0.5] = 0 - out = out.reshape(tile_size, tile_size) # reshape to tile size - seg[i:i + tile_size, j:j + tile_size] = out - return seg - - -def pad_image(img, target_size): - """ - Pad an image up to the target size. - """ - rows_missing = target_size - img.shape[0] - cols_missing = target_size - img.shape[1] - padded_img = np.pad( - img, ((0, rows_missing), (0, cols_missing), (0, 0)), 'constant' - ) - return padded_img - - -def predict_sliding(image, model='', stand_method='local', - stand_strategy='per-batch', stand_data=[], - tile_size=256, nclasses=6, overlap=0.25, spline=[] - ): - """ - Predict on tiles of exactly the network input shape. - This way nothing gets squeezed. - """ - model.eval() - stride = math.ceil(tile_size * (1 - overlap)) - tile_rows = max( - int(math.ceil((image.shape[0] - tile_size) / stride) + 1), 1 - ) # strided convolution formula - tile_cols = max( - int(math.ceil((image.shape[1] - tile_size) / stride) + 1), 1 - ) - logging.info("Need %i x %i prediction tiles @ stride %i px" % - (tile_cols, tile_rows, stride) - ) - - full_probs = np.zeros((image.shape[0], image.shape[1], nclasses)) - count_predictions = np.zeros((image.shape[0], image.shape[1], nclasses)) - tile_counter = 0 - for row in range(tile_rows): - for col in range(tile_cols): - x1 = int(col * stride) - y1 = int(row * stride) - x2 = min(x1 + tile_size, image.shape[1]) - y2 = min(y1 + tile_size, image.shape[0]) - x1 = max(int(x2 - tile_size), 0) - y1 = max(int(y2 - tile_size), 0) - - img = image[y1:y2, x1:x2] - padded_img = pad_image(img, tile_size) - tile_counter += 1 - - padded_img = np.expand_dims(padded_img, 0) - - if stand_method == 'local': - imgn = local_standardization( - padded_img, ndata=stand_data, strategy=stand_strategy - ) - elif stand_method == 'global': - imgn = global_standardization( - padded_img, strategy=stand_strategy - ) - else: - imgn = padded_img - - imgn = imgn.astype('float32') - imgn_tensor = torch.from_numpy(imgn) - imgn_tensor = imgn_tensor.transpose(-1, 1) - with torch.no_grad(): - padded_prediction = model(imgn_tensor) - # if padded_prediction.shape[1] > 1: - # padded_prediction = np.argmax(padded_prediction, axis=1) - padded_prediction = np.squeeze(padded_prediction) - padded_prediction = padded_prediction.transpose(0, -1).numpy() - prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :] - count_predictions[y1:y2, x1:x2] += 1 - full_probs[y1:y2, x1:x2] += prediction # * spline - # average the predictions in the overlapping regions - full_probs /= count_predictions - return full_probs - - -def predict_sliding_binary(image, model='model.h5', tile_size=256, - nclasses=6, overlap=1/3, norm_data=[] - ): - """ - Predict on tiles of exactly the network input shape. - This way nothing gets squeezed. - """ - stride = math.ceil(tile_size * (1 - overlap)) - tile_rows = max( - int(math.ceil((image.shape[0] - tile_size) / stride) + 1), 1 - ) # strided convolution formula - tile_cols = max( - int(math.ceil((image.shape[1] - tile_size) / stride) + 1), 1 - ) - logging.info("Need %i x %i prediction tiles @ stride %i px" % - (tile_cols, tile_rows, stride) - ) - full_probs = np.zeros((image.shape[0], image.shape[1], nclasses)) - count_predictions = np.zeros((image.shape[0], image.shape[1], nclasses)) - tile_counter = 0 - for row in range(tile_rows): - for col in range(tile_cols): - x1 = int(col * stride) - y1 = int(row * stride) - x2 = min(x1 + tile_size, image.shape[1]) - y2 = min(y1 + tile_size, image.shape[0]) - x1 = max(int(x2 - tile_size), 0) - y1 = max(int(y2 - tile_size), 0) - - img = image[y1:y2, x1:x2] - padded_img = pad_image(img, tile_size) - tile_counter += 1 - - imgn = normalize(padded_img, norm_data) - imgn = imgn.astype('float32') - padded_prediction = model.predict(np.expand_dims(imgn, 0))[0] - prediction = padded_prediction[0:img.shape[0], 0:img.shape[1], :] - count_predictions[y1:y2, x1:x2] += 1 - full_probs[y1:y2, x1:x2] += prediction - # average the predictions in the overlapping regions - full_probs /= count_predictions - full_probs[full_probs >= 0.8] = 1 - full_probs[full_probs < 0.8] = 0 - return full_probs.reshape((image.shape[0], image.shape[1])) - - -def predict_windowing(x, model, stand_method='local', - stand_strategy='per-batch', stand_data=[], - patch_sz=160, n_classes=5, b_size=128, spline=[] - ): - img_height = x.shape[0] - img_width = x.shape[1] - n_channels = x.shape[2] - # make extended img so that it contains integer number of patches - npatches_vertical = math.ceil(img_height / patch_sz) - npatches_horizontal = math.ceil(img_width / patch_sz) - extended_height = patch_sz * npatches_vertical - extended_width = patch_sz * npatches_horizontal - ext_x = np.zeros( - shape=(extended_height, extended_width, n_channels), dtype=np.float32 - ) - # fill extended image with mirrors: - ext_x[:img_height, :img_width, :] = x - for i in range(img_height, extended_height): - ext_x[i, :, :] = ext_x[2 * img_height - i - 1, :, :] - for j in range(img_width, extended_width): - ext_x[:, j, :] = ext_x[:, 2 * img_width - j - 1, :] - - # now we assemble all patches in one array - patches_list = [] - for i in range(0, npatches_vertical): - for j in range(0, npatches_horizontal): - x0, x1 = i * patch_sz, (i + 1) * patch_sz - y0, y1 = j * patch_sz, (j + 1) * patch_sz - patches_list.append(ext_x[x0:x1, y0:y1, :]) - patches_array = np.asarray(patches_list) - - # normalization(patches_array, ndata) - - if stand_method == 'local': # apply local zero center standardization - patches_array = local_standardization( - patches_array, ndata=stand_data, strategy=stand_strategy - ) - elif stand_method == 'global': # apply global zero center standardization - patches_array = global_standardization( - patches_array, strategy=stand_strategy - ) - - # predictions: - patches_predict = model.predict(patches_array, batch_size=b_size) - prediction = np.zeros( - shape=(extended_height, extended_width, n_classes), dtype=np.float32 - ) - logging.info("prediction shape: ", prediction.shape) - for k in range(patches_predict.shape[0]): - i = k // npatches_horizontal - j = k % npatches_horizontal - x0, x1 = i * patch_sz, (i + 1) * patch_sz - y0, y1 = j * patch_sz, (j + 1) * patch_sz - prediction[x0:x1, y0:y1, :] = patches_predict[k, :, :, :] * spline - return prediction[:img_height, :img_width, :] - - -# ------------------------------------------------------------------------------- -# module model Unit Tests -# ------------------------------------------------------------------------------- - -if __name__ == "__main__": - - logging.basicConfig(level=logging.INFO) - - # Add unit tests here diff --git a/pytorch_caney/models/mim/__init__.py b/pytorch_caney/inference/__init__.py similarity index 100% rename from pytorch_caney/models/mim/__init__.py rename to pytorch_caney/inference/__init__.py diff --git a/pytorch_caney/loss/build.py b/pytorch_caney/loss/build.py deleted file mode 100644 index aa1cc16..0000000 --- a/pytorch_caney/loss/build.py +++ /dev/null @@ -1,64 +0,0 @@ -from segmentation_models_pytorch.losses import TverskyLoss - - -LOSSES = { - 'tversky': TverskyLoss, -} - - -def get_loss_from_dict(loss_name, config): - """Gets the proper loss given a loss name. - - Args: - loss_name (str): name of the loss - config: config object - - Raises: - KeyError: thrown if loss key is not present in dict - - Returns: - loss: pytorch loss - """ - - try: - - loss_to_use = LOSSES[loss_name] - - except KeyError: - - error_msg = f"{loss_name} is not an implemented loss" - - error_msg = f"{error_msg}. Available loss functions: {LOSSES.keys()}" - - raise KeyError(error_msg) - - if loss_name == 'tversky': - loss = loss_to_use(mode=config.LOSS.MODE, - classes=config.LOSS.CLASSES, - log_loss=config.LOSS.LOG, - from_logits=config.LOSS.LOGITS, - smooth=config.LOSS.SMOOTH, - ignore_index=config.LOSS.IGNORE_INDEX, - eps=config.LOSS.EPS, - alpha=config.LOSS.ALPHA, - beta=config.LOSS.BETA, - gamma=config.LOSS.GAMMA) - return loss - - -def build_loss(config): - """ - Builds the loss function given a configuration object. - - Args: - config: config object - - Returns: - loss_to_use: pytorch loss function - """ - - loss_name = config.LOSS.NAME - - loss_to_use = get_loss_from_dict(loss_name, config) - - return loss_to_use diff --git a/pytorch_caney/loss/utils.py b/pytorch_caney/loss/utils.py deleted file mode 100755 index 4319803..0000000 --- a/pytorch_caney/loss/utils.py +++ /dev/null @@ -1,26 +0,0 @@ -import numpy as np - -import torch - - -# --- -# Adapted from -# https://github.com/qubvel/segmentation_models.pytorch \ -# /tree/master/segmentation_models_pytorch/losses -# --- -def to_tensor(x, dtype=None) -> torch.Tensor: - if isinstance(x, torch.Tensor): - if dtype is not None: - x = x.type(dtype) - return x - if isinstance(x, np.ndarray): - x = torch.from_numpy(x) - if dtype is not None: - x = x.type(dtype) - return x - if isinstance(x, (list, tuple)): - x = np.array(x) - x = torch.from_numpy(x) - if dtype is not None: - x = x.type(dtype) - return x diff --git a/pytorch_caney/models/simmim/__init__.py b/pytorch_caney/losses/__init__.py similarity index 100% rename from pytorch_caney/models/simmim/__init__.py rename to pytorch_caney/losses/__init__.py diff --git a/pytorch_caney/lr_scheduler.py b/pytorch_caney/lr_scheduler.py deleted file mode 100644 index cd693c9..0000000 --- a/pytorch_caney/lr_scheduler.py +++ /dev/null @@ -1,185 +0,0 @@ -from bisect import bisect_right - -from timm.scheduler.cosine_lr import CosineLRScheduler -from timm.scheduler.step_lr import StepLRScheduler -from timm.scheduler.scheduler import Scheduler - -import torch -import torch.distributed as dist - - -def build_scheduler(config, optimizer, n_iter_per_epoch): - num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) - warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) - decay_steps = int( - config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) - multi_steps = [ - i * n_iter_per_epoch for i in config.TRAIN.LR_SCHEDULER.MULTISTEPS] - - lr_scheduler = None - if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': - lr_scheduler = CosineLRScheduler( - optimizer, - t_initial=num_steps, - cycle_mul=1., - lr_min=config.TRAIN.MIN_LR, - warmup_lr_init=config.TRAIN.WARMUP_LR, - warmup_t=warmup_steps, - cycle_limit=1, - t_in_epochs=False, - ) - elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': - lr_scheduler = LinearLRScheduler( - optimizer, - t_initial=num_steps, - lr_min_rate=0.01, - warmup_lr_init=config.TRAIN.WARMUP_LR, - warmup_t=warmup_steps, - t_in_epochs=False, - ) - elif config.TRAIN.LR_SCHEDULER.NAME == 'step': - lr_scheduler = StepLRScheduler( - optimizer, - decay_t=decay_steps, - decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, - warmup_lr_init=config.TRAIN.WARMUP_LR, - warmup_t=warmup_steps, - t_in_epochs=False, - ) - elif config.TRAIN.LR_SCHEDULER.NAME == 'multistep': - lr_scheduler = MultiStepLRScheduler( - optimizer, - milestones=multi_steps, - gamma=config.TRAIN.LR_SCHEDULER.GAMMA, - warmup_lr_init=config.TRAIN.WARMUP_LR, - warmup_t=warmup_steps, - t_in_epochs=False, - ) - - return lr_scheduler - - -class LinearLRScheduler(Scheduler): - def __init__(self, - optimizer: torch.optim.Optimizer, - t_initial: int, - lr_min_rate: float, - warmup_t=0, - warmup_lr_init=0., - t_in_epochs=True, - noise_range_t=None, - noise_pct=0.67, - noise_std=1.0, - noise_seed=42, - initialize=True, - ) -> None: - super().__init__( - optimizer, param_group_field="lr", - noise_range_t=noise_range_t, noise_pct=noise_pct, - noise_std=noise_std, noise_seed=noise_seed, - initialize=initialize) - - self.t_initial = t_initial - self.lr_min_rate = lr_min_rate - self.warmup_t = warmup_t - self.warmup_lr_init = warmup_lr_init - self.t_in_epochs = t_in_epochs - if self.warmup_t: - self.warmup_steps = [(v - warmup_lr_init) / - self.warmup_t for v in self.base_values] - super().update_groups(self.warmup_lr_init) - else: - self.warmup_steps = [1 for _ in self.base_values] - - def _get_lr(self, t): - if t < self.warmup_t: - lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] - else: - t = t - self.warmup_t - total_t = self.t_initial - self.warmup_t - lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) - for v in self.base_values] - return lrs - - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None - - -class MultiStepLRScheduler(Scheduler): - def __init__(self, optimizer: torch.optim.Optimizer, - milestones, gamma=0.1, warmup_t=0, - warmup_lr_init=0, t_in_epochs=True) -> None: - super().__init__(optimizer, param_group_field="lr") - - self.milestones = milestones - self.gamma = gamma - self.warmup_t = warmup_t - self.warmup_lr_init = warmup_lr_init - self.t_in_epochs = t_in_epochs - if self.warmup_t: - self.warmup_steps = [(v - warmup_lr_init) / - self.warmup_t for v in self.base_values] - super().update_groups(self.warmup_lr_init) - else: - self.warmup_steps = [1 for _ in self.base_values] - - assert self.warmup_t <= min(self.milestones) - - def _get_lr(self, t): - if t < self.warmup_t: - lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] - else: - lrs = [v * (self.gamma ** bisect_right(self.milestones, t)) - for v in self.base_values] - return lrs - - def get_epoch_values(self, epoch: int): - if self.t_in_epochs: - return self._get_lr(epoch) - else: - return None - - def get_update_values(self, num_updates: int): - if not self.t_in_epochs: - return self._get_lr(num_updates) - else: - return None - - -def setup_scaled_lr(config): - # linear scale the learning rate according to total batch size, - # may not be optimal - - batch_size = config.DATA.BATCH_SIZE - - world_size = dist.get_world_size() - - denom_const = 512.0 - - accumulation_steps = config.TRAIN.ACCUMULATION_STEPS - - linear_scaled_lr = config.TRAIN.BASE_LR * \ - batch_size * world_size / denom_const - - linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * \ - batch_size * world_size / denom_const - - linear_scaled_min_lr = config.TRAIN.MIN_LR * \ - batch_size * world_size / denom_const - - # gradient accumulation also need to scale the learning rate - if accumulation_steps > 1: - linear_scaled_lr = linear_scaled_lr * accumulation_steps - linear_scaled_warmup_lr = linear_scaled_warmup_lr * accumulation_steps - linear_scaled_min_lr = linear_scaled_min_lr * accumulation_steps - - return linear_scaled_lr, linear_scaled_warmup_lr, linear_scaled_min_lr diff --git a/pytorch_caney/network/__init__.py b/pytorch_caney/lr_schedulers/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/network/__init__.py rename to pytorch_caney/lr_schedulers/__init__.py diff --git a/pytorch_caney/metrics.py b/pytorch_caney/metrics.py deleted file mode 100755 index 4679464..0000000 --- a/pytorch_caney/metrics.py +++ /dev/null @@ -1,80 +0,0 @@ -import logging -from typing import List - -import torch -import numpy as np -from sklearn.metrics import accuracy_score -from sklearn.metrics import precision_score -from sklearn.metrics import recall_score - -__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" -__email__ = "jordan.a.caraballo-vega@nasa.gov" -__status__ = "Production" - -# --------------------------------------------------------------------------- -# module metrics -# -# General functions to compute custom metrics. -# --------------------------------------------------------------------------- - -# --------------------------------------------------------------------------- -# Module Methods -# --------------------------------------------------------------------------- - -EPSILON = 1e-15 - - -# ------------------------------ Metric Functions -------------------------- # - -def iou_val(y_true, y_pred): - intersection = np.logical_and(y_true, y_pred) - union = np.logical_or(y_true, y_pred) - iou_score = np.sum(intersection) / np.sum(union) - return iou_score - - -def acc_val(y_true, y_pred): - return accuracy_score(y_true, y_pred) - - -def prec_val(y_true, y_pred): - return precision_score(y_true, y_pred, average='macro'), \ - precision_score(y_true, y_pred, average=None) - - -def recall_val(y_true, y_pred): - return recall_score(y_true, y_pred, average='macro'), \ - recall_score(y_true, y_pred, average=None) - - -def find_average(outputs: List, name: str) -> torch.Tensor: - if len(outputs[0][name].shape) == 0: - return torch.stack([x[name] for x in outputs]).mean() - return torch.cat([x[name] for x in outputs]).mean() - - -def binary_mean_iou( - logits: torch.Tensor, - targets: torch.Tensor - ) -> torch.Tensor: - - output = (logits > 0).int() - - if output.shape != targets.shape: - targets = torch.squeeze(targets, 1) - - intersection = (targets * output).sum() - - union = targets.sum() + output.sum() - intersection - - result = (intersection + EPSILON) / (union + EPSILON) - - return result - - -# ------------------------------------------------------------------------------- -# module metrics Unit Tests -# ------------------------------------------------------------------------------- -if __name__ == "__main__": - - logging.basicConfig(level=logging.INFO) diff --git a/pytorch_caney/models/__init__.py b/pytorch_caney/models/__init__.py old mode 100755 new mode 100644 index e69de29..e381289 --- a/pytorch_caney/models/__init__.py +++ b/pytorch_caney/models/__init__.py @@ -0,0 +1,9 @@ +from .model_factory import ModelFactory +from .mim import MiMModel +from .heads import SegmentationHead +from .decoders import FcnDecoder +from .encoders import SatVision, SwinTransformerV2, FcnEncoder + + +__all__ = [ModelFactory, MiMModel, SegmentationHead, + FcnDecoder, SatVision, SwinTransformerV2, FcnEncoder] diff --git a/pytorch_caney/models/build.py b/pytorch_caney/models/build.py deleted file mode 100644 index 9ffc47c..0000000 --- a/pytorch_caney/models/build.py +++ /dev/null @@ -1,105 +0,0 @@ -from .swinv2_model import SwinTransformerV2 -from .unet_swin_model import unet_swin -from .mim.mim import build_mim_model -from ..training.mim_utils import load_pretrained - -import logging - - -def build_model(config, - pretrain: bool = False, - pretrain_method: str = 'mim', - logger: logging.Logger = None): - """ - Given a config object, builds a pytorch model. - - Returns: - model: built model - """ - - if pretrain: - - if pretrain_method == 'mim': - model = build_mim_model(config) - return model - - encoder_architecture = config.MODEL.TYPE - decoder_architecture = config.MODEL.DECODER - - if encoder_architecture == 'swinv2': - - logger.info(f'Hit encoder only build, building {encoder_architecture}') - - window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES - - model = SwinTransformerV2( - img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWINV2.PATCH_SIZE, - in_chans=config.MODEL.SWINV2.IN_CHANS, - num_classes=config.MODEL.NUM_CLASSES, - embed_dim=config.MODEL.SWINV2.EMBED_DIM, - depths=config.MODEL.SWINV2.DEPTHS, - num_heads=config.MODEL.SWINV2.NUM_HEADS, - window_size=config.MODEL.SWINV2.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, - qkv_bias=config.MODEL.SWINV2.QKV_BIAS, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWINV2.APE, - patch_norm=config.MODEL.SWINV2.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT, - pretrained_window_sizes=window_sizes) - - if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): - load_pretrained(config, model, logger) - - else: - - errorMsg = f'Unknown encoder architecture {encoder_architecture}' - - logger.error(errorMsg) - - raise NotImplementedError(errorMsg) - - if decoder_architecture is not None: - - if encoder_architecture == 'swinv2': - - window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES - - model = SwinTransformerV2( - img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWINV2.PATCH_SIZE, - in_chans=config.MODEL.SWINV2.IN_CHANS, - num_classes=config.MODEL.NUM_CLASSES, - embed_dim=config.MODEL.SWINV2.EMBED_DIM, - depths=config.MODEL.SWINV2.DEPTHS, - num_heads=config.MODEL.SWINV2.NUM_HEADS, - window_size=config.MODEL.SWINV2.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, - qkv_bias=config.MODEL.SWINV2.QKV_BIAS, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWINV2.APE, - patch_norm=config.MODEL.SWINV2.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT, - pretrained_window_sizes=window_sizes) - - else: - - raise NotImplementedError() - - if decoder_architecture == 'unet': - - num_classes = config.MODEL.NUM_CLASSES - - if config.MODEL.PRETRAINED and (not config.MODEL.RESUME): - load_pretrained(config, model, logger) - - model = unet_swin(encoder=model, num_classes=num_classes) - - else: - error_msg = f'Unknown decoder architecture: {decoder_architecture}' - raise NotImplementedError(error_msg) - - return model diff --git a/pytorch_caney/models/decoders/__init__.py b/pytorch_caney/models/decoders/__init__.py new file mode 100644 index 0000000..303682d --- /dev/null +++ b/pytorch_caney/models/decoders/__init__.py @@ -0,0 +1,4 @@ +from .fcn_decoder import FcnDecoder + + +__all__ = [FcnDecoder] diff --git a/pytorch_caney/models/decoders/fcn_decoder.py b/pytorch_caney/models/decoders/fcn_decoder.py new file mode 100644 index 0000000..cdb3d04 --- /dev/null +++ b/pytorch_caney/models/decoders/fcn_decoder.py @@ -0,0 +1,25 @@ +import torch.nn as nn + +from ..model_factory import ModelFactory + + +@ModelFactory.decoder("fcn") +class FcnDecoder(nn.Module): + def __init__(self, num_features: int = 1024): + super(FcnDecoder, self).__init__() + self.output_channels = 64 + self.decoder = nn.Sequential( + nn.ConvTranspose2d(num_features, 2048, kernel_size=3, stride=2, padding=1, output_padding=1), # 16x16x512 # noqa: E501 + nn.ReLU(), + nn.ConvTranspose2d(2048, 512, kernel_size=3, stride=2, padding=1, output_padding=1), # 32x32x256 # noqa: E501 + nn.ReLU(), + nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), # 64x64x128 # noqa: E501 + nn.ReLU(), + nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1), # 64x64x128 # noqa: E501 + nn.ReLU(), + nn.ConvTranspose2d(128, self.output_channels, kernel_size=3, stride=2, padding=1, output_padding=1), # 128x128x64 # noqa: E501 + nn.ReLU() + ) + + def forward(self, x): + return self.decoder(x) diff --git a/pytorch_caney/models/decoders/unet_decoder.py b/pytorch_caney/models/decoders/unet_decoder.py deleted file mode 100644 index b55fcb3..0000000 --- a/pytorch_caney/models/decoders/unet_decoder.py +++ /dev/null @@ -1,181 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from segmentation_models_pytorch.base import modules as md - - -class DecoderBlock(nn.Module): - def __init__( - self, - in_channels, - skip_channels, - out_channels, - use_batchnorm=True, - attention_type=None, - ): - super().__init__() - - self.conv1 = md.Conv2dReLU( - in_channels + skip_channels, - out_channels, - kernel_size=3, - padding=1, - use_batchnorm=use_batchnorm, - ) - - in_and_skip_channels = in_channels + skip_channels - - self.attention1 = md.Attention(attention_type, - in_channels=in_and_skip_channels) - - self.conv2 = md.Conv2dReLU( - out_channels, - out_channels, - kernel_size=3, - padding=1, - use_batchnorm=use_batchnorm, - ) - - self.attention2 = md.Attention(attention_type, - in_channels=out_channels) - - self.in_channels = in_channels - self.out_channels = out_channels - self.skip_channels = skip_channels - - def forward(self, x, skip=None): - - if skip is None: - x = F.interpolate(x, scale_factor=2, mode="nearest") - - else: - - if x.shape[-1] != skip.shape[-1]: - x = F.interpolate(x, scale_factor=2, mode="nearest") - - if skip is not None: - - x = torch.cat([x, skip], dim=1) - x = self.attention1(x) - - x = self.conv1(x) - x = self.conv2(x) - x = self.attention2(x) - - return x - - -class CenterBlock(nn.Sequential): - def __init__(self, in_channels, out_channels, use_batchnorm=True): - conv1 = md.Conv2dReLU( - in_channels, - out_channels, - kernel_size=3, - padding=1, - use_batchnorm=use_batchnorm, - ) - conv2 = md.Conv2dReLU( - out_channels, - out_channels, - kernel_size=3, - padding=1, - use_batchnorm=use_batchnorm, - ) - super().__init__(conv1, conv2) - - -class UnetDecoder(nn.Module): - def __init__(self, - encoder_channels, - decoder_channels, - n_blocks=5, - use_batchnorm=True, - attention_type=None, - center=False): - super().__init__() - - if n_blocks != len(decoder_channels): - raise ValueError( - f"Model depth is {n_blocks}, but you provided " - f"decoder_channels for {len(decoder_channels)} blocks." - ) - - # remove first skip with same spatial resolution - encoder_channels = encoder_channels[1:] - - # reverse channels to start from head of encoder - encoder_channels = encoder_channels[::-1] - - # computing blocks input and output channels - head_channels = encoder_channels[0] - - in_channels = [head_channels] + list(decoder_channels[:-1]) - - skip_channels = list(encoder_channels[1:]) + [0] - - out_channels = decoder_channels - - if center: - - self.center = CenterBlock( - head_channels, head_channels, use_batchnorm=use_batchnorm) - - else: - - self.center = nn.Identity() - - # combine decoder keyword arguments - kwargs = dict(use_batchnorm=use_batchnorm, - attention_type=attention_type) - - blocks = [ - DecoderBlock(in_ch, skip_ch, out_ch, **kwargs) - for in_ch, skip_ch, out_ch in zip(in_channels, - skip_channels, - out_channels) - ] - - self.blocks = nn.ModuleList(blocks) - - def forward(self, *features): - - features = features[1:] - - # remove first skip with same spatial resolution - - features = features[:: -1] - # reverse channels to start from head of encoder - - head = features[0] - - skips = features[1:] - - x = self.center(head) - - for i, decoder_block in enumerate(self.blocks): - - skip = skips[i] if i < len(skips) else None - - x = decoder_block(x, skip) - - return x - - -class SegmentationHead(nn.Sequential): - - def __init__(self, - in_channels, - out_channels, - kernel_size=3, - upsampling=1): - - conv2d = nn.Conv2d(in_channels, - out_channels, - kernel_size=kernel_size, - padding=kernel_size // 2) - - upsampling = nn.UpsamplingBilinear2d( - scale_factor=upsampling) if upsampling > 1 else nn.Identity() - - super().__init__(conv2d, upsampling) diff --git a/pytorch_caney/models/encoders/__init__.py b/pytorch_caney/models/encoders/__init__.py new file mode 100644 index 0000000..ac897ad --- /dev/null +++ b/pytorch_caney/models/encoders/__init__.py @@ -0,0 +1,6 @@ +from .fcn_encoder import FcnEncoder +from .satvision import SatVision +from .swinv2 import SwinTransformerV2 + + +__all__ = [FcnEncoder, SatVision, SwinTransformerV2] diff --git a/pytorch_caney/models/encoders/fcn_encoder.py b/pytorch_caney/models/encoders/fcn_encoder.py new file mode 100644 index 0000000..3f77cc0 --- /dev/null +++ b/pytorch_caney/models/encoders/fcn_encoder.py @@ -0,0 +1,26 @@ +import torch.nn as nn + +from ..model_factory import ModelFactory + + +@ModelFactory.encoder("fcn") +class FcnEncoder(nn.Module): + def __init__(self, config): + super(FcnEncoder, self).__init__() + self.config = config + self.num_input_channels = self.config.MODEL.IN_CHANS + self.num_features = 1024 + self.encoder = nn.Sequential( + nn.Conv2d(self.num_input_channels, 64, kernel_size=3, stride=1, padding=1), # 128x128x64 # noqa: E501 + nn.ReLU(), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 64x64x128 # noqa: E501 + nn.ReLU(), + nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 32x32x256 # noqa: E501 + nn.ReLU(), + nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 16x16x512 # noqa: E501 + nn.ReLU(), + nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1) # 8x8x1024 # noqa: E501 + ) + + def forward(self, x): + return self.encoder(x) diff --git a/pytorch_caney/models/encoders/satvision.py b/pytorch_caney/models/encoders/satvision.py new file mode 100644 index 0000000..06141d0 --- /dev/null +++ b/pytorch_caney/models/encoders/satvision.py @@ -0,0 +1,99 @@ +from .swinv2 import SwinTransformerV2 +from ..model_factory import ModelFactory +import torch.nn as nn +import torch + + +# ----------------------------------------------------------------------------- +# SatVision +# ----------------------------------------------------------------------------- +@ModelFactory.encoder("satvision") +class SatVision(nn.Module): + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__(self, config): + super().__init__() + + self.config = config + + window_sizes = config.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES + + self.model = SwinTransformerV2( + img_size=config.DATA.IMG_SIZE, + patch_size=config.MODEL.SWINV2.PATCH_SIZE, + in_chans=config.MODEL.SWINV2.IN_CHANS, + num_classes=config.MODEL.NUM_CLASSES, + embed_dim=config.MODEL.SWINV2.EMBED_DIM, + depths=config.MODEL.SWINV2.DEPTHS, + num_heads=config.MODEL.SWINV2.NUM_HEADS, + window_size=config.MODEL.SWINV2.WINDOW_SIZE, + mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, + qkv_bias=config.MODEL.SWINV2.QKV_BIAS, + drop_rate=config.MODEL.DROP_RATE, + drop_path_rate=config.MODEL.DROP_PATH_RATE, + ape=config.MODEL.SWINV2.APE, + patch_norm=config.MODEL.SWINV2.PATCH_NORM, + use_checkpoint=config.TRAIN.USE_CHECKPOINT, + pretrained_window_sizes=window_sizes, + ) + + if self.config.MODEL.PRETRAINED: + self.load_pretrained() + + self.num_classes = self.model.num_classes + self.num_layers = self.model.num_layers + self.num_features = self.model.num_features + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def load_pretrained(self): + + checkpoint = torch.load( + self.config.MODEL.PRETRAINED, map_location='cpu') + + checkpoint_model = checkpoint['module'] + + if any([True if 'encoder.' in k else + False for k in checkpoint_model.keys()]): + + checkpoint_model = {k.replace( + 'encoder.', ''): v for k, v in checkpoint_model.items() + if k.startswith('encoder.')} + + print('Detect pre-trained model, remove [encoder.] prefix.') + + else: + + print( + 'Detect non-pre-trained model, pass without doing anything.') + + msg = self.model.load_state_dict(checkpoint_model, strict=False) + + print(msg) + + del checkpoint + + torch.cuda.empty_cache() + + print(f">>>>>>> loaded successfully '{self.config.MODEL.PRETRAINED}'") + + # ------------------------------------------------------------------------- + # forward + # ------------------------------------------------------------------------- + def forward(self, x): + return self.model.forward(x) + + # ------------------------------------------------------------------------- + # forward_features + # ------------------------------------------------------------------------- + def forward_features(self, x): + return self.model.forward_features(x) + + # ------------------------------------------------------------------------- + # extra_features + # ------------------------------------------------------------------------- + def extra_features(self, x): + return self.model.extra_features(x) diff --git a/pytorch_caney/models/swinv2_model.py b/pytorch_caney/models/encoders/swinv2.py similarity index 68% rename from pytorch_caney/models/swinv2_model.py rename to pytorch_caney/models/encoders/swinv2.py index f784af0..fd4ed8f 100644 --- a/pytorch_caney/models/swinv2_model.py +++ b/pytorch_caney/models/encoders/swinv2.py @@ -1,13 +1,249 @@ import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint as checkpoint + from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +import numpy as np + +from ..model_factory import ModelFactory + + +# ----------------------------------------------------------------------------- +# WindowAttention +# ----------------------------------------------------------------------------- +class WindowAttention(nn.Module): + """ + Window based multi-head self attention (W-MSA) module with + relative position bias. It supports both of shifted and + non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, + key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. + Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the + window in pre-training. + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + + self.dim = dim + + self.window_size = window_size # Wh, Ww + + self.pretrained_window_size = pretrained_window_size + + self.num_heads = num_heads + + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), + self.window_size[0], + dtype=torch.float32) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), + self.window_size[1], + dtype=torch.float32) + + # 1, 2*Wh-1, 2*Ww-1, 2 + relative_coords_table = torch.stack( + torch.meshgrid( + [relative_coords_h, + relative_coords_w])).permute(1, + 2, + 0).contiguous().unsqueeze(0) + + if pretrained_window_size[0] > 0: + + relative_coords_table[:, :, :, + 0] /= (pretrained_window_size[0] - 1) + + relative_coords_table[:, :, :, + 1] /= (pretrained_window_size[1] - 1) + + else: + + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + + relative_coords_table *= 8 # normalize to -8, 8 + + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) -from pytorch_caney.network.mlp import Mlp -from pytorch_caney.network.attention import WindowAttention + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside + # the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + + relative_coords = coords_flatten[:, :, None] - \ + coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + + relative_coords[:, :, 0] += self.window_size[0] - \ + 1 # shift to start from 0 + + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + + self.register_buffer("relative_position_index", + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + + if qkv_bias: + + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + + else: + + self.q_bias = None + self.v_bias = None + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) + or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like( + self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv[0], qkv[1], qkv[2] + + # cosine attention + attn = (F.normalize(q, dim=-1) @ + F.normalize(k, dim=-1).transpose(-2, -1)) + # logit_scale = torch.clamp( + # self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + logit_scale = torch.clamp(self.logit_scale, max=torch.log( + torch.tensor(1. / 0.01)).to(self.logit_scale.get_device())).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp( + self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = \ + relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], -1) + # Wh*Ww,Wh*Ww,nH + + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, ' \ + f'num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +# ----------------------------------------------------------------------------- +# Mlp +# ----------------------------------------------------------------------------- +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x +# ----------------------------------------------------------------------------- +# window_partition +# ----------------------------------------------------------------------------- def window_partition(x, window_size): """ Args: @@ -25,6 +261,9 @@ def window_partition(x, window_size): return windows +# ----------------------------------------------------------------------------- +# window_reverse +# ----------------------------------------------------------------------------- def window_reverse(windows, window_size, H, W): """ Args: @@ -43,6 +282,9 @@ def window_reverse(windows, window_size, H, W): return x +# ----------------------------------------------------------------------------- +# SwinTransformerBlock +# ----------------------------------------------------------------------------- class SwinTransformerBlock(nn.Module): r""" Swin Transformer Block. @@ -198,6 +440,9 @@ def flops(self): return flops +# ----------------------------------------------------------------------------- +# PatchMerging +# ----------------------------------------------------------------------------- class PatchMerging(nn.Module): r""" Patch Merging Layer. @@ -248,6 +493,9 @@ def flops(self): return flops +# ----------------------------------------------------------------------------- +# BasicLayer +# ----------------------------------------------------------------------------- class BasicLayer(nn.Module): """ A basic Swin Transformer layer for one stage. @@ -318,7 +566,7 @@ def _extra_norm(index): def forward(self, x): for blk in self.blocks: if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x) + x = checkpoint.checkpoint(blk, x, use_reentrant=False) else: x = blk(x) if self.downsample is not None: @@ -346,6 +594,9 @@ def _init_respostnorm(self): nn.init.constant_(blk.norm2.weight, 0) +# ----------------------------------------------------------------------------- +# PatchEmbed +# ----------------------------------------------------------------------------- class PatchEmbed(nn.Module): r""" Image to Patch Embedding @@ -404,6 +655,10 @@ def flops(self): return flops +# ----------------------------------------------------------------------------- +# SwinTransformerV2 +# ----------------------------------------------------------------------------- +@ModelFactory.encoder("swinv2") class SwinTransformerV2(nn.Module): r""" Swin Transformer A PyTorch impl of : `Swin Transformer: Hierarchical @@ -579,8 +834,7 @@ def get_unet_feature(self, x): return feature def forward(self, x): - x = self.forward_features(x) - x = self.head(x) + x = self.extra_features(x)[-1] return x def flops(self): diff --git a/pytorch_caney/models/heads/__init__.py b/pytorch_caney/models/heads/__init__.py new file mode 100644 index 0000000..fcf4565 --- /dev/null +++ b/pytorch_caney/models/heads/__init__.py @@ -0,0 +1,4 @@ +from .segmentation_head import SegmentationHead + + +__all__ = [SegmentationHead] diff --git a/pytorch_caney/models/heads/segmentation_head.py b/pytorch_caney/models/heads/segmentation_head.py new file mode 100644 index 0000000..5561bac --- /dev/null +++ b/pytorch_caney/models/heads/segmentation_head.py @@ -0,0 +1,21 @@ +import torch.nn as nn + +from ..model_factory import ModelFactory + + +@ModelFactory.head("segmentation_head") +class SegmentationHead(nn.Module): + def __init__(self, decoder_channels=128, num_classes=4, + head_dropout=0.2, output_shape=(91, 40)): + super(SegmentationHead, self).__init__() + self.head = nn.Sequential( + nn.Conv2d(decoder_channels, num_classes, + kernel_size=3, stride=1, padding=1), + nn.Dropout(head_dropout), + nn.Upsample(size=output_shape, + mode='bilinear', + align_corners=False) + ) + + def forward(self, x): + return self.head(x) diff --git a/pytorch_caney/models/mim/mim.py b/pytorch_caney/models/mim.py similarity index 88% rename from pytorch_caney/models/mim/mim.py rename to pytorch_caney/models/mim.py index aaf69c7..2d421cf 100644 --- a/pytorch_caney/models/mim/mim.py +++ b/pytorch_caney/models/mim.py @@ -3,9 +3,12 @@ import torch.nn.functional as F from timm.models.layers import trunc_normal_ -from ..swinv2_model import SwinTransformerV2 +from .encoders.swinv2 import SwinTransformerV2 +# ----------------------------------------------------------------------------- +# SwinTransformerV2ForMiM +# ----------------------------------------------------------------------------- class SwinTransformerV2ForSimMIM(SwinTransformerV2): def __init__(self, **kwargs): super().__init__(**kwargs) @@ -44,6 +47,9 @@ def no_weight_decay(self): return super().no_weight_decay() | {'mask_token'} +# ----------------------------------------------------------------------------- +# MiMModel +# ----------------------------------------------------------------------------- class MiMModel(nn.Module): """ Masked-Image-Modeling model @@ -94,6 +100,9 @@ def no_weight_decay_keywords(self): return {} +# ----------------------------------------------------------------------------- +# build_mim_model +# ----------------------------------------------------------------------------- def build_mim_model(config): """Builds the masked-image-modeling model. diff --git a/pytorch_caney/models/model_factory.py b/pytorch_caney/models/model_factory.py new file mode 100644 index 0000000..e888ae4 --- /dev/null +++ b/pytorch_caney/models/model_factory.py @@ -0,0 +1,85 @@ +# ----------------------------------------------------------------------------- +# ModelFactory +# ----------------------------------------------------------------------------- +class ModelFactory: + # Class-level registries + backbones = {} + decoders = {} + heads = {} + + # ------------------------------------------------------------------------- + # register_backbone + # ------------------------------------------------------------------------- + @classmethod + def register_backbone(cls, name: str, backbone_cls): + """Register a new backbone in the factory.""" + cls.backbones[name] = backbone_cls + + # ------------------------------------------------------------------------- + # register_decoder + # ------------------------------------------------------------------------- + @classmethod + def register_decoder(cls, name: str, decoder_cls): + """Register a new decoder in the factory.""" + cls.decoders[name] = decoder_cls + + # ------------------------------------------------------------------------- + # register_head + # ------------------------------------------------------------------------- + @classmethod + def register_head(cls, name: str, head_cls): + """Register a new head in the factory.""" + cls.heads[name] = head_cls + + # ------------------------------------------------------------------------- + # get_component + # ------------------------------------------------------------------------- + @classmethod + def get_component(cls, component_type: str, name: str, **kwargs): + """Public method to retrieve and instantiate a component by type and name.""" # noqa: E501 + print(cls.backbones) + print(cls.decoders) + print(cls.heads) + registry = { + "encoder": cls.backbones, + "decoder": cls.decoders, + "head": cls.heads, + }.get(component_type) + + if registry is None or name not in registry: + raise ValueError(f"{component_type.capitalize()} '{name}' not found in registry.") # noqa: E501 + + return registry[name](**kwargs) + + # ------------------------------------------------------------------------- + # encoder + # ------------------------------------------------------------------------- + @classmethod + def encoder(cls, name): + """Class decorator for registering an encoder.""" + def decorator(encoder_cls): + cls.register_backbone(name, encoder_cls) + return encoder_cls + return decorator + + # ------------------------------------------------------------------------- + # decoder + # ------------------------------------------------------------------------- + @classmethod + def decoder(cls, name): + """Class decorator for registering a decoder.""" + def decorator(decoder_cls): + cls.register_decoder(name, decoder_cls) + return decoder_cls + return decorator + + # ------------------------------------------------------------------------- + # head + # ------------------------------------------------------------------------- + @classmethod + def head(cls, name): + """Class decorator for registering a head.""" + def decorator(head_cls): + cls.register_head(name, head_cls) + return head_cls + return decorator diff --git a/pytorch_caney/models/simmim/simmim.py b/pytorch_caney/models/simmim/simmim.py deleted file mode 100644 index b13cfca..0000000 --- a/pytorch_caney/models/simmim/simmim.py +++ /dev/null @@ -1,117 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from timm.models.layers import trunc_normal_ - -from ..swinv2_model import SwinTransformerV2 - - -class SwinTransformerV2ForSimMIM(SwinTransformerV2): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - assert self.num_classes == 0 - - self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - trunc_normal_(self.mask_token, mean=0., std=.02) - - def forward(self, x, mask): - x = self.patch_embed(x) - - assert mask is not None - B, L, _ = x.shape - - mask_tokens = self.mask_token.expand(B, L, -1) - w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) - x = x * (1. - w) + mask_tokens * w - - if self.ape: - x = x + self.absolute_pos_embed - x = self.pos_drop(x) - - for layer in self.layers: - x = layer(x) - x = self.norm(x) - - x = x.transpose(1, 2) - B, C, L = x.shape - H = W = int(L ** 0.5) - x = x.reshape(B, C, H, W) - return x - - @torch.jit.ignore - def no_weight_decay(self): - return super().no_weight_decay() | {'mask_token'} - - -class MiMModel(nn.Module): - def __init__(self, encoder, encoder_stride, in_chans, patch_size): - super().__init__() - self.encoder = encoder - self.encoder_stride = encoder_stride - self.in_chans = in_chans - self.patch_size = patch_size - self.decoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.num_features, - out_channels=self.encoder_stride ** 2 * self.in_chans, - kernel_size=1), - nn.PixelShuffle(self.encoder_stride), - ) - - # self.in_chans = self.encoder.in_chans - # self.patch_size = self.encoder.patch_size - - def forward(self, x, mask): - z = self.encoder(x, mask) - x_rec = self.decoder(z) - - mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave( - self.patch_size, 2).unsqueeze(1).contiguous() - loss_recon = F.l1_loss(x, x_rec, reduction='none') - loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans - return loss - - @torch.jit.ignore - def no_weight_decay(self): - if hasattr(self.encoder, 'no_weight_decay'): - return {'encoder.' + i for i in self.encoder.no_weight_decay()} - return {} - - @torch.jit.ignore - def no_weight_decay_keywords(self): - if hasattr(self.encoder, 'no_weight_decay_keywords'): - return {'encoder.' + i for i in - self.encoder.no_weight_decay_keywords()} - return {} - - -def build_mim_model(config): - model_type = config.MODEL.TYPE - if model_type == 'swinv2': - encoder = SwinTransformerV2ForSimMIM( - img_size=config.DATA.IMG_SIZE, - patch_size=config.MODEL.SWINV2.PATCH_SIZE, - in_chans=config.MODEL.SWINV2.IN_CHANS, - num_classes=0, - embed_dim=config.MODEL.SWINV2.EMBED_DIM, - depths=config.MODEL.SWINV2.DEPTHS, - num_heads=config.MODEL.SWINV2.NUM_HEADS, - window_size=config.MODEL.SWINV2.WINDOW_SIZE, - mlp_ratio=config.MODEL.SWINV2.MLP_RATIO, - qkv_bias=config.MODEL.SWINV2.QKV_BIAS, - drop_rate=config.MODEL.DROP_RATE, - drop_path_rate=config.MODEL.DROP_PATH_RATE, - ape=config.MODEL.SWINV2.APE, - patch_norm=config.MODEL.SWINV2.PATCH_NORM, - use_checkpoint=config.TRAIN.USE_CHECKPOINT) - encoder_stride = 32 - in_chans = config.MODEL.SWINV2.IN_CHANS - patch_size = config.MODEL.SWINV2.PATCH_SIZE - else: - raise NotImplementedError(f"Unknown pre-train model: {model_type}") - - model = MiMModel(encoder=encoder, encoder_stride=encoder_stride, - in_chans=in_chans, patch_size=patch_size) - - return model diff --git a/pytorch_caney/models/unet_model.py b/pytorch_caney/models/unet_model.py deleted file mode 100755 index b8e0779..0000000 --- a/pytorch_caney/models/unet_model.py +++ /dev/null @@ -1,187 +0,0 @@ -from pl_bolts.models.vision.unet import UNet -from pytorch_lightning import LightningModule -from pytorch_lightning.utilities.cli import MODEL_REGISTRY - -import torch -from torch.nn import functional as F -from torchmetrics import MetricCollection, Accuracy, IoU - - -# ------------------------------------------------------------------------------- -# class UNet -# This class performs training and classification of satellite imagery using a -# UNet CNN. -# ------------------------------------------------------------------------------- -@MODEL_REGISTRY -class UNetSegmentation(LightningModule): - - # --------------------------------------------------------------------------- - # __init__ - # --------------------------------------------------------------------------- - def __init__( - self, - input_channels: int = 4, - num_classes: int = 19, - num_layers: int = 5, - features_start: int = 64, - bilinear: bool = False, - ): - super().__init__() - - self.input_channels = input_channels - self.num_classes = num_classes - self.num_layers = num_layers - self.features_start = features_start - self.bilinear = bilinear - - self.net = UNet( - input_channels=self.input_channels, - num_classes=num_classes, - num_layers=self.num_layers, - features_start=self.features_start, - bilinear=self.bilinear, - ) - - metrics = MetricCollection( - [ - Accuracy(), IoU(num_classes=self.num_classes) - ] - ) - self.train_metrics = metrics.clone(prefix='train_') - self.val_metrics = metrics.clone(prefix='val_') - - # --------------------------------------------------------------------------- - # model methods - # --------------------------------------------------------------------------- - def forward(self, x): - return self.net(x) - - def training_step(self, batch, batch_nb): - img, mask = batch - img, mask = img.float(), mask.long() - - # Forward step, calculate logits and loss - logits = self(img) - # loss_val = F.cross_entropy(logits, mask) - - # Get target tensor from logits for metrics, calculate metrics - probs = torch.nn.functional.softmax(logits, dim=1) - probs = torch.argmax(probs, dim=1) - - # metrics_train = self.train_metrics(probs, mask) - # log_dict = {"train_loss": loss_val.detach()} - # return {"loss": loss_val, "log": log_dict, "progress_bar": log_dict} - # return { - # "loss": loss_val, "train_acc": metrics_train['train_Accuracy'], - # "train_iou": metrics_train['train_IoU'] - # } - - tensorboard_logs = self.train_metrics(probs, mask) - tensorboard_logs['loss'] = F.cross_entropy(logits, mask) - # tensorboard_logs['lr'] = self._get_current_lr() - - self.log( - 'acc', tensorboard_logs['train_Accuracy'], - sync_dist=True, prog_bar=True - ) - self.log( - 'iou', tensorboard_logs['train_IoU'], - sync_dist=True, prog_bar=True - ) - return tensorboard_logs - - def training_epoch_end(self, outputs): - pass - - # Get average metrics from multi-GPU batch sources - # loss_val = torch.stack([x["loss"] for x in outputs]).mean() - # acc_train = torch.stack([x["train_acc"] for x in outputs]).mean() - # iou_train = torch.stack([x["train_iou"] for x in outputs]).mean() - - # tensorboard_logs = self.train_metrics(probs, mask) - # tensorboard_logs['loss'] = F.cross_entropy(logits, mask) - # tensorboard_logs['lr'] = self._get_current_lr() - - # self.log( - # 'acc', tensorboard_logs['train_Accuracy'], - # sync_dist=True, prog_bar=True - # ) - # self.log( - # 'iou', tensorboard_logs['train_IoU'], - # sync_dist=True, prog_bar=True - # ) - # # Send output to logger - # self.log( - # "loss", loss_val, on_epoch=True, prog_bar=True, logger=True) - # self.log( - # "train_acc", acc_train, - # on_epoch=True, prog_bar=True, logger=True) - # self.log( - # "train_iou", iou_train, - # on_epoch=True, prog_bar=True, logger=True) - # return tensorboard_logs - - def validation_step(self, batch, batch_idx): - - # Get data, change type for validation - img, mask = batch - img, mask = img.float(), mask.long() - - # Forward step, calculate logits and loss - logits = self(img) - # loss_val = F.cross_entropy(logits, mask) - - # Get target tensor from logits for metrics, calculate metrics - probs = torch.nn.functional.softmax(logits, dim=1) - probs = torch.argmax(probs, dim=1) - # metrics_val = self.val_metrics(probs, mask) - - # return { - # "val_loss": loss_val, "val_acc": metrics_val['val_Accuracy'], - # "val_iou": metrics_val['val_IoU'] - # } - tensorboard_logs = self.val_metrics(probs, mask) - tensorboard_logs['val_loss'] = F.cross_entropy(logits, mask) - - self.log( - 'val_loss', tensorboard_logs['val_loss'], - sync_dist=True, prog_bar=True - ) - self.log( - 'val_acc', tensorboard_logs['val_Accuracy'], - sync_dist=True, prog_bar=True - ) - self.log( - 'val_iou', tensorboard_logs['val_IoU'], - sync_dist=True, prog_bar=True - ) - return tensorboard_logs - - # def validation_epoch_end(self, outputs): - - # # Get average metrics from multi-GPU batch sources - # loss_val = torch.stack([x["val_loss"] for x in outputs]).mean() - # acc_val = torch.stack([x["val_acc"] for x in outputs]).mean() - # iou_val = torch.stack([x["val_iou"] for x in outputs]).mean() - - # # Send output to logger - # self.log( - # "val_loss", torch.mean(self.all_gather(loss_val)), - # on_epoch=True, prog_bar=True, logger=True) - # self.log( - # "val_acc", torch.mean(self.all_gather(acc_val)), - # on_epoch=True, prog_bar=True, logger=True) - # self.log( - # "val_iou", torch.mean(self.all_gather(iou_val)), - # on_epoch=True, prog_bar=True, logger=True) - - # def configure_optimizers(self): - # opt = torch.optim.Adam(self.net.parameters(), lr=self.lr) - # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10) - # return [opt], [sch] - - def test_step(self, batch, batch_idx, dataloader_idx=0): - return self(batch) - - def predict_step(self, batch, batch_idx, dataloader_idx=0): - return self(batch) diff --git a/pytorch_caney/models/unet_swin_model.py b/pytorch_caney/models/unet_swin_model.py deleted file mode 100644 index 63ede6a..0000000 --- a/pytorch_caney/models/unet_swin_model.py +++ /dev/null @@ -1,44 +0,0 @@ -from .decoders.unet_decoder import UnetDecoder -from .decoders.unet_decoder import SegmentationHead - -import torch.nn as nn - -from typing import Tuple - - -class unet_swin(nn.Module): - """ - Pytorch encoder-decoder model which pairs - an encoder (swin) with the attention unet - decoder. - """ - - FEATURE_CHANNELS: Tuple[int] = (3, 704, 1408, 2816, 2816) - DECODE_CHANNELS: Tuple[int] = (512, 256, 128, 64) - IN_CHANNELS: int = 64 - N_BLOCKS: int = 4 - KERNEL_SIZE: int = 3 - UPSAMPLING: int = 4 - - def __init__(self, encoder, num_classes=9): - super().__init__() - - self.encoder = encoder - - self.decoder = UnetDecoder( - encoder_channels=self.FEATURE_CHANNELS, - n_blocks=self.N_BLOCKS, - decoder_channels=self.DECODE_CHANNELS, - attention_type=None) - self.segmentation_head = SegmentationHead( - in_channels=self.IN_CHANNELS, - out_channels=num_classes, - kernel_size=self.KERNEL_SIZE, - upsampling=self.UPSAMPLING) - - def forward(self, x): - encoder_featrue = self.encoder.get_unet_feature(x) - decoder_output = self.decoder(*encoder_featrue) - masks = self.segmentation_head(decoder_output) - - return masks diff --git a/pytorch_caney/network/attention.py b/pytorch_caney/network/attention.py deleted file mode 100644 index 99f216f..0000000 --- a/pytorch_caney/network/attention.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np - - -class WindowAttention(nn.Module): - """ - Window based multi-head self attention (W-MSA) module with - relative position bias. It supports both of shifted and - non-shifted window. - - Args: - dim (int): Number of input channels. - window_size (tuple[int]): The height and width of the window. - num_heads (int): Number of attention heads. - qkv_bias (bool, optional): If True, add a learnable bias to query, - key, value. Default: True - attn_drop (float, optional): Dropout ratio of attention weight. - Default: 0.0 - proj_drop (float, optional): Dropout ratio of output. Default: 0.0 - pretrained_window_size (tuple[int]): The height and width of the - window in pre-training. - """ - - def __init__(self, - dim, - window_size, - num_heads, - qkv_bias=True, - attn_drop=0., - proj_drop=0., - pretrained_window_size=[0, 0]): - - super().__init__() - - self.dim = dim - - self.window_size = window_size # Wh, Ww - - self.pretrained_window_size = pretrained_window_size - - self.num_heads = num_heads - - self.logit_scale = nn.Parameter( - torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) - - # mlp to generate continuous relative position bias - self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), - nn.ReLU(inplace=True), - nn.Linear(512, num_heads, bias=False)) - - # get relative_coords_table - relative_coords_h = torch.arange( - -(self.window_size[0] - 1), - self.window_size[0], - dtype=torch.float32) - relative_coords_w = torch.arange( - -(self.window_size[1] - 1), - self.window_size[1], - dtype=torch.float32) - - # 1, 2*Wh-1, 2*Ww-1, 2 - relative_coords_table = torch.stack( - torch.meshgrid( - [relative_coords_h, - relative_coords_w])).permute(1, - 2, - 0).contiguous().unsqueeze(0) - - if pretrained_window_size[0] > 0: - - relative_coords_table[:, :, :, - 0] /= (pretrained_window_size[0] - 1) - - relative_coords_table[:, :, :, - 1] /= (pretrained_window_size[1] - 1) - - else: - - relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) - - relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) - - relative_coords_table *= 8 # normalize to -8, 8 - - relative_coords_table = torch.sign(relative_coords_table) * torch.log2( - torch.abs(relative_coords_table) + 1.0) / np.log2(8) - - self.register_buffer("relative_coords_table", relative_coords_table) - - # get pair-wise relative position index for each token inside - # the window - coords_h = torch.arange(self.window_size[0]) - coords_w = torch.arange(self.window_size[1]) - - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - - relative_coords = coords_flatten[:, :, None] - \ - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww - - relative_coords = relative_coords.permute( - 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 - - relative_coords[:, :, 0] += self.window_size[0] - \ - 1 # shift to start from 0 - - relative_coords[:, :, 1] += self.window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 - - relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - - self.register_buffer("relative_position_index", - relative_position_index) - - self.qkv = nn.Linear(dim, dim * 3, bias=False) - - if qkv_bias: - - self.q_bias = nn.Parameter(torch.zeros(dim)) - self.v_bias = nn.Parameter(torch.zeros(dim)) - - else: - - self.q_bias = None - self.v_bias = None - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - self.softmax = nn.Softmax(dim=-1) - - def forward(self, x, mask=None): - """ - Args: - x: input features with shape of (num_windows*B, N, C) - mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) - or None - """ - B_, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like( - self.v_bias, requires_grad=False), self.v_bias)) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - # make torchscript happy (cannot use tensor as tuple) - q, k, v = qkv[0], qkv[1], qkv[2] - - # cosine attention - attn = (F.normalize(q, dim=-1) @ - F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp( - self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() - # logit_scale = torch.clamp(self.logit_scale, max=torch.log( - # torch.tensor(1. / 0.01)).to(self.logit_scale.get_device())).exp() - attn = attn * logit_scale - - relative_position_bias_table = self.cpb_mlp( - self.relative_coords_table).view(-1, self.num_heads) - relative_position_bias = \ - relative_position_bias_table[ - self.relative_position_index.view(-1)].view( - self.window_size[0] * self.window_size[1], - self.window_size[0] * self.window_size[1], -1) - # Wh*Ww,Wh*Ww,nH - - relative_position_bias = relative_position_bias.permute( - 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww - - relative_position_bias = 16 * torch.sigmoid(relative_position_bias) - attn = attn + relative_position_bias.unsqueeze(0) - - if mask is not None: - nW = mask.shape[0] - attn = attn.view(B_ // nW, nW, self.num_heads, N, - N) + mask.unsqueeze(1).unsqueeze(0) - attn = attn.view(-1, self.num_heads, N, N) - attn = self.softmax(attn) - - else: - attn = self.softmax(attn) - - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B_, N, C) - x = self.proj(x) - x = self.proj_drop(x) - return x - - def extra_repr(self) -> str: - return f'dim={self.dim}, window_size={self.window_size}, ' \ - f'pretrained_window_size={self.pretrained_window_size}, ' \ - f'num_heads={self.num_heads}' - - def flops(self, N): - # calculate flops for 1 window with token length of N - flops = 0 - # qkv = self.qkv(x) - flops += N * self.dim * 3 * self.dim - # attn = (q @ k.transpose(-2, -1)) - flops += self.num_heads * N * (self.dim // self.num_heads) * N - # x = (attn @ v) - flops += self.num_heads * N * N * (self.dim // self.num_heads) - # x = self.proj(x) - flops += N * self.dim * self.dim - return flops diff --git a/pytorch_caney/network/mlp.py b/pytorch_caney/network/mlp.py deleted file mode 100644 index d154808..0000000 --- a/pytorch_caney/network/mlp.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch.nn as nn - - -class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, - out_features=None, act_layer=nn.GELU, drop=0.): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.drop(x) - x = self.fc2(x) - x = self.drop(x) - return x diff --git a/pytorch_caney/data/datasets/classification_dataset.py b/pytorch_caney/optimizers/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/data/datasets/classification_dataset.py rename to pytorch_caney/optimizers/__init__.py diff --git a/pytorch_caney/optimizers/build.py b/pytorch_caney/optimizers/build.py new file mode 100644 index 0000000..c5d6c17 --- /dev/null +++ b/pytorch_caney/optimizers/build.py @@ -0,0 +1,248 @@ +from functools import partial + +import torch +import deepspeed + +from pytorch_caney.optimizers.lamb import Lamb + + +OPTIMIZERS = { + 'adamw': torch.optim.AdamW, + 'lamb': Lamb, + 'fusedlamb': deepspeed.ops.lamb.FusedLamb, + 'fusedadamw': deepspeed.ops.adam.FusedAdam, +} + + +# ----------------------------------------------------------------------------- +# get_optimizer_from_dict +# ----------------------------------------------------------------------------- +def get_optimizer_from_dict(optimizer_name, config): + """Gets the proper optimizer given an optimizer name. + + Args: + optimizer_name (str): name of the optimizer + config: config object + + Raises: + KeyError: thrown if loss key is not present in dict + + Returns: + loss: pytorch optimizer + """ + + try: + + optimizer_to_use = OPTIMIZERS[optimizer_name.lower()] + + except KeyError: + + error_msg = f"{optimizer_name} is not an implemented optimizer" + + error_msg = f"{error_msg}. Available optimizer functions: {OPTIMIZERS.keys()}" # noqa: E501 + + raise KeyError(error_msg) + + return optimizer_to_use + + +# ----------------------------------------------------------------------------- +# build_optimizer +# ----------------------------------------------------------------------------- +def build_optimizer(config, model, is_pretrain=False, logger=None): + """ + Build optimizer, set weight decay of normalization to 0 by default. + AdamW only. + """ + if logger: + logger.info('>>>>>>>>>> Build Optimizer') + + skip = {} + skip_keywords = {} + optimizer_name = config.TRAIN.OPTIMIZER.NAME + + if logger: + logger.info(f'Building {optimizer_name}') + + optimizer_to_use = get_optimizer_from_dict(optimizer_name, config) + + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + + if hasattr(model, 'no_weight_decay_keywords'): + skip_keywords = model.no_weight_decay_keywords() + + if is_pretrain: + parameters = get_pretrain_param_groups(model, skip, skip_keywords) + + else: + depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \ + else config.MODEL.SWINV2.DEPTHS + + num_layers = sum(depths) + + get_layer_func = partial(get_swin_layer, + num_layers=num_layers + 2, + depths=depths) + + scales = list(config.TRAIN.LAYER_DECAY ** i for i in + reversed(range(num_layers + 2))) + + parameters = get_finetune_param_groups(model, + config.TRAIN.BASE_LR, + config.TRAIN.WEIGHT_DECAY, + get_layer_func, + scales, + skip, + skip_keywords) + + optimizer = None + optimizer = optimizer_to_use(parameters, + eps=config.TRAIN.OPTIMIZER.EPS, + betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, + weight_decay=config.TRAIN.WEIGHT_DECAY) + if logger: + logger.info(optimizer) + + return optimizer + + +# ----------------------------------------------------------------------------- +# get_finetune_param_groups +# ----------------------------------------------------------------------------- +def get_finetune_param_groups(model, + lr, + weight_decay, + get_layer_func, + scales, + skip_list=(), + skip_keywords=()): + + parameter_group_names = {} + parameter_group_vars = {} + + for name, param in model.named_parameters(): + + if not param.requires_grad: + continue + + if len(param.shape) == 1 or name.endswith(".bias") \ + or (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + group_name = "no_decay" + this_weight_decay = 0. + + else: + group_name = "decay" + this_weight_decay = weight_decay + + if get_layer_func is not None: + layer_id = get_layer_func(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + else: + layer_id = None + + if group_name not in parameter_group_names: + if scales is not None: + scale = scales[layer_id] + else: + scale = 1. + + parameter_group_names[group_name] = { + "group_name": group_name, + "weight_decay": this_weight_decay, + "params": [], + "lr": lr * scale, + "lr_scale": scale, + } + + parameter_group_vars[group_name] = { + "group_name": group_name, + "weight_decay": this_weight_decay, + "params": [], + "lr": lr * scale, + "lr_scale": scale + } + + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + return list(parameter_group_vars.values()) + + +# ----------------------------------------------------------------------------- +# check_keywords_in_name +# ----------------------------------------------------------------------------- +def check_keywords_in_name(name, keywords=()): + isin = False + for keyword in keywords: + if keyword in name: + isin = True + return isin + + +# ----------------------------------------------------------------------------- +# get_pretrain_param_groups +# ----------------------------------------------------------------------------- +def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): + + has_decay = [] + no_decay = [] + has_decay_name = [] + no_decay_name = [] + + for name, param in model.named_parameters(): + + if not param.requires_grad: + + continue + + if len(param.shape) == 1 or name.endswith(".bias") or \ + (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + + no_decay.append(param) + + no_decay_name.append(name) + + else: + + has_decay.append(param) + + has_decay_name.append(name) + + return [{'params': has_decay}, + {'params': no_decay, 'weight_decay': 0.}] + + +# ----------------------------------------------------------------------------- +# get_swin_layer +# ----------------------------------------------------------------------------- +def get_swin_layer(name, num_layers, depths): + + if name in ("mask_token"): + + return 0 + + elif name.startswith("patch_embed"): + + return 0 + + elif name.startswith("layers"): + + layer_id = int(name.split('.')[1]) + + block_id = name.split('.')[3] + + if block_id == 'reduction' or block_id == 'norm': + + return sum(depths[:layer_id + 1]) + + layer_id = sum(depths[:layer_id]) + int(block_id) + + return layer_id + 1 + + else: + + return num_layers - 1 diff --git a/pytorch_caney/optimizers/lamb.py b/pytorch_caney/optimizers/lamb.py new file mode 100644 index 0000000..e51466d --- /dev/null +++ b/pytorch_caney/optimizers/lamb.py @@ -0,0 +1,225 @@ +""" PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb + +This optimizer code was adapted from the following (starting with latest) +* https://github.com/HabanaAI/Model-References/blob/2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py # noqa: E501 +* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py # noqa: E501 +* https://github.com/cybertronai/pytorch-lamb + +Use FusedLamb if you can (GPU). The reason for including this variant of Lamb is to have a version that is # noqa: E501 +similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or cannot install/use APEX. # noqa: E501 + +In addition to some cleanup, this Lamb impl has been modified to support PyTorch XLA and has been tested on TPU. # noqa: E501 + +Original copyrights for above sources are below. + +Modifications Copyright 2021 Ross Wightman +""" +# Copyright (c) 2021, Habana Labs Ltd. All rights reserved. + +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all # noqa: E501 +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +import collections +import math + +import torch +from torch.optim import Optimizer + +from torch.utils.tensorboard import SummaryWriter + + +def log_lamb_rs(optimizer: Optimizer, + event_writer: SummaryWriter, + token_count: int): + """Log a histogram of trust ratio scalars in across layers.""" + results = collections.defaultdict(list) + for group in optimizer.param_groups: + for p in group['params']: + state = optimizer.state[p] + for i in ('weight_norm', 'adam_norm', 'trust_ratio'): + if i in state: + results[i].append(state[i]) + + for k, v in results.items(): + event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) + + +class Lamb(Optimizer): + """Implements a pure pytorch variant of FuseLAMB (NvLamb variant) optimizer from apex.optimizers.FusedLAMB # noqa: E501 + reference: https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/LanguageModeling/Transformer-XL/pytorch/lamb.py # noqa: E501 + + LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. # noqa: E501 + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups. # noqa: E501 + lr (float, optional): learning rate. (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability. (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + grad_averaging (bool, optional): whether apply (1-beta2) to grad when + calculating running averages of gradient. (default: True) + max_grad_norm (float, optional): value used to clip global grad norm (default: 1.0) + trust_clip (bool): enable LAMBC trust ratio clipping (default: False) + always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 + weight decay parameter (default: False) + + .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + .. _On the Convergence of Adam and Beyond: + https://openreview.net/forum?id=ryQu7f-RZ + """ + + def __init__( + self, params, lr=1e-3, bias_correction=True, + betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, + grad_averaging=True, max_grad_norm=1.0, + trust_clip=False, always_adapt=False): + + defaults = dict( + lr=lr, bias_correction=bias_correction, + betas=betas, eps=eps, weight_decay=weight_decay, + grad_averaging=grad_averaging, max_grad_norm=max_grad_norm, + trust_clip=trust_clip, always_adapt=always_adapt) + + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + device = self.param_groups[0]['params'][0].device + # because torch.where doesn't handle scalars correctly + one_tensor = torch.tensor(1.0, device=device) + global_grad_norm = torch.zeros(1, device=device) + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.is_sparse: + raise RuntimeError( + 'Lamb does not support sparse gradients, consider SparseAdam instad.') # noqa: E501 + global_grad_norm.add_(grad.pow(2).sum()) + + global_grad_norm = torch.sqrt(global_grad_norm) + # FIXME it'd be nice to remove explicit tensor conversion of scalars when torch.where promotes # noqa: E501 + # scalar types properly https://github.com/pytorch/pytorch/issues/9190 # noqa: E501 + max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], + device=device) + clip_global_grad_norm = torch.where( + global_grad_norm > max_grad_norm, + global_grad_norm / max_grad_norm, + one_tensor) + + for group in self.param_groups: + bias_correction = 1 if group['bias_correction'] else 0 + beta1, beta2 = group['betas'] + grad_averaging = 1 if group['grad_averaging'] else 0 + beta3 = 1 - beta1 if grad_averaging else 1.0 + + # assume same step across group now to simplify things + # per parameter step can be easily support by making it tensor, or pass list into kernel # noqa: E501 + if 'step' in group: + group['step'] += 1 + else: + group['step'] = 1 + + if bias_correction: + bias_correction1 = 1 - beta1 ** group['step'] + bias_correction2 = 1 - beta2 ** group['step'] + else: + bias_correction1, bias_correction2 = 1.0, 1.0 + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.div_(clip_global_grad_norm) + state = self.state[p] + + # State initialization + if len(state) == 0: + # Exponential moving average of gradient valuesa + state['exp_avg'] = torch.zeros_like(p) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + + # Decay the first and second moment running average coefficient + exp_avg.mul_(beta1).add_(grad, alpha=beta3) # m_t + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # v_t # noqa: E501 + + denom = (exp_avg_sq.sqrt() / + math.sqrt(bias_correction2)).add_(group['eps']) + update = (exp_avg / bias_correction1).div_(denom) + + weight_decay = group['weight_decay'] + if weight_decay != 0: + update.add_(p, alpha=weight_decay) + + if weight_decay != 0 or group['always_adapt']: + # Layer-wise LR adaptation. By default, skip adaptation on parameters that are # noqa: E501 + # excluded from weight decay, unless always_adapt == True, then always enabled. # noqa: E501 + w_norm = p.norm(2.0) + g_norm = update.norm(2.0) + # FIXME nested where required since logical and/or not working in PT XLA # noqa: E501 + trust_ratio = torch.where( + w_norm > 0, + torch.where(g_norm > 0, w_norm / g_norm, one_tensor), + one_tensor, + ) + if group['trust_clip']: + # LAMBC trust clipping, upper bound fixed at one + trust_ratio = torch.minimum(trust_ratio, one_tensor) + + state['weight_norm'] = w_norm + state['adam_norm'] = g_norm + state['trust_ratio'] = trust_ratio + + update.mul_(trust_ratio) + + p.add_(update, alpha=-group['lr']) + + return loss diff --git a/pytorch_caney/pipelines/__init__.py b/pytorch_caney/pipelines/__init__.py new file mode 100644 index 0000000..911c274 --- /dev/null +++ b/pytorch_caney/pipelines/__init__.py @@ -0,0 +1,12 @@ +from .satvision_toa_pretrain_pipeline import SatVisionToaPretrain +from .three_d_cloud_pipeline import ThreeDCloudTask + + +PIPELINES = { + 'satvisiontoapretrain': SatVisionToaPretrain, + '3dcloud': ThreeDCloudTask +} + + +def get_available_pipelines(): + return {name: cls for name, cls in PIPELINES.items()} diff --git a/pytorch_caney/pipelines/finetuning/finetune.py b/pytorch_caney/pipelines/finetuning/finetune.py deleted file mode 100644 index 72ade94..0000000 --- a/pytorch_caney/pipelines/finetuning/finetune.py +++ /dev/null @@ -1,438 +0,0 @@ -from pytorch_caney.models.build import build_model - -from pytorch_caney.data.datamodules.finetune_datamodule \ - import build_finetune_dataloaders - -from pytorch_caney.training.mim_utils \ - import build_optimizer, save_checkpoint, reduce_tensor - -from pytorch_caney.config import get_config -from pytorch_caney.loss.build import build_loss -from pytorch_caney.lr_scheduler import build_scheduler, setup_scaled_lr -from pytorch_caney.ptc_logging import create_logger -from pytorch_caney.training.mim_utils import get_grad_norm - -import argparse -import datetime -import joblib -import numpy as np -import os -import time - -import torch -import torch.cuda.amp as amp -import torch.backends.cudnn as cudnn -import torch.distributed as dist - -from timm.utils import AverageMeter - - -def parse_args(): - """ - Parse command-line arguments - """ - - parser = argparse.ArgumentParser( - 'pytorch-caney finetuning', - add_help=False) - - parser.add_argument( - '--cfg', - type=str, - required=True, - metavar="FILE", - help='path to config file') - - parser.add_argument( - "--data-paths", - nargs='+', - required=True, - help="paths where dataset is stored") - - parser.add_argument( - '--dataset', - type=str, - required=True, - help='Dataset to use') - - parser.add_argument( - '--pretrained', - type=str, - help='path to pre-trained model') - - parser.add_argument( - '--batch-size', - type=int, - help="batch size for single GPU") - - parser.add_argument( - '--resume', - help='resume from checkpoint') - - parser.add_argument( - '--accumulation-steps', - type=int, - help="gradient accumulation steps") - - parser.add_argument( - '--use-checkpoint', - action='store_true', - help="whether to use gradient checkpointing to save memory") - - parser.add_argument( - '--enable-amp', - action='store_true') - - parser.add_argument( - '--disable-amp', - action='store_false', - dest='enable_amp') - - parser.set_defaults(enable_amp=True) - - parser.add_argument( - '--output', - default='output', - type=str, - metavar='PATH', - help='root of output folder, the full path is ' + - '// (default: output)') - - parser.add_argument( - '--tag', - help='tag of experiment') - - args = parser.parse_args() - - config = get_config(args) - - return args, config - - -def train(config, - dataloader_train, - dataloader_val, - model, - model_wo_ddp, - optimizer, - lr_scheduler, - scaler, - criterion): - """ - Start fine-tuning a specific model and dataset. - - Args: - config: config object - dataloader_train: training pytorch dataloader - dataloader_val: validation pytorch dataloader - model: model to pre-train - model_wo_ddp: model to pre-train that is not the DDP version - optimizer: pytorch optimizer - lr_scheduler: learning-rate scheduler - scaler: loss scaler - criterion: loss function to use for fine-tuning - """ - - loss = validate(config, model, dataloader_val, criterion) - - logger.info(f'Model validation loss: {loss:.3f}%') - - logger.info("Start fine-tuning") - - start_time = time.time() - - for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): - - dataloader_train.sampler.set_epoch(epoch) - - execute_one_epoch(config, model, dataloader_train, - optimizer, criterion, epoch, lr_scheduler, scaler) - - loss = validate(config, model, dataloader_val, criterion) - - logger.info(f'Model validation loss: {loss:.3f}%') - - if dist.get_rank() == 0 and \ - (epoch % config.SAVE_FREQ == 0 or - epoch == (config.TRAIN.EPOCHS - 1)): - - save_checkpoint(config, epoch, model_wo_ddp, 0., - optimizer, lr_scheduler, scaler, logger) - - total_time = time.time() - start_time - - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - - logger.info('Training time {}'.format(total_time_str)) - - -def execute_one_epoch(config, - model, - dataloader, - optimizer, - criterion, - epoch, - lr_scheduler, - scaler): - """ - Execute training iterations on a single epoch. - - Args: - config: config object - model: model to pre-train - dataloader: dataloader to use - optimizer: pytorch optimizer - epoch: int epoch number - lr_scheduler: learning-rate scheduler - scaler: loss scaler - """ - model.train() - - optimizer.zero_grad() - - num_steps = len(dataloader) - - # Set up logging meters - batch_time = AverageMeter() - data_time = AverageMeter() - loss_meter = AverageMeter() - norm_meter = AverageMeter() - loss_scale_meter = AverageMeter() - - start = time.time() - end = time.time() - for idx, (samples, targets) in enumerate(dataloader): - - data_time.update(time.time() - start) - - samples = samples.cuda(non_blocking=True) - targets = targets.cuda(non_blocking=True) - - samples = samples.to(torch.bfloat16) - - with amp.autocast(enabled=config.ENABLE_AMP): - logits = model(samples) - - loss = criterion(logits, targets) - loss = loss / config.TRAIN.ACCUMULATION_STEPS - - scaler.scale(loss).backward() - - grad_norm = get_grad_norm(model.parameters()) - - if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: - optimizer.zero_grad() - lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS) - - loss_scale_value = scaler.state_dict()["scale"] - - torch.cuda.synchronize() - - loss_meter.update(loss.item(), targets.size(0)) - norm_meter.update(grad_norm) - loss_scale_meter.update(loss_scale_value) - batch_time.update(time.time() - end) - end = time.time() - - if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[0]['lr'] - memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) - etas = batch_time.avg * (num_steps - idx) - logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - f'loss_scale {loss_scale_meter.val:.4f}' + - f' ({loss_scale_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') - - epoch_time = time.time() - start - logger.info( - f"EPOCH {epoch} training takes " + - f"{datetime.timedelta(seconds=int(epoch_time))}") - - -@torch.no_grad() -def validate(config, model, dataloader, criterion): - """Validation function which given a model and validation loader - performs a validation run and returns the average loss according - to the criterion. - - Args: - config: config object - model: pytorch model to validate - dataloader: pytorch validation loader - criterion: pytorch-friendly loss function - - Returns: - loss_meter.avg: average of the loss throught the validation - iterations - """ - - model.eval() - - batch_time = AverageMeter() - - loss_meter = AverageMeter() - - end = time.time() - - for idx, (images, target) in enumerate(dataloader): - - images = images.cuda(non_blocking=True) - - target = target.cuda(non_blocking=True) - - images = images.to(torch.bfloat16) - - # compute output - with amp.autocast(enabled=config.ENABLE_AMP): - output = model(images) - - # measure accuracy and record loss - loss = criterion(output, target) - - loss = reduce_tensor(loss) - - loss_meter.update(loss.item(), target.size(0)) - - # measure elapsed time - batch_time.update(time.time() - end) - - end = time.time() - - if idx % config.PRINT_FREQ == 0: - - memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) - - logger.info( - f'Test: [{idx}/{len(dataloader)}]\t' - f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'Mem {memory_used:.0f}MB') - - return loss_meter.avg - - -def main(config): - """ - Performs the main function of building model, loader, etc. and starts - training. - """ - - dataloader_train, dataloader_val = build_finetune_dataloaders( - config, logger) - - model = build_finetune_model(config, logger) - - optimizer = build_optimizer(config, - model, - is_pretrain=False, - logger=logger) - - model, model_wo_ddp = make_ddp(model) - - n_iter_per_epoch = len(dataloader_train) - - lr_scheduler = build_scheduler(config, optimizer, n_iter_per_epoch) - - scaler = amp.GradScaler() - - criterion = build_loss(config) - - train(config, - dataloader_train, - dataloader_val, - model, - model_wo_ddp, - optimizer, - lr_scheduler, - scaler, - criterion) - - -def build_finetune_model(config, logger): - - logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") - - # You can replace this section by simply calling your model class - # For example: model = UNet(parameters) - model = build_model(config, - pretrain=False, - pretrain_method='mim', - logger=logger) - - model.cuda() - - logger.info(str(model)) - - return model - - -def make_ddp(model): - - model = torch.nn.parallel.DistributedDataParallel( - model, - device_ids=[int(os.environ["RANK"])], - broadcast_buffers=False, - find_unused_parameters=True) - - model_without_ddp = model.module - - return model, model_without_ddp - - -def setup_rank_worldsize(): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: - rank = int(os.environ["RANK"]) - world_size = int(os.environ['WORLD_SIZE']) - print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") - else: - rank = -1 - world_size = -1 - return rank, world_size - - -def setup_distributed_processing(rank, world_size): - torch.cuda.set_device(int(os.environ["RANK"])) - torch.distributed.init_process_group( - backend='nccl', init_method='env://', world_size=world_size, rank=rank) - torch.distributed.barrier() - - -def setup_seeding(config): - seed = config.SEED + dist.get_rank() - torch.manual_seed(seed) - np.random.seed(seed) - - -if __name__ == '__main__': - _, config = parse_args() - - rank, world_size = setup_rank_worldsize() - - setup_distributed_processing(rank, world_size) - - setup_seeding(config) - - cudnn.benchmark = True - - os.makedirs(config.OUTPUT, exist_ok=True) - logger = create_logger(output_dir=config.OUTPUT, - dist_rank=dist.get_rank(), - name=f"{config.MODEL.NAME}") - - if dist.get_rank() == 0: - path = os.path.join(config.OUTPUT, "config.json") - with open(path, "w") as f: - f.write(config.dump()) - logger.info(f"Full config saved to {path}") - logger.info(config.dump()) - config_file_name = f'{config.TAG}.config.sav' - config_file_path = os.path.join(config.OUTPUT, config_file_name) - joblib.dump(config, config_file_path) - - main(config) diff --git a/pytorch_caney/pipelines/modis_segmentation.py b/pytorch_caney/pipelines/modis_segmentation.py deleted file mode 100644 index 2c58e9c..0000000 --- a/pytorch_caney/pipelines/modis_segmentation.py +++ /dev/null @@ -1,364 +0,0 @@ -from argparse import ArgumentParser, Namespace -import multiprocessing - -import torch -from torch import nn -import torch.nn.functional as F -from torch.utils.data import DataLoader - -import torchvision.transforms as transforms - -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from lightning.pytorch.loggers import CSVLogger - -from pytorch_caney.datasets.modis_dataset import MODISDataset -from pytorch_caney.utils import check_gpus_available - - -class UNet(nn.Module): - """ - Architecture based on U-Net: Convolutional Networks for - Biomedical Image Segmentation. - Link - https://arxiv.org/abs/1505.04597 - >>> UNet(num_classes=2, num_layers=3) \ - # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - UNet( - (layers): ModuleList( - (0): DoubleConv(...) - (1): Down(...) - (2): Down(...) - (3): Up(...) - (4): Up(...) - (5): Conv2d(64, 2, kernel_size=(1, 1), stride=(1, 1)) - ) - ) - """ - - def __init__( - self, - num_channels: int = 7, - num_classes: int = 19, - num_layers: int = 5, - features_start: int = 64, - bilinear: bool = False - ): - - super().__init__() - self.num_layers = num_layers - - layers = [DoubleConv(num_channels, features_start)] - - feats = features_start - for _ in range(num_layers - 1): - layers.append(Down(feats, feats * 2)) - feats *= 2 - - for _ in range(num_layers - 1): - layers.append(Up(feats, feats // 2, bilinear)) - feats //= 2 - - layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) - - self.layers = nn.ModuleList(layers) - - def forward(self, x): - xi = [self.layers[0](x)] - # Down path - for layer in self.layers[1: self.num_layers]: - xi.append(layer(xi[-1])) - # Up path - for i, layer in enumerate(self.layers[self.num_layers: -1]): - xi[-1] = layer(xi[-1], xi[-2 - i]) - return self.layers[-1](xi[-1]) - - -class DoubleConv(nn.Module): - """Double Convolution and BN and ReLU (3x3 conv -> BN -> ReLU) ** 2. - >>> DoubleConv(4, 4) \ - # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - DoubleConv( - (net): Sequential(...) - ) - """ - - def __init__(self, in_ch: int, out_ch: int): - super().__init__() - self.net = nn.Sequential( - nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), - nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), - nn.BatchNorm2d(out_ch), - nn.ReLU(inplace=True), - ) - - def forward(self, x): - return self.net(x) - - -class Down(nn.Module): - """Combination of MaxPool2d and DoubleConv in series. - >>> Down(4, 8) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Down( - (net): Sequential( - (0): MaxPool2d( - kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) - (1): DoubleConv( - (net): Sequential(...) - ) - ) - ) - """ - - def __init__(self, in_ch: int, out_ch: int): - super().__init__() - self.net = nn.Sequential( - nn.MaxPool2d(kernel_size=2, stride=2), DoubleConv(in_ch, out_ch)) - - def forward(self, x): - return self.net(x) - - -class Up(nn.Module): - """Upsampling (by either bilinear interpolation or transpose convolutions) - followed by concatenation of feature - map from contracting path, followed by double 3x3 convolution. - >>> Up(8, 4) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE - Up( - (upsample): ConvTranspose2d(8, 4, kernel_size=(2, 2), stride=(2, 2)) - (conv): DoubleConv( - (net): Sequential(...) - ) - ) - """ - - def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): - super().__init__() - self.upsample = None - if bilinear: - self.upsample = nn.Sequential( - nn.Upsample( - scale_factor=2, mode="bilinear", align_corners=True), - nn.Conv2d( - in_ch, in_ch // 2, kernel_size=1), - ) - else: - self.upsample = nn.ConvTranspose2d( - in_ch, in_ch // 2, kernel_size=2, stride=2) - - self.conv = DoubleConv(in_ch, out_ch) - - def forward(self, x1, x2): - x1 = self.upsample(x1) - - # Pad x1 to the size of x2 - diff_h = x2.shape[2] - x1.shape[2] - diff_w = x2.shape[3] - x1.shape[3] - - x1 = F.pad( - x1, - [ - diff_w // 2, diff_w - diff_w // 2, - diff_h // 2, diff_h - diff_h // 2 - ]) - - # Concatenate along the channels axis - x = torch.cat([x2, x1], dim=1) - return self.conv(x) - - -class SegmentationModel(LightningModule): - - def __init__( - self, - data_path: list = [], - n_classes: int = 18, - batch_size: int = 256, - lr: float = 3e-4, - num_layers: int = 5, - features_start: int = 64, - bilinear: bool = False, - **kwargs, - ): - super().__init__(**kwargs) - self.data_paths = data_path - self.n_classes = n_classes - self.batch_size = batch_size - self.learning_rate = lr - self.num_layers = num_layers - self.features_start = features_start - self.bilinear = bilinear - self.validation_step_outputs = [] - - self.net = UNet( - num_classes=self.n_classes, - num_layers=self.num_layers, - features_start=self.features_start, - bilinear=self.bilinear - ) - self.transform = transforms.Compose( - [ - transforms.ToTensor(), - transforms.Normalize( - mean=[0.0173, 0.0332, 0.0088, - 0.0136, 0.0381, 0.0348, 0.0249], - std=[0.0150, 0.0127, 0.0124, - 0.0128, 0.0120, 0.0159, 0.0164] - ), - ] - ) - print('> Init datasets') - self.trainset = MODISDataset( - self.data_paths, split="train", transform=self.transform) - self.validset = MODISDataset( - self.data_paths, split="valid", transform=self.transform) - print('Done init datasets') - - def forward(self, x): - return self.net(x) - - def training_step(self, batch, batch_nb): - img, mask = batch - img = img.float() - mask = mask.long() - out = self(img) - loss = F.cross_entropy(out, mask, ignore_index=250) - log_dict = {"train_loss": loss} - self.log_dict(log_dict) - return {"loss": loss, "log": log_dict, "progress_bar": log_dict} - - def validation_step(self, batch, batch_idx): - img, mask = batch - img = img.float() - mask = mask.long() - out = self(img) - loss_val = F.cross_entropy(out, mask, ignore_index=250) - self.validation_step_outputs.append(loss_val) - return {"val_loss": loss_val} - - def on_validation_epoch_end(self): - loss_val = torch.stack(self.validation_step_outputs).mean() - log_dict = {"val_loss": loss_val} - self.log("val_loss", loss_val, sync_dist=True) - self.validation_step_outputs.clear() - return { - "log": log_dict, - "val_loss": log_dict["val_loss"], - "progress_bar": log_dict - } - - def configure_optimizers(self): - opt = torch.optim.Adam(self.net.parameters(), lr=self.learning_rate) - # sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10) - return [opt] # , [sch] - - def train_dataloader(self): - return DataLoader( - self.trainset, - batch_size=self.batch_size, - num_workers=multiprocessing.cpu_count(), - shuffle=True - ) - - def val_dataloader(self): - return DataLoader( - self.validset, - batch_size=self.batch_size, - num_workers=multiprocessing.cpu_count(), - shuffle=False - ) - - -def main(hparams: Namespace): - # ------------------------ - # 1 INIT LIGHTNING MODEL - # ------------------------ - ngpus = int(hparams.ngpus) - # PT ligtning does not expect this, del after use - del hparams.ngpus - - model = SegmentationModel(**vars(hparams)) - - # ------------------------ - # 2 SET LOGGER - # ------------------------ - # logger = True - # if hparams.log_wandb: - # logger = WandbLogger() - # # optional: log model topology - # logger.watch(model.net) - - train_callbacks = [ - # TQDMProgressBar(refresh_rate=20), - ModelCheckpoint(dirpath='models/', - monitor='val_loss', - save_top_k=5, - filename='{epoch}-{val_loss:.2f}.ckpt'), - EarlyStopping("val_loss", patience=10, mode='min'), - ] - - # See number of devices - check_gpus_available(ngpus) - - # ------------------------ - # 3 INIT TRAINER - # ------------------------ - # trainer = Trainer( - # ------------------------ - trainer = Trainer( - accelerator="gpu", - devices=ngpus, - strategy="ddp", - min_epochs=1, - max_epochs=500, - callbacks=train_callbacks, - logger=CSVLogger(save_dir="logs/"), - # precision=16 # makes loss nan, need to fix that - ) - - # ------------------------ - # 5 START TRAINING - # ------------------------ - trainer.fit(model) - trainer.save_checkpoint("best_model.ckpt") - - # ------------------------ - # 6 START TEST - # ------------------------ - # test_set = MODISDataset( - # self.data_path, split=None, transform=self.transform) - # test_dataloader = DataLoader(...) - # trainer.test(ckpt_path="best", dataloaders=) - - -if __name__ == "__main__": - cli_lightning_logo() - - parser = ArgumentParser() - parser.add_argument( - "--data_path", nargs='+', required=True, - help="path where dataset is stored") - parser.add_argument('--ngpus', type=int, - default=torch.cuda.device_count(), - help='number of gpus to use') - parser.add_argument( - "--n-classes", type=int, default=18, help="number of classes") - parser.add_argument( - "--batch_size", type=int, default=256, help="size of the batches") - parser.add_argument( - "--lr", type=float, default=3e-4, help="adam: learning rate") - parser.add_argument( - "--num_layers", type=int, default=5, help="number of layers on u-net") - parser.add_argument( - "--features_start", type=float, default=64, - help="number of features in first layer") - parser.add_argument( - "--bilinear", action="store_true", default=False, - help="whether to use bilinear interpolation or transposed") - # parser.add_argument( - # "--log-wandb", action="store_true", default=True, - # help="whether to use wandb as the logger") - hparams = parser.parse_args() - - main(hparams) diff --git a/pytorch_caney/pipelines/pretraining/mim.py b/pytorch_caney/pipelines/pretraining/mim.py deleted file mode 100644 index 3bcc795..0000000 --- a/pytorch_caney/pipelines/pretraining/mim.py +++ /dev/null @@ -1,371 +0,0 @@ -from pytorch_caney.data.datamodules.mim_datamodule \ - import build_mim_dataloader - -from pytorch_caney.models.mim.mim \ - import build_mim_model - -from pytorch_caney.training.mim_utils \ - import build_optimizer, save_checkpoint - -from pytorch_caney.training.mim_utils import get_grad_norm -from pytorch_caney.lr_scheduler import build_scheduler, setup_scaled_lr -from pytorch_caney.ptc_logging import create_logger -from pytorch_caney.config import get_config - -import argparse -import datetime -import joblib -import numpy as np -import os -import time - -import torch -import torch.cuda.amp as amp -import torch.backends.cudnn as cudnn -import torch.distributed as dist - -from timm.utils import AverageMeter - - -def parse_args(): - """ - Parse command-line arguments - """ - parser = argparse.ArgumentParser( - 'pytorch-caney implementation of MiM pre-training script', - add_help=False) - - parser.add_argument( - '--cfg', - type=str, - required=True, - metavar="FILE", - help='path to config file') - - parser.add_argument( - "--data-paths", - nargs='+', - required=True, - help="paths where dataset is stored") - - parser.add_argument( - '--dataset', - type=str, - required=True, - help='Dataset to use') - - parser.add_argument( - '--batch-size', - type=int, - help="batch size for single GPU") - - parser.add_argument( - '--resume', - help='resume from checkpoint') - - parser.add_argument( - '--accumulation-steps', - type=int, - help="gradient accumulation steps") - - parser.add_argument( - '--use-checkpoint', - action='store_true', - help="whether to use gradient checkpointing to save memory") - - parser.add_argument( - '--enable-amp', - action='store_true') - - parser.add_argument( - '--disable-amp', - action='store_false', - dest='enable_amp') - - parser.set_defaults(enable_amp=True) - - parser.add_argument( - '--output', - default='output', - type=str, - metavar='PATH', - help='root of output folder, the full path is ' + - '// (default: output)') - - parser.add_argument( - '--tag', - help='tag of experiment') - - args = parser.parse_args() - - config = get_config(args) - - return args, config - - -def train(config, - dataloader, - model, - model_wo_ddp, - optimizer, - lr_scheduler, - scaler): - """ - Start pre-training a specific model and dataset. - - Args: - config: config object - dataloader: dataloader to use - model: model to pre-train - model_wo_ddp: model to pre-train that is not the DDP version - optimizer: pytorch optimizer - lr_scheduler: learning-rate scheduler - scaler: loss scaler - """ - - logger.info("Start training") - - start_time = time.time() - - for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): - - dataloader.sampler.set_epoch(epoch) - - execute_one_epoch(config, model, dataloader, - optimizer, epoch, lr_scheduler, scaler) - - if dist.get_rank() == 0 and \ - (epoch % config.SAVE_FREQ == 0 or - epoch == (config.TRAIN.EPOCHS - 1)): - - save_checkpoint(config, epoch, model_wo_ddp, 0., - optimizer, lr_scheduler, scaler, logger) - - total_time = time.time() - start_time - - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - - logger.info('Training time {}'.format(total_time_str)) - - -def execute_one_epoch(config, - model, - dataloader, - optimizer, - epoch, - lr_scheduler, - scaler): - """ - Execute training iterations on a single epoch. - - Args: - config: config object - model: model to pre-train - dataloader: dataloader to use - optimizer: pytorch optimizer - epoch: int epoch number - lr_scheduler: learning-rate scheduler - scaler: loss scaler - """ - - model.train() - - optimizer.zero_grad() - - num_steps = len(dataloader) - - # Set up logging meters - batch_time = AverageMeter() - data_time = AverageMeter() - loss_meter = AverageMeter() - norm_meter = AverageMeter() - loss_scale_meter = AverageMeter() - - start = time.time() - end = time.time() - for idx, (img, mask, _) in enumerate(dataloader): - - data_time.update(time.time() - start) - - img = img.cuda(non_blocking=True) - mask = mask.cuda(non_blocking=True) - - with amp.autocast(enabled=config.ENABLE_AMP): - loss = model(img, mask) - - if config.TRAIN.ACCUMULATION_STEPS > 1: - loss = loss / config.TRAIN.ACCUMULATION_STEPS - scaler.scale(loss).backward() - loss.backward() - if config.TRAIN.CLIP_GRAD: - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), - config.TRAIN.CLIP_GRAD) - else: - grad_norm = get_grad_norm(model.parameters()) - if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: - scaler.step(optimizer) - optimizer.zero_grad() - scaler.update() - lr_scheduler.step_update(epoch * num_steps + idx) - else: - optimizer.zero_grad() - scaler.scale(loss).backward() - if config.TRAIN.CLIP_GRAD: - scaler.unscale_(optimizer) - grad_norm = torch.nn.utils.clip_grad_norm_( - model.parameters(), - config.TRAIN.CLIP_GRAD) - else: - grad_norm = get_grad_norm(model.parameters()) - scaler.step(optimizer) - scaler.update() - lr_scheduler.step_update(epoch * num_steps + idx) - - torch.cuda.synchronize() - - loss_meter.update(loss.item(), img.size(0)) - norm_meter.update(grad_norm) - loss_scale_meter.update(scaler.get_scale()) - batch_time.update(time.time() - end) - end = time.time() - - if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[0]['lr'] - memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) - etas = batch_time.avg * (num_steps - idx) - logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - f'loss_scale {loss_scale_meter.val:.4f}' + - f' ({loss_scale_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') - - epoch_time = time.time() - start - logger.info( - f"EPOCH {epoch} training takes " + - f"{datetime.timedelta(seconds=int(epoch_time))}") - - -def main(config): - """ - Starts training process after building the proper model, optimizer, etc. - - Args: - config: config object - """ - - pretrain_data_loader = build_mim_dataloader(config, logger) - - simmim_model = build_model(config, logger) - - simmim_optimizer = build_optimizer(config, - simmim_model, - is_pretrain=True, - logger=logger) - - model, model_wo_ddp = make_ddp(simmim_model) - - n_iter_per_epoch = len(pretrain_data_loader) - - lr_scheduler = build_scheduler(config, simmim_optimizer, n_iter_per_epoch) - - scaler = amp.GradScaler() - - train(config, - pretrain_data_loader, - model, - model_wo_ddp, - simmim_optimizer, - lr_scheduler, - scaler) - - -def build_model(config, logger): - - logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") - - model = build_mim_model(config) - - model.cuda() - - logger.info(str(model)) - - return model - - -def make_ddp(model): - - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[int(os.environ["RANK"])], broadcast_buffers=False) - - model_without_ddp = model.module - - return model, model_without_ddp - - -def setup_rank_worldsize(): - if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: - rank = int(os.environ["RANK"]) - world_size = int(os.environ['WORLD_SIZE']) - print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}") - else: - rank = -1 - world_size = -1 - return rank, world_size - - -def setup_distributed_processing(rank, world_size): - torch.cuda.set_device(int(os.environ["RANK"])) - torch.distributed.init_process_group( - backend='nccl', init_method='env://', world_size=world_size, rank=rank) - torch.distributed.barrier() - - -def setup_seeding(config): - seed = config.SEED + dist.get_rank() - torch.manual_seed(seed) - np.random.seed(seed) - - -if __name__ == '__main__': - _, config = parse_args() - - rank, world_size = setup_rank_worldsize() - - setup_distributed_processing(rank, world_size) - - setup_seeding(config) - - cudnn.benchmark = True - - linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \ - setup_scaled_lr(config) - - config.defrost() - config.TRAIN.BASE_LR = linear_scaled_lr - config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr - config.TRAIN.MIN_LR = linear_scaled_min_lr - config.freeze() - - os.makedirs(config.OUTPUT, exist_ok=True) - logger = create_logger(output_dir=config.OUTPUT, - dist_rank=dist.get_rank(), - name=f"{config.MODEL.NAME}") - - if dist.get_rank() == 0: - path = os.path.join(config.OUTPUT, "config.json") - with open(path, "w") as f: - f.write(config.dump()) - logger.info(f"Full config saved to {path}") - logger.info(config.dump()) - config_file_name = f'{config.TAG}.config.sav' - config_file_path = os.path.join(config.OUTPUT, config_file_name) - joblib.dump(config, config_file_path) - - main(config) diff --git a/pytorch_caney/pipelines/pretraining/mim_deepspeed.py b/pytorch_caney/pipelines/pretraining/mim_deepspeed.py deleted file mode 100644 index 4d7d33c..0000000 --- a/pytorch_caney/pipelines/pretraining/mim_deepspeed.py +++ /dev/null @@ -1,340 +0,0 @@ -from pytorch_caney.data.datamodules.mim_datamodule \ - import build_mim_dataloader - -from pytorch_caney.models.mim.mim \ - import build_mim_model - -from pytorch_caney.training.mim_utils \ - import build_optimizer, save_checkpoint - -# from pytorch_caney.training.mim_utils import get_grad_norm -from pytorch_caney.lr_scheduler import build_scheduler, setup_scaled_lr -from pytorch_caney.ptc_logging import create_logger -from pytorch_caney.config import get_config - -import deepspeed - -import argparse -import datetime -import joblib -import numpy as np -import os -import time - -import torch -import torch.cuda.amp as amp -import torch.backends.cudnn as cudnn -import torch.distributed as dist - -from timm.utils import AverageMeter - - -def parse_args(): - """ - Parse command-line arguments - """ - parser = argparse.ArgumentParser( - 'pytorch-caney implementation of MiM pre-training script', - add_help=False) - - parser.add_argument( - '--cfg', - type=str, - required=True, - metavar="FILE", - help='path to config file') - - parser.add_argument( - "--data-paths", - nargs='+', - required=True, - help="paths where dataset is stored") - - parser.add_argument( - '--dataset', - type=str, - required=True, - help='Dataset to use') - - parser.add_argument( - '--batch-size', - type=int, - help="batch size for single GPU") - - parser.add_argument( - '--resume', - help='resume from checkpoint') - - parser.add_argument( - '--accumulation-steps', - type=int, - help="gradient accumulation steps") - - parser.add_argument( - '--use-checkpoint', - action='store_true', - help="whether to use gradient checkpointing to save memory") - - parser.add_argument( - '--enable-amp', - action='store_true') - - parser.add_argument( - '--disable-amp', - action='store_false', - dest='enable_amp') - - parser.set_defaults(enable_amp=True) - - parser.add_argument( - '--output', - default='output', - type=str, - metavar='PATH', - help='root of output folder, the full path is ' + - '// (default: output)') - - parser.add_argument( - '--tag', - help='tag of experiment') - - # distributed training (deepspeed) - parser.add_argument("--local_rank", - type=int, - required=True, - help='local rank for DistributedDataParallel') - - args = parser.parse_args() - - config = get_config(args) - - return args, config - - -def train(config, - dataloader, - model_engine, - optimizer, - lr_scheduler, - scaler): - """ - Start pre-training a specific model and dataset. - - Args: - config: config object - dataloader: dataloader to use - model: model to pre-train - model_wo_ddp: model to pre-train that is not the DDP version - optimizer: pytorch optimizer - lr_scheduler: learning-rate scheduler - scaler: loss scaler - """ - - logger.info("Start training") - - start_time = time.time() - - for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): - - dataloader.sampler.set_epoch(epoch) - - execute_one_epoch(config, model_engine, dataloader, - optimizer, epoch, lr_scheduler, scaler) - - if dist.get_rank() == 0 and \ - (epoch % config.SAVE_FREQ == 0 or - epoch == (config.TRAIN.EPOCHS - 1)): - - save_checkpoint(config, epoch, model_engine, 0., - optimizer, lr_scheduler, scaler, logger) - - total_time = time.time() - start_time - - total_time_str = str(datetime.timedelta(seconds=int(total_time))) - - logger.info('Training time {}'.format(total_time_str)) - - -def execute_one_epoch(config, - model, - dataloader, - optimizer, - epoch, - lr_scheduler, - scaler): - """ - Execute training iterations on a single epoch. - - Args: - config: config object - model: model to pre-train - dataloader: dataloader to use - optimizer: pytorch optimizer - epoch: int epoch number - lr_scheduler: learning-rate scheduler - scaler: loss scaler - """ - - model.train() - - optimizer.zero_grad() - - num_steps = len(dataloader) - - # Set up logging meters - batch_time = AverageMeter() - data_time = AverageMeter() - loss_meter = AverageMeter() - norm_meter = AverageMeter() - loss_scale_meter = AverageMeter() - - start = time.time() - end = time.time() - for idx, (img, mask, _) in enumerate(dataloader): - - data_time.update(time.time() - start) - - img = img.cuda(non_blocking=True) - mask = mask.cuda(non_blocking=True) - - loss = model(img, mask) - - model.backward(loss) - - model.step() - - loss_meter.update(loss.item(), img.size(0)) - batch_time.update(time.time() - end) - end = time.time() - - if idx % config.PRINT_FREQ == 0: - lr = optimizer.param_groups[0]['lr'] - memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) - etas = batch_time.avg * (num_steps - idx) - logger.info( - f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' - f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' - f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' - f'data_time {data_time.val:.4f} ({data_time.avg:.4f})\t' - f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' - f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' - # f'loss_scale {loss_scale_meter.val:.4f}' + - f' ({loss_scale_meter.avg:.4f})\t' - f'mem {memory_used:.0f}MB') - - epoch_time = time.time() - start - logger.info( - f"EPOCH {epoch} training takes " + - f"{datetime.timedelta(seconds=int(epoch_time))}") - - -def main(config): - """ - Starts training process after building the proper model, optimizer, etc. - - Args: - config: config object - """ - - logger.info('In main') - - pretrain_data_loader = build_mim_dataloader(config, logger) - - simmim_model = build_model(config, logger) - - simmim_optimizer = build_optimizer(config, - simmim_model, - is_pretrain=True, - logger=logger) - - n_iter_per_epoch = len(pretrain_data_loader) - - lr_scheduler = build_scheduler(config, simmim_optimizer, n_iter_per_epoch) - - deepspeed_config = { - "train_micro_batch_size_per_gpu": config.DATA.BATCH_SIZE, - - "zero_optimization": { - "stage": 2, - "allgather_partitions": True, - "allgather_bucket_size": 2e8, - "overlap_comm": True, - "reduce_scatter": True, - "reduce_bucket_size": "auto", - "contiguous_gradients": True, - } - } - - logger.info('Initializing deepspeed') - - model_engine, _, _, _ = deepspeed.initialize( - model=simmim_model, - optimizer=simmim_optimizer, - lr_scheduler=lr_scheduler, - model_parameters=simmim_model.parameters(), - config=deepspeed_config - ) - - scaler = amp.GradScaler() - - logger.info('Starting training block') - - train(config, - pretrain_data_loader, - model_engine, - simmim_optimizer, - lr_scheduler, - scaler) - - -def build_model(config, logger): - - logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}") - - model = build_mim_model(config) - - logger.info(str(model)) - - return model - - -def setup_seeding(config): - seed = config.SEED + dist.get_rank() - torch.manual_seed(seed) - np.random.seed(seed) - - -if __name__ == '__main__': - _, config = parse_args() - - deepspeed.init_distributed() - - setup_seeding(config) - - cudnn.benchmark = True - - linear_scaled_lr, linear_scaled_min_lr, linear_scaled_warmup_lr = \ - setup_scaled_lr(config) - - config.defrost() - config.TRAIN.BASE_LR = linear_scaled_lr - config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr - config.TRAIN.MIN_LR = linear_scaled_min_lr - config.freeze() - - os.makedirs(config.OUTPUT, exist_ok=True) - logger = create_logger(output_dir=config.OUTPUT, - dist_rank=dist.get_rank(), - name=f"{config.MODEL.NAME}") - - if dist.get_rank() == 0: - path = os.path.join(config.OUTPUT, "config.json") - with open(path, "w") as f: - f.write(config.dump()) - logger.info(f"Full config saved to {path}") - logger.info(config.dump()) - config_file_name = f'{config.TAG}.config.sav' - config_file_path = os.path.join(config.OUTPUT, config_file_name) - joblib.dump(config, config_file_path) - - main(config) diff --git a/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py b/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py new file mode 100644 index 0000000..c5461fe --- /dev/null +++ b/pytorch_caney/pipelines/satvision_toa_pretrain_pipeline.py @@ -0,0 +1,101 @@ +import torch +import torchmetrics +from torch.utils.data import DataLoader + +import lightning.pytorch as pl + +from pytorch_caney.datasets.sharded_dataset import ShardedDataset +from pytorch_caney.models.mim import build_mim_model +from pytorch_caney.optimizers.build import build_optimizer +from pytorch_caney.transforms.mim_modis_toa import MimTransform + + +# ----------------------------------------------------------------------------- +# SatVisionToaPretrain +# ----------------------------------------------------------------------------- +class SatVisionToaPretrain(pl.LightningModule): + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__(self, config): + super(SatVisionToaPretrain, self).__init__() + self.save_hyperparameters(ignore=['model']) + self.config = config + + self.model = build_mim_model(self.config) + if self.config.MODEL.PRETRAINED: + self.load_checkpoint() + + self.transform = MimTransform(self.config) + self.batch_size = config.DATA.BATCH_SIZE + self.num_workers = config.DATA.NUM_WORKERS + self.img_size = config.DATA.IMG_SIZE + self.train_data_paths = config.DATA.DATA_PATHS + self.train_data_length = config.DATA.LENGTH + self.pin_memory = config.DATA.PIN_MEMORY + + self.train_loss_avg = torchmetrics.MeanMetric() + self.trainset = ShardedDataset( + self.config, + self.train_data_paths, + split='train', + length=self.train_data_length, + img_size=self.img_size, + transform=self.transform, + batch_size=self.batch_size).dataset() + + # ------------------------------------------------------------------------- + # load_checkpoint + # ------------------------------------------------------------------------- + def load_checkpoint(self): + print('Loading checkpoint from {self.config.MODEL.PRETRAINED}') + checkpoint = torch.load(self.config.MODEL.PRETRAINED) + self.model.load_state_dict(checkpoint['module']) + print('Successfully applied checkpoint') + + # ------------------------------------------------------------------------- + # forward + # ------------------------------------------------------------------------- + def forward(self, x, x_mask): + return self.model(x, x_mask) + + # ------------------------------------------------------------------------- + # training_step + # ------------------------------------------------------------------------- + def training_step(self, batch, batch_idx): + image_imagemask = batch[0] + image = torch.stack([pair[0] for pair in image_imagemask]) + mask = torch.stack([pair[1] for pair in image_imagemask]) + loss = self.forward(image, mask) + self.train_loss_avg.update(loss) + self.log('train_loss', + self.train_loss_avg.compute(), + rank_zero_only=True, + batch_size=self.batch_size, + prog_bar=True) + + return loss + + # ------------------------------------------------------------------------- + # configure_optimizers + # ------------------------------------------------------------------------- + def configure_optimizers(self): + optimizer = build_optimizer(self.config, self.model, is_pretrain=True) + return optimizer + + # ------------------------------------------------------------------------- + # on_train_epoch_start + # ------------------------------------------------------------------------- + def on_train_epoch_start(self): + self.train_loss_avg.reset() + + # ------------------------------------------------------------------------- + # train_dataloader + # ------------------------------------------------------------------------- + def train_dataloader(self): + return DataLoader(self.trainset, + batch_size=None, + shuffle=False, + pin_memory=self.pin_memory, + num_workers=self.num_workers) diff --git a/pytorch_caney/pipelines/three_d_cloud_pipeline.py b/pytorch_caney/pipelines/three_d_cloud_pipeline.py new file mode 100644 index 0000000..717c705 --- /dev/null +++ b/pytorch_caney/pipelines/three_d_cloud_pipeline.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +import torchmetrics + +import lightning.pytorch as pl + +from pytorch_caney.optimizers.build import build_optimizer +from pytorch_caney.transforms.abi_toa import AbiToaTransform +from pytorch_caney.models import ModelFactory +from typing import Tuple + + +# ----------------------------------------------------------------------------- +# ThreeDCloudTask +# ----------------------------------------------------------------------------- +class ThreeDCloudTask(pl.LightningModule): + + NUM_CLASSES: int = 1 + OUTPUT_SHAPE: Tuple[int, int] = (91, 40) + + # ------------------------------------------------------------------------- + # __init__ + # ------------------------------------------------------------------------- + def __init__(self, config): + super(ThreeDCloudTask, self).__init__() + self.save_hyperparameters(ignore=['model']) + self.config = config + self.configure_models() + self.configure_losses() + self.configure_metrics() + self.transform = AbiToaTransform(self.config) + + # ------------------------------------------------------------------------- + # configure_models + # ------------------------------------------------------------------------- + def configure_models(self): + factory = ModelFactory() + + self.encoder = factory.get_component(component_type="encoder", + name=self.config.MODEL.ENCODER, + config=self.config) + + self.decoder = factory.get_component( + component_type="decoder", + name=self.config.MODEL.DECODER, + num_features=self.encoder.num_features) + + self.segmentation_head = factory.get_component( + component_type="head", + name="segmentation_head", + decoder_channels=self.decoder.output_channels, + num_classes=self.NUM_CLASSES, + output_shape=self.OUTPUT_SHAPE + ) + + self.model = nn.Sequential(self.encoder, + self.decoder, + self.segmentation_head) + print(self.model) + + # ------------------------------------------------------------------------- + # configure_losses + # ------------------------------------------------------------------------- + def configure_losses(self): + loss: str = self.config.LOSS.NAME + if loss == 'bce': + self.criterion = nn.BCEWithLogitsLoss() + else: + raise ValueError( + f'Loss type "{loss}" is not valid. ' + 'Currecntly supports "ce".' + ) + + # ------------------------------------------------------------------------- + # configure_metrics + # ------------------------------------------------------------------------- + def configure_metrics(self): + num_classes = 2 + self.train_iou = torchmetrics.JaccardIndex(num_classes=num_classes, + task="binary") + self.val_iou = torchmetrics.JaccardIndex(num_classes=num_classes, + task="binary") + self.train_loss_avg = torchmetrics.MeanMetric() + self.val_loss_avg = torchmetrics.MeanMetric() + + self.train_iou_avg = torchmetrics.MeanMetric() + self.val_iou_avg = torchmetrics.MeanMetric() + + # ------------------------------------------------------------------------- + # forward + # ------------------------------------------------------------------------- + def forward(self, x): + return self.model(x) + + # ------------------------------------------------------------------------- + # training_step + # ------------------------------------------------------------------------- + def training_step(self, batch, batch_idx): + inputs, targets = batch + targets = targets.unsqueeze(1) + logits = self.forward(inputs) + loss = self.criterion(logits, targets.float()) + preds = torch.sigmoid(logits) + iou = self.train_iou(preds, targets.int()) + + self.train_loss_avg.update(loss) + self.train_iou_avg.update(iou) + self.log('train_loss', self.train_loss_avg.compute(), + on_step=True, on_epoch=True, prog_bar=True) + self.log('train_iou', self.train_iou_avg.compute(), + on_step=True, on_epoch=True, prog_bar=True) + return loss + + # ------------------------------------------------------------------------- + # validation_step + # ------------------------------------------------------------------------- + def validation_step(self, batch, batch_idx): + inputs, targets = batch + targets = targets.unsqueeze(1) + logits = self.forward(inputs) + val_loss = self.criterion(logits, targets.float()) + preds = torch.sigmoid(logits) + val_iou = self.val_iou(preds, targets.int()) + self.val_loss_avg.update(val_loss) + self.val_iou_avg.update(val_iou) + self.log('val_loss', self.val_loss_avg.compute(), + on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + self.log('val_iou', self.val_iou_avg.compute(), + on_step=True, on_epoch=True, prog_bar=True, sync_dist=True) + + return val_loss + + # ------------------------------------------------------------------------- + # configure_optimizers + # ------------------------------------------------------------------------- + def configure_optimizers(self): + optimizer = build_optimizer(self.config, self.model, is_pretrain=True) + print(f'Using optimizer: {optimizer}') + return optimizer + + # ------------------------------------------------------------------------- + # on_train_epoch_start + # ------------------------------------------------------------------------- + def on_train_epoch_start(self): + self.train_loss_avg.reset() + self.train_iou_avg.reset() + + # ------------------------------------------------------------------------- + # on_validation_epoch_start + # ------------------------------------------------------------------------- + def on_validation_epoch_start(self): + self.val_loss_avg.reset() diff --git a/pytorch_caney/data/datasets/object_dataset.py b/pytorch_caney/plotting/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/data/datasets/object_dataset.py rename to pytorch_caney/plotting/__init__.py diff --git a/pytorch_caney/plotting/modis_toa.py b/pytorch_caney/plotting/modis_toa.py new file mode 100644 index 0000000..37671b6 --- /dev/null +++ b/pytorch_caney/plotting/modis_toa.py @@ -0,0 +1,152 @@ +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.backends.backend_pdf import PdfPages + +from ..transforms.modis_toa_scale import MinMaxEmissiveScaleReflectance + + +# ----------------------------------------------------------------------------- +# MODIS Reconstruction Visualization Pipeline +# ----------------------------------------------------------------------------- +# This script processes MODIS TOA images and model reconstructions, generating +# comparison visualizations in a PDF format. It contains several functions that +# interact to prepare, transform, and visualize MODIS image data, applying +# necessary transformations for reflective and emissive band scaling, masking, +# and normalization. The flow is as follows: +# +# 1. `plot_export_pdf`: Main function that generates PDF visualizations. +# It uses other functions to process and organize data. +# 2. `process_reconstruction_prediction`: Prepares images and masks for +# visualization, applying transformations and normalization. +# 3. `minmax_norm`: Scales image arrays to 0-255 range for display. +# 4. `process_mask`: Prepares mask images to match the input image dimensions. +# 5. `reverse_transform`: Applies band-specific scaling to MODIS data. +# +# ASCII Diagram: +# +# plot_export_pdf +# └── process_reconstruction_prediction +# ├── minmax_norm +# ├── process_mask +# └── reverse_transform +# +# ----------------------------------------------------------------------------- + + +# ----------------------------------------------------------------------------- +# plot_export_pdf +# ----------------------------------------------------------------------------- +# Generates a multi-page PDF with visualizations of original, reconstructed, +# and masked MODIS images. Uses the `process_reconstruction_prediction` funct +# to prepare data for display and organizes subplots for easy comparison. +# ----------------------------------------------------------------------------- +def plot_export_pdf(path, inputs, outputs, masks, rgb_index): + pdf_plot_obj = PdfPages(path) + + for idx in range(len(inputs)): + # prediction processing + image = inputs[idx] + img_recon = outputs[idx] + mask = masks[idx] + rgb_image, rgb_image_masked, rgb_recon_masked, mask = \ + process_reconstruction_prediction( + image, img_recon, mask, rgb_index) + + # matplotlib code + fig, (ax01, ax23) = plt.subplots(2, 2, figsize=(40, 30)) + ax0, ax1 = ax01 + ax2, ax3 = ax23 + ax2.imshow(rgb_image) + ax2.set_title(f"Idx: {idx} MOD021KM v6.1 Bands: {rgb_index}") + + ax0.imshow(rgb_recon_masked) + ax0.set_title(f"Idx: {idx} Model reconstruction") + + ax1.imshow(rgb_image_masked) + ax1.set_title(f"Idx: {idx} MOD021KM Bands: {rgb_index}, masked") + + ax3.matshow(mask[:, :, 0]) + ax3.set_title(f"Idx: {idx} Reconstruction Mask") + pdf_plot_obj.savefig() + + pdf_plot_obj.close() + + +# ----------------------------------------------------------------------------- +# process_reconstruction_prediction +# ----------------------------------------------------------------------------- +# Prepares RGB images, reconstructions, and masked versions by extracting and +# normalizing specific bands based on the provided RGB indices. Returns masked +# images and the processed mask for visualization in the PDF. +# ----------------------------------------------------------------------------- +def process_reconstruction_prediction(image, img_recon, mask, rgb_index): + + mask = process_mask(mask) + + red_idx = rgb_index[0] + blue_idx = rgb_index[1] + green_idx = rgb_index[2] + + image = reverse_transform(image.numpy()) + + img_recon = reverse_transform(img_recon.numpy()) + + rgb_image = np.stack((image[red_idx, :, :], + image[blue_idx, :, :], + image[green_idx, :, :]), axis=-1) + rgb_image = minmax_norm(rgb_image) + + rgb_image_recon = np.stack((img_recon[red_idx, :, :], + img_recon[blue_idx, :, :], + img_recon[green_idx, :, :]), axis=-1) + rgb_image_recon = minmax_norm(rgb_image_recon) + + rgb_masked = np.where(mask == 0, rgb_image, rgb_image_recon) + rgb_image_masked = np.where(mask == 1, 0, rgb_image) + rgb_recon_masked = rgb_masked + + return rgb_image, rgb_image_masked, rgb_recon_masked, mask + + +# ----------------------------------------------------------------------------- +# minmax_norm +# ----------------------------------------------------------------------------- +# Normalizes an image array to a range of 0-255 for consistent display. +# ----------------------------------------------------------------------------- +def minmax_norm(img_arr): + arr_min = img_arr.min() + arr_max = img_arr.max() + img_arr_scaled = (img_arr - arr_min) / (arr_max - arr_min) + img_arr_scaled = img_arr_scaled * 255 + img_arr_scaled = img_arr_scaled.astype(np.uint8) + return img_arr_scaled + + +# ----------------------------------------------------------------------------- +# process_mask +# ----------------------------------------------------------------------------- +# Adjusts the dimensions of a binary mask to match the input image shape, +# replicating mask values across the image. +# ----------------------------------------------------------------------------- +def process_mask(mask): + mask_img = mask.unsqueeze(0) + mask_img = mask_img.repeat_interleave(4, 1).repeat_interleave(4, 2) + mask_img = mask_img.unsqueeze(1).contiguous()[0, 0] + return np.stack([mask_img] * 3, axis=-1) + + +# ----------------------------------------------------------------------------- +# reverse_transform +# ----------------------------------------------------------------------------- +# Reverses scaling transformations applied to the original MODIS data to +# prepare the image for RGB visualization. +# ----------------------------------------------------------------------------- +def reverse_transform(image): + minMaxTransform = MinMaxEmissiveScaleReflectance() + image = image.transpose((1, 2, 0)) + image[:, :, minMaxTransform.reflectance_indices] *= 100 + emis_min, emis_max = \ + minMaxTransform.emissive_mins, minMaxTransform.emissive_maxs + image[:, :, minMaxTransform.emissive_indices] *= (emis_max - emis_min) + image[:, :, minMaxTransform.emissive_indices] += emis_min + return image.transpose((2, 0, 1)) diff --git a/pytorch_caney/processing.py b/pytorch_caney/processing.py deleted file mode 100755 index 30723d0..0000000 --- a/pytorch_caney/processing.py +++ /dev/null @@ -1,410 +0,0 @@ -import logging -import random -from tqdm import tqdm - -import numpy as np -from numpy import fliplr, flipud - -import scipy.signal - - -SEED = 42 -np.random.seed(SEED) - -__author__ = "Jordan A Caraballo-Vega, Science Data Processing Branch" -__email__ = "jordan.a.caraballo-vega@nasa.gov" -__status__ = "Production" - -# ---------------------------------------------------------------------------- -# module processing -# -# General functions to perform standardization of images (numpy arrays). -# A couple of methods have been implemented for testing, including global and -# local standardization for neural networks input. Data manipulation stage, -# extract random patches for training and store them in numpy arrays. -# --------------------------------------------------------------------------- - -# --------------------------------------------------------------------------- -# Module Methods -# --------------------------------------------------------------------------- - - -# --------------------------- Normalization Functions ----------------------- # -def normalize(images, factor=65535.0) -> np.array: - """ - Normalize numpy array in the range of [0,1] - :param images: numpy array in the format (n,w,h,c). - :param factor: float number to normalize images, e.g. 2^(16)-1 - :return: numpy array in the [0,1] range - """ - return images / factor - - -# ------------------------ Standardization Functions ----------------------- # -def global_standardization(images, strategy='per-batch') -> np.array: - """ - Standardize numpy array using global standardization. - :param images: numpy array in the format (n,w,h,c). - :param strategy: can select between per-image or per-batch. - :return: globally standardized numpy array - """ - if strategy == 'per-batch': - mean = np.mean(images) # global mean of all images - std = np.std(images) # global std of all images - for i in range(images.shape[0]): # for each image in images - images[i, :, :, :] = (images[i, :, :, :] - mean) / std - elif strategy == 'per-image': - for i in range(images.shape[0]): # for each image in images - mean = np.mean(images[i, :, :, :]) # image mean - std = np.std(images[i, :, :, :]) # image std - images[i, :, :, :] = (images[i, :, :, :] - mean) / std - return images - - -def local_standardization(images, filename='normalization_data', - ndata=None, strategy='per-batch' - ) -> np.array: - """ - Standardize numpy array using local standardization. - :param images: numpy array in the format (n,w,h,c). - :param filename: filename to store mean and std data. - :param ndata: pandas df with mean and std values for each channel. - :param strategy: can select between per-image or per-batch. - :return: locally standardized numpy array - """ - if ndata: # for inference only - for i in range(images.shape[-1]): # for each channel in images - # standardize all images based on given mean and std - images[:, :, :, i] = \ - (images[:, :, :, i] - ndata['channel_mean'][i]) / \ - ndata['channel_std'][i] - return images - elif strategy == 'per-batch': # for all images in batch - f = open(filename + "_norm_data.csv", "w+") - f.write( - "i,channel_mean,channel_std,channel_mean_post,channel_std_post\n" - ) - for i in range(images.shape[-1]): # for each channel in images - channel_mean = np.mean(images[:, :, :, i]) # mean for each channel - channel_std = np.std(images[:, :, :, i]) # std for each channel - images[:, :, :, i] = \ - (images[:, :, :, i] - channel_mean) / channel_std - channel_mean_post = np.mean(images[:, :, :, i]) - channel_std_post = np.std(images[:, :, :, i]) - # write to file for each channel - f.write('{},{},{},{},{}\n'.format(i, channel_mean, channel_std, - channel_mean_post, - channel_std_post - ) - ) - f.close() # close file - elif strategy == 'per-image': # standardization for each image - for i in range(images.shape[0]): # for each image - for j in range(images.shape[-1]): # for each channel in images - channel_mean = np.mean(images[i, :, :, j]) - channel_std = np.std(images[i, :, :, j]) - images[i, :, :, j] = \ - (images[i, :, :, j] - channel_mean) / channel_std - else: - raise RuntimeError(f'Standardization <{strategy}> not supported') - - return images - - -def standardize_image( - image, - standardization_type: str, - mean: list = None, - std: list = None, - global_min: list = None, - global_max: list = None -): - """ - Standardize image within parameter, simple scaling of values. - Loca, Global, and Mixed options. - """ - image = image.astype(np.float32) - if standardization_type == 'local': - for i in range(image.shape[-1]): - image[:, :, i] = (image[:, :, i] - np.mean(image[:, :, i])) / \ - (np.std(image[:, :, i]) + 1e-8) - elif standardization_type == 'minmax': - for i in range(image.shape[-1]): - image[:, :, i] = (image[:, :, i] - 0) / (55-0) - elif standardization_type == 'localminmax': - for i in range(image.shape[-1]): - image[:, :, i] = (image[:, :, i] - np.min(image[:, :, 0])) / \ - (np.max(image[:, :, i])-np.min(image[:, :, i])) - elif standardization_type == 'globalminmax': - for i in range(image.shape[-1]): - image[:, :, i] = (image[:, :, i] - global_min) / \ - (global_max - global_min) - elif standardization_type == 'global': - for i in range(image.shape[-1]): - image[:, :, i] = (image[:, :, i] - mean[i]) / (std[i] + 1e-8) - elif standardization_type == 'mixed': - raise NotImplementedError - return image - - -def standardize_batch( - image_batch, - standardization_type: str, - mean: list = None, - std: list = None -): - """ - Standardize image within parameter, simple scaling of values. - Loca, Global, and Mixed options. - """ - for item in range(image_batch.shape[0]): - image_batch[item, :, :, :] = standardize_image( - image_batch[item, :, :, :], standardization_type, mean, std) - return image_batch - -# ------------------------ Data Preparation Functions ----------------------- # - - -def get_rand_patches_rand_cond(img, mask, n_patches=16000, sz=160, nclasses=6, - nodata_ascloud=True, method='rand' - ) -> np.array: - """ - Generate training data. - :param images: ndarray in the format (w,h,c). - :param mask: integer ndarray with shape (x_sz, y_sz) - :param n_patches: number of patches - :param sz: tile size, will be used for both height and width - :param nclasses: number of classes present in the output data - :param nodata_ascloud: convert no-data values to cloud labels - :param method: choose between rand, cond, cloud - rand - select N number of random patches for each image - cond - select N number of random patches for each image, - with the condition of having 1+ class per tile. - cloud - select tiles that have clouds - :return: two numpy array with data and labels. - """ - if nodata_ascloud: - # if no-data present, change to final class - mask = mask.values # return numpy array - mask[mask > nclasses] = nclasses # some no-data are 255 or other big - mask[mask < 0] = nclasses # some no-data are -128 or smaller negative - - patches = [] # list to store data patches - labels = [] # list to store label patches - - for i in tqdm(range(n_patches)): - - # Generate random integers from image - xc = random.randint(0, img.shape[0] - sz) - yc = random.randint(0, img.shape[1] - sz) - - if method == 'cond': - # while loop to regenerate random ints if tile has only one class - while len(np.unique(mask[xc:(xc+sz), yc:(yc+sz)])) == 1 or \ - 6 in mask[xc:(xc+sz), yc:(yc+sz)] or \ - img[xc:(xc+sz), yc:(yc+sz), :].values.min() < 0: - xc = random.randint(0, img.shape[0] - sz) - yc = random.randint(0, img.shape[1] - sz) - elif method == 'rand': - while 6 in mask[xc:(xc+sz), yc:(yc+sz)] or \ - img[xc:(xc+sz), yc:(yc+sz), :].values.min() < 0: - xc = random.randint(0, img.shape[0] - sz) - yc = random.randint(0, img.shape[1] - sz) - elif method == 'cloud': - while np.count_nonzero(mask[xc:(xc+sz), yc:(yc+sz)] == 6) < 15: - xc = random.randint(0, img.shape[0] - sz) - yc = random.randint(0, img.shape[1] - sz) - - # Generate img and mask patches - patch_img = img[xc:(xc + sz), yc:(yc + sz)] - patch_mask = mask[xc:(xc + sz), yc:(yc + sz)] - - # Apply some random transformations - random_transformation = np.random.randint(1, 7) - if random_transformation == 1: # flip left and right - patch_img = fliplr(patch_img) - patch_mask = fliplr(patch_mask) - elif random_transformation == 2: # reverse second dimension - patch_img = flipud(patch_img) - patch_mask = flipud(patch_mask) - elif random_transformation == 3: # rotate 90 degrees - patch_img = np.rot90(patch_img, 1) - patch_mask = np.rot90(patch_mask, 1) - elif random_transformation == 4: # rotate 180 degrees - patch_img = np.rot90(patch_img, 2) - patch_mask = np.rot90(patch_mask, 2) - elif random_transformation == 5: # rotate 270 degrees - patch_img = np.rot90(patch_img, 3) - patch_mask = np.rot90(patch_mask, 3) - else: # original image - pass - patches.append(patch_img) - labels.append(patch_mask) - return np.asarray(patches), np.asarray(labels) - - -def get_rand_patches_aug_augcond(img, mask, n_patches=16000, sz=256, - nclasses=6, over=50, nodata_ascloud=True, - nodata=-9999, method='augcond' - ) -> np.array: - """ - Generate training data. - :param images: ndarray in the format (w,h,c). - :param mask: integer ndarray with shape (x_sz, y_sz) - :param n_patches: number of patches - :param sz: tile size, will be used for both height and width - :param nclasses: number of classes present in the output data - :param over: number of pixels to overlap between images - :param nodata_ascloud: convert no-data values to cloud labels - :param method: choose between rand, cond, cloud - aug - select N * 8 number of random patches for each - image after data augmentation. - augcond - select N * 8 number of random patches for - each image, with the condition of having 1+ per - tile, after data augmentation. - :return: two numpy array with data and labels. - """ - mask = mask.values # return numpy array - - if nodata_ascloud: - # if no-data present, change to final class - mask[mask > nclasses] = nodata # some no-data are 255 or other big - mask[mask < 0] = nodata # some no-data are -128 or smaller negative - - patches = [] # list to store data patches - labels = [] # list to store label patches - - for i in tqdm(range(n_patches)): - - # Generate random integers from image - xc = random.randint(0, img.shape[0] - sz - sz) - yc = random.randint(0, img.shape[1] - sz - sz) - - if method == 'augcond': - # while loop to regenerate random ints if tile has only one class - while len(np.unique(mask[xc:(xc + sz), yc:(yc + sz)])) == 1 or \ - nodata in mask[xc:(xc + sz), yc:(yc + sz)] or \ - nodata in mask[(xc + sz - over):(xc + sz + sz - over), - (yc + sz - over):(yc + sz + sz - over)] or \ - nodata in mask[(xc + sz - over):(xc + sz + sz - over), - yc:(yc + sz)]: - xc = random.randint(0, img.shape[0] - sz - sz) - yc = random.randint(0, img.shape[1] - sz - sz) - elif method == 'aug': - # while loop to regenerate random ints if tile has only one class - while nodata in mask[xc:(xc + sz), yc:(yc + sz)] or \ - nodata in mask[(xc + sz - over):(xc + sz + sz - over), - (yc + sz - over):(yc + sz + sz - over)] or \ - nodata in mask[(xc + sz - over):(xc + sz + sz - over), - yc:(yc + sz)]: - xc = random.randint(0, img.shape[0] - sz - sz) - yc = random.randint(0, img.shape[1] - sz - sz) - - # Generate img and mask patches - patch_img = img[xc:(xc + sz), yc:(yc + sz)] # original image patch - patch_mask = mask[xc:(xc + sz), yc:(yc + sz)] # original mask patch - - # Apply transformations for data augmentation - # 1. No augmentation and append to list - patches.append(patch_img) - labels.append(patch_mask) - - # 2. Rotate 90 and append to list - patches.append(np.rot90(patch_img, 1)) - labels.append(np.rot90(patch_mask, 1)) - - # 3. Rotate 180 and append to list - patches.append(np.rot90(patch_img, 2)) - labels.append(np.rot90(patch_mask, 2)) - - # 4. Rotate 270 - patches.append(np.rot90(patch_img, 3)) - labels.append(np.rot90(patch_mask, 3)) - - # 5. Flipped up and down’ - patches.append(flipud(patch_img)) - labels.append(flipud(patch_mask)) - - # 6. Flipped left and right - patches.append(fliplr(patch_img)) - labels.append(fliplr(patch_mask)) - - # 7. overlapping tiles - next tile, down - patches.append(img[(xc + sz - over):(xc + sz + sz - over), - (yc + sz - over):(yc + sz + sz - over)]) - labels.append(mask[(xc + sz - over):(xc + sz + sz - over), - (yc + sz - over):(yc + sz + sz - over)]) - - # 8. overlapping tiles - next tile, side - patches.append(img[(xc + sz - over):(xc + sz + sz - over), - yc:(yc + sz)]) - labels.append(mask[(xc + sz - over):(xc + sz + sz - over), - yc:(yc + sz)]) - return np.asarray(patches), np.asarray(labels) - - -# ------------------------ Artifact Removal Functions ----------------------- # - -def _2d_spline(window_size=128, power=2) -> np.array: - """ - Window method for boundaries/edge artifacts smoothing. - :param window_size: size of window/tile to smooth - :param power: spline polinomial power to use - :return: smoothing distribution numpy array - """ - intersection = int(window_size/4) - tria = scipy.signal.triang(window_size) - wind_outer = (abs(2*(tria)) ** power)/2 - wind_outer[intersection:-intersection] = 0 - - wind_inner = 1 - (abs(2*(tria - 1)) ** power)/2 - wind_inner[:intersection] = 0 - wind_inner[-intersection:] = 0 - - wind = wind_inner + wind_outer - wind = wind / np.average(wind) - wind = np.expand_dims(np.expand_dims(wind, 1), 2) - wind = wind * wind.transpose(1, 0, 2) - return wind - - -def _hann_matrix(window_size=128, power=2) -> np.array: - logging.info("Placeholder for next release.") - - -# ------------------------------------------------------------------------------- -# module preprocessing Unit Tests -# ------------------------------------------------------------------------------- -if __name__ == "__main__": - - logging.basicConfig(level=logging.INFO) - - # Unit Test #1 - Testing normalization distributions - x = (np.random.randint(65536, size=(10, 128, 128, 6))).astype('float32') - x_norm = normalize(x, factor=65535) # apply static normalization - assert x_norm.max() == 1.0, "Unexpected max value." - logging.info(f"UT #1 PASS: {x_norm.mean()}, {x_norm.std()}") - - # Unit Test #2 - Testing standardization distributions - standardized = global_standardization(x_norm, strategy='per-batch') - assert standardized.max() > 1.731, "Unexpected max value." - logging.info(f"UT #2 PASS: {standardized.mean()}, {standardized.std()}") - - # Unit Test #3 - Testing standardization distributions - standardized = global_standardization(x_norm, strategy='per-image') - assert standardized.max() > 1.73, "Unexpected max value." - logging.info(f"UT #3 PASS: {standardized.mean()}, {standardized.std()}") - - # Unit Test #4 - Testing standardization distributions - standardized = local_standardization(x_norm, filename='normalization_data', - strategy='per-batch' - ) - assert standardized.max() > 1.74, "Unexpected max value." - logging.info(f"UT #4 PASS: {standardized.mean()}, {standardized.std()}") - - # Unit Test #5 - Testing standardization distributions - standardized = local_standardization(x_norm, filename='normalization_data', - strategy='per-image' - ) - assert standardized.max() > 1.75, "Unexpected max value." - logging.info(f"UT #5 PASS: {standardized.mean()}, {standardized.std()}") diff --git a/pytorch_caney/ptc_cli.py b/pytorch_caney/ptc_cli.py new file mode 100644 index 0000000..d41ed96 --- /dev/null +++ b/pytorch_caney/ptc_cli.py @@ -0,0 +1,87 @@ +import argparse +import os + +from lightning.pytorch import Trainer + +from pytorch_caney.configs.config import _C, _update_config_from_file +from pytorch_caney.utils import get_strategy, get_distributed_train_batches +from pytorch_caney.pipelines import PIPELINES, get_available_pipelines +from pytorch_caney.datamodules import DATAMODULES, get_available_datamodules + + +# ----------------------------------------------------------------------------- +# main +# ----------------------------------------------------------------------------- +def main(config, output_dir): + + print('Training') + + # Get the proper pipeline + available_pipelines = get_available_pipelines() + print("Available pipelines:", available_pipelines) + pipeline = PIPELINES[config.PIPELINE] + print(f'Using {pipeline}') + ptlPipeline = pipeline(config) + + # Resume from checkpoint + if config.MODEL.RESUME: + print(f'Attempting to resume from checkpoint {config.MODEL.RESUME}') + ptlPipeline = pipeline.load_from_checkpoint(config.MODEL.RESUME) + + # Determine training strategy + strategy = get_strategy(config) + + trainer = Trainer( + accelerator=config.TRAIN.ACCELERATOR, + devices=-1, + strategy=strategy, + precision=config.PRECISION, + max_epochs=config.TRAIN.EPOCHS, + log_every_n_steps=config.PRINT_FREQ, + default_root_dir=output_dir, + ) + + if config.TRAIN.LIMIT_TRAIN_BATCHES: + trainer.limit_train_batches = get_distributed_train_batches( + config, trainer) + + if config.DATA.DATAMODULE: + available_datamodules = get_available_datamodules() + print(f"Available data modules: {available_datamodules}") + datamoduleClass = DATAMODULES[config.DATAMODULE] + datamodule = datamoduleClass(config) + print(f'Training using datamodule: {datamodule}') + trainer.fit(model=ptlPipeline, datamodule=datamodule) + + else: + print(f'Training without datamodule, assuming data is set in pipeline: {ptlPipeline}') # noqa: E501 + trainer.fit(model=ptlPipeline) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument( + '--config-path', type=str, help='Path to pretrained model config' + ) + + hparams = parser.parse_args() + + config = _C.clone() + _update_config_from_file(config, hparams.config_path) + + output_dir = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG) + print(f'Output directory: {output_dir}') + os.makedirs(output_dir, exist_ok=True) + + path = os.path.join(output_dir, + f"{config.TAG}.config.json") + + with open(path, "w") as f: + f.write(config.dump()) + + print(f"Full config saved to {path}") + print(config.dump()) + + main(config, output_dir) diff --git a/pytorch_caney/ptc_logging.py b/pytorch_caney/ptc_logging.py deleted file mode 100644 index 3b76462..0000000 --- a/pytorch_caney/ptc_logging.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -import sys -import logging -import functools -from termcolor import colored - - -@functools.lru_cache() -def create_logger(output_dir, dist_rank=0, name=''): - # create logger - logger = logging.getLogger(name) - - logger.setLevel(logging.DEBUG) - - logger.propagate = False - - # create formatter - fmt = '[%(asctime)s %(name)s] ' + \ - '(%(filename)s %(lineno)d): ' + \ - '%(levelname)s %(message)s' - - color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ - colored('(%(filename)s %(lineno)d)', 'yellow') + \ - ': %(levelname)s %(message)s' - - # create console handlers for master process - if dist_rank == 0: - - console_handler = logging.StreamHandler(sys.stdout) - - console_handler.setLevel(logging.DEBUG) - - console_handler.setFormatter( - logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) - - logger.addHandler(console_handler) - - # create file handlers - file_handler = logging.FileHandler(os.path.join( - output_dir, f'log_rank{dist_rank}.txt'), mode='a') - - file_handler.setLevel(logging.DEBUG) - - file_handler.setFormatter(logging.Formatter( - fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) - - logger.addHandler(file_handler) - - return logger diff --git a/pytorch_caney/models/maskrcnn_model.py b/pytorch_caney/template/__init__.py old mode 100755 new mode 100644 similarity index 100% rename from pytorch_caney/models/maskrcnn_model.py rename to pytorch_caney/template/__init__.py diff --git a/pytorch_caney/tests/config/test_config.yaml b/pytorch_caney/tests/config/test_config.yaml deleted file mode 100644 index c13ee63..0000000 --- a/pytorch_caney/tests/config/test_config.yaml +++ /dev/null @@ -1,27 +0,0 @@ -MODEL: - TYPE: swinv2 - NAME: test_config - DROP_PATH_RATE: 0.1 - SWINV2: - IN_CHANS: 7 - EMBED_DIM: 128 - DEPTHS: [ 2, 2, 18, 2 ] - NUM_HEADS: [ 4, 8, 16, 32 ] - WINDOW_SIZE: 12 -DATA: - IMG_SIZE: 192 - MASK_PATCH_SIZE: 32 - MASK_RATIO: 0.6 -TRAIN: - EPOCHS: 800 - WARMUP_EPOCHS: 10 - BASE_LR: 1e-4 - WARMUP_LR: 5e-7 - WEIGHT_DECAY: 0.05 - LR_SCHEDULER: - NAME: 'multistep' - GAMMA: 0.1 - MULTISTEPS: [700,] -PRINT_FREQ: 100 -SAVE_FREQ: 5 -TAG: test_config_tag \ No newline at end of file diff --git a/pytorch_caney/tests/test_build.py b/pytorch_caney/tests/test_build.py deleted file mode 100644 index a472882..0000000 --- a/pytorch_caney/tests/test_build.py +++ /dev/null @@ -1,50 +0,0 @@ -from pytorch_caney.models.build import build_model -from pytorch_caney.config import get_config - -import unittest -import argparse -import logging - - -class TestBuildModel(unittest.TestCase): - - def setUp(self): - # Initialize any required configuration here - config_path = 'pytorch_caney/' + \ - 'tests/config/test_config.yaml' - args = argparse.Namespace(cfg=config_path) - self.config = get_config(args) - self.logger = logging.getLogger("TestLogger") - self.logger.setLevel(logging.DEBUG) - - def test_build_mim_model(self): - _ = build_model(self.config, - pretrain=True, - pretrain_method='mim', - logger=self.logger) - # Add assertions here to validate the returned 'model' instance - # For example: self.assertIsInstance(model, YourMimModelClass) - - def test_build_swinv2_encoder(self): - _ = build_model(self.config, logger=self.logger) - # Add assertions here to validate the returned 'model' instance - # For example: self.assertIsInstance(model, SwinTransformerV2) - - def test_build_unet_decoder(self): - self.config.defrost() - self.config.MODEL.DECODER = 'unet' - self.config.freeze() - _ = build_model(self.config, logger=self.logger) - # Add assertions here to validate the returned 'model' instance - # For example: self.assertIsInstance(model, YourUnetSwinModelClass) - - def test_unknown_decoder_architecture(self): - self.config.defrost() - self.config.MODEL.DECODER = 'unknown_decoder' - self.config.freeze() - with self.assertRaises(NotImplementedError): - build_model(self.config, logger=self.logger) - - -if __name__ == '__main__': - unittest.main() diff --git a/pytorch_caney/tests/test_config.py b/pytorch_caney/tests/test_config.py deleted file mode 100644 index f75c534..0000000 --- a/pytorch_caney/tests/test_config.py +++ /dev/null @@ -1,44 +0,0 @@ -from pytorch_caney.config import get_config - -import argparse -import unittest - - -class TestConfig(unittest.TestCase): - - @classmethod - def setUpClass(cls): - cls.config_yaml_path = 'pytorch_caney/' + \ - 'tests/config/test_config.yaml' - - def test_default_config(self): - # Get the default configuration - args = argparse.Namespace(cfg=self.config_yaml_path) - config = get_config(args) - - # Test specific configuration values - self.assertEqual(config.DATA.BATCH_SIZE, 128) - self.assertEqual(config.DATA.DATASET, 'MODIS') - self.assertEqual(config.MODEL.TYPE, 'swinv2') - self.assertEqual(config.MODEL.NAME, 'test_config') - self.assertEqual(config.TRAIN.EPOCHS, 800) - - def test_custom_config(self): - # Test with custom arguments - args = argparse.Namespace( - cfg=self.config_yaml_path, - batch_size=64, - dataset='CustomDataset', - data_paths=['solongandthanksforallthefish'], - ) - config = get_config(args) - - # Test specific configuration values with custom arguments - self.assertEqual(config.DATA.BATCH_SIZE, 64) - self.assertEqual(config.DATA.DATASET, 'CustomDataset') - self.assertEqual(config.DATA.DATA_PATHS, - ['solongandthanksforallthefish']) - - -if __name__ == '__main__': - unittest.main() diff --git a/pytorch_caney/tests/test_data.py b/pytorch_caney/tests/test_data.py deleted file mode 100644 index d6a5852..0000000 --- a/pytorch_caney/tests/test_data.py +++ /dev/null @@ -1,38 +0,0 @@ -from pytorch_caney.data.datamodules.finetune_datamodule \ - import get_dataset_from_dict - -from pytorch_caney.data.datamodules.finetune_datamodule \ - import DATASETS - -import unittest - - -class TestGetDatasetFromDict(unittest.TestCase): - - def test_existing_datasets(self): - # Test existing datasets - for dataset_name in ['modis', 'modislc9', 'modislc5']: - dataset = get_dataset_from_dict(dataset_name) - self.assertIsNotNone(dataset) - - def test_non_existing_dataset(self): - # Test non-existing dataset - invalid_dataset_name = 'invalid_dataset' - with self.assertRaises(KeyError) as context: - get_dataset_from_dict(invalid_dataset_name) - expected_error_msg = f'"{invalid_dataset_name} ' + \ - 'is not an existing dataset. Available datasets:' + \ - f' {DATASETS.keys()}"' - self.assertEqual(str(context.exception), expected_error_msg) - - def test_dataset_name_case_insensitive(self): - # Test case insensitivity - dataset_name = 'MoDiSLC5' - dataset = get_dataset_from_dict(dataset_name) - self.assertIsNotNone(dataset) - -# Add more test cases as needed - - -if __name__ == '__main__': - unittest.main() diff --git a/pytorch_caney/tests/test_loss_utils.py b/pytorch_caney/tests/test_loss_utils.py deleted file mode 100644 index 74a256a..0000000 --- a/pytorch_caney/tests/test_loss_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -from pytorch_caney.loss.utils import to_tensor - -import unittest -import numpy as np -import torch - - -class TestToTensorFunction(unittest.TestCase): - - def test_tensor_input(self): - tensor = torch.tensor([1, 2, 3]) - result = to_tensor(tensor) - self.assertTrue(torch.equal(result, tensor)) - - def test_tensor_input_with_dtype(self): - tensor = torch.tensor([1, 2, 3]) - result = to_tensor(tensor, dtype=torch.float32) - self.assertTrue(torch.equal(result, tensor.float())) - - def test_numpy_array_input(self): - numpy_array = np.array([1, 2, 3]) - expected_tensor = torch.tensor([1, 2, 3]) - result = to_tensor(numpy_array) - self.assertTrue(torch.equal(result, expected_tensor)) - - def test_numpy_array_input_with_dtype(self): - numpy_array = np.array([1, 2, 3]) - expected_tensor = torch.tensor([1, 2, 3], dtype=torch.float32) - result = to_tensor(numpy_array, dtype=torch.float32) - self.assertTrue(torch.equal(result, expected_tensor)) - - def test_list_input(self): - input_list = [1, 2, 3] - expected_tensor = torch.tensor([1, 2, 3]) - result = to_tensor(input_list) - self.assertTrue(torch.equal(result, expected_tensor)) - - def test_list_input_with_dtype(self): - input_list = [1, 2, 3] - expected_tensor = torch.tensor([1, 2, 3], dtype=torch.float32) - result = to_tensor(input_list, dtype=torch.float32) - self.assertTrue(torch.equal(result, expected_tensor)) - - -if __name__ == '__main__': - unittest.main() diff --git a/pytorch_caney/tests/test_lr_scheduler.py b/pytorch_caney/tests/test_lr_scheduler.py deleted file mode 100644 index f0cd7f2..0000000 --- a/pytorch_caney/tests/test_lr_scheduler.py +++ /dev/null @@ -1,48 +0,0 @@ -from pytorch_caney.lr_scheduler import build_scheduler - -import unittest -from unittest.mock import Mock, patch - - -class TestBuildScheduler(unittest.TestCase): - def setUp(self): - self.config = Mock( - TRAIN=Mock( - EPOCHS=300, - WARMUP_EPOCHS=20, - MIN_LR=1e-6, - WARMUP_LR=1e-7, - LR_SCHEDULER=Mock( - NAME='cosine', - DECAY_EPOCHS=30, - DECAY_RATE=0.1, - MULTISTEPS=[50, 100], - GAMMA=0.1 - ) - ) - ) - - self.optimizer = Mock() - self.n_iter_per_epoch = 100 # Example value - - def test_build_cosine_scheduler(self): - with patch('pytorch_caney.lr_scheduler.CosineLRScheduler') \ - as mock_cosine_scheduler: - _ = build_scheduler(self.config, - self.optimizer, - self.n_iter_per_epoch) - - mock_cosine_scheduler.assert_called_once_with( - self.optimizer, - t_initial=300 * 100, - cycle_mul=1., - lr_min=1e-6, - warmup_lr_init=1e-7, - warmup_t=20 * 100, - cycle_limit=1, - t_in_epochs=False - ) - - -if __name__ == '__main__': - unittest.main() diff --git a/pytorch_caney/tests/test_transforms.py b/pytorch_caney/tests/test_transforms.py deleted file mode 100644 index 9656e0b..0000000 --- a/pytorch_caney/tests/test_transforms.py +++ /dev/null @@ -1,70 +0,0 @@ -from pytorch_caney.config import get_config -from pytorch_caney.data.transforms import SimmimTransform -from pytorch_caney.data.transforms import TensorResizeTransform - -import argparse -import unittest -import torch -import numpy as np - - -class TestTransforms(unittest.TestCase): - - def setUp(self): - # Initialize any required configuration here - config_path = 'pytorch_caney/' + \ - 'tests/config/test_config.yaml' - args = argparse.Namespace(cfg=config_path) - self.config = get_config(args) - - def test_simmim_transform(self): - - # Create an instance of SimmimTransform - transform = SimmimTransform(self.config) - - # Create a sample ndarray - img = np.random.randn(self.config.DATA.IMG_SIZE, - self.config.DATA.IMG_SIZE, - 7) - - # Apply the transform - img_transformed, mask = transform(img) - - # Assertions - self.assertIsInstance(img_transformed, torch.Tensor) - self.assertEqual(img_transformed.shape, (7, - self.config.DATA.IMG_SIZE, - self.config.DATA.IMG_SIZE)) - self.assertIsInstance(mask, np.ndarray) - - def test_tensor_resize_transform(self): - # Create an instance of TensorResizeTransform - transform = TensorResizeTransform(self.config) - - # Create a sample image tensor - img = np.random.randn(self.config.DATA.IMG_SIZE, - self.config.DATA.IMG_SIZE, - 7) - - target = np.random.randint(0, 5, - size=((self.config.DATA.IMG_SIZE, - self.config.DATA.IMG_SIZE))) - - # Apply the transform - img_transformed = transform(img) - target_transformed = transform(target) - - # Assertions - self.assertIsInstance(img_transformed, torch.Tensor) - self.assertEqual(img_transformed.shape, - (7, self.config.DATA.IMG_SIZE, - self.config.DATA.IMG_SIZE)) - - self.assertIsInstance(target_transformed, torch.Tensor) - self.assertEqual(target_transformed.shape, - (1, self.config.DATA.IMG_SIZE, - self.config.DATA.IMG_SIZE)) - - -if __name__ == '__main__': - unittest.main() diff --git a/pytorch_caney/training/mim_utils.py b/pytorch_caney/training/mim_utils.py deleted file mode 100644 index 5373a98..0000000 --- a/pytorch_caney/training/mim_utils.py +++ /dev/null @@ -1,720 +0,0 @@ -from functools import partial -from torch import optim as optim - -import os -import torch -import torch.distributed as dist -import numpy as np -from scipy import interpolate - - -def build_optimizer(config, model, is_pretrain=False, logger=None): - """ - Build optimizer, set weight decay of normalization to 0 by default. - AdamW only. - """ - logger.info('>>>>>>>>>> Build Optimizer') - - skip = {} - - skip_keywords = {} - - if hasattr(model, 'no_weight_decay'): - skip = model.no_weight_decay() - - if hasattr(model, 'no_weight_decay_keywords'): - skip_keywords = model.no_weight_decay_keywords() - - if is_pretrain: - parameters = get_pretrain_param_groups(model, skip, skip_keywords) - - else: - - depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \ - else config.MODEL.SWINV2.DEPTHS - - num_layers = sum(depths) - - get_layer_func = partial(get_swin_layer, - num_layers=num_layers + 2, - depths=depths) - - scales = list(config.TRAIN.LAYER_DECAY ** i for i in - reversed(range(num_layers + 2))) - - parameters = get_finetune_param_groups(model, - config.TRAIN.BASE_LR, - config.TRAIN.WEIGHT_DECAY, - get_layer_func, - scales, - skip, - skip_keywords) - - optimizer = None - - optimizer = optim.AdamW(parameters, - eps=config.TRAIN.OPTIMIZER.EPS, - betas=config.TRAIN.OPTIMIZER.BETAS, - lr=config.TRAIN.BASE_LR, - weight_decay=config.TRAIN.WEIGHT_DECAY) - - logger.info(optimizer) - - return optimizer - - -def set_weight_decay(model, skip_list=(), skip_keywords=()): - """ - - Args: - model (_type_): _description_ - skip_list (tuple, optional): _description_. Defaults to (). - skip_keywords (tuple, optional): _description_. Defaults to (). - - Returns: - _type_: _description_ - """ - - has_decay = [] - - no_decay = [] - - for name, param in model.named_parameters(): - - if not param.requires_grad: - - continue # frozen weights - - if len(param.shape) == 1 or name.endswith(".bias") \ - or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): - - no_decay.append(param) - - else: - - has_decay.append(param) - - return [{'params': has_decay}, - {'params': no_decay, 'weight_decay': 0.}] - - -def check_keywords_in_name(name, keywords=()): - - isin = False - - for keyword in keywords: - - if keyword in name: - - isin = True - - return isin - - -def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): - - has_decay = [] - - no_decay = [] - - has_decay_name = [] - - no_decay_name = [] - - for name, param in model.named_parameters(): - - if not param.requires_grad: - - continue - - if len(param.shape) == 1 or name.endswith(".bias") or \ - (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): - - no_decay.append(param) - - no_decay_name.append(name) - - else: - - has_decay.append(param) - - has_decay_name.append(name) - - return [{'params': has_decay}, - {'params': no_decay, 'weight_decay': 0.}] - - -def get_swin_layer(name, num_layers, depths): - - if name in ("mask_token"): - - return 0 - - elif name.startswith("patch_embed"): - - return 0 - - elif name.startswith("layers"): - - layer_id = int(name.split('.')[1]) - - block_id = name.split('.')[3] - - if block_id == 'reduction' or block_id == 'norm': - - return sum(depths[:layer_id + 1]) - - layer_id = sum(depths[:layer_id]) + int(block_id) - - return layer_id + 1 - - else: - - return num_layers - 1 - - -def get_finetune_param_groups(model, - lr, - weight_decay, - get_layer_func, - scales, - skip_list=(), - skip_keywords=()): - - parameter_group_names = {} - - parameter_group_vars = {} - - for name, param in model.named_parameters(): - - if not param.requires_grad: - - continue - - if len(param.shape) == 1 or name.endswith(".bias") \ - or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): - - group_name = "no_decay" - - this_weight_decay = 0. - - else: - - group_name = "decay" - - this_weight_decay = weight_decay - - if get_layer_func is not None: - - layer_id = get_layer_func(name) - - group_name = "layer_%d_%s" % (layer_id, group_name) - - else: - - layer_id = None - - if group_name not in parameter_group_names: - - if scales is not None: - - scale = scales[layer_id] - - else: - - scale = 1. - - parameter_group_names[group_name] = { - "group_name": group_name, - "weight_decay": this_weight_decay, - "params": [], - "lr": lr * scale, - "lr_scale": scale, - } - - parameter_group_vars[group_name] = { - "group_name": group_name, - "weight_decay": this_weight_decay, - "params": [], - "lr": lr * scale, - "lr_scale": scale - } - - parameter_group_vars[group_name]["params"].append(param) - - parameter_group_names[group_name]["params"].append(name) - - return list(parameter_group_vars.values()) - - -def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): - - logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") - - if config.MODEL.RESUME.startswith('https'): - - checkpoint = torch.hub.load_state_dict_from_url( - config.MODEL.RESUME, map_location='cpu', check_hash=True) - - else: - - checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') - - # re-map keys due to name change (only for loading provided models) - rpe_mlp_keys = [k for k in checkpoint['module'].keys() if "rpe_mlp" in k] - - for k in rpe_mlp_keys: - - checkpoint['module'][k.replace( - 'rpe_mlp', 'cpb_mlp')] = checkpoint['module'].pop(k) - - msg = model.load_state_dict(checkpoint['module'], strict=False) - - logger.info(msg) - - max_accuracy = 0.0 - - if not config.EVAL_MODE and 'optimizer' in checkpoint \ - and 'lr_scheduler' in checkpoint \ - and 'scaler' in checkpoint \ - and 'epoch' in checkpoint: - - optimizer.load_state_dict(checkpoint['optimizer']) - - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - - scaler.load_state_dict(checkpoint['scaler']) - - config.defrost() - config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 - config.freeze() - - logger.info( - f"=> loaded successfully '{config.MODEL.RESUME}' " + - f"(epoch {checkpoint['epoch']})") - - if 'max_accuracy' in checkpoint: - max_accuracy = checkpoint['max_accuracy'] - - else: - max_accuracy = 0.0 - - del checkpoint - - torch.cuda.empty_cache() - - return max_accuracy - - -def save_checkpoint(config, epoch, model, max_accuracy, - optimizer, lr_scheduler, scaler, logger): - - save_state = {'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'scaler': scaler.state_dict(), - 'max_accuracy': max_accuracy, - 'epoch': epoch, - 'config': config} - - save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') - - logger.info(f"{save_path} saving......") - - torch.save(save_state, save_path) - - logger.info(f"{save_path} saved !!!") - - -def get_grad_norm(parameters, norm_type=2): - - if isinstance(parameters, torch.Tensor): - - parameters = [parameters] - - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - norm_type = float(norm_type) - - total_norm = 0 - - for p in parameters: - - param_norm = p.grad.data.norm(norm_type) - - total_norm += param_norm.item() ** norm_type - - total_norm = total_norm ** (1. / norm_type) - - return total_norm - - -def auto_resume_helper(output_dir, logger): - - checkpoints = os.listdir(output_dir) - - checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] - - logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") - - if len(checkpoints) > 0: - - latest_checkpoint = max([os.path.join(output_dir, d) - for d in checkpoints], key=os.path.getmtime) - - logger.info(f"The latest checkpoint founded: {latest_checkpoint}") - - resume_file = latest_checkpoint - - else: - - resume_file = None - - return resume_file - - -def reduce_tensor(tensor): - - rt = tensor.clone() - - dist.all_reduce(rt, op=dist.ReduceOp.SUM) - - rt /= dist.get_world_size() - - return rt - - -def load_pretrained(config, model, logger): - - logger.info( - f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........") - - checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') - - try: - - checkpoint_model = checkpoint['model'] - - except KeyError: - - try: - - checkpoint_model = checkpoint['module'] - - except KeyError: - - errorMsg = 'Ckpt model does not have key "model" or "module"' - - raise RuntimeError(errorMsg) - - if any([True if 'encoder.' in k else - False for k in checkpoint_model.keys()]): - - checkpoint_model = {k.replace( - 'encoder.', ''): v for k, v in checkpoint_model.items() - if k.startswith('encoder.')} - - logger.info('Detect pre-trained model, remove [encoder.] prefix.') - - else: - - logger.info( - 'Detect non-pre-trained model, pass without doing anything.') - - if config.MODEL.TYPE in ['swin', 'swinv2']: - - logger.info( - ">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") - - checkpoint = remap_pretrained_keys_swin( - model, checkpoint_model, logger) - - else: - - raise NotImplementedError - - msg = model.load_state_dict(checkpoint_model, strict=False) - - logger.info(msg) - - del checkpoint - - torch.cuda.empty_cache() - - logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") - - -def remap_pretrained_keys_swin(model, checkpoint_model, logger): - - state_dict = model.state_dict() - - # Geometric interpolation when pre-trained patch size mismatch - # with fine-tuned patch size - all_keys = list(checkpoint_model.keys()) - - for key in all_keys: - - if "relative_position_bias_table" in key: - - logger.info(f"Key: {key}") - - rel_position_bias_table_pretrained = checkpoint_model[key] - - rel_position_bias_table_current = state_dict[key] - - L1, nH1 = rel_position_bias_table_pretrained.size() - - L2, nH2 = rel_position_bias_table_current.size() - - if nH1 != nH2: - logger.info(f"Error in loading {key}, passing......") - - else: - - if L1 != L2: - - logger.info( - f"{key}: Interpolate " + - "relative_position_bias_table using geo.") - - src_size = int(L1 ** 0.5) - - dst_size = int(L2 ** 0.5) - - def geometric_progression(a, r, n): - return a * (1.0 - r ** n) / (1.0 - r) - - left, right = 1.01, 1.5 - - while right - left > 1e-6: - - q = (left + right) / 2.0 - - gp = geometric_progression(1, q, src_size // 2) - - if gp > dst_size // 2: - - right = q - - else: - - left = q - - # if q > 1.090307: - # q = 1.090307 - - dis = [] - - cur = 1 - - for i in range(src_size // 2): - - dis.append(cur) - - cur += q ** (i + 1) - - r_ids = [-_ for _ in reversed(dis)] - - x = r_ids + [0] + dis - - y = r_ids + [0] + dis - - t = dst_size // 2.0 - - dx = np.arange(-t, t + 0.1, 1.0) - - dy = np.arange(-t, t + 0.1, 1.0) - - logger.info("Original positions = %s" % str(x)) - - logger.info("Target positions = %s" % str(dx)) - - all_rel_pos_bias = [] - - for i in range(nH1): - - z = rel_position_bias_table_pretrained[:, i].view( - src_size, src_size).float().numpy() - - f_cubic = interpolate.interp2d(x, y, z, kind='cubic') - - all_rel_pos_bias_host = \ - torch.Tensor(f_cubic(dx, dy) - ).contiguous().view(-1, 1) - - all_rel_pos_bias.append( - all_rel_pos_bias_host.to( - rel_position_bias_table_pretrained.device)) - - new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) - - checkpoint_model[key] = new_rel_pos_bias - - # delete relative_position_index since we always re-init it - relative_position_index_keys = [ - k for k in checkpoint_model.keys() if "relative_position_index" in k] - - for k in relative_position_index_keys: - - del checkpoint_model[k] - - # delete relative_coords_table since we always re-init it - relative_coords_table_keys = [ - k for k in checkpoint_model.keys() if "relative_coords_table" in k] - - for k in relative_coords_table_keys: - - del checkpoint_model[k] - - # delete attn_mask since we always re-init it - attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] - - for k in attn_mask_keys: - - del checkpoint_model[k] - - return checkpoint_model - - -def remap_pretrained_keys_vit(model, checkpoint_model, logger): - - # Duplicate shared rel_pos_bias to each layer - if getattr(model, 'use_rel_pos_bias', False) and \ - "rel_pos_bias.relative_position_bias_table" in checkpoint_model: - - logger.info( - "Expand the shared relative position " + - "embedding to each transformer block.") - - num_layers = model.get_num_layers() - - rel_pos_bias = \ - checkpoint_model["rel_pos_bias.relative_position_bias_table"] - - for i in range(num_layers): - - checkpoint_model["blocks.%d.attn.relative_position_bias_table" % - i] = rel_pos_bias.clone() - - checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") - - # Geometric interpolation when pre-trained patch - # size mismatch with fine-tuned patch size - all_keys = list(checkpoint_model.keys()) - - for key in all_keys: - - if "relative_position_index" in key: - - checkpoint_model.pop(key) - - if "relative_position_bias_table" in key: - - rel_pos_bias = checkpoint_model[key] - - src_num_pos, num_attn_heads = rel_pos_bias.size() - - dst_num_pos, _ = model.state_dict()[key].size() - - dst_patch_shape = model.patch_embed.patch_shape - - if dst_patch_shape[0] != dst_patch_shape[1]: - - raise NotImplementedError() - - num_extra_tokens = dst_num_pos - \ - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) - - src_size = int((src_num_pos - num_extra_tokens) ** 0.5) - - dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) - - if src_size != dst_size: - - logger.info("Position interpolate for " + - "%s from %dx%d to %dx%d" % ( - key, - src_size, - src_size, - dst_size, - dst_size)) - - extra_tokens = rel_pos_bias[-num_extra_tokens:, :] - - rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] - - def geometric_progression(a, r, n): - - return a * (1.0 - r ** n) / (1.0 - r) - - left, right = 1.01, 1.5 - - while right - left > 1e-6: - - q = (left + right) / 2.0 - - gp = geometric_progression(1, q, src_size // 2) - - if gp > dst_size // 2: - - right = q - - else: - - left = q - - # if q > 1.090307: - # q = 1.090307 - - dis = [] - - cur = 1 - - for i in range(src_size // 2): - - dis.append(cur) - - cur += q ** (i + 1) - - r_ids = [-_ for _ in reversed(dis)] - - x = r_ids + [0] + dis - - y = r_ids + [0] + dis - - t = dst_size // 2.0 - - dx = np.arange(-t, t + 0.1, 1.0) - - dy = np.arange(-t, t + 0.1, 1.0) - - logger.info("Original positions = %s" % str(x)) - - logger.info("Target positions = %s" % str(dx)) - - all_rel_pos_bias = [] - - for i in range(num_attn_heads): - - z = rel_pos_bias[:, i].view( - src_size, src_size).float().numpy() - - f = interpolate.interp2d(x, y, z, kind='cubic') - - all_rel_pos_bias_host = \ - torch.Tensor(f(dx, dy)).contiguous().view(-1, 1) - - all_rel_pos_bias.append( - all_rel_pos_bias_host.to(rel_pos_bias.device)) - - rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) - - new_rel_pos_bias = torch.cat( - (rel_pos_bias, extra_tokens), dim=0) - - checkpoint_model[key] = new_rel_pos_bias - - return checkpoint_model diff --git a/pytorch_caney/training/pre_training.py b/pytorch_caney/training/pre_training.py deleted file mode 100644 index e69de29..0000000 diff --git a/pytorch_caney/training/simmim_utils.py b/pytorch_caney/training/simmim_utils.py deleted file mode 100644 index 949f307..0000000 --- a/pytorch_caney/training/simmim_utils.py +++ /dev/null @@ -1,706 +0,0 @@ -from functools import partial -from torch import optim as optim - -import os -import torch -import torch.distributed as dist -import numpy as np -from scipy import interpolate - - -def build_optimizer(config, model, is_pretrain=False, logger=None): - """ - Build optimizer, set weight decay of normalization to 0 by default. - AdamW only. - """ - logger.info('>>>>>>>>>> Build Optimizer') - - skip = {} - - skip_keywords = {} - - if hasattr(model, 'no_weight_decay'): - skip = model.no_weight_decay() - - if hasattr(model, 'no_weight_decay_keywords'): - skip_keywords = model.no_weight_decay_keywords() - - if is_pretrain: - parameters = get_pretrain_param_groups(model, skip, skip_keywords) - - else: - - depths = config.MODEL.SWIN.DEPTHS if config.MODEL.TYPE == 'swin' \ - else config.MODEL.SWINV2.DEPTHS - - num_layers = sum(depths) - - get_layer_func = partial(get_swin_layer, - num_layers=num_layers + 2, - depths=depths) - - scales = list(config.TRAIN.LAYER_DECAY ** i for i in - reversed(range(num_layers + 2))) - - parameters = get_finetune_param_groups(model, - config.TRAIN.BASE_LR, - config.TRAIN.WEIGHT_DECAY, - get_layer_func, - scales, - skip, - skip_keywords) - - optimizer = None - - optimizer = optim.AdamW(parameters, - eps=config.TRAIN.OPTIMIZER.EPS, - betas=config.TRAIN.OPTIMIZER.BETAS, - lr=config.TRAIN.BASE_LR, - weight_decay=config.TRAIN.WEIGHT_DECAY) - - logger.info(optimizer) - - return optimizer - - -def set_weight_decay(model, skip_list=(), skip_keywords=()): - """ - - Args: - model (_type_): _description_ - skip_list (tuple, optional): _description_. Defaults to (). - skip_keywords (tuple, optional): _description_. Defaults to (). - - Returns: - _type_: _description_ - """ - - has_decay = [] - - no_decay = [] - - for name, param in model.named_parameters(): - - if not param.requires_grad: - - continue # frozen weights - - if len(param.shape) == 1 or name.endswith(".bias") \ - or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): - - no_decay.append(param) - - else: - - has_decay.append(param) - - return [{'params': has_decay}, - {'params': no_decay, 'weight_decay': 0.}] - - -def check_keywords_in_name(name, keywords=()): - - isin = False - - for keyword in keywords: - - if keyword in name: - - isin = True - - return isin - - -def get_pretrain_param_groups(model, skip_list=(), skip_keywords=()): - - has_decay = [] - - no_decay = [] - - has_decay_name = [] - - no_decay_name = [] - - for name, param in model.named_parameters(): - - if not param.requires_grad: - - continue - - if len(param.shape) == 1 or name.endswith(".bias") or \ - (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): - - no_decay.append(param) - - no_decay_name.append(name) - - else: - - has_decay.append(param) - - has_decay_name.append(name) - - return [{'params': has_decay}, - {'params': no_decay, 'weight_decay': 0.}] - - -def get_swin_layer(name, num_layers, depths): - - if name in ("mask_token"): - - return 0 - - elif name.startswith("patch_embed"): - - return 0 - - elif name.startswith("layers"): - - layer_id = int(name.split('.')[1]) - - block_id = name.split('.')[3] - - if block_id == 'reduction' or block_id == 'norm': - - return sum(depths[:layer_id + 1]) - - layer_id = sum(depths[:layer_id]) + int(block_id) - - return layer_id + 1 - - else: - - return num_layers - 1 - - -def get_finetune_param_groups(model, - lr, - weight_decay, - get_layer_func, - scales, - skip_list=(), - skip_keywords=()): - - parameter_group_names = {} - - parameter_group_vars = {} - - for name, param in model.named_parameters(): - - if not param.requires_grad: - - continue - - if len(param.shape) == 1 or name.endswith(".bias") \ - or (name in skip_list) or \ - check_keywords_in_name(name, skip_keywords): - - group_name = "no_decay" - - this_weight_decay = 0. - - else: - - group_name = "decay" - - this_weight_decay = weight_decay - - if get_layer_func is not None: - - layer_id = get_layer_func(name) - - group_name = "layer_%d_%s" % (layer_id, group_name) - - else: - - layer_id = None - - if group_name not in parameter_group_names: - - if scales is not None: - - scale = scales[layer_id] - - else: - - scale = 1. - - parameter_group_names[group_name] = { - "group_name": group_name, - "weight_decay": this_weight_decay, - "params": [], - "lr": lr * scale, - "lr_scale": scale, - } - - parameter_group_vars[group_name] = { - "group_name": group_name, - "weight_decay": this_weight_decay, - "params": [], - "lr": lr * scale, - "lr_scale": scale - } - - parameter_group_vars[group_name]["params"].append(param) - - parameter_group_names[group_name]["params"].append(name) - - return list(parameter_group_vars.values()) - - -def load_checkpoint(config, model, optimizer, lr_scheduler, scaler, logger): - - logger.info(f">>>>>>>>>> Resuming from {config.MODEL.RESUME} ..........") - - if config.MODEL.RESUME.startswith('https'): - - checkpoint = torch.hub.load_state_dict_from_url( - config.MODEL.RESUME, map_location='cpu', check_hash=True) - - else: - - checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu') - - # re-map keys due to name change (only for loading provided models) - rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k] - - for k in rpe_mlp_keys: - - checkpoint['model'][k.replace( - 'rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k) - - msg = model.load_state_dict(checkpoint['model'], strict=False) - - logger.info(msg) - - max_accuracy = 0.0 - - if not config.EVAL_MODE and 'optimizer' in checkpoint \ - and 'lr_scheduler' in checkpoint \ - and 'scaler' in checkpoint \ - and 'epoch' in checkpoint: - - optimizer.load_state_dict(checkpoint['optimizer']) - - lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) - - scaler.load_state_dict(checkpoint['scaler']) - - config.defrost() - config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1 - config.freeze() - - logger.info( - f"=> loaded successfully '{config.MODEL.RESUME}' " + - f"(epoch {checkpoint['epoch']})") - - if 'max_accuracy' in checkpoint: - max_accuracy = checkpoint['max_accuracy'] - - else: - max_accuracy = 0.0 - - del checkpoint - - torch.cuda.empty_cache() - - return max_accuracy - - -def save_checkpoint(config, epoch, model, max_accuracy, - optimizer, lr_scheduler, scaler, logger): - - save_state = {'model': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'lr_scheduler': lr_scheduler.state_dict(), - 'scaler': scaler.state_dict(), - 'max_accuracy': max_accuracy, - 'epoch': epoch, - 'config': config} - - save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth') - - logger.info(f"{save_path} saving......") - - torch.save(save_state, save_path) - - logger.info(f"{save_path} saved !!!") - - -def get_grad_norm(parameters, norm_type=2): - - if isinstance(parameters, torch.Tensor): - - parameters = [parameters] - - parameters = list(filter(lambda p: p.grad is not None, parameters)) - - norm_type = float(norm_type) - - total_norm = 0 - - for p in parameters: - - param_norm = p.grad.data.norm(norm_type) - - total_norm += param_norm.item() ** norm_type - - total_norm = total_norm ** (1. / norm_type) - - return total_norm - - -def auto_resume_helper(output_dir, logger): - - checkpoints = os.listdir(output_dir) - - checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')] - - logger.info(f"All checkpoints founded in {output_dir}: {checkpoints}") - - if len(checkpoints) > 0: - - latest_checkpoint = max([os.path.join(output_dir, d) - for d in checkpoints], key=os.path.getmtime) - - logger.info(f"The latest checkpoint founded: {latest_checkpoint}") - - resume_file = latest_checkpoint - - else: - - resume_file = None - - return resume_file - - -def reduce_tensor(tensor): - - rt = tensor.clone() - - dist.all_reduce(rt, op=dist.ReduceOp.SUM) - - rt /= dist.get_world_size() - - return rt - - -def load_pretrained(config, model, logger): - - logger.info( - f">>>>>>>>>> Fine-tuned from {config.MODEL.PRETRAINED} ..........") - - checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu') - - checkpoint_model = checkpoint['model'] - - if any([True if 'encoder.' in k else - False for k in checkpoint_model.keys()]): - - checkpoint_model = {k.replace( - 'encoder.', ''): v for k, v in checkpoint_model.items() - if k.startswith('encoder.')} - - logger.info('Detect pre-trained model, remove [encoder.] prefix.') - - else: - - logger.info( - 'Detect non-pre-trained model, pass without doing anything.') - - if config.MODEL.TYPE in ['swin', 'swinv2']: - - logger.info( - ">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") - - checkpoint = remap_pretrained_keys_swin( - model, checkpoint_model, logger) - - else: - - raise NotImplementedError - - msg = model.load_state_dict(checkpoint_model, strict=False) - - logger.info(msg) - - del checkpoint - - torch.cuda.empty_cache() - - logger.info(f">>>>>>>>>> loaded successfully '{config.MODEL.PRETRAINED}'") - - -def remap_pretrained_keys_swin(model, checkpoint_model, logger): - - state_dict = model.state_dict() - - # Geometric interpolation when pre-trained patch size mismatch - # with fine-tuned patch size - all_keys = list(checkpoint_model.keys()) - - for key in all_keys: - - if "relative_position_bias_table" in key: - - logger.info(f"Key: {key}") - - rel_position_bias_table_pretrained = checkpoint_model[key] - - rel_position_bias_table_current = state_dict[key] - - L1, nH1 = rel_position_bias_table_pretrained.size() - - L2, nH2 = rel_position_bias_table_current.size() - - if nH1 != nH2: - logger.info(f"Error in loading {key}, passing......") - - else: - - if L1 != L2: - - logger.info( - f"{key}: Interpolate " + - "relative_position_bias_table using geo.") - - src_size = int(L1 ** 0.5) - - dst_size = int(L2 ** 0.5) - - def geometric_progression(a, r, n): - return a * (1.0 - r ** n) / (1.0 - r) - - left, right = 1.01, 1.5 - - while right - left > 1e-6: - - q = (left + right) / 2.0 - - gp = geometric_progression(1, q, src_size // 2) - - if gp > dst_size // 2: - - right = q - - else: - - left = q - - # if q > 1.090307: - # q = 1.090307 - - dis = [] - - cur = 1 - - for i in range(src_size // 2): - - dis.append(cur) - - cur += q ** (i + 1) - - r_ids = [-_ for _ in reversed(dis)] - - x = r_ids + [0] + dis - - y = r_ids + [0] + dis - - t = dst_size // 2.0 - - dx = np.arange(-t, t + 0.1, 1.0) - - dy = np.arange(-t, t + 0.1, 1.0) - - logger.info("Original positions = %s" % str(x)) - - logger.info("Target positions = %s" % str(dx)) - - all_rel_pos_bias = [] - - for i in range(nH1): - - z = rel_position_bias_table_pretrained[:, i].view( - src_size, src_size).float().numpy() - - f_cubic = interpolate.interp2d(x, y, z, kind='cubic') - - all_rel_pos_bias_host = \ - torch.Tensor(f_cubic(dx, dy) - ).contiguous().view(-1, 1) - - all_rel_pos_bias.append( - all_rel_pos_bias_host.to( - rel_position_bias_table_pretrained.device)) - - new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) - - checkpoint_model[key] = new_rel_pos_bias - - # delete relative_position_index since we always re-init it - relative_position_index_keys = [ - k for k in checkpoint_model.keys() if "relative_position_index" in k] - - for k in relative_position_index_keys: - - del checkpoint_model[k] - - # delete relative_coords_table since we always re-init it - relative_coords_table_keys = [ - k for k in checkpoint_model.keys() if "relative_coords_table" in k] - - for k in relative_coords_table_keys: - - del checkpoint_model[k] - - # delete attn_mask since we always re-init it - attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] - - for k in attn_mask_keys: - - del checkpoint_model[k] - - return checkpoint_model - - -def remap_pretrained_keys_vit(model, checkpoint_model, logger): - - # Duplicate shared rel_pos_bias to each layer - if getattr(model, 'use_rel_pos_bias', False) and \ - "rel_pos_bias.relative_position_bias_table" in checkpoint_model: - - logger.info( - "Expand the shared relative position " + - "embedding to each transformer block.") - - num_layers = model.get_num_layers() - - rel_pos_bias = \ - checkpoint_model["rel_pos_bias.relative_position_bias_table"] - - for i in range(num_layers): - - checkpoint_model["blocks.%d.attn.relative_position_bias_table" % - i] = rel_pos_bias.clone() - - checkpoint_model.pop("rel_pos_bias.relative_position_bias_table") - - # Geometric interpolation when pre-trained patch - # size mismatch with fine-tuned patch size - all_keys = list(checkpoint_model.keys()) - - for key in all_keys: - - if "relative_position_index" in key: - - checkpoint_model.pop(key) - - if "relative_position_bias_table" in key: - - rel_pos_bias = checkpoint_model[key] - - src_num_pos, num_attn_heads = rel_pos_bias.size() - - dst_num_pos, _ = model.state_dict()[key].size() - - dst_patch_shape = model.patch_embed.patch_shape - - if dst_patch_shape[0] != dst_patch_shape[1]: - - raise NotImplementedError() - - num_extra_tokens = dst_num_pos - \ - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1) - - src_size = int((src_num_pos - num_extra_tokens) ** 0.5) - - dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5) - - if src_size != dst_size: - - logger.info("Position interpolate for " + - "%s from %dx%d to %dx%d" % ( - key, - src_size, - src_size, - dst_size, - dst_size)) - - extra_tokens = rel_pos_bias[-num_extra_tokens:, :] - - rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :] - - def geometric_progression(a, r, n): - - return a * (1.0 - r ** n) / (1.0 - r) - - left, right = 1.01, 1.5 - - while right - left > 1e-6: - - q = (left + right) / 2.0 - - gp = geometric_progression(1, q, src_size // 2) - - if gp > dst_size // 2: - - right = q - - else: - - left = q - - # if q > 1.090307: - # q = 1.090307 - - dis = [] - - cur = 1 - - for i in range(src_size // 2): - - dis.append(cur) - - cur += q ** (i + 1) - - r_ids = [-_ for _ in reversed(dis)] - - x = r_ids + [0] + dis - - y = r_ids + [0] + dis - - t = dst_size // 2.0 - - dx = np.arange(-t, t + 0.1, 1.0) - - dy = np.arange(-t, t + 0.1, 1.0) - - logger.info("Original positions = %s" % str(x)) - - logger.info("Target positions = %s" % str(dx)) - - all_rel_pos_bias = [] - - for i in range(num_attn_heads): - - z = rel_pos_bias[:, i].view( - src_size, src_size).float().numpy() - - f = interpolate.interp2d(x, y, z, kind='cubic') - - all_rel_pos_bias_host = \ - torch.Tensor(f(dx, dy)).contiguous().view(-1, 1) - - all_rel_pos_bias.append( - all_rel_pos_bias_host.to(rel_pos_bias.device)) - - rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) - - new_rel_pos_bias = torch.cat( - (rel_pos_bias, extra_tokens), dim=0) - - checkpoint_model[key] = new_rel_pos_bias - - return checkpoint_model diff --git a/pytorch_caney/training/utils.py b/pytorch_caney/training/utils.py deleted file mode 100644 index e69de29..0000000 diff --git a/pytorch_caney/training/fine_tuning.py b/pytorch_caney/transforms/__init__.py similarity index 100% rename from pytorch_caney/training/fine_tuning.py rename to pytorch_caney/transforms/__init__.py diff --git a/pytorch_caney/transforms/abi_radiance_conversion.py b/pytorch_caney/transforms/abi_radiance_conversion.py new file mode 100644 index 0000000..4470b75 --- /dev/null +++ b/pytorch_caney/transforms/abi_radiance_conversion.py @@ -0,0 +1,59 @@ +import numpy as np + + +# ----------------------------------------------------------------------------- +# vis_calibrate +# ----------------------------------------------------------------------------- +def vis_calibrate(data): + """Calibrate visible channels to reflectance.""" + solar_irradiance = np.array(2017) + esd = np.array(0.99) + factor = np.pi * esd * esd / solar_irradiance + + return data * np.float32(factor) * 100 + + +# ----------------------------------------------------------------------------- +# ir_calibrate +# ----------------------------------------------------------------------------- +def ir_calibrate(data): + """Calibrate IR channels to BT.""" + fk1 = np.array(13432.1), + fk2 = np.array(1497.61), + bc1 = np.array(0.09102), + bc2 = np.array(0.99971), + + # if self.clip_negative_radiances: + # min_rad = self._get_minimum_radiance(data) + # data = data.clip(min=data.dtype.type(min_rad)) + + res = (fk2 / np.log(fk1 / data + 1) - bc1) / bc2 + return res + + +# ----------------------------------------------------------------------------- +# ConvertABIToReflectanceBT +# ----------------------------------------------------------------------------- +class ConvertABIToReflectanceBT(object): + """ + Performs scaling of MODIS TOA data + - Scales reflectance percentages to reflectance units (% -> (0,1)) + - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) + """ + + def __init__(self): + + self.reflectance_indices = [0, 1, 2, 3, 4, 6] + self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] + + def __call__(self, img): + + # Reflectance % to reflectance units + img[:, :, self.reflectance_indices] = \ + vis_calibrate(img[:, :, self.reflectance_indices]) + + # Brightness temp scaled to (0,1) range + img[:, :, self.emissive_indices] = ir_calibrate( + img[:, :, self.emissive_indices]) + + return img diff --git a/pytorch_caney/transforms/abi_toa.py b/pytorch_caney/transforms/abi_toa.py new file mode 100644 index 0000000..30afb9c --- /dev/null +++ b/pytorch_caney/transforms/abi_toa.py @@ -0,0 +1,30 @@ +import torchvision.transforms as T + +from .abi_toa_scale import MinMaxEmissiveScaleReflectance +from .abi_radiance_conversion import ConvertABIToReflectanceBT + + +# ----------------------------------------------------------------------------- +# AbiToaTransform +# ----------------------------------------------------------------------------- +class AbiToaTransform: + """ + torchvision transform which transforms the input imagery into + addition to generating a MiM mask + """ + + def __init__(self, img_size): + + self.transform_img = \ + T.Compose([ + ConvertABIToReflectanceBT(), + MinMaxEmissiveScaleReflectance(), + T.ToTensor(), + T.Resize((img_size, img_size), antialias=True), + ]) + + def __call__(self, img): + + img = self.transform_img(img) + + return img diff --git a/pytorch_caney/transforms/abi_toa_scale.py b/pytorch_caney/transforms/abi_toa_scale.py new file mode 100644 index 0000000..852aafd --- /dev/null +++ b/pytorch_caney/transforms/abi_toa_scale.py @@ -0,0 +1,37 @@ +import numpy as np + + +class MinMaxEmissiveScaleReflectance(object): + """ + Performs scaling of MODIS TOA data + - Scales reflectance percentages to reflectance units (% -> (0,1)) + - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) + """ + + def __init__(self): + + self.reflectance_indices = [0, 1, 2, 3, 4, 6] + self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] + + self.emissive_mins = np.array( + [117.04327, 152.00592, 157.96591, 176.15349, + 210.60493, 210.52264, 218.10147, 225.9894], + dtype=np.float32) + + self.emissive_maxs = np.array( + [221.07022, 224.44113, 242.3326, 307.42004, + 290.8879, 343.72617, 345.72894, 323.5239], + dtype=np.float32) + + def __call__(self, img): + + # Reflectance % to reflectance units + img[:, :, self.reflectance_indices] = \ + img[:, :, self.reflectance_indices] * 0.01 + + # Brightness temp scaled to (0,1) range + img[:, :, self.emissive_indices] = \ + (img[:, :, self.emissive_indices] - self.emissive_mins) / \ + (self.emissive_maxs - self.emissive_mins) + + return img diff --git a/pytorch_caney/transforms/mim_mask_generator.py b/pytorch_caney/transforms/mim_mask_generator.py new file mode 100644 index 0000000..530b1ca --- /dev/null +++ b/pytorch_caney/transforms/mim_mask_generator.py @@ -0,0 +1,58 @@ +import numpy as np +from numba import njit + + +# ----------------------------------------------------------------------------- +# MimMaskGenerator +# ----------------------------------------------------------------------------- +class MimMaskGenerator: + """ + Generates the masks for masked-image-modeling + """ + def __init__(self, + input_size=192, + mask_patch_size=32, + model_patch_size=4, + mask_ratio=0.6): + self.input_size = input_size + self.mask_patch_size = mask_patch_size + self.model_patch_size = model_patch_size + self.mask_ratio = mask_ratio + + assert self.input_size % self.mask_patch_size == 0 + assert self.mask_patch_size % self.model_patch_size == 0 + + self.rand_size = self.input_size // self.mask_patch_size + self.scale = self.mask_patch_size // self.model_patch_size + + self.token_count = self.rand_size ** 2 + self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) + + def __call__(self): + mask = make_mim_mask(self.token_count, self.mask_count, + self.rand_size, self.scale) + mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) + return mask + + +# ----------------------------------------------------------------------------- +# make_mim_mask +# ----------------------------------------------------------------------------- +@njit() +def make_mim_mask(token_count, mask_count, rand_size, scale): + """JIT-compiled random mask generation + + Args: + token_count + mask_count + rand_size + scale + + Returns: + mask + """ + mask_idx = np.random.permutation(token_count)[:mask_count] + mask = np.zeros(token_count, dtype=np.int64) + mask[mask_idx] = 1 + mask = mask.reshape((rand_size, rand_size)) + return mask diff --git a/pytorch_caney/data/transforms.py b/pytorch_caney/transforms/mim_modis_toa.py similarity index 62% rename from pytorch_caney/data/transforms.py rename to pytorch_caney/transforms/mim_modis_toa.py index e0a71b5..1d168d9 100644 --- a/pytorch_caney/data/transforms.py +++ b/pytorch_caney/transforms/mim_modis_toa.py @@ -1,10 +1,14 @@ -from .utils import RandomResizedCropNP -from .utils import SimmimMaskGenerator - import torchvision.transforms as T +from .random_resize_crop import RandomResizedCropNP +from .mim_mask_generator import MimMaskGenerator +from .modis_toa_scale import MinMaxEmissiveScaleReflectance + -class SimmimTransform: +# ----------------------------------------------------------------------------- +# MimTransform +# ----------------------------------------------------------------------------- +class MimTransform: """ torchvision transform which transforms the input imagery into addition to generating a MiM mask @@ -14,6 +18,7 @@ def __init__(self, config): self.transform_img = \ T.Compose([ + MinMaxEmissiveScaleReflectance(), RandomResizedCropNP(scale=(0.67, 1.), ratio=(3. / 4., 4. / 3.)), T.ToTensor(), @@ -28,7 +33,7 @@ def __init__(self, config): raise NotImplementedError - self.mask_generator = SimmimMaskGenerator( + self.mask_generator = MimMaskGenerator( input_size=config.DATA.IMG_SIZE, mask_patch_size=config.DATA.MASK_PATCH_SIZE, model_patch_size=model_patch_size, @@ -41,24 +46,3 @@ def __call__(self, img): mask = self.mask_generator() return img, mask - - -class TensorResizeTransform: - """ - torchvision transform which transforms the input imagery into - addition to generating a MiM mask - """ - - def __init__(self, config): - - self.transform_img = \ - T.Compose([ - T.ToTensor(), - T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), - ]) - - def __call__(self, img): - - img = self.transform_img(img) - - return img diff --git a/pytorch_caney/transforms/modis_toa.py b/pytorch_caney/transforms/modis_toa.py new file mode 100644 index 0000000..24fdb1f --- /dev/null +++ b/pytorch_caney/transforms/modis_toa.py @@ -0,0 +1,27 @@ +import torchvision.transforms as T + +from .modis_toa_scale import MinMaxEmissiveScaleReflectance + + +# ----------------------------------------------------------------------------- +# ModisToaTransform +# ----------------------------------------------------------------------------- +class ModisToaTransform: + """ + torchvision transform which transforms the input imagery + """ + + def __init__(self, config): + + self.transform_img = \ + T.Compose([ + MinMaxEmissiveScaleReflectance(), + T.ToTensor(), + T.Resize((config.DATA.IMG_SIZE, config.DATA.IMG_SIZE)), + ]) + + def __call__(self, img): + + img = self.transform_img(img) + + return img diff --git a/pytorch_caney/transforms/modis_toa_scale.py b/pytorch_caney/transforms/modis_toa_scale.py new file mode 100644 index 0000000..1eb5a30 --- /dev/null +++ b/pytorch_caney/transforms/modis_toa_scale.py @@ -0,0 +1,40 @@ +import numpy as np + + +# ----------------------------------------------------------------------------- +# MinMaxEmissiveScaleReflectance +# ----------------------------------------------------------------------------- +class MinMaxEmissiveScaleReflectance(object): + """ + Performs scaling of MODIS TOA data + - Scales reflectance percentages to reflectance units (% -> (0,1)) + - Performs per-channel minmax scaling for emissive bands (k -> (0,1)) + """ + + def __init__(self): + + self.reflectance_indices = [0, 1, 2, 3, 4, 6] + self.emissive_indices = [5, 7, 8, 9, 10, 11, 12, 13] + + self.emissive_mins = np.array( + [223.1222, 178.9174, 204.3739, 204.7677, + 194.8686, 202.1759, 201.3823, 203.3537], + dtype=np.float32) + + self.emissive_maxs = np.array( + [352.7182, 261.2920, 282.5529, 319.0373, + 295.0209, 324.0677, 321.5254, 285.9848], + dtype=np.float32) + + def __call__(self, img): + + # Reflectance % to reflectance units + img[:, :, self.reflectance_indices] = \ + img[:, :, self.reflectance_indices] * 0.01 + + # Brightness temp scaled to (0,1) range + img[:, :, self.emissive_indices] = \ + (img[:, :, self.emissive_indices] - self.emissive_mins) / \ + (self.emissive_maxs - self.emissive_mins) + + return img diff --git a/pytorch_caney/data/utils.py b/pytorch_caney/transforms/random_resize_crop.py similarity index 57% rename from pytorch_caney/data/utils.py rename to pytorch_caney/transforms/random_resize_crop.py index 86d9555..06609ec 100644 --- a/pytorch_caney/data/utils.py +++ b/pytorch_caney/transforms/random_resize_crop.py @@ -1,11 +1,10 @@ import torch import numpy as np -from numba import njit - -# TRANSFORMS UTILS - +# ----------------------------------------------------------------------------- +# RandomResizedCropNP +# ----------------------------------------------------------------------------- class RandomResizedCropNP(object): """ Numpy implementation of RandomResizedCrop @@ -62,55 +61,3 @@ def __call__(self, img): cropped_squeezed_numpy = cropped_resized.squeeze().numpy() cropped_squeezed_numpy = np.moveaxis(cropped_squeezed_numpy, 0, -1) return cropped_squeezed_numpy - - -# MASKING - -class SimmimMaskGenerator: - """ - Generates the masks for masked-image-modeling - """ - def __init__(self, - input_size=192, - mask_patch_size=32, - model_patch_size=4, - mask_ratio=0.6): - self.input_size = input_size - self.mask_patch_size = mask_patch_size - self.model_patch_size = model_patch_size - self.mask_ratio = mask_ratio - - assert self.input_size % self.mask_patch_size == 0 - assert self.mask_patch_size % self.model_patch_size == 0 - - self.rand_size = self.input_size // self.mask_patch_size - self.scale = self.mask_patch_size // self.model_patch_size - - self.token_count = self.rand_size ** 2 - self.mask_count = int(np.ceil(self.token_count * self.mask_ratio)) - - def __call__(self): - mask = make_simmim_mask(self.token_count, self.mask_count, - self.rand_size, self.scale) - mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1) - return mask - - -@njit() -def make_simmim_mask(token_count, mask_count, rand_size, scale): - """JIT-compiled random mask generation - - Args: - token_count - mask_count - rand_size - scale - - Returns: - mask - """ - mask_idx = np.random.permutation(token_count)[:mask_count] - mask = np.zeros(token_count, dtype=np.int64) - mask[mask_idx] = 1 - mask = mask.reshape((rand_size, rand_size)) - return mask diff --git a/pytorch_caney/utils.py b/pytorch_caney/utils.py index af65b11..fc7e46d 100644 --- a/pytorch_caney/utils.py +++ b/pytorch_caney/utils.py @@ -1,15 +1,45 @@ -import torch -import warnings - - -def check_gpus_available(ngpus: int) -> None: - ngpus_available = torch.cuda.device_count() - if ngpus < ngpus_available: - msg = 'Not using all available GPUS.' + \ - f' N GPUs available: {ngpus_available},' + \ - f' N GPUs selected: {ngpus}. ' - warnings.warn(msg) - elif ngpus > ngpus_available: - msg = 'Not enough GPUs to satisfy selected amount' + \ - f': {ngpus}. N GPUs available: {ngpus_available}' - warnings.warn(msg) +from lightning.pytorch.strategies import DeepSpeedStrategy + + +# ----------------------------------------------------------------------------- +# get_strategy +# ----------------------------------------------------------------------------- +def get_strategy(config): + + strategy = config.TRAIN.STRATEGY + + if strategy == 'deepspeed': + deepspeed_config = { + "train_micro_batch_size_per_gpu": config.DATA.BATCH_SIZE, + "steps_per_print": config.PRINT_FREQ, + "zero_allow_untested_optimizer": True, + "zero_optimization": { + "stage": config.DEEPSPEED.STAGE, + "contiguous_gradients": + config.DEEPSPEED.CONTIGUOUS_GRADIENTS, + "overlap_comm": config.DEEPSPEED.OVERLAP_COMM, + "reduce_bucket_size": config.DEEPSPEED.REDUCE_BUCKET_SIZE, + "allgather_bucket_size": + config.DEEPSPEED.ALLGATHER_BUCKET_SIZE, + }, + "activation_checkpointing": { + "partition_activations": config.TRAIN.USE_CHECKPOINT, + }, + } + + return DeepSpeedStrategy(config=deepspeed_config) + + else: + # These may be return as strings + return strategy + + +# ----------------------------------------------------------------------------- +# get_distributed_train_batches +# ----------------------------------------------------------------------------- +def get_distributed_train_batches(config, trainer): + if config.TRAIN.NUM_TRAIN_BATCHES: + return config.TRAIN.NUM_TRAIN_BATCHES + else: + return config.DATA.LENGTH // \ + (config.DATA.BATCH_SIZE * trainer.world_size) diff --git a/requirements/Dockerfile b/requirements/Dockerfile index 5f26ebb..3c238bf 100644 --- a/requirements/Dockerfile +++ b/requirements/Dockerfile @@ -7,11 +7,9 @@ FROM ${FROM_IMAGE}:${VERSION_DATE}-py3 # Ubuntu needs noninteractive to be forced ENV DEBIAN_FRONTEND noninteractive -ENV PROJ_LIB="/usr/share/proj" ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" ENV C_INCLUDE_PATH="/usr/include/gdal" -# System dependencies # System dependencies RUN apt-get update && \ apt-get -y install software-properties-common && \ @@ -75,13 +73,13 @@ RUN git clone --single-branch --branch master https://github.com/pkolano/shift.g rm -rf /app # Pip -RUN pip --no-cache-dir install omegaconf \ +RUN pip --no-cache-dir install \ pytorch-lightning \ Lightning \ transformers \ datasets \ - webdataset \ deepspeed \ + webdataset \ 'huggingface_hub[cli,torch]' \ torchgeo \ rasterio \ @@ -95,9 +93,6 @@ RUN pip --no-cache-dir install omegaconf \ opencv-contrib-python-headless \ tifffile \ webcolors \ - Pillow \ - seaborn \ - xgboost \ tiler \ segmentation-models \ timm \ @@ -110,8 +105,7 @@ RUN pip --no-cache-dir install omegaconf \ yacs \ termcolor \ segmentation-models-pytorch \ - pytorch-caney \ - GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'` + coverage HEALTHCHECK NONE ENTRYPOINT [] diff --git a/requirements/Dockerfile.dev b/requirements/Dockerfile.dev index b7fc5d6..0157464 100644 --- a/requirements/Dockerfile.dev +++ b/requirements/Dockerfile.dev @@ -1,5 +1,5 @@ # Arguments to pass to the image -ARG VERSION_DATE=23.01 +ARG VERSION_DATE=24.01 ARG FROM_IMAGE=nvcr.io/nvidia/pytorch # Import RAPIDS container as the BASE Image (cuda base image) @@ -7,7 +7,6 @@ FROM ${FROM_IMAGE}:${VERSION_DATE}-py3 # Ubuntu needs noninteractive to be forced ENV DEBIAN_FRONTEND noninteractive -ENV PROJ_LIB="/usr/share/proj" ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" ENV C_INCLUDE_PATH="/usr/include/gdal" @@ -75,7 +74,7 @@ RUN git clone --single-branch --branch master https://github.com/pkolano/shift.g rm -rf /app # Pip -RUN pip --no-cache-dir install omegaconf \ +RUN pip --no-cache-dir install \ pytorch-lightning \ Lightning \ transformers \ @@ -95,9 +94,6 @@ RUN pip --no-cache-dir install omegaconf \ opencv-contrib-python-headless \ tifffile \ webcolors \ - Pillow \ - seaborn \ - xgboost \ tiler \ segmentation-models \ timm \ @@ -110,7 +106,7 @@ RUN pip --no-cache-dir install omegaconf \ yacs \ termcolor \ segmentation-models-pytorch \ - GDAL==`ogrinfo --version | grep -Eo '[0-9]\.[0-9]\.[0-9]+'` + coverage HEALTHCHECK NONE ENTRYPOINT [] diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index d35fb47..f8749a8 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,22 +1,14 @@ torch>=2.0.0 torchvision>=0.15 -pytorch-lightning -omegaconf +lightning +datasets rasterio rioxarray xarray geopandas -opencv-python -opencv-python-headless -opencv-contrib-python -opencv-contrib-python-headless tifffile webcolors -Pillow -seaborn -xgboost tiler -segmentation-models pytest coveralls rtree @@ -26,5 +18,10 @@ yacs termcolor numba segmentation-models-pytorch -timm +joblib +GDAL>=3.3.0 +coverage deepspeed +timm +webdataset +torchgeo diff --git a/requirements/requirements.txt b/requirements/requirements.txt index a5d7614..d0efeda 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,22 +1,14 @@ torch>=2.0.0 torchvision>=0.15 -pytorch-lightning -omegaconf +lightning +datasets rasterio rioxarray xarray geopandas -opencv-python -opencv-python-headless -opencv-contrib-python -opencv-contrib-python-headless tifffile webcolors -Pillow -seaborn -xgboost tiler -segmentation-models pytest coveralls rtree @@ -30,3 +22,7 @@ joblib timm deepspeed GDAL>=3.3.0 +coverage +deepspeed +webdataset +torchgeo