diff --git a/README.md b/README.md new file mode 100644 index 0000000..41bc4b5 --- /dev/null +++ b/README.md @@ -0,0 +1,204 @@ +# ReadMe + +This is the repository for the submitted Remote Sensing of Environment article: *Deep Point Cloud Regression for +Above-Ground Forest Biomass Estimation from Airborne LiDAR*. + +To start, unzip `biomasspointclouds.7z` with the password given in the supplementary material under Reproducibility +(e.g. `7z x biomasspointclouds.7z`). + +We include **code**, **model weights** (soon), and the **dataset** (soon). + +Regarding the code: +We forked the [torch-points3d](https://github.com/nicolas-chaulet/torch-points3d) framework and added our own support for +regression tasks, including datasets, tracking, and models. In the process, we also simplified the usage of the +package. + +In addition, we included our code to load the trained linear regression and random forest models in +the `pointcloud_stats_method` folder. Just run the notebook `learn_with_stats.ipynb`. + +Finally, the results/plots for each method can be seen in the `eval_scripts` folder within +the `eval_deep_learning_v2.ipynb` notebook. The results for the network size experiment are in `eval_deep_learning_v2_size.ipynb`. 
+ +**results on the test set:** + +| target | model | treeadd | $R^2$ | | RMSE | | MAPE | | mean bias | | +|:----------------|:---------|:--------|---------:|------:|---------:|--------:|----------:|----------:|----------:|---------:| +| | | | *median* | *max* | *median* | *min* | *median* | *min* | *median* | *min* | +| **biomass** | KPConv | False | 0.800 | 0.815 | 45.264 | 43.540 | 396.685 | 272.288 | 0.460 | 0.389 | +| | | True | 0.780 | 0.803 | 47.526 | 44.975 | 467.581 | 246.927 | 3.660 | -0.707 | +| | MSENet14 | False | 0.825 | 0.829 | 42.373 | 41.806 | 299.497 | 192.777 | 0.666 | -0.291 | +| | | True | 0.823 | 0.829 | 42.596 | 41.851 | 271.716 | 131.120 | 0.313 | 0.122 | +| | MSENet50 | False | 0.827 | 0.835 | 42.140 | 41.083 | 469.104 | 174.245 | 0.837 | -0.114 | +| | | True | 0.824 | 0.837 | 42.481 | 40.909 | 339.700 | 119.264 | 0.889 | 0.596 | +| | PointNet | False | 0.770 | 0.772 | 48.565 | 48.288 | 889.293 | 625.091 | 0.539 | 0.119 | +| | | True | 0.766 | 0.768 | 48.932 | 48.753 | 896.835 | 622.713 | 2.464 | 1.774 | +| | RF | False | 0.754 | 0.754 | 50.188 | 50.158 | 625.439 | 616.635 | 1.470 | 1.459 | +| | | True | 0.151 | 0.157 | 93.238 | 92.930 | 7644.787 | 7423.094 | 47.625 | -47.521 | +| | power | False | 0.761 | 0.761 | 49.509 | 49.509 | 365.606 | 365.606 | 2.027 | 2.027 | +| | | True | 0.034 | 0.034 | 99.478 | 99.478 | 7604.844 | 7604.844 | 57.525 | -57.525 | +| | linear | False | 0.762 | 0.762 | 49.420 | 49.420 | 425.605 | 425.605 | 1.894 | 1.894 | +| | | True | 0.195 | 0.195 | 90.801 | 90.801 | 11448.501 | 11448.501 | 39.149 | -39.149 | +| **wood volume** | KPConv | False | 0.799 | 0.805 | 85.434 | 84.255 | 103.866 | 85.633 | 0.377 | 0.285 | +| | | True | 0.778 | 0.792 | 89.808 | 87.002 | 126.543 | 85.812 | 7.885 | -1.012 | +| | MSENet14 | False | 0.823 | 0.826 | 80.309 | 79.631 | 99.105 | 72.597 | 0.515 | 0.389 | +| | | True | 0.821 | 0.825 | 80.750 | 79.716 | 84.473 | 70.097 | 2.577 | 1.829 | +| | MSENet50 | False | 0.824 | 0.831 | 79.986 | 
78.344 | 131.525 | 72.381 | 0.169 | 0.123 | +| | | True | 0.822 | 0.832 | 80.571 | 78.177 | 115.634 | 78.422 | 3.572 | 2.646 | +| | PointNet | False | 0.777 | 0.781 | 90.183 | 89.198 | 205.366 | 162.049 | 1.991 | 1.369 | +| | | True | 0.773 | 0.776 | 90.844 | 90.220 | 236.383 | 174.903 | 5.708 | 4.578 | +| | RF | False | 0.757 | 0.757 | 94.091 | 94.070 | 223.652 | 222.600 | 3.979 | 3.955 | +| | | True | 0.192 | 0.197 | 171.475 | 170.930 | 1683.778 | 1676.524 | 85.629 | -85.465 | +| | power | False | 0.763 | 0.763 | 92.819 | 92.819 | 223.654 | 223.654 | 4.497 | 4.497 | +| | | True | 0.120 | 0.120 | 178.973 | 178.973 | 1793.822 | 1793.822 | 101.104 | -101.104 | +| | linear | False | 0.766 | 0.766 | 92.292 | 92.292 | 171.483 | 171.483 | 4.602 | 4.602 | +| | | True | 0.243 | 0.243 | 166.034 | 166.034 | 1747.807 | 1747.807 | 72.340 | -72.340 | + +# Install torch-points3d + +We set up our environment in the following way (conda is already installed): + +1. Go to `pointcloud-biomass-estimator/torch-points3d` +2. Make sure to install cuda 11.8 (don't forget to deselect the driver install if your drivers are current) + +``` +wget https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run +sudo sh cuda_11.8.0_520.61.05_linux.run +``` + +3. After installing, close and reopen the terminal to check if the PATH is set correctly with `echo $PATH`. It should + **not** have `/usr/local/cuda-10.2` but should have something like `/usr/local/cuda-11.8` in there. + +4. Install mamba (optional but highly recommended): + +``` +conda install mamba -c conda-forge +``` + +5. Create the conda environment: + +``` +mamba env create -f env.yml +``` + +or for cpu-version: + +``` +mamba env create -f env_cpu.yml +``` + +6. Activate the environment: + +``` +mamba activate pts +``` + +7. 
Install the missing pip packages for Minkowski networks: + +``` +pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --config-settings blas_include_dirs=${CONDA_PREFIX}/include blas=openblas + +``` + +or for cpu-version: + +``` +pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --config-settings blas=openblas + +``` + +8. Compile the KPConv scripts: + +``` +sh compile_wrappers.sh +``` + +# Training for Regression + +Run from within the torch-points3d folder. + +*MSENet50:* + +``` +python -u train.py task=instance models=instance/minkowski_baseline model_name=SENet50 data=instance/NFI/reg data.transform_type=sparse_xy training=nfi/minkowski lr_scheduler=cosineawr update_lr_scheduler_on=on_num_batch +``` + +*MSENet14:* + +``` +python -u train.py task=instance models=instance/minkowski_baseline model_name=SENet14 data=instance/NFI/reg data.transform_type=sparse_xy training=nfi/minkowski lr_scheduler=cosineawr update_lr_scheduler_on=on_num_batch +``` + +*KPConv:* + +``` +python -u train.py task=instance models=instance/kpconv model_name=KPConv data=instance/NFI/reg training=nfi/kpconv data.transform_type=xy lr_scheduler=cosineawr update_lr_scheduler_on=on_num_batch +``` + +*PointNet:* + +``` +python -u train.py task=instance models=instance/minkowski_baseline model_name=MPointNet data=instance/NFI/reg training=nfi/pointnet data.transform_type=sparse_xy lr_scheduler=cosineawr update_lr_scheduler_on=on_num_batch +``` + +# Calibration batch normalization + +To calibrate the trained models' batch norm statistics, run `calibrate_bn.py` as shown below. 
Note that the checkpoint directory has to be an absolute path, +e.g.: `checkpoint_dir=/home/user/torch-points3d/weights/SENet50/0` + +For Minkowski or PointNet (`model_name=SENet50`, `model_name=SENet14`, or `model_name=MPointNet`): + +``` +python calibrate_bn.py model_name=${model_name} checkpoint_dir=${checkpoint_dir} data=instance/NFI/reg num_workers=4 task=instance weight_name="total_BMag_ha_rmse" batch_size=64 num_workers=4 data.transform_type=sparse_xy epochs=20 +``` + +For KPConv: + +``` +python calibrate_bn.py model_name=KPConv checkpoint_dir=${checkpoint_dir} data=instance/NFI/reg num_workers=4 task=instance weight_name="total_BMag_ha_rmse" batch_size=64 num_workers=4 data.transform_type=xy epochs=20 +``` + +# Evaluating our models + +Run from within the torch-points3d folder. Note that the checkpoint directory has to be an absolute path, so set the framework root first, +e.g.: `PATHTOFRAMEWORK=/home/user/torch-points3d`. +Also, there are 5 weights for each model (from different trials); select one with, e.g., `TRIAL=1`. + +*MSENet50:* + +``` +python eval.py model_name=SENet50 checkpoint_dir=${PATHTOFRAMEWORK}/weights/SENet50/${TRIAL}/ weight_name="latest" batch_size=32 num_workers=4 eval_stages=["val","test"] data.transform_type=sparse_xy_eval data=instance/NFI/reg task=instance +``` + +The save folder location is `weights/msenet50/eval`. + +*MSENet14:* + +``` +python eval.py model_name=SENet14 checkpoint_dir=${PATHTOFRAMEWORK}/weights/SENet14/${TRIAL}/ weight_name="latest" batch_size=32 num_workers=4 eval_stages=["val","test"] data.transform_type=sparse_xy_eval data=instance/NFI/reg task=instance +``` + +The save folder location is `weights/msenet14/eval`. + +*KPConv:* + +``` +python eval.py model_name=KPConv checkpoint_dir=${PATHTOFRAMEWORK}/weights/KPConv/${TRIAL}/ weight_name="latest" batch_size=32 num_workers=4 eval_stages=["val","test"] data.transform_type=xy_eval data=instance/NFI/reg task=instance +``` + +The save folder location is `weights/kpconv/eval`. 
+ +*PointNet:* + +``` +python eval.py model_name=MPointNet checkpoint_dir=${PATHTOFRAMEWORK}/weights/PointNet/${TRIAL}/ weight_name="latest" batch_size=32 num_workers=4 eval_stages=["val","test"] data.transform_type=sparse_xy_eval data=instance/NFI/reg task=instance +``` + +the save folder location is `weights/pointnet/eval`. + +# Using tree-adding augmentations during test + +same as before, but the transform type changes to use tree augmentations, e.g.: + +``` +python eval.py model_name=MPointNet checkpoint_dir=${PATHTOFRAMEWORK}/weights/pointnet/ weight_name="total_rmse" batch_size=32 num_workers=4 eval_stages=["val","test"] data.transform_type=sparse_xy_eval_treeadd +``` diff --git a/eval_scripts/eval_deep_learning_v2.ipynb b/eval_scripts/eval_deep_learning_v2.ipynb new file mode 100644 index 0000000..67609ff --- /dev/null +++ b/eval_scripts/eval_deep_learning_v2.ipynb @@ -0,0 +1,9112 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "id": "38cde886", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import geopandas as gpd\n", + "import seaborn as sns\n", + "\n", + "sns.set_context(\"paper\")\n", + "sns.set_style(\"whitegrid\")\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rcParams[\"svg.fonttype\"] = \"none\"\n", + "from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error, r2_score\n", + "\n", + "sns.set_color_codes()\n", + "from glob import glob\n", + "from itertools import product\n", + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "358d4bc2", + "metadata": {}, + "outputs": [], + "source": [ + "target_vars = [\"BMag_ha\", \"V_ha\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "30bad65e", + "metadata": {}, + "outputs": [], + "source": [ + "bias_correct_splits = [\"val\", \"train\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "4d80056e", + "metadata": {}, + 
"outputs": [], + "source": [ + "# choose one of test, train, val\n", + "splits = [\"train\", \"val\", \"test\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "93cd866d", + "metadata": {}, + "outputs": [], + "source": [ + "models = {\n", + " # other baselines and models\n", + " \"linear\": (\n", + " f\"results_new/linear_?.gpkg\",\n", + " ),\n", + " \"RF\": (\n", + " f\"results_new/rf_?.gpkg\",\n", + " ),\n", + "\n", + " \"KPConv\": (\n", + " f\"results_new/KPConv_xy_??.gpkg\",\n", + " ),\n", + " \"PointNet\": (\n", + " f\"results_new/MPointNet_xy_??.gpkg\",\n", + " ),\n", + "\n", + " # favored basline and model\n", + " \"\\power{}\": (\n", + " f\"results_new/power_?.gpkg\",\n", + " ),\n", + "\n", + " \"MSENet14\": (\n", + " f\"results_new/SENet14_xy_??.gpkg\",\n", + " ),\n", + " \"MSENet50\": (\n", + " f\"results_new/SENet50_xy_??.gpkg\",\n", + " ),\n", + " \n", + "\n", + " # evaluation on augmented test set (treeadding augmentation)\n", + " \"linear_treeval\": (\n", + " f\"results_new/linear_?_treeadd.gpkg\",\n", + " ),\n", + " \"RF_treeval\": (\n", + " f\"results_new/rf_?_treeadd.gpkg\",\n", + " ),\n", + " \"\\power{}_treeval\": (\n", + " f\"results_new/power_?_treeadd.gpkg\",\n", + " ),\n", + "\n", + " \"KPConv_treeval\": (\n", + " f\"results_new/KPConv_xy_??_treeadd.gpkg\",\n", + " ),\n", + " \"PointNet_treeval\": (\n", + " f\"results_new/MPointNet_xy_??_treeadd.gpkg\",\n", + " ),\n", + " \n", + "\n", + " \"MSENet14_treeval\": (\n", + " f\"results_new/SENet14_xy_??_treeadd.gpkg\",\n", + " ),\n", + " \"MSENet50_treeval\": (\n", + " f\"results_new/SENet50_xy_??_treeadd.gpkg\",\n", + " ),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "id": "49266c45-aa7e-41a0-a704-f6bcc07bff48", + "metadata": {}, + "outputs": [], + "source": [ + "with open('results_new.pickle', 'rb') as handle:\n", + " results = pickle.load(handle)" + ] + }, + { + "cell_type": "markdown", + "id": "6f20b57b-0f00-4b69-8585-71a5ee12e3e0", + 
"metadata": {}, + "source": [ + "# Bias correction" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "d4c831b4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "linear 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "RF 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "RF 1\n", + "0.9151888974556669\n", + "[0 0]\n", + "RF 2\n", + "0.9151888974556669\n", + "[0 0]\n", + "RF 3\n", + "0.9151888974556669\n", + "[0 0]\n", + "RF 4\n", + "0.9151888974556669\n", + "[0 0]\n", + "KPConv 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "KPConv 1\n", + "0.9151888974556669\n", + "[0 0]\n", + "KPConv 2\n", + "0.9151888974556669\n", + "[0 0]\n", + "KPConv 3\n", + "0.9151888974556669\n", + "[0 0]\n", + "KPConv 4\n", + "0.9151888974556669\n", + "[0 0]\n", + "PointNet 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "PointNet 1\n", + "0.9151888974556669\n", + "[0 0]\n", + "PointNet 2\n", + "0.9151888974556669\n", + "[0 0]\n", + "PointNet 3\n", + "0.9151888974556669\n", + "[0 0]\n", + "PointNet 4\n", + "0.9151888974556669\n", + "[0 0]\n", + "\\power{} 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet14 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet14 2\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet14 3\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet14 4\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet14 1\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet50 0\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet50 1\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet50 2\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet50 3\n", + "0.9151888974556669\n", + "[0 0]\n", + "MSENet50 4\n", + "0.9151888974556669\n", + "[0 0]\n" + ] + } + ], + "source": [ + "# get bias correction\n", + "# we do not include the 0 predictions into the adjustment since they come from a different data distribution\n", + "\n", + "deltas = {}\n", + "results_corrected = {}\n", + "use_treeadd = False\n", + "use_treeval = 
False\n", + "exclude_1y = False\n", + "exclude_pred_0 = False\n", + "clip_0 = True\n", + "for model in models:\n", + " if \"treeval\" in model: # using the original correction\n", + " continue\n", + " corrected = []\n", + " corrected_treeval = []\n", + " for run in pd.unique(results[model][\"run\"]):\n", + " print(model, run)\n", + " pred_vars = [f\"{v}_pred\" for v in target_vars]\n", + " preds_cal = pd.concat(\n", + " [\n", + " results[model].query(f\"(run == {run}) & (split == @split)\")\n", + " for split in bias_correct_splits\n", + " ],\n", + " axis=0,\n", + " )[target_vars + pred_vars + [\"mask\", \"temp_diff_years\"] ].copy(deep=True)\n", + " if use_treeadd and model+\"_treeval\" in results and \"treeadd\" in model:\n", + " preds_cal = pd.concat(\n", + " [ preds_cal ] + \n", + " [\n", + " results[model+\"_treeadd\"].query(f\"(run == {run}) & (split == @split)\")\n", + " for split in bias_correct_splits\n", + " ],\n", + " axis=0,\n", + " )[target_vars + pred_vars + [\"mask\", \"temp_diff_years\"] ].copy(deep=True)\n", + " if use_treeval and model+\"_treeval\" in results:\n", + " preds_cal = pd.concat(\n", + " [ preds_cal ] + \n", + " [\n", + " results[model+\"_treeval\"].query(f\"(run == {run}) & (split == @split)\")\n", + " for split in bias_correct_splits\n", + " ],\n", + " axis=0,\n", + " )[target_vars + pred_vars + [\"mask\", \"temp_diff_years\"] ].copy(deep=True)\n", + " \n", + " \n", + " #reds_cal = preds_cal.sample(len(preds_cal))\n", + " \n", + " mask = np.ones_like(preds_cal[\"mask\"])\n", + " if exclude_1y:\n", + " mask &= (preds_cal[\"temp_diff_years\"] <= 1)\n", + " if exclude_pred_0:\n", + " mask &= ~preds_cal[\"mask\"]\n", + " \n", + " correct_ = ~mask == (preds_cal[target_vars] == 0).any(axis=1)\n", + " print(correct_.sum() / len(correct_))\n", + " #print(f\"num vals != 0: {mask.sum()}\")\n", + " y_cal_ = preds_cal[target_vars][mask].values\n", + " preds_cal_ = preds_cal[pred_vars][mask].values\n", + "\n", + " '''\n", + " ds = []\n", + " 
num_vals = 100\n", + " for i in range(0, len(y_cal_), num_vals):\n", + " mm = np.ones(len(y_cal_), dtype=bool)\n", + " mm[i:i+num_vals] = False\n", + " ds.append((\n", + " y_cal_[mm].astype(np.float64).sum(0)\n", + " - preds_cal_[mm].astype(np.float64).sum(0)\n", + " ) / (mm.sum()))\n", + " delta = np.median(ds, 0)\n", + " '''\n", + " delta = (y_cal_.astype(np.float64).sum(0)\n", + " - preds_cal_.astype(np.float64).sum(0)) / (len(y_cal_))\n", + " deltas[model, run] = delta\n", + " \n", + " # check if calibration is close to 0 on calibration set\n", + " assert np.isclose(0, y_cal_.sum(0) - ((preds_cal_ + delta).sum(0))).all() \n", + " \n", + " # apply delta to all values\n", + " df = results[model].query(f\"run == {run}\")[target_vars + pred_vars + [\"run\", \"mask\", \"split\", \"C_qfrac\"]]\n", + " dff = df[pred_vars]\n", + " if exclude_pred_0:\n", + " mask = ~df[[\"mask\"]].values\n", + " else:\n", + " mask = np.ones_like(df[[\"mask\"]])\n", + " df[pred_vars] = (dff + delta) * mask + (~mask) * dff\n", + " if clip_0:\n", + " df[pred_vars] = df[pred_vars].mask(dff < 0.00, 0.0)\n", + " corrected.append(df)\n", + " \n", + " # apply delta to all values of treeval\n", + " if model+\"_treeval\" not in results:\n", + " continue\n", + " df = results[model+\"_treeval\"].query(f\"run == {run}\")[target_vars + pred_vars + [\"run\", \"mask\", \"split\", \"C_qfrac\"]]\n", + " dff = df[pred_vars]\n", + " if exclude_pred_0:\n", + " mask = df[[\"mask\"]].values\n", + " else:\n", + " mask = np.ones_like(df[[\"mask\"]])\n", + " df[pred_vars] = (dff + delta) * mask + (~mask) * dff\n", + " if clip_0:\n", + " df[pred_vars] = df[pred_vars].mask(dff < 0.00, 0.0)\n", + " print((dff < 0.00).sum().values)\n", + " corrected_treeval.append(df)\n", + "\n", + " results_corrected[model] = pd.concat(corrected, axis=0)\n", + " if len(corrected_treeval) > 0:\n", + " results_corrected[model+\"_treeval\"] = pd.concat(corrected_treeval, axis=0)" + ] + }, + { + "cell_type": "markdown", + "id": 
"64349efc-e498-4cd7-9b47-bb677f5a0aab", + "metadata": {}, + "source": [ + "# Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "eb2c8782", + "metadata": {}, + "outputs": [], + "source": [ + "def cohen_d(y1_pred, y2_pred):\n", + " mse1 = (y1_pred**2).mean()\n", + " mse2 = (y2_pred**2).mean()\n", + " \n", + " diff = mse1 - mse2\n", + " s_pooled = np.sqrt((mse1 + mse2) / 2)\n", + " cohens_d = diff / s_pooled\n", + " return cohens_d\n", + "\n", + "def evaluate(name, results):\n", + " print(name)\n", + " columns = [\n", + " \"method\",\n", + " \"target\",\n", + " \"R2\",\n", + " \"MSE\",\n", + " \"RMSE\",\n", + " \"nRMSE\",\n", + " \"MAPE\",\n", + " \"mean error\",\n", + " \"mean bias\",\n", + " \"rel. error\",\n", + " \"run\",\n", + " ]\n", + " results_df = []\n", + "\n", + " for target in target_vars:\n", + " pred = target + \"_pred\"\n", + " for run, result in results.groupby(\"run\"):\n", + " mask = mm = result[target] != 0\n", + " #mm = result[pred] != 0\n", + " \n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " r2_score(result[target], result[pred]),\n", + " mean_squared_error(\n", + " result[target], result[pred]\n", + " ),\n", + " mean_squared_error(\n", + " result[target], result[pred], squared=False\n", + " ),\n", + " mean_squared_error(\n", + " result[target], result[pred], squared=False\n", + " ) / result[target].mean(),\n", + " mean_absolute_percentage_error(\n", + " result[target][mask], result[pred][mask]\n", + " )\n", + " * 100,\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[pred][mm])\n", + " ),\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[target][mm])\n", + " ,\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / (result[target][mm]).sum()\n", + " )\n", + " * 100,\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " results_df = 
pd.concat(results_df, axis=0)\n", + " return results, results_df\n", + "\n", + "'''\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[pred][mm])\n", + " ),\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[pred][mm])\n", + " ,\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / (result[pred][mm]).sum()\n", + " )\n", + " * 100,\n", + "''';\n", + "'''\n", + " abs(\n", + " (result[target] - result[pred]).sum()\n", + " / len(result[pred])\n", + " ),\n", + " (result[target] - result[pred]).sum()\n", + " / len(result[pred])\n", + " ,\n", + " abs(\n", + " (result[target] - result[pred]).sum()\n", + " / (result[pred]).sum()\n", + " )\n", + " * 100,\n", + "''';" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "4635281c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "linear\n", + "linear\n", + "RF\n", + "RF\n", + "KPConv\n", + "KPConv\n", + "PointNet\n", + "PointNet\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", 
+ "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: 
SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}\n", + "\\power{}\n", + "MSENet14\n", + "MSENet14\n", + "MSENet50\n", + "MSENet50\n", + "linear_treeval\n", + "linear_treeval\n", + "RF_treeval\n", + "RF_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on 
a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is 
trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}_treeval\n", + "\\power{}_treeval\n", + "KPConv_treeval\n", + "KPConv_treeval\n", + "PointNet_treeval\n", + "PointNet_treeval\n", + "MSENet14_treeval\n", + "MSENet14_treeval\n", + "MSENet50_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy 
of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is 
trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50_treeval\n", + "linear\n", + "linear\n", + "RF\n", + "RF\n", + "KPConv\n", + "KPConv\n", + "PointNet\n", + "PointNet\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = 
value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using 
.loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}\n", + "\\power{}\n", + "MSENet14\n", + "MSENet14\n", + "MSENet50\n", + "MSENet50\n", + "linear_treeval\n", + "linear_treeval\n", + "RF_treeval\n", + "RF_treeval\n", + "\\power{}_treeval\n", + "\\power{}_treeval\n", + "KPConv_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + 
"Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice 
from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set 
on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "KPConv_treeval\n", + "PointNet_treeval\n", + "PointNet_treeval\n", + "MSENet14_treeval\n", + "MSENet14_treeval\n", + "MSENet50_treeval\n", + "MSENet50_treeval\n", + "linear\n", + "linear\n", + "RF\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a 
DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a 
copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RF\n", + "KPConv\n", + "KPConv\n", + "PointNet\n", + "PointNet\n", + "\\power{}\n", + "\\power{}\n", + "MSENet14\n", + "MSENet14\n", + "MSENet50\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value 
instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using 
.loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50\n", + "linear_treeval\n", + "linear_treeval\n", + "RF_treeval\n", + "RF_treeval\n", + "\\power{}_treeval\n", + "\\power{}_treeval\n", + "KPConv_treeval\n", + "KPConv_treeval\n", + "PointNet_treeval\n", + "PointNet_treeval\n", + "MSENet14_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a 
DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a 
copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is 
trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_treeval\n", + "MSENet50_treeval\n", + "MSENet50_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_42898/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_42898/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + } + ], + "source": [ + "result_dict = {}\n", + "result_dict_corrected = {}\n", + "result_scores = {}\n", + "for split in splits:\n", + " result_score = []\n", + " for name in models.keys():\n", + " # use corrected version except for linear 
regressor (optimal already)\n", + " file, scores = evaluate(name, results[name].query(\"split == @split\"))\n", + " file.loc[:, \"corrected\"] = False\n", + " scores.loc[:, \"corrected\"] = False\n", + "\n", + " result_dict[name] = file\n", + " result_score.append(scores)\n", + "\n", + " file, scores = evaluate(name, results_corrected[name].query(\"split == @split\"))\n", + " file.loc[:, \"corrected\"] = True\n", + " scores.loc[:, \"corrected\"] = True\n", + "\n", + " result_dict_corrected[name] = file\n", + " result_score.append(scores)\n", + "\n", + " result_score = pd.concat(result_score, axis=0)\n", + " result_scores[split] = result_score" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "id": "85324aa9-c572-42cf-bc46-ee03d286c843", + "metadata": {}, + "outputs": [], + "source": [ + "# resave treeval results via flag\n", + "\n", + "for split in splits:\n", + " if \"treeval\" not in result_scores[split].columns:\n", + " treevals = result_scores[split][\"method\"].str.contains(\"treeval\")\n", + " method = result_scores[split][\"method\"]\n", + " result_scores[split].eval(\"treeval = @treevals\", inplace=True)\n", + " result_scores[split][\"method\"] = result_scores[split][\"method\"].str.replace(\"_treeval\", \"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "id": "f7494784-6232-490e-9e16-897f139b890b", + "metadata": {}, + "outputs": [], + "source": [ + "def abs_min(x): return x.iloc[np.argmin(abs(x))]\n", + "def abs_max(x): return x.iloc[np.argmax(abs(x))]\n", + "def abs_median(x): return np.median(abs(x))\n", + "def avg_sign(x): return np.mean(np.sign(x))\n", + "def abs_mean(x): return np.mean(abs(x))\n", + "def arg_abs_min(x): return np.argmin(abs(x))\n", + "def arg_abs_max(x): return np.argmax(abs(x))\n", + "def arg_max(x): return np.argmax(abs(x))\n", + "\n", + "agg = {\n", + " \"R2\": [\"median\", \"max\"],\n", + " #'MSE' : ['median', 'min'],\n", + " 'RMSE' : ['median', 'min'],\n", + " 'MAPE' : ['median', 'min'],\n", + " 
#\"mean error\": [\"median\", \"max\", \"min\"],\n", + " \"mean bias\": [abs_median, abs_min],\n", + " #'rel. error' : ['median', \"min\"],\n", + "}\n", + "\n", + "rr = (\n", + " result_scores[\"test\"]\n", + " #.query(\"target == 'BMag_ha'\")\n", + " .query(\"corrected == True\")\n", + " .groupby([\"target\", \"method\", \"treeval\"])\n", + " .agg(agg)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "id": "0dd65233-9f1e-4660-a23f-76e8c24acd2e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
R2RMSEMAPEmean bias
medianmaxmedianminmedianminabs_medianabs_min
targetmethodtreeval
BMag_haKPConvFalse0.7999950.81494245.26376143.539632396.684826272.2877350.4603210.388961
True0.7795040.80253747.52595144.975245467.581224246.9273543.659850-0.707264
MSENet14False0.8247250.82938842.37314941.805641299.496832192.7774400.665678-0.290542
True0.8228800.82902242.59557041.850502271.716014131.1203850.3131840.122477
MSENet50False0.8266480.83523342.14004541.083267469.104138174.2454990.837429-0.114375
True0.8238310.83663242.48099040.908553339.700072119.2637520.8894410.596189
PointNetFalse0.7697590.77237448.56482748.288268889.292563625.0914890.5390270.118963
True0.7662660.76797348.93179748.752772896.834833622.7132942.4638041.774162
RFFalse0.7541100.75440150.18810150.158444625.439118616.6353701.4702751.458803
True0.1513640.15695293.23759092.9301267644.7870897423.09393347.625241-47.521057
\\power{}False0.7607200.76072049.50896149.508961365.605998365.6059982.0265792.026579
True0.0339630.03396399.47806399.4780637604.8440367604.84403657.525097-57.525097
linearFalse0.7615800.76158049.41993149.419931425.605420425.6054201.8938161.893816
True0.1951460.19514690.80062890.80062811448.50058711448.50058739.149145-39.149145
V_haKPConvFalse0.7994930.80498985.43377584.254618103.86617885.6334520.3765680.284536
True0.7784360.79206589.80774187.001898126.54252185.8116327.885448-1.011822
MSENet14False0.8228260.82580780.30902679.63057499.10521072.5969960.5149830.388571
True0.8208740.82543180.75021179.71647484.47311470.0974842.5773541.829469
MSENet50False0.8242490.83139179.98593278.343874131.52539372.3812830.1685060.122771
True0.8216660.83210880.57143178.177082115.63450078.4223543.5717482.645925
PointNetFalse0.7765790.78143790.18336589.197636205.366066162.0488161.9906121.368627
True0.7732930.77639990.84408990.219658236.383012174.9030005.7077394.577837
RFFalse0.7567980.75690994.09102194.069611223.651644222.6003793.9789853.955142
True0.1922560.197387171.475273170.9296971683.7780861676.52439385.629363-85.465170
\\power{}False0.7633280.76332892.81918392.819183223.653969223.6539694.4969084.496908
True0.1200750.120075178.972966178.9729661793.8216791793.821679101.104058-101.104058
linearFalse0.7660100.76601092.29174392.291743171.483495171.4834954.6019734.601973
True0.2427040.242704166.034183166.0341831747.8065061747.80650672.339538-72.339538
\n", + "
" + ], + "text/plain": [ + " R2 RMSE \\\n", + " median max median min \n", + "target method treeval \n", + "BMag_ha KPConv False 0.799995 0.814942 45.263761 43.539632 \n", + " True 0.779504 0.802537 47.525951 44.975245 \n", + " MSENet14 False 0.824725 0.829388 42.373149 41.805641 \n", + " True 0.822880 0.829022 42.595570 41.850502 \n", + " MSENet50 False 0.826648 0.835233 42.140045 41.083267 \n", + " True 0.823831 0.836632 42.480990 40.908553 \n", + " PointNet False 0.769759 0.772374 48.564827 48.288268 \n", + " True 0.766266 0.767973 48.931797 48.752772 \n", + " RF False 0.754110 0.754401 50.188101 50.158444 \n", + " True 0.151364 0.156952 93.237590 92.930126 \n", + " \\power{} False 0.760720 0.760720 49.508961 49.508961 \n", + " True 0.033963 0.033963 99.478063 99.478063 \n", + " linear False 0.761580 0.761580 49.419931 49.419931 \n", + " True 0.195146 0.195146 90.800628 90.800628 \n", + "V_ha KPConv False 0.799493 0.804989 85.433775 84.254618 \n", + " True 0.778436 0.792065 89.807741 87.001898 \n", + " MSENet14 False 0.822826 0.825807 80.309026 79.630574 \n", + " True 0.820874 0.825431 80.750211 79.716474 \n", + " MSENet50 False 0.824249 0.831391 79.985932 78.343874 \n", + " True 0.821666 0.832108 80.571431 78.177082 \n", + " PointNet False 0.776579 0.781437 90.183365 89.197636 \n", + " True 0.773293 0.776399 90.844089 90.219658 \n", + " RF False 0.756798 0.756909 94.091021 94.069611 \n", + " True 0.192256 0.197387 171.475273 170.929697 \n", + " \\power{} False 0.763328 0.763328 92.819183 92.819183 \n", + " True 0.120075 0.120075 178.972966 178.972966 \n", + " linear False 0.766010 0.766010 92.291743 92.291743 \n", + " True 0.242704 0.242704 166.034183 166.034183 \n", + "\n", + " MAPE mean bias \n", + " median min abs_median abs_min \n", + "target method treeval \n", + "BMag_ha KPConv False 396.684826 272.287735 0.460321 0.388961 \n", + " True 467.581224 246.927354 3.659850 -0.707264 \n", + " MSENet14 False 299.496832 192.777440 0.665678 -0.290542 \n", + " True 
271.716014 131.120385 0.313184 0.122477 \n", + " MSENet50 False 469.104138 174.245499 0.837429 -0.114375 \n", + " True 339.700072 119.263752 0.889441 0.596189 \n", + " PointNet False 889.292563 625.091489 0.539027 0.118963 \n", + " True 896.834833 622.713294 2.463804 1.774162 \n", + " RF False 625.439118 616.635370 1.470275 1.458803 \n", + " True 7644.787089 7423.093933 47.625241 -47.521057 \n", + " \\power{} False 365.605998 365.605998 2.026579 2.026579 \n", + " True 7604.844036 7604.844036 57.525097 -57.525097 \n", + " linear False 425.605420 425.605420 1.893816 1.893816 \n", + " True 11448.500587 11448.500587 39.149145 -39.149145 \n", + "V_ha KPConv False 103.866178 85.633452 0.376568 0.284536 \n", + " True 126.542521 85.811632 7.885448 -1.011822 \n", + " MSENet14 False 99.105210 72.596996 0.514983 0.388571 \n", + " True 84.473114 70.097484 2.577354 1.829469 \n", + " MSENet50 False 131.525393 72.381283 0.168506 0.122771 \n", + " True 115.634500 78.422354 3.571748 2.645925 \n", + " PointNet False 205.366066 162.048816 1.990612 1.368627 \n", + " True 236.383012 174.903000 5.707739 4.577837 \n", + " RF False 223.651644 222.600379 3.978985 3.955142 \n", + " True 1683.778086 1676.524393 85.629363 -85.465170 \n", + " \\power{} False 223.653969 223.653969 4.496908 4.496908 \n", + " True 1793.821679 1793.821679 101.104058 -101.104058 \n", + " linear False 171.483495 171.483495 4.601973 4.601973 \n", + " True 1747.806506 1747.806506 72.339538 -72.339538 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(rr)" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "id": "226a4b8e-6fbd-4cff-8df3-f12d69735685", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| | ('R2', 'median') | ('R2', 'max') | ('RMSE', 'median') | ('RMSE', 'min') | ('MAPE', 'median') | ('MAPE', 'min') | ('mean bias', 'abs_median') | ('mean bias', 'abs_min') |\n", + 
"|:--------------------------------|-------------------:|----------------:|---------------------:|------------------:|---------------------:|------------------:|------------------------------:|---------------------------:|\n", + "| ('BMag_ha', 'KPConv', False) | 0.800 | 0.815 | 45.264 | 43.540 | 396.685 | 272.288 | 0.460 | 0.389 |\n", + "| ('BMag_ha', 'KPConv', True) | 0.780 | 0.803 | 47.526 | 44.975 | 467.581 | 246.927 | 3.660 | -0.707 |\n", + "| ('BMag_ha', 'MSENet14', False) | 0.825 | 0.829 | 42.373 | 41.806 | 299.497 | 192.777 | 0.666 | -0.291 |\n", + "| ('BMag_ha', 'MSENet14', True) | 0.823 | 0.829 | 42.596 | 41.851 | 271.716 | 131.120 | 0.313 | 0.122 |\n", + "| ('BMag_ha', 'MSENet50', False) | 0.827 | 0.835 | 42.140 | 41.083 | 469.104 | 174.245 | 0.837 | -0.114 |\n", + "| ('BMag_ha', 'MSENet50', True) | 0.824 | 0.837 | 42.481 | 40.909 | 339.700 | 119.264 | 0.889 | 0.596 |\n", + "| ('BMag_ha', 'PointNet', False) | 0.770 | 0.772 | 48.565 | 48.288 | 889.293 | 625.091 | 0.539 | 0.119 |\n", + "| ('BMag_ha', 'PointNet', True) | 0.766 | 0.768 | 48.932 | 48.753 | 896.835 | 622.713 | 2.464 | 1.774 |\n", + "| ('BMag_ha', 'RF', False) | 0.754 | 0.754 | 50.188 | 50.158 | 625.439 | 616.635 | 1.470 | 1.459 |\n", + "| ('BMag_ha', 'RF', True) | 0.151 | 0.157 | 93.238 | 92.930 | 7644.787 | 7423.094 | 47.625 | -47.521 |\n", + "| ('BMag_ha', '\\\\power{}', False) | 0.761 | 0.761 | 49.509 | 49.509 | 365.606 | 365.606 | 2.027 | 2.027 |\n", + "| ('BMag_ha', '\\\\power{}', True) | 0.034 | 0.034 | 99.478 | 99.478 | 7604.844 | 7604.844 | 57.525 | -57.525 |\n", + "| ('BMag_ha', 'linear', False) | 0.762 | 0.762 | 49.420 | 49.420 | 425.605 | 425.605 | 1.894 | 1.894 |\n", + "| ('BMag_ha', 'linear', True) | 0.195 | 0.195 | 90.801 | 90.801 | 11448.501 | 11448.501 | 39.149 | -39.149 |\n", + "| ('V_ha', 'KPConv', False) | 0.799 | 0.805 | 85.434 | 84.255 | 103.866 | 85.633 | 0.377 | 0.285 |\n", + "| ('V_ha', 'KPConv', True) | 0.778 | 0.792 | 89.808 | 87.002 | 126.543 | 85.812 | 7.885 | -1.012 
|\n", + "| ('V_ha', 'MSENet14', False) | 0.823 | 0.826 | 80.309 | 79.631 | 99.105 | 72.597 | 0.515 | 0.389 |\n", + "| ('V_ha', 'MSENet14', True) | 0.821 | 0.825 | 80.750 | 79.716 | 84.473 | 70.097 | 2.577 | 1.829 |\n", + "| ('V_ha', 'MSENet50', False) | 0.824 | 0.831 | 79.986 | 78.344 | 131.525 | 72.381 | 0.169 | 0.123 |\n", + "| ('V_ha', 'MSENet50', True) | 0.822 | 0.832 | 80.571 | 78.177 | 115.634 | 78.422 | 3.572 | 2.646 |\n", + "| ('V_ha', 'PointNet', False) | 0.777 | 0.781 | 90.183 | 89.198 | 205.366 | 162.049 | 1.991 | 1.369 |\n", + "| ('V_ha', 'PointNet', True) | 0.773 | 0.776 | 90.844 | 90.220 | 236.383 | 174.903 | 5.708 | 4.578 |\n", + "| ('V_ha', 'RF', False) | 0.757 | 0.757 | 94.091 | 94.070 | 223.652 | 222.600 | 3.979 | 3.955 |\n", + "| ('V_ha', 'RF', True) | 0.192 | 0.197 | 171.475 | 170.930 | 1683.778 | 1676.524 | 85.629 | -85.465 |\n", + "| ('V_ha', '\\\\power{}', False) | 0.763 | 0.763 | 92.819 | 92.819 | 223.654 | 223.654 | 4.497 | 4.497 |\n", + "| ('V_ha', '\\\\power{}', True) | 0.120 | 0.120 | 178.973 | 178.973 | 1793.822 | 1793.822 | 101.104 | -101.104 |\n", + "| ('V_ha', 'linear', False) | 0.766 | 0.766 | 92.292 | 92.292 | 171.483 | 171.483 | 4.602 | 4.602 |\n", + "| ('V_ha', 'linear', True) | 0.243 | 0.243 | 166.034 | 166.034 | 1747.807 | 1747.807 | 72.340 | -72.340 |\n" + ] + } + ], + "source": [ + "print(rr.to_markdown(floatfmt=\".3f\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "id": "5b115ec7-d849-4a78-b14f-e0fb64b555c7", + "metadata": {}, + "outputs": [], + "source": [ + "def abs_min(x): return x.iloc[np.argmin(abs(x))]\n", + "def abs_max(x): return x.iloc[np.argmax(abs(x))]\n", + "def abs_median(x): return np.median(abs(x))\n", + "def avg_sign(x): return np.mean(np.sign(x))\n", + "def abs_mean(x): return np.mean(abs(x))\n", + "def arg_abs_min(x): return np.argmin(abs(x))\n", + "def arg_abs_max(x): return np.argmax(abs(x))\n", + "def arg_max(x): return np.argmax(abs(x))\n", + "\n", + "agg = {\n", + " 'nRMSE' : 
['median', 'min'],\n", + "}\n", + "\n", + "rr = (\n", + " result_scores[\"test\"]\n", + " #.query(\"target == 'BMag_ha'\")\n", + " .query(\"corrected == True\")\n", + " .groupby([\"target\", \"method\", \"treeval\"])\n", + " .agg(agg)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "abd3ae88-8c20-473a-b415-4e05c58b643c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
nRMSE
medianmin
targetmethodtreeval
BMag_haKPConvFalse0.4211150.405074
True0.4421610.418431
MSENet14False0.3942220.388942
True0.3962910.389359
MSENet50False0.3920530.382221
True0.3952250.380596
PointNetFalse0.4518270.449254
True0.4552410.453575
RFFalse0.4669290.466653
True0.8674430.864583
\\power{}False0.4606100.460610
True0.9255020.925502
linearFalse0.4597820.459782
True0.8447710.844771
V_haKPConvFalse0.4227710.416936
True0.4444160.430531
MSENet14False0.3974110.394054
True0.3995940.394479
MSENet50False0.3958120.387686
True0.3987100.386861
PointNetFalse0.4462740.441397
True0.4495440.446454
RFFalse0.4656120.465506
True0.8485490.845850
\\power{}False0.4593180.459318
True0.8856520.885652
linearFalse0.4567080.456708
True0.8216240.821624
\n", + "
" + ], + "text/plain": [ + " nRMSE \n", + " median min\n", + "target method treeval \n", + "BMag_ha KPConv False 0.421115 0.405074\n", + " True 0.442161 0.418431\n", + " MSENet14 False 0.394222 0.388942\n", + " True 0.396291 0.389359\n", + " MSENet50 False 0.392053 0.382221\n", + " True 0.395225 0.380596\n", + " PointNet False 0.451827 0.449254\n", + " True 0.455241 0.453575\n", + " RF False 0.466929 0.466653\n", + " True 0.867443 0.864583\n", + " \\power{} False 0.460610 0.460610\n", + " True 0.925502 0.925502\n", + " linear False 0.459782 0.459782\n", + " True 0.844771 0.844771\n", + "V_ha KPConv False 0.422771 0.416936\n", + " True 0.444416 0.430531\n", + " MSENet14 False 0.397411 0.394054\n", + " True 0.399594 0.394479\n", + " MSENet50 False 0.395812 0.387686\n", + " True 0.398710 0.386861\n", + " PointNet False 0.446274 0.441397\n", + " True 0.449544 0.446454\n", + " RF False 0.465612 0.465506\n", + " True 0.848549 0.845850\n", + " \\power{} False 0.459318 0.459318\n", + " True 0.885652 0.885652\n", + " linear False 0.456708 0.456708\n", + " True 0.821624 0.821624" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(rr)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "id": "ee83f2cc-3b27-40d4-8c8f-b0f8fe246aa6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\begin{tabular}{lllrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{nRMSE} \\\\\n", + " & & & median & min \\\\\n", + "target & method & treeval & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.421 & 0.405 \\\\\n", + " & & True & 0.442 & 0.418 \\\\\n", + " & MSENet14 & False & 0.394 & 0.389 \\\\\n", + " & & True & 0.396 & 0.389 \\\\\n", + " & MSENet50 & False & 0.392 & 0.382 \\\\\n", + " & & True & 0.395 & 0.381 \\\\\n", + " & PointNet & False & 0.452 & 0.449 \\\\\n", + " & & True & 0.455 & 0.454 \\\\\n", + " & RF & False & 0.467 & 0.467 \\\\\n", + " & & True & 0.867 & 0.865 
\\\\\n", + " & \\textbackslash power\\{\\} & False & 0.461 & 0.461 \\\\\n", + " & & True & 0.926 & 0.926 \\\\\n", + " & linear & False & 0.460 & 0.460 \\\\\n", + " & & True & 0.845 & 0.845 \\\\\n", + "V\\_ha & KPConv & False & 0.423 & 0.417 \\\\\n", + " & & True & 0.444 & 0.431 \\\\\n", + " & MSENet14 & False & 0.397 & 0.394 \\\\\n", + " & & True & 0.400 & 0.394 \\\\\n", + " & MSENet50 & False & 0.396 & 0.388 \\\\\n", + " & & True & 0.399 & 0.387 \\\\\n", + " & PointNet & False & 0.446 & 0.441 \\\\\n", + " & & True & 0.450 & 0.446 \\\\\n", + " & RF & False & 0.466 & 0.466 \\\\\n", + " & & True & 0.849 & 0.846 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.459 & 0.459 \\\\\n", + " & & True & 0.886 & 0.886 \\\\\n", + " & linear & False & 0.457 & 0.457 \\\\\n", + " & & True & 0.822 & 0.822 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/2767379605.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. 
It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " print(rr.to_latex(formatters=[lambda x: \"%.3f\" % x] * 2))\n" + ] + } + ], + "source": [ + "print(rr.to_latex(formatters=[lambda x: \"%.3f\" % x] * 2))" + ] + }, + { + "cell_type": "markdown", + "id": "f1d45cb1-b689-451c-8588-e8db6dce2cd8", + "metadata": {}, + "source": [ + "### Border artifact augmentation effect" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "id": "fdd3dc7a-c7b1-46f1-a739-0712b7bbbdbd", + "metadata": {}, + "outputs": [], + "source": [ + "pd.set_option(\"display.precision\", 2)\n", + "pd.set_option(\"display.float_format\", lambda x: \"%.2f\" % x)" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "082dffcd-84b2-4f1a-be30-678df2ef0bdd", + "metadata": {}, + "outputs": [], + "source": [ + "agg = {\n", + " \"R2\": \"median\",\n", + " #'MSE' : ['median', 'min'],\n", + " 'RMSE' : 'median',\n", + " #\"MAPE\": \"median\",\n", + " \"mean bias\": abs_median,\n", + " #'rel. error' : ['median', \"min\"],\n", + "}\n", + "\n", + "\n", + "rr = (\n", + " result_scores[\"test\"]\n", + " .query(\"corrected == True\")\n", + " .groupby([\"target\", \"method\", \"treeval\"])\n", + " .agg(agg)\n", + ")\n", + "\n", + "agg_ = {\n", + " \"R2\": [\"median\", \"max\"],\n", + " #'MSE' : ['median', 'min'],\n", + " 'RMSE' : ['median', 'min'],\n", + " #\"mean error\": [\"median\", \"max\", \"min\"],\n", + " \"mean bias\": [abs_median, abs_min],\n", + " #'rel. 
error' : ['median', \"min\"],\n", + "}\n", + "\n", + "\n", + "rr_ = (\n", + " result_scores[\"test\"]\n", + " .query(\"corrected == True\").query(\"treeval\")\n", + " .groupby([\"target\", \"method\", \"treeval\"])\n", + " .agg(agg_)\n", + ")\n", + "rr_.columns = [' '.join(col).strip() for col in rr_.columns.values]" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "a01acc60-5f7b-4f0f-b29c-c77911c6a9d7", + "metadata": {}, + "outputs": [], + "source": [ + "rr_treeval = rr.query(\"treeval\").abs().reset_index().drop(columns=[\"treeval\"]).set_index([\"target\", \"method\"])\n", + "rr_notreeval = rr.query(\"not treeval\").abs().reset_index().drop(columns=[\"treeval\"]).set_index([\"target\", \"method\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "bbc17b4c-dc7d-47a6-8605-d04273157f57", + "metadata": {}, + "outputs": [], + "source": [ + "rr_diff = rr_treeval - rr_notreeval\n", + "rr_full = rr_.join(rr_diff, rsuffix=\"_diff\").reset_index().drop(columns=[\"treeval\"]).set_index([\"target\", \"method\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "1dbb9a74-9592-46d9-a084-34f0fc50fa62", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['R2 median', 'R2 max', 'RMSE median', 'RMSE min',\n", + " 'mean bias abs_median', 'mean bias abs_min', 'R2', 'RMSE', 'mean bias'],\n", + " dtype='object')" + ] + }, + "execution_count": 64, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rr_full.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "5d07fcf5-b96a-4226-8274-110608d5ab8e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
R2 medianR2R2 maxRMSE medianRMSERMSE minmean bias abs_medianmean biasmean bias abs_min
targetmethod
BMag_haKPConv0.78-0.020.8047.532.2644.983.663.20-0.71
MSENet140.82-0.000.8342.600.2241.850.31-0.350.12
MSENet500.82-0.000.8442.480.3440.910.890.050.60
PointNet0.77-0.000.7748.930.3748.752.461.921.77
RF0.15-0.600.1693.2443.0592.9347.6346.15-47.52
\\power{}0.03-0.730.0399.4849.9799.4857.5355.50-57.53
linear0.20-0.570.2090.8041.3890.8039.1537.26-39.15
V_haKPConv0.78-0.020.7989.814.3787.007.897.51-1.01
MSENet140.82-0.000.8380.750.4479.722.582.061.83
MSENet500.82-0.000.8380.570.5978.183.573.402.65
PointNet0.77-0.000.7890.840.6690.225.713.724.58
RF0.19-0.560.20171.4877.38170.9385.6381.65-85.47
\\power{}0.12-0.640.12178.9786.15178.97101.1096.61-101.10
linear0.24-0.520.24166.0373.74166.0372.3467.74-72.34
\n", + "
" + ], + "text/plain": [ + " R2 median R2 R2 max RMSE median RMSE RMSE min \\\n", + "target method \n", + "BMag_ha KPConv 0.78 -0.02 0.80 47.53 2.26 44.98 \n", + " MSENet14 0.82 -0.00 0.83 42.60 0.22 41.85 \n", + " MSENet50 0.82 -0.00 0.84 42.48 0.34 40.91 \n", + " PointNet 0.77 -0.00 0.77 48.93 0.37 48.75 \n", + " RF 0.15 -0.60 0.16 93.24 43.05 92.93 \n", + " \\power{} 0.03 -0.73 0.03 99.48 49.97 99.48 \n", + " linear 0.20 -0.57 0.20 90.80 41.38 90.80 \n", + "V_ha KPConv 0.78 -0.02 0.79 89.81 4.37 87.00 \n", + " MSENet14 0.82 -0.00 0.83 80.75 0.44 79.72 \n", + " MSENet50 0.82 -0.00 0.83 80.57 0.59 78.18 \n", + " PointNet 0.77 -0.00 0.78 90.84 0.66 90.22 \n", + " RF 0.19 -0.56 0.20 171.48 77.38 170.93 \n", + " \\power{} 0.12 -0.64 0.12 178.97 86.15 178.97 \n", + " linear 0.24 -0.52 0.24 166.03 73.74 166.03 \n", + "\n", + " mean bias abs_median mean bias mean bias abs_min \n", + "target method \n", + "BMag_ha KPConv 3.66 3.20 -0.71 \n", + " MSENet14 0.31 -0.35 0.12 \n", + " MSENet50 0.89 0.05 0.60 \n", + " PointNet 2.46 1.92 1.77 \n", + " RF 47.63 46.15 -47.52 \n", + " \\power{} 57.53 55.50 -57.53 \n", + " linear 39.15 37.26 -39.15 \n", + "V_ha KPConv 7.89 7.51 -1.01 \n", + " MSENet14 2.58 2.06 1.83 \n", + " MSENet50 3.57 3.40 2.65 \n", + " PointNet 5.71 3.72 4.58 \n", + " RF 85.63 81.65 -85.47 \n", + " \\power{} 101.10 96.61 -101.10 \n", + " linear 72.34 67.74 -72.34 " + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "rr_full[\n", + "[\"R2 median\", \"R2\", \"R2 max\", \"RMSE median\", \"RMSE\", \"RMSE min\", \"mean bias abs_median\", \"mean bias\", \"mean bias abs_min\"]\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "id": "289db645-f91c-404c-bb9f-99e2099df080", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\begin{tabular}{llrrrrrrrrr}\n", + "\\toprule\n", + " & & R2 median & R2 & R2 max & RMSE median & RMSE & RMSE min & 
mean bias abs\\_median & mean bias & mean bias abs\\_min \\\\\n", + "target & method & & & & & & & & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & 0.780 & -0.020 & 0.803 & 47.53 & 2.26 & 44.98 & 3.66 & 3.20 & -0.71 \\\\\n", + " & MSENet14 & 0.823 & -0.002 & 0.829 & 42.60 & 0.22 & 41.85 & 0.31 & -0.35 & 0.12 \\\\\n", + " & MSENet50 & 0.824 & -0.003 & 0.837 & 42.48 & 0.34 & 40.91 & 0.89 & 0.05 & 0.60 \\\\\n", + " & PointNet & 0.766 & -0.003 & 0.768 & 48.93 & 0.37 & 48.75 & 2.46 & 1.92 & 1.77 \\\\\n", + " & RF & 0.151 & -0.603 & 0.157 & 93.24 & 43.05 & 92.93 & 47.63 & 46.15 & -47.52 \\\\\n", + " & \\textbackslash power\\{\\} & 0.034 & -0.727 & 0.034 & 99.48 & 49.97 & 99.48 & 57.53 & 55.50 & -57.53 \\\\\n", + " & linear & 0.195 & -0.566 & 0.195 & 90.80 & 41.38 & 90.80 & 39.15 & 37.26 & -39.15 \\\\\n", + "V\\_ha & KPConv & 0.778 & -0.021 & 0.792 & 89.81 & 4.37 & 87.00 & 7.89 & 7.51 & -1.01 \\\\\n", + " & MSENet14 & 0.821 & -0.002 & 0.825 & 80.75 & 0.44 & 79.72 & 2.58 & 2.06 & 1.83 \\\\\n", + " & MSENet50 & 0.822 & -0.003 & 0.832 & 80.57 & 0.59 & 78.18 & 3.57 & 3.40 & 2.65 \\\\\n", + " & PointNet & 0.773 & -0.003 & 0.776 & 90.84 & 0.66 & 90.22 & 5.71 & 3.72 & 4.58 \\\\\n", + " & RF & 0.192 & -0.565 & 0.197 & 171.48 & 77.38 & 170.93 & 85.63 & 81.65 & -85.47 \\\\\n", + " & \\textbackslash power\\{\\} & 0.120 & -0.643 & 0.120 & 178.97 & 86.15 & 178.97 & 101.10 & 96.61 & -101.10 \\\\\n", + " & linear & 0.243 & -0.523 & 0.243 & 166.03 & 73.74 & 166.03 & 72.34 & 67.74 & -72.34 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/559235959.py:1: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. 
It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " print(rr_full[\n" + ] + } + ], + "source": [ + "print(rr_full[\n", + "[\"R2 median\", \"R2\", \"R2 max\", \"RMSE median\", \"RMSE\", \"RMSE min\", \"mean bias abs_median\", \"mean bias\", \"mean bias abs_min\"]\n", + "].to_latex(formatters=[lambda x: \"%.3f\" % x] * 3 + [lambda x: \"%.2f\" % x] * 6))" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "id": "2b115509-d3ab-4fe8-8147-7c60d2311cda", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
R2mean bias
medianabs_median
targetmethodruntreeval
BMag_haMSENet500False0.820.84
True0.821.17
1False0.830.66
True0.820.85
2False0.840.96
True0.840.60
3False0.820.11
True0.821.56
4False0.830.96
True0.830.89
\n", + "
" + ], + "text/plain": [ + " R2 mean bias\n", + " median abs_median\n", + "target method run treeval \n", + "BMag_ha MSENet50 0 False 0.82 0.84\n", + " True 0.82 1.17\n", + " 1 False 0.83 0.66\n", + " True 0.82 0.85\n", + " 2 False 0.84 0.96\n", + " True 0.84 0.60\n", + " 3 False 0.82 0.11\n", + " True 0.82 1.56\n", + " 4 False 0.83 0.96\n", + " True 0.83 0.89" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "def abs_min(x): return x.iloc[np.argmin(abs(x))]\n", + "def abs_max(x): return x.iloc[np.argmax(abs(x))]\n", + "def abs_median(x): return np.median(abs(x))\n", + "def abs_mean(x): return np.mean(abs(x))\n", + "def arg_abs_min(x): return np.argmin(abs(x))\n", + "def arg_abs_max(x): return np.argmax(abs(x))\n", + "def arg_max(x): return np.argmax(abs(x))\n", + "\n", + "agg = {\n", + " \"R2\": [\"median\"\n", + " ],\n", + " #'MSE' : ['median', 'min'],\n", + " #'RMSE' : ['median', 'min'],\n", + " #'MAPE' : ['median', 'min'],\n", + " #\"mean error\": [\"median\", \"max\", \"min\"],\n", + " \"mean bias\": [abs_median],\n", + " #'rel. 
error' : ['median', \"min\"],\n", + "}\n", + "\n", + "\n", + "display(\n", + " result_scores[\"test\"]\n", + " .query(\"target == 'BMag_ha'\")\n", + " #.query(\"method in ['PointNet', 'PointNet_treeval']\")\n", + " #.query(\"method in ['MSENet50_treeadd', 'MSENet50_treeadd_treeval']\")\n", + " .query(\"method in ['MSENet50', 'MSENet50_treeval']\")\n", + " #.query(\"method in ['MSENet14_treeadd', 'MSENet14_treeadd_treeval']\")\n", + " #.query(\"method in ['MSENet14', 'MSENet14_treeval']\")\n", + " .query(\"corrected == True\")\n", + " .groupby([\"target\", \"method\", \"run\", \"treeval\"])\n", + " .agg(agg)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "436f2d31-547d-4d77-aa4a-87aaa03a521e", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 68, + "id": "962ef2d9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train\n", + "\\begin{tabular}{lllrrrrrrrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{R2} & \\multicolumn{2}{l}{RMSE} & \\multicolumn{2}{l}{MAPE} & \\multicolumn{2}{l}{mean bias} \\\\\n", + " & & & median & max & median & min & median & min & abs\\_median & abs\\_min \\\\\n", + "target & method & treeval & & & & & & & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.799 & 0.828 & 46.43 & 42.93 & 287.18 & 183.09 & 1.17 & 0.78 \\\\\n", + " & MSENet14 & False & 0.804 & 0.841 & 45.85 & 41.26 & 329.84 & 161.76 & 1.60 & 1.18 \\\\\n", + " & MSENet50 & False & 0.796 & 0.804 & 46.68 & 45.80 & 453.20 & 181.45 & 1.67 & 1.39 \\\\\n", + " & PointNet & False & 0.714 & 0.726 & 55.33 & 54.16 & 887.95 & 643.06 & 2.47 & 2.30 \\\\\n", + " & RF & False & 0.723 & 0.723 & 54.45 & 54.42 & 797.90 & 779.69 & 2.90 & 2.89 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.700 & 0.700 & 56.64 & 56.64 & 690.99 & 690.99 & 2.50 & 2.50 \\\\\n", + " & linear & False & 0.706 & 0.706 & 56.09 & 56.09 & 802.65 & 802.65 & 2.79 & 2.79 
\\\\\n", + "V\\_ha & KPConv & False & 0.793 & 0.824 & 88.27 & 81.56 & 97.46 & 81.15 & 2.20 & 1.42 \\\\\n", + " & MSENet14 & False & 0.800 & 0.838 & 86.84 & 78.16 & 111.12 & 71.91 & 3.02 & 2.21 \\\\\n", + " & MSENet50 & False & 0.795 & 0.803 & 87.95 & 86.18 & 130.12 & 72.89 & 3.29 & 2.42 \\\\\n", + " & PointNet & False & 0.721 & 0.732 & 102.51 & 100.61 & 211.84 & 172.15 & 4.52 & 4.29 \\\\\n", + " & RF & False & 0.728 & 0.728 & 101.31 & 101.26 & 204.57 & 202.69 & 5.41 & 5.38 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.703 & 0.703 & 105.84 & 105.84 & 197.07 & 197.07 & 4.56 & 4.56 \\\\\n", + " & linear & False & 0.708 & 0.708 & 104.92 & 104.92 & 185.27 & 185.27 & 5.27 & 5.27 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n", + "val\n", + "\\begin{tabular}{lllrrrrrrrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{R2} & \\multicolumn{2}{l}{RMSE} & \\multicolumn{2}{l}{MAPE} & \\multicolumn{2}{l}{mean bias} \\\\\n", + " & & & median & max & median & min & median & min & abs\\_median & abs\\_min \\\\\n", + "target & method & treeval & & & & & & & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.790 & 0.802 & 48.46 & 47.11 & 362.01 & 239.81 & 0.46 & -0.01 \\\\\n", + " & MSENet14 & False & 0.808 & 0.811 & 46.36 & 45.99 & 188.78 & 151.14 & 0.25 & -0.04 \\\\\n", + " & MSENet50 & False & 0.805 & 0.810 & 46.75 & 46.16 & 330.53 & 192.81 & 0.25 & -0.06 \\\\\n", + " & PointNet & False & 0.710 & 0.713 & 56.94 & 56.69 & 814.77 & 592.68 & 0.15 & -0.03 \\\\\n", + " & RF & False & 0.758 & 0.758 & 52.08 & 52.01 & 445.04 & 429.35 & 2.03 & 1.98 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.739 & 0.739 & 54.02 & 54.02 & 385.32 & 385.32 & 1.96 & 1.96 \\\\\n", + " & linear & False & 0.739 & 0.739 & 54.02 & 54.02 & 283.79 & 283.79 & 2.35 & 2.35 \\\\\n", + "V\\_ha & KPConv & False & 0.808 & 0.817 & 85.77 & 83.78 & 90.12 & 73.12 & 1.34 & -0.49 \\\\\n", + " & MSENet14 & False & 0.821 & 0.826 & 82.75 & 81.69 & 80.74 & 69.17 & 0.46 & 0.08 \\\\\n", + " & 
MSENet50 & False & 0.824 & 0.826 & 82.11 & 81.65 & 94.88 & 62.36 & 0.20 & 0.04 \\\\\n", + " & PointNet & False & 0.745 & 0.750 & 98.82 & 97.83 & 221.24 & 194.86 & 0.64 & -0.00 \\\\\n", + " & RF & False & 0.784 & 0.785 & 90.93 & 90.81 & 110.56 & 109.67 & 3.62 & 3.52 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.766 & 0.766 & 94.66 & 94.66 & 118.16 & 118.16 & 3.30 & 3.30 \\\\\n", + " & linear & False & 0.766 & 0.766 & 94.65 & 94.65 & 98.58 & 98.58 & 4.29 & 4.29 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n", + "test\n", + "\\begin{tabular}{lllrrrrrrrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{R2} & \\multicolumn{2}{l}{RMSE} & \\multicolumn{2}{l}{MAPE} & \\multicolumn{2}{l}{mean bias} \\\\\n", + " & & & median & max & median & min & median & min & abs\\_median & abs\\_min \\\\\n", + "target & method & treeval & & & & & & & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.800 & 0.815 & 45.26 & 43.54 & 396.68 & 272.29 & 0.46 & 0.39 \\\\\n", + " & MSENet14 & False & 0.825 & 0.829 & 42.37 & 41.81 & 299.50 & 192.78 & 0.67 & -0.29 \\\\\n", + " & MSENet50 & False & 0.827 & 0.835 & 42.14 & 41.08 & 469.10 & 174.25 & 0.84 & -0.11 \\\\\n", + " & PointNet & False & 0.770 & 0.772 & 48.56 & 48.29 & 889.29 & 625.09 & 0.54 & 0.12 \\\\\n", + " & RF & False & 0.754 & 0.754 & 50.19 & 50.16 & 625.44 & 616.64 & 1.47 & 1.46 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.761 & 0.761 & 49.51 & 49.51 & 365.61 & 365.61 & 2.03 & 2.03 \\\\\n", + " & linear & False & 0.762 & 0.762 & 49.42 & 49.42 & 425.61 & 425.61 & 1.89 & 1.89 \\\\\n", + "V\\_ha & KPConv & False & 0.799 & 0.805 & 85.43 & 84.25 & 103.87 & 85.63 & 0.38 & 0.28 \\\\\n", + " & MSENet14 & False & 0.823 & 0.826 & 80.31 & 79.63 & 99.11 & 72.60 & 0.51 & 0.39 \\\\\n", + " & MSENet50 & False & 0.824 & 0.831 & 79.99 & 78.34 & 131.53 & 72.38 & 0.17 & 0.12 \\\\\n", + " & PointNet & False & 0.777 & 0.781 & 90.18 & 89.20 & 205.37 & 162.05 & 1.99 & 1.37 \\\\\n", + " & RF & False & 0.757 & 0.757 
& 94.09 & 94.07 & 223.65 & 222.60 & 3.98 & 3.96 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.763 & 0.763 & 92.82 & 92.82 & 223.65 & 223.65 & 4.50 & 4.50 \\\\\n", + " & linear & False & 0.766 & 0.766 & 92.29 & 92.29 & 171.48 & 171.48 & 4.60 & 4.60 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/4029165409.py:15: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " result_scores[split]\n", + "/tmp/ipykernel_42898/4029165409.py:15: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " result_scores[split]\n", + "/tmp/ipykernel_42898/4029165409.py:15: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. 
It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " result_scores[split]\n" + ] + } + ], + "source": [ + "pd.set_option(\"display.precision\", 3)\n", + "pd.set_option(\"display.float_format\", lambda x: \"%.3f\" % x)\n", + "\n", + "agg = {\n", + " \"R2\": [\"median\", \"max\"],\n", + " #'MSE' : ['median', 'min'],\n", + " 'RMSE' : ['median', 'min'],\n", + " 'MAPE' : ['median', 'min'],\n", + " \"mean bias\": [abs_median, abs_min],\n", + "}\n", + "\n", + "for split in splits:\n", + " print(split)\n", + " print(\n", + " result_scores[split]\n", + " .query(\"corrected == True\")\n", + " .query(\"treeval == False\")\n", + " .groupby([\"target\", \"method\", \"treeval\"])[[\"R2\", \"RMSE\", \"MAPE\", \"mean bias\"]]\n", + " .agg(agg)\n", + " .to_latex(formatters=[lambda x: \"%.3f\" % x] * 2 + [lambda x: \"%.2f\" % x] * 6)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 69, + "id": "2b07de00-3b1d-4b70-9578-5eb6a8c31b33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train\n", + "\\begin{tabular}{lllrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{nRMSE} \\\\\n", + " & & & median & min \\\\\n", + "target & method & treeval & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.434 & 0.401 \\\\\n", + " & & True & 0.473 & 0.401 \\\\\n", + " & MSENet14 & False & 0.428 & 0.386 \\\\\n", + " & & True & 0.429 & 0.388 \\\\\n", + " & MSENet50 & False & 0.436 & 0.428 \\\\\n", + " & & True & 0.437 & 0.430 \\\\\n", + " & PointNet & False & 0.517 & 0.506 \\\\\n", + " & & True & 0.525 & 0.512 \\\\\n", + " & RF & False & 0.509 & 0.508 \\\\\n", + " & & True & 0.877 & 0.875 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.529 & 0.529 \\\\\n", + " & & True & 0.937 & 0.937 \\\\\n", + " & linear & False & 0.524 & 0.524 \\\\\n", + " & & True & 0.863 & 0.863 \\\\\n", + "V\\_ha & KPConv & False & 0.441 & 0.407 \\\\\n", + " & & True & 0.482 & 
0.409 \\\\\n", + " & MSENet14 & False & 0.434 & 0.390 \\\\\n", + " & & True & 0.434 & 0.392 \\\\\n", + " & MSENet50 & False & 0.439 & 0.430 \\\\\n", + " & & True & 0.440 & 0.431 \\\\\n", + " & PointNet & False & 0.512 & 0.502 \\\\\n", + " & & True & 0.521 & 0.508 \\\\\n", + " & RF & False & 0.506 & 0.506 \\\\\n", + " & & True & 0.861 & 0.859 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.529 & 0.529 \\\\\n", + " & & True & 0.909 & 0.909 \\\\\n", + " & linear & False & 0.524 & 0.524 \\\\\n", + " & & True & 0.851 & 0.851 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n", + "val\n", + "\\begin{tabular}{lllrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{nRMSE} \\\\\n", + " & & & median & min \\\\\n", + "target & method & treeval & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.431 & 0.419 \\\\\n", + " & & True & 0.444 & 0.428 \\\\\n", + " & MSENet14 & False & 0.412 & 0.409 \\\\\n", + " & & True & 0.415 & 0.409 \\\\\n", + " & MSENet50 & False & 0.416 & 0.411 \\\\\n", + " & & True & 0.417 & 0.413 \\\\\n", + " & PointNet & False & 0.506 & 0.504 \\\\\n", + " & & True & 0.514 & 0.510 \\\\\n", + " & RF & False & 0.463 & 0.463 \\\\\n", + " & & True & 0.805 & 0.803 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.481 & 0.481 \\\\\n", + " & & True & 0.857 & 0.857 \\\\\n", + " & linear & False & 0.481 & 0.481 \\\\\n", + " & & True & 0.798 & 0.798 \\\\\n", + "V\\_ha & KPConv & False & 0.408 & 0.398 \\\\\n", + " & & True & 0.422 & 0.407 \\\\\n", + " & MSENet14 & False & 0.394 & 0.389 \\\\\n", + " & & True & 0.396 & 0.389 \\\\\n", + " & MSENet50 & False & 0.391 & 0.388 \\\\\n", + " & & True & 0.392 & 0.391 \\\\\n", + " & PointNet & False & 0.470 & 0.465 \\\\\n", + " & & True & 0.480 & 0.472 \\\\\n", + " & RF & False & 0.432 & 0.432 \\\\\n", + " & & True & 0.776 & 0.774 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.450 & 0.450 \\\\\n", + " & & True & 0.813 & 0.813 \\\\\n", + " & linear & False & 0.450 & 0.450 \\\\\n", + " & & True 
& 0.767 & 0.767 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n", + "test\n", + "\\begin{tabular}{lllrr}\n", + "\\toprule\n", + " & & & \\multicolumn{2}{l}{nRMSE} \\\\\n", + " & & & median & min \\\\\n", + "target & method & treeval & & \\\\\n", + "\\midrule\n", + "BMag\\_ha & KPConv & False & 0.421 & 0.405 \\\\\n", + " & & True & 0.442 & 0.418 \\\\\n", + " & MSENet14 & False & 0.394 & 0.389 \\\\\n", + " & & True & 0.396 & 0.389 \\\\\n", + " & MSENet50 & False & 0.392 & 0.382 \\\\\n", + " & & True & 0.395 & 0.381 \\\\\n", + " & PointNet & False & 0.452 & 0.449 \\\\\n", + " & & True & 0.455 & 0.454 \\\\\n", + " & RF & False & 0.467 & 0.467 \\\\\n", + " & & True & 0.867 & 0.865 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.461 & 0.461 \\\\\n", + " & & True & 0.926 & 0.926 \\\\\n", + " & linear & False & 0.460 & 0.460 \\\\\n", + " & & True & 0.845 & 0.845 \\\\\n", + "V\\_ha & KPConv & False & 0.423 & 0.417 \\\\\n", + " & & True & 0.444 & 0.431 \\\\\n", + " & MSENet14 & False & 0.397 & 0.394 \\\\\n", + " & & True & 0.400 & 0.394 \\\\\n", + " & MSENet50 & False & 0.396 & 0.388 \\\\\n", + " & & True & 0.399 & 0.387 \\\\\n", + " & PointNet & False & 0.446 & 0.441 \\\\\n", + " & & True & 0.450 & 0.446 \\\\\n", + " & RF & False & 0.466 & 0.466 \\\\\n", + " & & True & 0.849 & 0.846 \\\\\n", + " & \\textbackslash power\\{\\} & False & 0.459 & 0.459 \\\\\n", + " & & True & 0.886 & 0.886 \\\\\n", + " & linear & False & 0.457 & 0.457 \\\\\n", + " & & True & 0.822 & 0.822 \\\\\n", + "\\bottomrule\n", + "\\end{tabular}\n", + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/876785269.py:11: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. 
It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " result_scores[split]\n", + "/tmp/ipykernel_42898/876785269.py:11: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " result_scores[split]\n", + "/tmp/ipykernel_42898/876785269.py:11: FutureWarning: In future versions `DataFrame.to_latex` is expected to utilise the base implementation of `Styler.to_latex` for formatting and rendering. The arguments signature may therefore change. It is recommended instead to use `DataFrame.style.to_latex` which also contains additional functionality.\n", + " result_scores[split]\n" + ] + } + ], + "source": [ + "pd.set_option(\"display.precision\", 3)\n", + "pd.set_option(\"display.float_format\", lambda x: \"%.3f\" % x)\n", + "\n", + "agg = {\n", + " 'nRMSE' : ['median', 'min'],\n", + "}\n", + "\n", + "for split in splits:\n", + " print(split)\n", + " print(\n", + " result_scores[split]\n", + " .query(\"corrected == True\")\n", + " #.query(\"treeval == False\")\n", + " .groupby([\"target\", \"method\", \"treeval\"])[[\"nRMSE\"]]\n", + " .agg(agg)\n", + " .to_latex(formatters=[lambda x: \"%.3f\" % x] * 2)\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "736680ac-c3df-436e-a822-31042b42e3b4", + "metadata": {}, + "source": [ + "# Statistical Tests" + ] + }, + { + "cell_type": "code", + "execution_count": 70, + "id": "f8499e57", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['linear', 'linear_treeval', 'RF', 'RF_treeval', 'KPConv', 'KPConv_treeval', 'PointNet', 'PointNet_treeval', '\\\\power{}', '\\\\power{}_treeval', 'MSENet14', 'MSENet14_treeval', 'MSENet50', 'MSENet50_treeval'])" + ] + }, + "execution_count": 70, + 
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_corrected.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 71, + "id": "b3252e87-2731-4164-b587-22bab5cba9cd", + "metadata": {}, + "outputs": [], + "source": [ + "split = \"test\"\n", + "baseline = '\\\\power{}'\n", + "rr = {r: results_corrected[r].query(\"split == @split\") for r in results_corrected}\n", + "rb = rr[baseline].reset_index()\n", + "favorite = \"MSENet50\"\n", + "rf = rr[favorite].query(\"run == 0\").reset_index()" + ] + }, + { + "cell_type": "code", + "execution_count": 78, + "id": "f4a28c1a-77f1-45d3-88cf-dcd2045c8e7d", + "metadata": {}, + "outputs": [], + "source": [ + "target = \"BMag_ha\"\n", + "pred = \"BMag_ha_pred\"" + ] + }, + { + "cell_type": "code", + "execution_count": 82, + "id": "b3c56ead-7d60-4e2a-82d6-b0b54ed09b4a", + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.stats import wilcoxon, ttest_rel\n", + "# effect size from wilcoxon: Kerby, Dave S. 
(2014), \"The simple difference formula: An approach to teaching nonparametric correlation.\", Comprehensive Psychology, 3: 11.IT.3.1, doi:10.2466/11.IT.3.1\n", + "\n", + "def get_total_ranksum(x, y):\n", + " diff = np.array(x) - np.array(y)\n", + " samples = len(diff) - (diff == 0).sum() # remove ties\n", + " return sum([i+1 for i in range(samples)])\n", + "\n", + "stat_columns = [\n", + " \"method\", \"target\", \"run\",\n", + " \"wilcoxon_two_T\", \"wilcoxon_two_p\", \"wilcoxon_two_es\", \n", + "]\n", + "stat_diff_results = []\n", + "for target in target_vars:\n", + " pred = f\"{target}_pred\"\n", + " y_b = abs(rb[target].values - rb[pred].values)\n", + " # y_b = (rb[target].values)\n", + " # y_b = [np.mean(y_b[i]) for i in index_folds]\n", + " for run_i in range(10):\n", + " for method, df in rr.items():\n", + " if method == baseline or \"treeval\" in method:\n", + " continue\n", + " df = df.query(\"run==@run_i\")\n", + " n_runs = len(df.run.unique())\n", + " if n_runs == 0:\n", + " continue\n", + " yy = rb[target].values[None, :].flatten()\n", + " y_m = abs(yy - df[pred].values)\n", + "\n", + " S = get_total_ranksum(y_b, y_m)\n", + " wilcoxon_two_T, wilcoxon_two_p = wilcoxon(y_b, y_m, alternative=\"two-sided\")\n", + " wilcoxon_two_effect_size = ((S-wilcoxon_two_T)/S) - (wilcoxon_two_T/S)\n", + "\n", + " stat_diff_results.append([\n", + " method, target, run_i,\n", + " wilcoxon_two_T, wilcoxon_two_p, wilcoxon_two_effect_size,\n", + " ])\n", + " \n", + "stat_diff_results = pd.DataFrame(stat_diff_results, columns = stat_columns)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 83, + "id": "5d173654-6e9a-4e2a-a9ef-133070aa6b84", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
runwilcoxon_two_Twilcoxon_two_pwilcoxon_two_es
targetmethod
BMag_haKPConv2.0000171037.00000.00000.1766
MSENet142.0000156050.00000.00000.2487
MSENet502.0000154345.00000.00000.2569
PointNet2.0000194488.00000.09610.0636
RF2.0000207225.00000.95150.0023
linear0.0000205608.00000.79150.0101
V_haKPConv2.0000172431.00000.00000.1698
MSENet142.0000156289.00000.00000.2476
MSENet502.0000155540.00000.00000.2512
PointNet2.0000188915.00000.01800.0905
RF2.0000199221.00000.28540.0409
linear0.0000192746.00000.05960.0720
\n", + "
" + ], + "text/plain": [ + " run wilcoxon_two_T wilcoxon_two_p wilcoxon_two_es\n", + "target method \n", + "BMag_ha KPConv 2.0000 171037.0000 0.0000 0.1766\n", + " MSENet14 2.0000 156050.0000 0.0000 0.2487\n", + " MSENet50 2.0000 154345.0000 0.0000 0.2569\n", + " PointNet 2.0000 194488.0000 0.0961 0.0636\n", + " RF 2.0000 207225.0000 0.9515 0.0023\n", + " linear 0.0000 205608.0000 0.7915 0.0101\n", + "V_ha KPConv 2.0000 172431.0000 0.0000 0.1698\n", + " MSENet14 2.0000 156289.0000 0.0000 0.2476\n", + " MSENet50 2.0000 155540.0000 0.0000 0.2512\n", + " PointNet 2.0000 188915.0000 0.0180 0.0905\n", + " RF 2.0000 199221.0000 0.2854 0.0409\n", + " linear 0.0000 192746.0000 0.0596 0.0720" + ] + }, + "execution_count": 83, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.set_option(\"display.precision\", 3)\n", + "pd.set_option(\"display.float_format\", lambda x: \"%.4f\" % x)\n", + "stat_diff_results.query(\"target != 'Cag_ha'\").groupby([\"target\", \"method\"]).median()" + ] + }, + { + "cell_type": "markdown", + "id": "a46929b8-7766-45e3-ab23-51f82d7ed8df", + "metadata": {}, + "source": [ + "# Statistical Tests Aggregating Trees" + ] + }, + { + "cell_type": "code", + "execution_count": 85, + "id": "08ace53f-8789-4c15-9bfb-d073113f3da3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['linear', 'linear_treeval', 'RF', 'RF_treeval', 'KPConv', 'KPConv_treeval', 'PointNet', 'PointNet_treeval', '\\\\power{}', '\\\\power{}_treeval', 'MSENet14', 'MSENet14_treeval', 'MSENet50', 'MSENet50_treeval'])" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results_corrected.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 86, + "id": "15ba5a6d-c26b-463c-b0f4-d8a148dad4a2", + "metadata": {}, + "outputs": [], + "source": [ + "split = \"test\"\n", + "rr = {r: results_corrected[r].query(\"split == @split\") for r in results_corrected}" + ] + }, + 
{ + "cell_type": "code", + "execution_count": 89, + "id": "37773f1e-32f0-4479-962c-f6ce608d8438", + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.stats import wilcoxon, ttest_rel\n", + "# effect size from wilcoxon: Kerby, Dave S. (2014), \"The simple difference formula: An approach to teaching nonparametric correlation.\", Comprehensive Psychology, 3: 11.IT.3.1, doi:10.2466/11.IT.3.1\n", + "\n", + "def get_total_ranksum(x, y):\n", + " diff = np.array(x) - np.array(y)\n", + " samples = len(diff) - (diff == 0).sum() # remove ties\n", + " return sum([i+1 for i in range(samples)])\n", + "\n", + "stat_columns = [\n", + " \"method\", \"target\", \"run\",\n", + " \"wilcoxon_two_T\", \"wilcoxon_two_p\", \"wilcoxon_two_es\", \n", + "]\n", + "stat_diff_results = []\n", + "for target in target_vars:\n", + " pred = f\"{target}_pred\"\n", + " for run_i in range(10):\n", + " for method in rr.keys():\n", + " if \"treeval\" in method:\n", + " continue\n", + " df = rr[f\"{method}\"].query(\"run==@run_i\")\n", + " if len(df) == 0: # run does not exist\n", + " continue\n", + " df_treeval = rr[f\"{method}_treeval\"].query(\"run==@run_i\")\n", + " \n", + " y_b = abs(df[target].values[None, :].flatten() - df[pred].values[None, :].flatten())\n", + " y_m = abs(df[target].values[None, :].flatten() - df_treeval[pred].values[None, :].flatten())\n", + " # resid = (df[target].values)\n", + " # y_m = [np.mean(y_m[i]) for i in index_folds]\n", + "\n", + " S = get_total_ranksum(y_b, y_m)\n", + " wilcoxon_two_T, wilcoxon_two_p = wilcoxon(y_b, y_m, alternative=\"two-sided\")\n", + " wilcoxon_two_effect_size = ((S-wilcoxon_two_T)/S) - (wilcoxon_two_T/S)\n", + "\n", + "\n", + " stat_diff_results.append([\n", + " method, target, run_i,\n", + " wilcoxon_two_T, wilcoxon_two_p, wilcoxon_two_effect_size,\n", + " ])\n", + " \n", + "stat_diff_results = pd.DataFrame(stat_diff_results, columns = stat_columns)\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "id": 
"1e8070d1-f59c-452d-9bf0-caa2810d966b", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
runwilcoxon_two_Twilcoxon_two_pwilcoxon_two_es
targetmethod
BMag_haKPConv2.0000167358.00000.00030.1377
MSENet142.0000119040.00000.23630.0519
MSENet502.0000125323.00000.19640.0547
PointNet2.0000187113.50000.36860.0358
RF2.000056803.00000.00000.6997
\\power{}0.000057945.00000.00000.7210
linear0.000054639.00000.00000.7021
V_haKPConv2.0000172075.00000.00410.1097
MSENet142.0000118817.00000.15560.0612
MSENet502.0000119360.50000.09220.0726
PointNet2.0000189249.50000.70720.0146
RF2.000060310.00000.00000.6809
\\power{}0.000063803.00000.00000.6928
linear0.000059058.00000.00000.6780
\n", + "
" + ], + "text/plain": [ + " run wilcoxon_two_T wilcoxon_two_p wilcoxon_two_es\n", + "target method \n", + "BMag_ha KPConv 2.0000 167358.0000 0.0003 0.1377\n", + " MSENet14 2.0000 119040.0000 0.2363 0.0519\n", + " MSENet50 2.0000 125323.0000 0.1964 0.0547\n", + " PointNet 2.0000 187113.5000 0.3686 0.0358\n", + " RF 2.0000 56803.0000 0.0000 0.6997\n", + " \\power{} 0.0000 57945.0000 0.0000 0.7210\n", + " linear 0.0000 54639.0000 0.0000 0.7021\n", + "V_ha KPConv 2.0000 172075.0000 0.0041 0.1097\n", + " MSENet14 2.0000 118817.0000 0.1556 0.0612\n", + " MSENet50 2.0000 119360.5000 0.0922 0.0726\n", + " PointNet 2.0000 189249.5000 0.7072 0.0146\n", + " RF 2.0000 60310.0000 0.0000 0.6809\n", + " \\power{} 0.0000 63803.0000 0.0000 0.6928\n", + " linear 0.0000 59058.0000 0.0000 0.6780" + ] + }, + "execution_count": 90, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pd.set_option(\"display.precision\", 3)\n", + "pd.set_option(\"display.float_format\", lambda x: \"%.4f\" % x)\n", + "stat_diff_results.query(\"target != 'Cag_ha'\").groupby([\"target\", \"method\"]).median()" + ] + }, + { + "cell_type": "markdown", + "id": "216cbf89-d074-4774-955a-b756cc0179f4", + "metadata": {}, + "source": [ + "# Species Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "id": "8d827085-2e90-4805-8266-ceb8206e91ac", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figsize = (7.48031 / 2, 2)\n", + "g = sns.histplot(data=results[\"linear\"].reset_index(), x=\"C_qfrac\", hue=\"split\", stat=\"count\", multiple=\"stack\", discrete=True)\n", + "g.set(xlabel=\"conifer fraction\")\n", + "fig = plt.gcf()\n", + "fig.set_size_inches(figsize)\n", + "#plt.savefig(\"figures/c_qfrag_hist.svg\")" + ] + }, + { + "cell_type": "markdown", + "id": "5b5e22e1-5d33-472c-b8ec-3a4e23e165cd", + "metadata": {}, + "source": [ + "## boxplot comparing NFI over conifer fractions" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "id": "da5c5d4d-0443-4472-bf39-c099df23c85b", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figsize = (7.48031 / 2, 2.5)\n", + "g = sns.catplot(data=results[\"linear\"].reset_index(), y=\"BMag_ha\", x=\"C_qfrac\", kind=\"box\", hue=\"split\", notch=True)\n", + "g.set(xlabel=\"conifer fraction\")\n", + "g.set(ylabel=\"AGB in \\,Mg\\,ha$^{-1}$\")\n", + "fig = plt.gcf()\n", + "fig.set_size_inches(figsize)\n", + "fig.tight_layout()\n", + "#plt.savefig(\"figures/frac_agb.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "id": "a8ddf178-8089-4832-869b-63e4c4669a33", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figsize = (7.48031 / 2, 2.5)\n", + "g = sns.catplot(data=results[\"linear\"].reset_index(), y=\"V_ha\", x=\"C_qfrac\", kind=\"box\", hue=\"split\", notch=True)\n", + "g.set(xlabel=\"conifer fraction\")\n", + "g.set(ylabel=\"wood volume in m$^3$\\,ha$^{-1}$\")\n", + "fig = plt.gcf()\n", + "fig.set_size_inches(figsize)\n", + "fig.tight_layout()\n", + "#plt.savefig(\"figures/frac_volume.svg\")" + ] + }, + { + "cell_type": "markdown", + "id": "ccc0d723-cf0d-4457-9f4f-b09db3fef285", + "metadata": {}, + "source": [ + "## boxplot comparing MSENet and NFI over fractions" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "id": "4bdf76dc-0673-4dcb-b562-e8593aeeb653", + "metadata": {}, + "outputs": [], + "source": [ + "df_pred = results_corrected[\"MSENet14\"].query(\"split == 'test'\").reset_index().copy()\n", + "df_pred[\"method\"] = 'MSENet14'\n", + "df_pred[\"BMag_ha\"] = df_pred[\"BMag_ha_pred\"]\n", + "df_pred[\"V_ha\"] = df_pred[\"V_ha_pred\"]\n", + "df_pred2 = results_corrected[\"\\power{}\"].query(\"split == 'test'\").reset_index().copy()\n", + "df_pred2[\"method\"] = '\\power{}'\n", + "df_pred2[\"BMag_ha\"] = df_pred2[\"BMag_ha_pred\"]\n", + "df_pred2[\"V_ha\"] = df_pred2[\"V_ha_pred\"]\n", + "df_pred3 = results_corrected[\"MSENet50\"].query(\"split == 'test'\").reset_index().copy()\n", + "df_pred3[\"method\"] = 'MSENet50'\n", + "df_pred3[\"BMag_ha\"] = df_pred3[\"BMag_ha_pred\"]\n", + "df_pred3[\"V_ha\"] = df_pred3[\"V_ha_pred\"]\n", + "df_truth = results_corrected[\"MSENet14\"].query(\"split == 'test'\").reset_index().copy()\n", + "df_truth[\"method\"] = 'NFI'\n", + "\n", + "df = pd.concat([df_pred, df_pred2, \n", + " #df_pred3, \n", + " df_truth])\n", + "\n", + "\n", + "df = df.query(\"C_qfrac != 'nan'\")\n", + "df[\"C_qfrac\"] = df[\"C_qfrac\"].astype(int)" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "id": "9e9409de-2271-4bbc-8142-da028d2f8f0e", 
+ "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figsize = (7.48031 / 2, 2.5)\n", + "g = sns.catplot(data=df, y=\"V_ha\", x=\"C_qfrac\", kind=\"box\", hue=\"method\", notch=True)\n", + "g.set(xlabel=\"conifer fraction\")\n", + "g.set(ylabel=\"wood volume in m^3\\,ha^{-1}\")\n", + "#g.set(yscale=\"log\")\n", + "fig = plt.gcf()\n", + "fig.set_size_inches(figsize)\n", + "fig.tight_layout()\n", + "#plt.savefig(\"figures/frac_volume_comp.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "id": "44dd396a-8234-44ee-bc3f-aa8500f54882", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figsize = (7.48031 / 2, 2.5)\n", + "g = sns.catplot(data=df, y=\"BMag_ha\", x=\"C_qfrac\", kind=\"box\", hue=\"method\", notch=True)\n", + "g.set(xlabel=\"conifer fraction\")\n", + "g.set(ylabel=\"AGB in \\,Mg\\,ha^{-1}\")\n", + "fig = plt.gcf()\n", + "fig.set_size_inches(figsize)\n", + "fig.tight_layout()\n", + "#plt.savefig(\"figures/frac_agb_comp.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 101, + "id": "00d6b498", + "metadata": {}, + "outputs": [], + "source": [ + "def evaluate_species(name, results, split):\n", + " print(name)\n", + " columns = [\"method\", \"target\", \"C_qfrac\", \"value\", \"metric\", \"run\"]\n", + " \n", + " results = results.query(\"split == @split\")\n", + " \n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + " \n", + " results_df = []\n", + " for target in target_vars:\n", + " for run, result in results.groupby(\"run\"):\n", + " C_qfracs = set(result.C_qfrac)\n", + " mask = result[target] != 0\n", + " for qfrac in C_qfracs:\n", + " frac_mask = result.C_qfrac == qfrac\n", + " if frac_mask.sum() == 0:\n", + " continue\n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " qfrac,\n", + " r2_score(\n", + " result[frac_mask][target],\n", + " result[frac_mask][target+\"_pred\"],\n", + " ),\n", + " \"R2\",\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " qfrac,\n", + " mean_squared_error(\n", + " result[frac_mask][target],\n", + " result[frac_mask][target+\"_pred\"],\n", + " squared=False,\n", + " ),\n", + " \"RMSE\",\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " 
name,\n", + " target,\n", + " qfrac,\n", + " mean_absolute_percentage_error(\n", + " result[frac_mask & mask][target],\n", + " result[frac_mask & mask][target+\"_pred\"],\n", + " )\n", + " * 100,\n", + " \"MAPE\",\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " return pd.concat(results_df, axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "id": "b1b7c010", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "linear\n", + "RF\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice 
from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "KPConv\n", + "PointNet\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}\n", + "MSENet14\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " 
results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50\n", + "linear_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + 
"text": [ + "RF_treeval\n", + "\\power{}_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "KPConv_treeval\n", + "PointNet_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_treeval\n", + "MSENet50_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a 
copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "linear\n", + "RF\n", + "KPConv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats 
in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice 
from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PointNet\n", + "\\power{}\n", + "MSENet14\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in 
the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50\n", + "linear_treeval\n", + "RF_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}_treeval\n", + "KPConv_treeval\n", + 
"PointNet_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", 
+ "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_treeval\n", + "MSENet50_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "linear\n", + "RF\n", + "KPConv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " 
results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the 
documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PointNet\n", + "\\power{}\n", + "MSENet14\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50\n", + "linear_treeval\n", + "RF_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " 
results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}_treeval\n", + "KPConv_treeval\n", + "PointNet_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_treeval\n", + "MSENet50_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a 
slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:7: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_42898/276587871.py:8: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + } + ], + "source": [ + "result_scores_cfrac = {}\n", + "for split in splits:\n", + " result_score = []\n", + " for name in models.keys():\n", + " result_dict[name] = file\n", + "\n", + " scores = evaluate_species(name, results_corrected[name], split)\n", + " result_score.append(scores)\n", + " result_score = pd.concat(result_score, axis=0)\n", + " result_score = result_score.query(\"C_qfrac != 'nan'\")\n", + " 
result_score[\"C_qfrac\"] = result_score[\"C_qfrac\"].astype(int)\n", + " result_scores_cfrac[split] = result_score" + ] + }, + { + "cell_type": "code", + "execution_count": 103, + "id": "213f4777", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
value
metricMAPER2RMSE
targetC_qfracmethod
BMag_ha0KPConv374.62280.787955.9159
KPConv_treeval1222.24040.717663.8283
MSENet14357.06640.794755.0272
MSENet14_treeval322.72380.794055.1264
MSENet50474.15870.779457.1103
MSENet50_treeval469.87020.778557.2248
PointNet1075.72630.668070.0579
PointNet_treeval1120.39390.658871.0236
RF664.55060.687268.0115
RF_treeval7388.16750.372996.3007
\\power{}677.34190.656271.3079
\\power{}_treeval13318.93840.332199.3840
linear1062.98820.671869.6684
linear_treeval6325.31910.397994.3655
33KPConv49.87770.817544.4828
KPConv_treeval62.10220.761050.0162
MSENet1455.98610.837241.9658
MSENet14_treeval44.52790.834242.3487
MSENet5059.76460.831242.8003
MSENet50_treeval62.40840.829643.0044
PointNet86.05330.727854.3584
PointNet_treeval85.53420.711255.9890
RF91.75920.721654.9777
RF_treeval514.07590.400880.6557
\\power{}108.48410.718655.2698
\\power{}_treeval458.58980.255189.9315
linear70.69660.705156.5878
linear_treeval436.65320.414579.7319
66KPConv45.34800.821133.5294
KPConv_treeval57.34540.789436.0695
MSENet1448.70370.831832.5518
MSENet14_treeval43.71270.832232.5081
MSENet5050.45140.808734.7096
MSENet50_treeval53.05260.811234.4847
PointNet62.86210.726941.4939
PointNet_treeval73.11280.704943.1428
RF59.29320.725641.6017
RF_treeval223.13430.031178.1727
\\power{}66.54870.702843.2948
\\power{}_treeval332.0502-0.198486.9403
linear63.04170.710642.7254
linear_treeval173.38060.111974.8453
100KPConv45.45460.706240.8768
KPConv_treeval57.05340.606046.8419
MSENet1442.89860.733838.9256
MSENet14_treeval42.05400.730439.1778
MSENet5047.72780.710040.6361
MSENet50_treeval45.55460.706240.9045
PointNet68.29570.674943.0369
PointNet_treeval70.69300.661143.9330
RF68.21380.697341.5248
RF_treeval441.2907-0.069378.0501
\\power{}62.62620.658244.1299
\\power{}_treeval263.7134-0.315986.5819
linear104.01960.662743.8370
linear_treeval451.5604-0.094778.9698
101KPConv454.00000.765839.8251
KPConv_treeval1253.16820.706244.2901
MSENet14370.63720.782238.4141
MSENet14_treeval368.39950.778638.7274
MSENet50466.27300.768039.6512
MSENet50_treeval485.54970.765039.9078
PointNet1010.52660.712344.1623
PointNet_treeval1026.52410.711944.1935
RF1527.42710.710644.2926
RF_treeval11221.5671-0.220190.9475
\\power{}1177.91380.687746.0147
\\power{}_treeval13871.9134-0.417398.0225
linear962.35510.689045.9157
linear_treeval9970.9772-0.138487.8503
V_ha0KPConv113.02050.7696103.4888
KPConv_treeval166.44360.6926118.0473
MSENet14114.87620.7808100.9990
MSENet14_treeval103.01410.7800101.1862
MSENet50131.97200.7681104.0084
MSENet50_treeval138.00760.7677104.1027
PointNet241.23100.6621125.5487
PointNet_treeval255.36530.6548126.8861
RF208.12990.6807122.0570
RF_treeval1434.52080.3468174.5697
\\power{}224.72210.6418129.2763
\\power{}_treeval1862.14000.3082179.6580
linear208.68480.6591126.1115
linear_treeval1176.93930.3698171.4722
33KPConv43.09830.807884.2133
KPConv_treeval47.05350.732197.2151
MSENet1440.23640.833478.3611
MSENet14_treeval38.93540.829279.3392
MSENet5042.54840.821881.1590
MSENet50_treeval43.34030.818981.8220
PointNet56.94190.7241100.9738
PointNet_treeval56.55480.7054104.3490
RF57.71720.7195101.8241
RF_treeval301.75680.4133147.2648
\\power{}65.46930.7123103.1315
\\power{}_treeval271.88890.2750163.6982
linear57.80890.7056104.3137
linear_treeval254.97760.4311145.0114
66KPConv44.60530.824563.8559
KPConv_treeval52.97340.770772.0635
MSENet1448.16980.836761.6807
MSENet14_treeval44.03880.835261.9369
MSENet5049.99090.809566.6498
MSENet50_treeval52.67610.812266.1735
PointNet61.69270.729579.4505
PointNet_treeval70.70720.706782.7277
RF58.47280.731779.1334
RF_treeval222.06750.1077144.3070
\\power{}69.44770.709182.3981
\\power{}_treeval325.0404-0.0803158.7845
linear62.18680.716881.2927
linear_treeval170.15060.1699139.1890
100KPConv42.78370.752676.1588
KPConv_treeval51.16350.627891.5020
MSENet1439.47940.774272.7413
MSENet14_treeval38.66490.772672.9951
MSENet5043.31390.752376.2178
MSENet50_treeval42.92790.748276.8461
PointNet55.74470.717581.4207
PointNet_treeval57.87670.700383.8527
RF55.28710.734378.9587
RF_treeval330.74960.1102144.4909
\\power{}56.41690.707282.8810
\\power{}_treeval247.1476-0.0617157.8359
linear72.94860.703383.4383
linear_treeval310.62370.0860146.4426
101KPConv159.66860.785481.9686
KPConv_treeval207.46730.699895.6697
MSENet14138.93510.799779.2320
MSENet14_treeval126.60560.796879.7954
MSENet50162.20380.790881.0053
MSENet50_treeval166.68210.788981.3671
PointNet267.67940.744189.5783
PointNet_treeval275.87350.738690.5287
RF336.31780.742089.9442
RF_treeval2840.50920.0755170.2752
\\power{}285.95540.719193.8661
\\power{}_treeval2939.5035-0.0358180.2302
linear257.94840.721293.5128
linear_treeval2527.98750.1315165.0414
\n", + "
" + ], + "text/plain": [ + " value \n", + "metric MAPE R2 RMSE\n", + "target C_qfrac method \n", + "BMag_ha 0 KPConv 374.6228 0.7879 55.9159\n", + " KPConv_treeval 1222.2404 0.7176 63.8283\n", + " MSENet14 357.0664 0.7947 55.0272\n", + " MSENet14_treeval 322.7238 0.7940 55.1264\n", + " MSENet50 474.1587 0.7794 57.1103\n", + " MSENet50_treeval 469.8702 0.7785 57.2248\n", + " PointNet 1075.7263 0.6680 70.0579\n", + " PointNet_treeval 1120.3939 0.6588 71.0236\n", + " RF 664.5506 0.6872 68.0115\n", + " RF_treeval 7388.1675 0.3729 96.3007\n", + " \\power{} 677.3419 0.6562 71.3079\n", + " \\power{}_treeval 13318.9384 0.3321 99.3840\n", + " linear 1062.9882 0.6718 69.6684\n", + " linear_treeval 6325.3191 0.3979 94.3655\n", + " 33 KPConv 49.8777 0.8175 44.4828\n", + " KPConv_treeval 62.1022 0.7610 50.0162\n", + " MSENet14 55.9861 0.8372 41.9658\n", + " MSENet14_treeval 44.5279 0.8342 42.3487\n", + " MSENet50 59.7646 0.8312 42.8003\n", + " MSENet50_treeval 62.4084 0.8296 43.0044\n", + " PointNet 86.0533 0.7278 54.3584\n", + " PointNet_treeval 85.5342 0.7112 55.9890\n", + " RF 91.7592 0.7216 54.9777\n", + " RF_treeval 514.0759 0.4008 80.6557\n", + " \\power{} 108.4841 0.7186 55.2698\n", + " \\power{}_treeval 458.5898 0.2551 89.9315\n", + " linear 70.6966 0.7051 56.5878\n", + " linear_treeval 436.6532 0.4145 79.7319\n", + " 66 KPConv 45.3480 0.8211 33.5294\n", + " KPConv_treeval 57.3454 0.7894 36.0695\n", + " MSENet14 48.7037 0.8318 32.5518\n", + " MSENet14_treeval 43.7127 0.8322 32.5081\n", + " MSENet50 50.4514 0.8087 34.7096\n", + " MSENet50_treeval 53.0526 0.8112 34.4847\n", + " PointNet 62.8621 0.7269 41.4939\n", + " PointNet_treeval 73.1128 0.7049 43.1428\n", + " RF 59.2932 0.7256 41.6017\n", + " RF_treeval 223.1343 0.0311 78.1727\n", + " \\power{} 66.5487 0.7028 43.2948\n", + " \\power{}_treeval 332.0502 -0.1984 86.9403\n", + " linear 63.0417 0.7106 42.7254\n", + " linear_treeval 173.3806 0.1119 74.8453\n", + " 100 KPConv 45.4546 0.7062 40.8768\n", + " KPConv_treeval 
57.0534 0.6060 46.8419\n", + " MSENet14 42.8986 0.7338 38.9256\n", + " MSENet14_treeval 42.0540 0.7304 39.1778\n", + " MSENet50 47.7278 0.7100 40.6361\n", + " MSENet50_treeval 45.5546 0.7062 40.9045\n", + " PointNet 68.2957 0.6749 43.0369\n", + " PointNet_treeval 70.6930 0.6611 43.9330\n", + " RF 68.2138 0.6973 41.5248\n", + " RF_treeval 441.2907 -0.0693 78.0501\n", + " \\power{} 62.6262 0.6582 44.1299\n", + " \\power{}_treeval 263.7134 -0.3159 86.5819\n", + " linear 104.0196 0.6627 43.8370\n", + " linear_treeval 451.5604 -0.0947 78.9698\n", + " 101 KPConv 454.0000 0.7658 39.8251\n", + " KPConv_treeval 1253.1682 0.7062 44.2901\n", + " MSENet14 370.6372 0.7822 38.4141\n", + " MSENet14_treeval 368.3995 0.7786 38.7274\n", + " MSENet50 466.2730 0.7680 39.6512\n", + " MSENet50_treeval 485.5497 0.7650 39.9078\n", + " PointNet 1010.5266 0.7123 44.1623\n", + " PointNet_treeval 1026.5241 0.7119 44.1935\n", + " RF 1527.4271 0.7106 44.2926\n", + " RF_treeval 11221.5671 -0.2201 90.9475\n", + " \\power{} 1177.9138 0.6877 46.0147\n", + " \\power{}_treeval 13871.9134 -0.4173 98.0225\n", + " linear 962.3551 0.6890 45.9157\n", + " linear_treeval 9970.9772 -0.1384 87.8503\n", + "V_ha 0 KPConv 113.0205 0.7696 103.4888\n", + " KPConv_treeval 166.4436 0.6926 118.0473\n", + " MSENet14 114.8762 0.7808 100.9990\n", + " MSENet14_treeval 103.0141 0.7800 101.1862\n", + " MSENet50 131.9720 0.7681 104.0084\n", + " MSENet50_treeval 138.0076 0.7677 104.1027\n", + " PointNet 241.2310 0.6621 125.5487\n", + " PointNet_treeval 255.3653 0.6548 126.8861\n", + " RF 208.1299 0.6807 122.0570\n", + " RF_treeval 1434.5208 0.3468 174.5697\n", + " \\power{} 224.7221 0.6418 129.2763\n", + " \\power{}_treeval 1862.1400 0.3082 179.6580\n", + " linear 208.6848 0.6591 126.1115\n", + " linear_treeval 1176.9393 0.3698 171.4722\n", + " 33 KPConv 43.0983 0.8078 84.2133\n", + " KPConv_treeval 47.0535 0.7321 97.2151\n", + " MSENet14 40.2364 0.8334 78.3611\n", + " MSENet14_treeval 38.9354 0.8292 79.3392\n", + " MSENet50 
42.5484 0.8218 81.1590\n", + " MSENet50_treeval 43.3403 0.8189 81.8220\n", + " PointNet 56.9419 0.7241 100.9738\n", + " PointNet_treeval 56.5548 0.7054 104.3490\n", + " RF 57.7172 0.7195 101.8241\n", + " RF_treeval 301.7568 0.4133 147.2648\n", + " \\power{} 65.4693 0.7123 103.1315\n", + " \\power{}_treeval 271.8889 0.2750 163.6982\n", + " linear 57.8089 0.7056 104.3137\n", + " linear_treeval 254.9776 0.4311 145.0114\n", + " 66 KPConv 44.6053 0.8245 63.8559\n", + " KPConv_treeval 52.9734 0.7707 72.0635\n", + " MSENet14 48.1698 0.8367 61.6807\n", + " MSENet14_treeval 44.0388 0.8352 61.9369\n", + " MSENet50 49.9909 0.8095 66.6498\n", + " MSENet50_treeval 52.6761 0.8122 66.1735\n", + " PointNet 61.6927 0.7295 79.4505\n", + " PointNet_treeval 70.7072 0.7067 82.7277\n", + " RF 58.4728 0.7317 79.1334\n", + " RF_treeval 222.0675 0.1077 144.3070\n", + " \\power{} 69.4477 0.7091 82.3981\n", + " \\power{}_treeval 325.0404 -0.0803 158.7845\n", + " linear 62.1868 0.7168 81.2927\n", + " linear_treeval 170.1506 0.1699 139.1890\n", + " 100 KPConv 42.7837 0.7526 76.1588\n", + " KPConv_treeval 51.1635 0.6278 91.5020\n", + " MSENet14 39.4794 0.7742 72.7413\n", + " MSENet14_treeval 38.6649 0.7726 72.9951\n", + " MSENet50 43.3139 0.7523 76.2178\n", + " MSENet50_treeval 42.9279 0.7482 76.8461\n", + " PointNet 55.7447 0.7175 81.4207\n", + " PointNet_treeval 57.8767 0.7003 83.8527\n", + " RF 55.2871 0.7343 78.9587\n", + " RF_treeval 330.7496 0.1102 144.4909\n", + " \\power{} 56.4169 0.7072 82.8810\n", + " \\power{}_treeval 247.1476 -0.0617 157.8359\n", + " linear 72.9486 0.7033 83.4383\n", + " linear_treeval 310.6237 0.0860 146.4426\n", + " 101 KPConv 159.6686 0.7854 81.9686\n", + " KPConv_treeval 207.4673 0.6998 95.6697\n", + " MSENet14 138.9351 0.7997 79.2320\n", + " MSENet14_treeval 126.6056 0.7968 79.7954\n", + " MSENet50 162.2038 0.7908 81.0053\n", + " MSENet50_treeval 166.6821 0.7889 81.3671\n", + " PointNet 267.6794 0.7441 89.5783\n", + " PointNet_treeval 275.8735 0.7386 
90.5287\n", + " RF 336.3178 0.7420 89.9442\n", + " RF_treeval 2840.5092 0.0755 170.2752\n", + " \\power{} 285.9554 0.7191 93.8661\n", + " \\power{}_treeval 2939.5035 -0.0358 180.2302\n", + " linear 257.9484 0.7212 93.5128\n", + " linear_treeval 2527.9875 0.1315 165.0414" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "val\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
value
metricMAPER2RMSE
targetC_qfracmethod
BMag_ha0KPConv804.78580.754763.8309
KPConv_treeval576.07410.694570.7592
MSENet14446.10040.770461.7558
MSENet14_treeval376.25800.765862.3754
MSENet50703.44580.762162.8652
MSENet50_treeval703.97740.758863.3022
PointNet1611.72130.651576.0939
PointNet_treeval1654.45820.639177.4405
RF924.32330.688771.9182
RF_treeval7470.63410.398699.9654
\\power{}727.66890.673973.6097
\\power{}_treeval3606.88540.3884100.8134
linear473.89770.668474.2345
linear_treeval7580.03550.406199.3420
33KPConv126.97580.751148.8300
KPConv_treeval167.20710.717851.8262
MSENet14101.74510.786345.2657
MSENet14_treeval100.90170.790244.8564
MSENet50114.53450.787045.2197
MSENet50_treeval112.17700.792544.6263
PointNet229.99330.718851.9652
PointNet_treeval229.08610.705853.1498
RF247.39210.832640.1019
RF_treeval623.29160.426874.2055
\\power{}347.82010.809442.7924
\\power{}_treeval2631.41210.375477.4635
linear284.10870.822241.3230
linear_treeval638.74690.426274.2462
66KPConv36.11280.739743.7282
KPConv_treeval36.27960.651050.2197
MSENet1432.10030.782640.0265
MSENet14_treeval29.38240.784139.8832
MSENet5034.06370.773440.8590
MSENet50_treeval32.98160.778240.4295
PointNet43.90190.664049.7441
PointNet_treeval46.65540.658050.1903
RF36.05970.685648.1408
RF_treeval238.08660.087981.9933
\\power{}39.26810.616753.1544
\\power{}_treeval296.8905-0.191193.7010
linear36.98820.697047.2561
linear_treeval196.34960.168278.3018
100KPConv33.71570.764936.3041
KPConv_treeval36.28520.704340.1052
MSENet1433.37020.781134.9987
MSENet14_treeval33.16100.783634.8044
MSENet5031.46890.798633.6059
MSENet50_treeval31.67370.791134.2236
PointNet39.21630.719739.6459
PointNet_treeval39.25980.715439.9497
RF29.04050.775335.5112
RF_treeval140.2106-0.151080.3693
\\power{}32.95680.721339.5450
\\power{}_treeval173.3453-0.532992.7506
linear31.47650.745137.8251
linear_treeval131.0684-0.124879.4501
101KPConv136.52590.810134.0605
KPConv_treeval191.47060.732039.7428
MSENet14113.48610.819233.2394
MSENet14_treeval119.71730.820133.1538
MSENet50159.19820.827132.5048
MSENet50_treeval155.81070.825532.6540
PointNet487.06190.695343.1507
PointNet_treeval503.46940.698342.9364
RF194.02990.804734.5516
RF_treeval1749.3710-0.075981.0880
\\power{}199.08290.788035.9955
\\power{}_treeval2158.5527-0.367991.4309
linear188.24960.777736.8614
linear_treeval1529.4261-0.091581.6741
V_ha0KPConv112.59980.7659107.4898
KPConv_treeval106.56250.6898122.3325
MSENet1488.62800.7852102.9869
MSENet14_treeval79.30600.7802104.1861
MSENet50101.70460.7801104.1924
MSENet50_treeval102.32820.7764105.0706
PointNet165.49210.6829125.1439
PointNet_treeval171.36410.6713127.4126
RF132.31150.7150118.6341
RF_treeval940.97030.3854174.2194
\\power{}128.20500.6978122.1618
\\power{}_treeval551.72760.3832174.5280
linear104.49050.6980122.1251
linear_treeval841.59270.3972172.5494
33KPConv67.53590.729591.1093
KPConv_treeval74.66540.697696.0043
MSENet1455.57070.761385.6213
MSENet14_treeval56.39180.760185.8523
MSENet5058.64200.766984.6641
MSENet50_treeval58.76090.768084.4736
PointNet83.72700.721392.5901
PointNet_treeval83.28170.712194.0927
RF77.53330.822373.9322
RF_treeval389.51310.4115134.5535
\\power{}95.18320.803677.7242
\\power{}_treeval612.40380.3542140.9528
linear85.44280.813175.8177
linear_treeval382.69350.4177133.8423
66KPConv33.30950.781478.8848
KPConv_treeval35.63300.662595.5652
MSENet1430.88570.818072.0372
MSENet14_treeval27.90520.815272.5972
MSENet5032.49760.821871.2757
MSENet50_treeval31.51160.826170.4043
PointNet41.68960.723088.8534
PointNet_treeval44.72740.711190.7499
RF33.82610.736386.7481
RF_treeval237.03720.2204149.1483
\\power{}38.62180.688994.2185
\\power{}_treeval311.77230.0201167.2185
linear33.24710.738886.3327
linear_treeval194.60520.2636144.9619
100KPConv35.96330.813467.1629
KPConv_treeval38.08930.772473.5215
MSENet1435.08680.817666.3357
MSENet14_treeval34.82840.820065.8931
MSENet5033.14570.839062.3983
MSENet50_treeval32.60390.834063.3564
PointNet43.14340.783972.2820
PointNet_treeval43.07360.777973.2763
RF30.18210.837962.6256
RF_treeval155.03550.1164146.2063
\\power{}38.13730.796570.1571
\\power{}_treeval192.1182-0.1183164.4848
linear34.21640.807768.2068
linear_treeval141.42970.1184146.0452
101KPConv112.14850.842069.4998
KPConv_treeval147.44710.756684.0405
MSENet14103.69560.843169.2719
MSENet14_treeval102.44610.842869.3301
MSENet50128.75110.842469.4285
MSENet50_treeval125.03590.840469.8583
PointNet466.02460.754886.5901
PointNet_treeval477.91830.751187.2450
RF146.52710.821573.8867
RF_treeval1021.14050.2328153.1810
\\power{}162.49110.803777.4920
\\power{}_treeval1388.84900.0801167.7326
linear141.35430.796878.8431
linear_treeval856.12780.2460151.8639
\n", + "
" + ], + "text/plain": [ + " value \n", + "metric MAPE R2 RMSE\n", + "target C_qfrac method \n", + "BMag_ha 0 KPConv 804.7858 0.7547 63.8309\n", + " KPConv_treeval 576.0741 0.6945 70.7592\n", + " MSENet14 446.1004 0.7704 61.7558\n", + " MSENet14_treeval 376.2580 0.7658 62.3754\n", + " MSENet50 703.4458 0.7621 62.8652\n", + " MSENet50_treeval 703.9774 0.7588 63.3022\n", + " PointNet 1611.7213 0.6515 76.0939\n", + " PointNet_treeval 1654.4582 0.6391 77.4405\n", + " RF 924.3233 0.6887 71.9182\n", + " RF_treeval 7470.6341 0.3986 99.9654\n", + " \\power{} 727.6689 0.6739 73.6097\n", + " \\power{}_treeval 3606.8854 0.3884 100.8134\n", + " linear 473.8977 0.6684 74.2345\n", + " linear_treeval 7580.0355 0.4061 99.3420\n", + " 33 KPConv 126.9758 0.7511 48.8300\n", + " KPConv_treeval 167.2071 0.7178 51.8262\n", + " MSENet14 101.7451 0.7863 45.2657\n", + " MSENet14_treeval 100.9017 0.7902 44.8564\n", + " MSENet50 114.5345 0.7870 45.2197\n", + " MSENet50_treeval 112.1770 0.7925 44.6263\n", + " PointNet 229.9933 0.7188 51.9652\n", + " PointNet_treeval 229.0861 0.7058 53.1498\n", + " RF 247.3921 0.8326 40.1019\n", + " RF_treeval 623.2916 0.4268 74.2055\n", + " \\power{} 347.8201 0.8094 42.7924\n", + " \\power{}_treeval 2631.4121 0.3754 77.4635\n", + " linear 284.1087 0.8222 41.3230\n", + " linear_treeval 638.7469 0.4262 74.2462\n", + " 66 KPConv 36.1128 0.7397 43.7282\n", + " KPConv_treeval 36.2796 0.6510 50.2197\n", + " MSENet14 32.1003 0.7826 40.0265\n", + " MSENet14_treeval 29.3824 0.7841 39.8832\n", + " MSENet50 34.0637 0.7734 40.8590\n", + " MSENet50_treeval 32.9816 0.7782 40.4295\n", + " PointNet 43.9019 0.6640 49.7441\n", + " PointNet_treeval 46.6554 0.6580 50.1903\n", + " RF 36.0597 0.6856 48.1408\n", + " RF_treeval 238.0866 0.0879 81.9933\n", + " \\power{} 39.2681 0.6167 53.1544\n", + " \\power{}_treeval 296.8905 -0.1911 93.7010\n", + " linear 36.9882 0.6970 47.2561\n", + " linear_treeval 196.3496 0.1682 78.3018\n", + " 100 KPConv 33.7157 0.7649 36.3041\n", + " 
KPConv_treeval 36.2852 0.7043 40.1052\n", + " MSENet14 33.3702 0.7811 34.9987\n", + " MSENet14_treeval 33.1610 0.7836 34.8044\n", + " MSENet50 31.4689 0.7986 33.6059\n", + " MSENet50_treeval 31.6737 0.7911 34.2236\n", + " PointNet 39.2163 0.7197 39.6459\n", + " PointNet_treeval 39.2598 0.7154 39.9497\n", + " RF 29.0405 0.7753 35.5112\n", + " RF_treeval 140.2106 -0.1510 80.3693\n", + " \\power{} 32.9568 0.7213 39.5450\n", + " \\power{}_treeval 173.3453 -0.5329 92.7506\n", + " linear 31.4765 0.7451 37.8251\n", + " linear_treeval 131.0684 -0.1248 79.4501\n", + " 101 KPConv 136.5259 0.8101 34.0605\n", + " KPConv_treeval 191.4706 0.7320 39.7428\n", + " MSENet14 113.4861 0.8192 33.2394\n", + " MSENet14_treeval 119.7173 0.8201 33.1538\n", + " MSENet50 159.1982 0.8271 32.5048\n", + " MSENet50_treeval 155.8107 0.8255 32.6540\n", + " PointNet 487.0619 0.6953 43.1507\n", + " PointNet_treeval 503.4694 0.6983 42.9364\n", + " RF 194.0299 0.8047 34.5516\n", + " RF_treeval 1749.3710 -0.0759 81.0880\n", + " \\power{} 199.0829 0.7880 35.9955\n", + " \\power{}_treeval 2158.5527 -0.3679 91.4309\n", + " linear 188.2496 0.7777 36.8614\n", + " linear_treeval 1529.4261 -0.0915 81.6741\n", + "V_ha 0 KPConv 112.5998 0.7659 107.4898\n", + " KPConv_treeval 106.5625 0.6898 122.3325\n", + " MSENet14 88.6280 0.7852 102.9869\n", + " MSENet14_treeval 79.3060 0.7802 104.1861\n", + " MSENet50 101.7046 0.7801 104.1924\n", + " MSENet50_treeval 102.3282 0.7764 105.0706\n", + " PointNet 165.4921 0.6829 125.1439\n", + " PointNet_treeval 171.3641 0.6713 127.4126\n", + " RF 132.3115 0.7150 118.6341\n", + " RF_treeval 940.9703 0.3854 174.2194\n", + " \\power{} 128.2050 0.6978 122.1618\n", + " \\power{}_treeval 551.7276 0.3832 174.5280\n", + " linear 104.4905 0.6980 122.1251\n", + " linear_treeval 841.5927 0.3972 172.5494\n", + " 33 KPConv 67.5359 0.7295 91.1093\n", + " KPConv_treeval 74.6654 0.6976 96.0043\n", + " MSENet14 55.5707 0.7613 85.6213\n", + " MSENet14_treeval 56.3918 0.7601 85.8523\n", + " 
MSENet50 58.6420 0.7669 84.6641\n", + " MSENet50_treeval 58.7609 0.7680 84.4736\n", + " PointNet 83.7270 0.7213 92.5901\n", + " PointNet_treeval 83.2817 0.7121 94.0927\n", + " RF 77.5333 0.8223 73.9322\n", + " RF_treeval 389.5131 0.4115 134.5535\n", + " \\power{} 95.1832 0.8036 77.7242\n", + " \\power{}_treeval 612.4038 0.3542 140.9528\n", + " linear 85.4428 0.8131 75.8177\n", + " linear_treeval 382.6935 0.4177 133.8423\n", + " 66 KPConv 33.3095 0.7814 78.8848\n", + " KPConv_treeval 35.6330 0.6625 95.5652\n", + " MSENet14 30.8857 0.8180 72.0372\n", + " MSENet14_treeval 27.9052 0.8152 72.5972\n", + " MSENet50 32.4976 0.8218 71.2757\n", + " MSENet50_treeval 31.5116 0.8261 70.4043\n", + " PointNet 41.6896 0.7230 88.8534\n", + " PointNet_treeval 44.7274 0.7111 90.7499\n", + " RF 33.8261 0.7363 86.7481\n", + " RF_treeval 237.0372 0.2204 149.1483\n", + " \\power{} 38.6218 0.6889 94.2185\n", + " \\power{}_treeval 311.7723 0.0201 167.2185\n", + " linear 33.2471 0.7388 86.3327\n", + " linear_treeval 194.6052 0.2636 144.9619\n", + " 100 KPConv 35.9633 0.8134 67.1629\n", + " KPConv_treeval 38.0893 0.7724 73.5215\n", + " MSENet14 35.0868 0.8176 66.3357\n", + " MSENet14_treeval 34.8284 0.8200 65.8931\n", + " MSENet50 33.1457 0.8390 62.3983\n", + " MSENet50_treeval 32.6039 0.8340 63.3564\n", + " PointNet 43.1434 0.7839 72.2820\n", + " PointNet_treeval 43.0736 0.7779 73.2763\n", + " RF 30.1821 0.8379 62.6256\n", + " RF_treeval 155.0355 0.1164 146.2063\n", + " \\power{} 38.1373 0.7965 70.1571\n", + " \\power{}_treeval 192.1182 -0.1183 164.4848\n", + " linear 34.2164 0.8077 68.2068\n", + " linear_treeval 141.4297 0.1184 146.0452\n", + " 101 KPConv 112.1485 0.8420 69.4998\n", + " KPConv_treeval 147.4471 0.7566 84.0405\n", + " MSENet14 103.6956 0.8431 69.2719\n", + " MSENet14_treeval 102.4461 0.8428 69.3301\n", + " MSENet50 128.7511 0.8424 69.4285\n", + " MSENet50_treeval 125.0359 0.8404 69.8583\n", + " PointNet 466.0246 0.7548 86.5901\n", + " PointNet_treeval 477.9183 0.7511 
87.2450\n", + " RF 146.5271 0.8215 73.8867\n", + " RF_treeval 1021.1405 0.2328 153.1810\n", + " \\power{} 162.4911 0.8037 77.4920\n", + " \\power{}_treeval 1388.8490 0.0801 167.7326\n", + " linear 141.3543 0.7968 78.8431\n", + " linear_treeval 856.1278 0.2460 151.8639" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "test\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
value
metricMAPER2RMSE
targetC_qfracmethod
BMag_ha0KPConv113.73070.779555.5312
KPConv_treeval174.99030.711862.8864
MSENet14120.00680.818250.4434
MSENet14_treeval107.84850.818350.4226
MSENet50165.72270.806352.0664
MSENet50_treeval159.74680.806951.9793
PointNet280.62550.749459.2245
PointNet_treeval280.45030.747959.4101
RF525.01870.734760.9495
RF_treeval2006.29370.435388.9206
\\power{}567.69820.743959.8803
\\power{}_treeval3790.54110.2753100.7281
linear421.47090.748659.3251
linear_treeval2041.37960.450387.7271
33KPConv37.87530.754748.2156
KPConv_treeval69.68220.655656.5580
MSENet1436.02720.733250.2557
MSENet14_treeval36.80390.734650.1321
MSENet5035.46460.762247.4598
MSENet50_treeval33.30310.776446.0203
PointNet45.02470.645257.9907
PointNet_treeval47.26730.635258.8029
RF37.27800.685754.5893
RF_treeval441.16360.132390.6949
\\power{}38.47410.690654.1601
\\power{}_treeval405.36480.149889.7779
linear39.51120.690354.1861
linear_treeval418.76930.197987.2009
66KPConv60.08660.801834.8103
KPConv_treeval136.23170.739439.6866
MSENet1454.36850.824432.7850
MSENet14_treeval49.20670.818533.3328
MSENet5058.54160.839631.3105
MSENet50_treeval56.33700.836631.5929
PointNet96.71670.761038.2226
PointNet_treeval132.82000.786036.1788
RF85.28490.727940.8137
RF_treeval638.60870.015377.6468
\\power{}55.30330.722541.2167
\\power{}_treeval1246.9583-0.514696.2992
linear51.35840.736340.1812
linear_treeval537.63810.170571.2671
100KPConv61.02290.796634.6517
KPConv_treeval143.92630.681142.9693
MSENet1441.49860.828231.8774
MSENet14_treeval39.25390.826732.0152
MSENet5050.94610.831631.5574
MSENet50_treeval71.05170.820532.5835
PointNet85.39180.775436.4519
PointNet_treeval86.67780.767837.0703
RF93.11040.801834.2500
RF_treeval355.1950-0.105380.8893
\\power{}99.84290.810633.4873
\\power{}_treeval476.1943-0.203284.3959
linear94.46810.803934.0701
linear_treeval522.7593-0.088580.2702
101KPConv1241.28820.813236.3953
KPConv_treeval1122.74150.730442.9605
MSENet14698.14850.816736.0585
MSENet14_treeval551.30030.811436.5736
MSENet50947.83110.833534.3580
MSENet50_treeval839.55440.826335.0886
PointNet2308.18910.783739.1815
PointNet_treeval2406.12210.771340.2842
RF1340.63830.756741.5531
RF_treeval21985.7921-0.227193.3179
\\power{}333.95810.753241.8513
\\power{}_treeval19277.0374-0.267694.8449
linear750.19620.750342.0953
linear_treeval35514.9320-0.159790.7198
V_ha0KPConv94.08200.7593102.4256
KPConv_treeval115.89260.6788116.9080
MSENet14106.01730.801593.0178
MSENet14_treeval93.62080.802192.8621
MSENet50126.55500.789795.7319
MSENet50_treeval121.78440.789495.7884
PointNet189.22030.7502104.3610
PointNet_treeval186.54030.7510104.1872
RF360.04260.7317108.1604
RF_treeval1270.61480.4106160.3037
\\power{}396.47690.7436105.7417
\\power{}_treeval2159.13470.2623179.3437
linear260.64030.7532103.7421
linear_treeval1275.93630.4372156.6507
33KPConv38.04920.704995.2133
KPConv_treeval53.50470.5912110.9469
MSENet1435.95080.675199.8496
MSENet14_treeval36.64520.6732100.1552
MSENet5034.91750.705495.1077
MSENet50_treeval33.32520.719192.8766
PointNet40.37280.6030110.4342
PointNet_treeval42.73920.5948111.5663
RF36.85570.6442104.5526
RF_treeval363.30430.1086165.4977
\\power{}38.42230.6496103.7661
\\power{}_treeval338.05410.1676159.9216
linear36.17550.6473104.0946
linear_treeval339.11220.1626160.4031
66KPConv49.63620.836561.7971
KPConv_treeval77.95570.779371.2130
MSENet1450.06750.860857.0490
MSENet14_treeval44.26550.854458.3527
MSENet5051.42950.866655.7932
MSENet50_treeval48.85900.862756.6026
PointNet73.94280.782471.3400
PointNet_treeval89.32210.804067.7084
RF62.69790.758975.1154
RF_treeval406.78130.1496141.0797
\\power{}54.46530.774672.6313
\\power{}_treeval699.6692-0.2760172.8152
linear52.39120.770773.2547
linear_treeval305.44280.2803129.7873
100KPConv70.27170.812268.3703
KPConv_treeval130.39460.659589.0039
MSENet1449.94870.832864.5441
MSENet14_treeval47.17930.831564.7895
MSENet5061.98060.836363.8682
MSENet50_treeval84.03230.825365.9787
PointNet94.21330.796171.2679
PointNet_treeval94.47690.787072.8608
RF107.95120.793371.7916
RF_treeval409.12710.0536153.6323
\\power{}118.29720.801970.2878
\\power{}_treeval545.43590.0044157.5772
linear113.47100.792172.0060
linear_treeval556.10680.0719152.1399
101KPConv201.97180.836474.2849
KPConv_treeval197.38270.725893.2016
MSENet14134.64170.846771.8896
MSENet14_treeval122.55830.842772.8249
MSENet50172.91020.862268.1356
MSENet50_treeval173.77400.855669.7393
PointNet325.22200.809980.0643
PointNet_treeval407.91010.796482.8831
RF198.67510.781485.8857
RF_treeval3635.66540.0954174.7154
\\power{}143.09740.768788.3447
\\power{}_treeval2599.07180.1106173.2390
linear153.54390.775587.0372
linear_treeval3820.08250.1605168.3134
\n", + "
" + ], + "text/plain": [ + " value \n", + "metric MAPE R2 RMSE\n", + "target C_qfrac method \n", + "BMag_ha 0 KPConv 113.7307 0.7795 55.5312\n", + " KPConv_treeval 174.9903 0.7118 62.8864\n", + " MSENet14 120.0068 0.8182 50.4434\n", + " MSENet14_treeval 107.8485 0.8183 50.4226\n", + " MSENet50 165.7227 0.8063 52.0664\n", + " MSENet50_treeval 159.7468 0.8069 51.9793\n", + " PointNet 280.6255 0.7494 59.2245\n", + " PointNet_treeval 280.4503 0.7479 59.4101\n", + " RF 525.0187 0.7347 60.9495\n", + " RF_treeval 2006.2937 0.4353 88.9206\n", + " \\power{} 567.6982 0.7439 59.8803\n", + " \\power{}_treeval 3790.5411 0.2753 100.7281\n", + " linear 421.4709 0.7486 59.3251\n", + " linear_treeval 2041.3796 0.4503 87.7271\n", + " 33 KPConv 37.8753 0.7547 48.2156\n", + " KPConv_treeval 69.6822 0.6556 56.5580\n", + " MSENet14 36.0272 0.7332 50.2557\n", + " MSENet14_treeval 36.8039 0.7346 50.1321\n", + " MSENet50 35.4646 0.7622 47.4598\n", + " MSENet50_treeval 33.3031 0.7764 46.0203\n", + " PointNet 45.0247 0.6452 57.9907\n", + " PointNet_treeval 47.2673 0.6352 58.8029\n", + " RF 37.2780 0.6857 54.5893\n", + " RF_treeval 441.1636 0.1323 90.6949\n", + " \\power{} 38.4741 0.6906 54.1601\n", + " \\power{}_treeval 405.3648 0.1498 89.7779\n", + " linear 39.5112 0.6903 54.1861\n", + " linear_treeval 418.7693 0.1979 87.2009\n", + " 66 KPConv 60.0866 0.8018 34.8103\n", + " KPConv_treeval 136.2317 0.7394 39.6866\n", + " MSENet14 54.3685 0.8244 32.7850\n", + " MSENet14_treeval 49.2067 0.8185 33.3328\n", + " MSENet50 58.5416 0.8396 31.3105\n", + " MSENet50_treeval 56.3370 0.8366 31.5929\n", + " PointNet 96.7167 0.7610 38.2226\n", + " PointNet_treeval 132.8200 0.7860 36.1788\n", + " RF 85.2849 0.7279 40.8137\n", + " RF_treeval 638.6087 0.0153 77.6468\n", + " \\power{} 55.3033 0.7225 41.2167\n", + " \\power{}_treeval 1246.9583 -0.5146 96.2992\n", + " linear 51.3584 0.7363 40.1812\n", + " linear_treeval 537.6381 0.1705 71.2671\n", + " 100 KPConv 61.0229 0.7966 34.6517\n", + " KPConv_treeval 
143.9263 0.6811 42.9693\n", + " MSENet14 41.4986 0.8282 31.8774\n", + " MSENet14_treeval 39.2539 0.8267 32.0152\n", + " MSENet50 50.9461 0.8316 31.5574\n", + " MSENet50_treeval 71.0517 0.8205 32.5835\n", + " PointNet 85.3918 0.7754 36.4519\n", + " PointNet_treeval 86.6778 0.7678 37.0703\n", + " RF 93.1104 0.8018 34.2500\n", + " RF_treeval 355.1950 -0.1053 80.8893\n", + " \\power{} 99.8429 0.8106 33.4873\n", + " \\power{}_treeval 476.1943 -0.2032 84.3959\n", + " linear 94.4681 0.8039 34.0701\n", + " linear_treeval 522.7593 -0.0885 80.2702\n", + " 101 KPConv 1241.2882 0.8132 36.3953\n", + " KPConv_treeval 1122.7415 0.7304 42.9605\n", + " MSENet14 698.1485 0.8167 36.0585\n", + " MSENet14_treeval 551.3003 0.8114 36.5736\n", + " MSENet50 947.8311 0.8335 34.3580\n", + " MSENet50_treeval 839.5544 0.8263 35.0886\n", + " PointNet 2308.1891 0.7837 39.1815\n", + " PointNet_treeval 2406.1221 0.7713 40.2842\n", + " RF 1340.6383 0.7567 41.5531\n", + " RF_treeval 21985.7921 -0.2271 93.3179\n", + " \\power{} 333.9581 0.7532 41.8513\n", + " \\power{}_treeval 19277.0374 -0.2676 94.8449\n", + " linear 750.1962 0.7503 42.0953\n", + " linear_treeval 35514.9320 -0.1597 90.7198\n", + "V_ha 0 KPConv 94.0820 0.7593 102.4256\n", + " KPConv_treeval 115.8926 0.6788 116.9080\n", + " MSENet14 106.0173 0.8015 93.0178\n", + " MSENet14_treeval 93.6208 0.8021 92.8621\n", + " MSENet50 126.5550 0.7897 95.7319\n", + " MSENet50_treeval 121.7844 0.7894 95.7884\n", + " PointNet 189.2203 0.7502 104.3610\n", + " PointNet_treeval 186.5403 0.7510 104.1872\n", + " RF 360.0426 0.7317 108.1604\n", + " RF_treeval 1270.6148 0.4106 160.3037\n", + " \\power{} 396.4769 0.7436 105.7417\n", + " \\power{}_treeval 2159.1347 0.2623 179.3437\n", + " linear 260.6403 0.7532 103.7421\n", + " linear_treeval 1275.9363 0.4372 156.6507\n", + " 33 KPConv 38.0492 0.7049 95.2133\n", + " KPConv_treeval 53.5047 0.5912 110.9469\n", + " MSENet14 35.9508 0.6751 99.8496\n", + " MSENet14_treeval 36.6452 0.6732 100.1552\n", + " MSENet50 
34.9175 0.7054 95.1077\n", + " MSENet50_treeval 33.3252 0.7191 92.8766\n", + " PointNet 40.3728 0.6030 110.4342\n", + " PointNet_treeval 42.7392 0.5948 111.5663\n", + " RF 36.8557 0.6442 104.5526\n", + " RF_treeval 363.3043 0.1086 165.4977\n", + " \\power{} 38.4223 0.6496 103.7661\n", + " \\power{}_treeval 338.0541 0.1676 159.9216\n", + " linear 36.1755 0.6473 104.0946\n", + " linear_treeval 339.1122 0.1626 160.4031\n", + " 66 KPConv 49.6362 0.8365 61.7971\n", + " KPConv_treeval 77.9557 0.7793 71.2130\n", + " MSENet14 50.0675 0.8608 57.0490\n", + " MSENet14_treeval 44.2655 0.8544 58.3527\n", + " MSENet50 51.4295 0.8666 55.7932\n", + " MSENet50_treeval 48.8590 0.8627 56.6026\n", + " PointNet 73.9428 0.7824 71.3400\n", + " PointNet_treeval 89.3221 0.8040 67.7084\n", + " RF 62.6979 0.7589 75.1154\n", + " RF_treeval 406.7813 0.1496 141.0797\n", + " \\power{} 54.4653 0.7746 72.6313\n", + " \\power{}_treeval 699.6692 -0.2760 172.8152\n", + " linear 52.3912 0.7707 73.2547\n", + " linear_treeval 305.4428 0.2803 129.7873\n", + " 100 KPConv 70.2717 0.8122 68.3703\n", + " KPConv_treeval 130.3946 0.6595 89.0039\n", + " MSENet14 49.9487 0.8328 64.5441\n", + " MSENet14_treeval 47.1793 0.8315 64.7895\n", + " MSENet50 61.9806 0.8363 63.8682\n", + " MSENet50_treeval 84.0323 0.8253 65.9787\n", + " PointNet 94.2133 0.7961 71.2679\n", + " PointNet_treeval 94.4769 0.7870 72.8608\n", + " RF 107.9512 0.7933 71.7916\n", + " RF_treeval 409.1271 0.0536 153.6323\n", + " \\power{} 118.2972 0.8019 70.2878\n", + " \\power{}_treeval 545.4359 0.0044 157.5772\n", + " linear 113.4710 0.7921 72.0060\n", + " linear_treeval 556.1068 0.0719 152.1399\n", + " 101 KPConv 201.9718 0.8364 74.2849\n", + " KPConv_treeval 197.3827 0.7258 93.2016\n", + " MSENet14 134.6417 0.8467 71.8896\n", + " MSENet14_treeval 122.5583 0.8427 72.8249\n", + " MSENet50 172.9102 0.8622 68.1356\n", + " MSENet50_treeval 173.7740 0.8556 69.7393\n", + " PointNet 325.2220 0.8099 80.0643\n", + " PointNet_treeval 407.9101 0.7964 
82.8831\n", + " RF 198.6751 0.7814 85.8857\n", + " RF_treeval 3635.6654 0.0954 174.7154\n", + " \\power{} 143.0974 0.7687 88.3447\n", + " \\power{}_treeval 2599.0718 0.1106 173.2390\n", + " linear 153.5439 0.7755 87.0372\n", + " linear_treeval 3820.0825 0.1605 168.3134" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pd.set_option(\"display.max_rows\", None)\n", + "for split in splits:\n", + " print(split)\n", + " display(\n", + " result_scores_cfrac[split]\n", + " .groupby(\n", + " [\n", + " \"target\",\n", + " \"C_qfrac\",\n", + " \"metric\",\n", + " \"method\",\n", + " ],\n", + " sort=True,\n", + " as_index=False,\n", + " )[\"value\"]\n", + " .mean()\n", + " .pivot(index=[\"target\", \"C_qfrac\", \"method\"], columns=[\"metric\"])\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "c538d37e-7a79-4ae1-aee4-e09b6176b476", + "metadata": {}, + "source": [ + "# Errors Cancel Out" + ] + }, + { + "cell_type": "code", + "execution_count": 104, + "id": "e5430d97-79a5-408d-bf0e-82fde618147f", + "metadata": {}, + "outputs": [], + "source": [ + "def spatial_aggregate(\n", + " name: str,\n", + " results,\n", + " split,\n", + " max_n_samples: int,\n", + " n_steps: int,\n", + " n_start: 1,\n", + " seed: int = 42,\n", + " n_repetitions: int = 10,\n", + "):\n", + " print(name)\n", + " columns = [\"method\", \"target\", \"n_samples\", \"i_repeat\", \"value\", \"metric\", \"run\"]\n", + " \n", + " results = results.query(\"split == @split\")\n", + " \n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n", + "\n", + " results_df = []\n", + " for target in [\"Cag_ha\"]:\n", + " for run, result in results.groupby(\"run\"):\n", + " for n_samples in range(n_start, max_n_samples, n_steps):\n", + " # if n_samples == 0:\n", + " # n_samples = 1\n", + " n_splits = len(result) / n_samples\n", + " if n_splits < 10:\n", + " print(f\"less than 10 samples possible with n_samples 
= {n_samples}\")\n", + " break\n", + " for i_repeat in range(n_repetitions):\n", + " # aggregate\n", + " index = np.linspace(0, n_splits, len(result)).astype(int)\n", + " rs = np.random.RandomState(seed + i_repeat)\n", + " rs.shuffle(index)\n", + "\n", + " agg_results = result.copy().select_dtypes(include=np.number)\n", + " # set aggregate index\n", + " agg_results[\"agg\"] = index\n", + " # ignore overhanging samples\n", + " agg_results.query(f\"agg < {int(n_splits)}\", inplace=True)\n", + "\n", + " agg_results = agg_results.groupby(\"agg\").apply(\n", + " lambda x: x.sum() / n_samples\n", + " )\n", + "\n", + " mask = agg_results[target] != 0\n", + "\n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " n_samples,\n", + " i_repeat,\n", + " r2_score(agg_results[target], agg_results[target+\"_pred\"]),\n", + " \"R2\",\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " n_samples,\n", + " i_repeat,\n", + " mean_squared_error(\n", + " agg_results[target],\n", + " agg_results[target+\"_pred\"],\n", + " squared=False,\n", + " ),\n", + " \"RMSE\",\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " n_samples,\n", + " i_repeat,\n", + " mean_absolute_percentage_error(\n", + " agg_results[mask][target],\n", + " agg_results[mask][target+\"_pred\"],\n", + " )\n", + " * 100,\n", + " \"MAPE\",\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " if n_samples == 1:\n", + " break\n", + " return pd.concat(results_df, axis=0)" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "id": "604f0c70-79a6-4f79-8555-b0e60a681d5b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + 
"linear\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RF\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + "KPConv\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PointNet\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", 
+ "text": [ + "\\power{}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "MSENet50\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_xy_treeadd\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + 
{ + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50_xy_treeadd\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "linear_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = 
results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RF_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\power{}_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " 
results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "KPConv_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PointNet_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the 
caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14_xy_treeadd_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha_pred\"] = results[\"BMag_ha_pred\"] * 0.47\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet50_xy_treeadd_treeval\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2805/1353084170.py:17: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " results[\"Cag_ha\"] = results[\"BMag_ha\"] * 0.47\n", + "/tmp/ipykernel_2805/1353084170.py:18: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using 
# Spatial aggregation of every model's corrected predictions on the test
# split: group sizes 1..90 (range(1, 91, 1)), default seed and repetitions.
result_scores_agg = {}
for split in ["test"]:
    # The former `result_dict[name] = file` statement was dropped here: it
    # overwrote result_dict with a variable unrelated to this loop.
    result_score = [
        spatial_aggregate(name, results_corrected[name], split, 91, 1, 1)
        for name in models
    ]
    result_scores_agg[split] = pd.concat(result_score, axis=0)

# Cache the scores on disk so the aggregation does not have to be re-run.
with open("result_scores_agg.pickle", "wb") as handle:
    pickle.dump(result_scores_agg, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Reload the cached scores.
# NOTE: pickle.load executes arbitrary code; only load files written by the
# dump above.
with open("result_scores_agg.pickle", "rb") as handle:
    result_scores_agg = pickle.load(handle)

# Plot order and colours of the "*_treeval" variants. A raw string keeps the
# LaTeX macro name intact without relying on the invalid escape "\p".
_base_methods = [
    "linear",
    r"\power{}",
    "RF",
    "PointNet",
    "KPConv",
    "MSENet14",
    "MSENet50",
]
hue_order = [f"{method}_treeval" for method in _base_methods]
palette = sns.color_palette("Set2", n_colors=len(hue_order))
# Display the colours assigned to hue_order (notebook cell output).
palette

# RMSE scores of the "*_treeval" variants for the carbon target.
treevals = result_scores_agg["test"]["method"].str.contains("treeval")
rs = result_scores_agg["test"].query("target == 'Cag_ha' & metric == 'RMSE' & @treevals")


def _rmse_lineplot(
    data, figsize, xlim, xticks, ylim=None, yticks=None, zoom=False, save_as=None
):
    """Draw RMSE against group size on a figure of its own and return the axes.

    ``zoom=True`` switches to error bars with caps and hides the legend, as
    used for the two detail views. ``save_as`` optionally writes the figure.
    """
    # A dedicated figure per plot: nothing depends on which axes happen to be
    # current when the function is called.
    fig, ax = plt.subplots(figsize=figsize)
    zoom_style = (
        {"err_style": "bars", "err_kws": {"capsize": 3}, "legend": False}
        if zoom
        else {}
    )
    sns.lineplot(
        data=data,
        x="n_samples",
        y="value",
        style="method",
        hue="method",
        hue_order=hue_order,
        palette=palette,
        orient="x",
        errorbar="se",
        dashes=True,
        ax=ax,
        **zoom_style,
    )
    # Limits are set before ticks on purpose: ticks outside the limits may
    # widen the view, exactly as in the original call order.
    ax.set_xlim(xlim)
    ax.set_xticks(xticks)
    if ylim is not None:
        ax.set_ylim(ylim)
    if yticks is not None:
        ax.set_yticks(yticks)
    ax.set_ylabel("RMSE")
    ax.set_xlabel("number of samples")
    fig.subplots_adjust(
        left=0.03, right=1, top=1, bottom=0.125, hspace=0.15, wspace=0.1
    )
    if save_as is not None:
        fig.savefig(save_as)
    return ax


# RMSE over the full range; export with save_as="figures/spatial_agg_RMSE_carbon.svg"
g = _rmse_lineplot(
    rs, figsize=(7.48031, 2.1), xlim=(1, 90.25), xticks=np.arange(0, 91, 10)
)

# RMSE 2: detail view; export with save_as="figures/spatial_agg_RMSE_zoom_carbon.svg"
g = _rmse_lineplot(
    rs,
    figsize=(7.48031 / 3, 1),
    xlim=(12.75, 16.25),
    xticks=np.arange(13, 17, 1),
    ylim=(4.75, 7),
    zoom=True,
)

# RMSE 3: detail view; export with save_as="figures/spatial_agg_RMSE_zoom2_carbon.svg"
g = _rmse_lineplot(
    rs,
    figsize=(7.48031 / 3, 1),
    xlim=(85.75, 90.25),
    xticks=np.arange(86, 91, 1),
    ylim=(1.95, 4.1),
    yticks=np.arange(2.0, 4.1, 0.5),
    zoom=True,
)

# Methods shown in the bar plots. The raw string avoids the invalid escape "\p".
_METHODS = ["linear", r"\power{}", "RF", "PointNet", "KPConv", "MSENet14", "MSENet50"]


def _metric_barplot(scores, target_name, figsize=(6.556, 1.750), save_as=None):
    """Bar plot of the median score per C_qfrac level, one column per metric."""
    grid = sns.catplot(
        data=scores.query("target == @target_name"),
        x="value",
        y="C_qfrac",
        hue="method",
        hue_order=_METHODS,
        col="metric",
        kind="bar",
        palette=sns.color_palette("Set2", n_colors=len(_METHODS)),
        legend_out=True,
        height=1.750,
        orient="h",
        sharex=False,
        edgecolor=".0",
        linewidth=0.05,
        errorbar=None,
        estimator=np.median,
    )
    grid.figure.set_size_inches(figsize)
    grid.figure.subplots_adjust(
        left=0.08, right=0.97, top=1, bottom=0.075, hspace=0.15, wspace=0.1
    )
    if save_as is not None:
        grid.figure.savefig(save_as)
    return grid


# Export with save_as="figures/species_bplot_b.svg" / "figures/species_bplot_v.svg"
g = _metric_barplot(result_scores_cfrac["test"], "BMag_ha")
g = _metric_barplot(result_scores_cfrac["test"], "V_ha")

target = "BMag_ha"

# Run index used per method - presumably the best-performing run; verify
# against the score tables.
best = {
    "MSENet14": 0,
    "MSENet50": 3,
    "KPConv": 4,
    "PointNet": 4,
    "RF": 3,
    "linear": 0,
    r"\power{}": 0,
}

# Test-split predictions of the selected run of every non-"treeval" model,
# stacked into one frame with a "method" column.
# Every such model name needs an entry in `best`, otherwise KeyError.
split = "test"
frames = []
for name in models:
    if "treeval" in name:
        continue
    best_run = best[name]
    frames.append(
        results_corrected[name]
        .query("run == @best_run & split == @split")
        .reset_index(drop=True)
        .assign(method=name)
    )

dff = pd.concat(frames, axis=0)
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJsAAADQCAYAAAAQ/hMjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAvpElEQVR4nO2dd3xUVdrHv/femclkUklIAoQWkaEJhCZV1kUUBQRRYKUqrkpRwRUF2XV1FVexS3NVfBGU1fejooCLoCzvKkVYmlKkhF6kJaRnMuXee94/JjMkhJJMZiYE7/fz4QNzbjnP3PlxynOe81xJCCEwMAgDcnUbYPDbwRCbQdgwxGYQNgyxGYQNQ2wGYcMQm0HYMMRmEDZM1W1ATUHXdX6rLklJkpDlqrdLhtiugK7rHD16FKfTWd2mVCtWq5VGjRpVSXSSsYJweU6fPo3L5SI1NRVJkqrbnGpBCMGvv/5KREQEderUCfg+Rst2GYQQ5Obm0rhxY0ym3/ajSklJ4ciRI6SkpAT8n86YIFwGIQRCCMxmc3WbUu2YzWb/8wgUQ2yXwRhhlMcQm0GNwBBbDeHdd9+la9euNGvWjP/+97/VbU5A/LZHvUHmxIkTzJkzh3Xr1pGXl0e9evW46aabePDBB6s0i/v11195++23mTNnDm3btiUuLi6IVocPo2ULEocOHeKee+4hNzeXt99+m++++44ZM2agqioLFiyo0r1PnDiBEIJbbrmFpKQkLBZLuXPcbneV6ggLwuCSqKoqdu/eLVRVveK59913nxg8eLDQdb3csby8vMtee/LkSTF69Ghxww03iAEDBojly5cLu90ujh8/LhYvXizsdnuZP0IIMXLkSDFjxgwxbdo00a5dOzF9+nThcrnEU089JXr27Cnatm0rBg0aJH788ccydWVlZYnJkyeLTp06ifT0dPGHP/xBHD16NKjP4lIY3WgQyM7OZuPGjbz55psX9UHFxsZe9vopU6bg8Xj4/PPPycrK4oUXXvAf69u3L5GRkTz++OOsW7euzHX/+7//y/jx4/nqq6+QZRlVVWncuDFjxozBZrPx9ddfM2HCBP7973+TmJgIwKOPPoqu6/zjH/8gMTGRn3/+GVVVg/AUKkDAMv0NUNH/zT///LOw2+1i9+7dla7jwIEDwm63iwMHDvjLPvnkE3/LJoQQ69ev97doPkaOHClGjhx5xfv36dNHfPXVV0IIITZs2CBatWolTp8+XWk7jZbtGuDw4cNERUXRpEkTf1mbNm0qdG3Lli3Llf3P//wPS5Ys4cyZM3g8HpxOJ6dOnQJg//79NG7cmJSUlOAYX0mMCUIQaNCgAZIkcfjw4UpfK4QIePnHarWW+bx06VLmzp3LAw88wEcffcSSJUu4/vrr/d2kqGYntSG2IJCQkEDnzp1ZuHDhRX/QgoKCS16blpZGYWEhhw4d8pft3LkzIDu2b99Oly5dGDRoEM2bN6d27dqcPHnSf9xut3PkyBHOnDkT0P2riiG2IPHss89y+PBh7r//ftavX8+JEyfYvn0706dPZ+7cuZe87vrrr6dTp0789a9/Ze/evfz44498+OGHAdnQsGFDfvrpJ7Zs2cL+/fuZNm0auq77j3fp0oXWrVszceJEtm7dyrFjx1i2bFkZoYcSQ2xBokmTJixevJg6deowdepU7rjjDp566ikkSWLMmDGXvfbVV19FlmUGDx7Myy+/zGOPPRaQDffeey9du3bloYceYsyYMbRv357mzZuXOWfOnDmkpqby8MMPM3DgQD799NOwRbQY8WyXQdM0MjIysNvtKIoStnqPHj3KbbfdxurVq6lfv37Y6r0cwXgWRstmEDYM10cY6NevX5mBeml++umnMFtTfRhiCwPvv/9+pbz0jRo1Yt++fSG0qH
owxBYGUlNTq9uEqwJjzGYQNgyxGYQNQ2wGYcMQm0HYMMRmEDYMsV0D5OfnM2nSJNq1a8dNN93EP//5z+o26aIYro8Q4nSrbM/IJCvPSe04K23tSVgtwX/kL7zwApqmsXbtWo4dO8aYMWNo0qQJXbp0CXpdVcEQW4g4fDKPuV9sJzPH4S9LqmXjkcFtSasXvN1RDoeDlStXsmTJEqKjo2nZsiWDBg1i8eLFV53YjG40BDjdKnO/2M6ZbAexURYSYq3ERlk4c87B3C+243QHL+b/yJEjgDdUyUfz5s3Zv39/0OoIFobYQsD2jEwycxzER1swKd5HbFJk4mMsnM1xsD0jM2h1ORwOoqKiypTFxsZSVFQUtDqChSG2EJCV583l5hOaD9/nc/nBy/Vms9nKCaugoKCcAK8GDLGFgNpx3r0BqqaXKfd9Toy1lrsmUBo3bgzAwYMH/WV79+6ladOmQasjWBhiCwFt7Ukk1bKRW+D2C0zVdHIL3CTXstHWnhS0umw2G3369GHmzJkUFhayd+9evvzyS+6+++6g1REsDLGFAKvFxCOD25KSaCO/yM25fCd5RW5SEr2z0WC7P5577jkAf16RiRMn0rVr16DWEQyMsPDLUNVQaJ+f7Vy+k8TY0PnZwkEwwsJr5jevIVgtJjrfULe6zbhqMLpRg7BhiM0gbBhiMwgbhtgMwoYhNoOwYYjNIGwYYjMIG4bYDMKGIbYazqJFi7j77ru54YYb+NOf/lTmWEZGBkOHDqVt27b079+fLVu2+I+dPXuWcePG0aNHD5o1a1ZmIT9UGGILIbrHRVHGZvK2rKQoYzO6xxX0OpKTk5kwYQJDhw4tU+7xeBg/fjy9e/dm8+bNPPTQQ0yYMIG8vDwAZFnmpptu4p133gm6TZfCWK4KEa4zR8ha8R5q3vlASVNcErXvGEtESuOg1XPbbbcBsGfPHnJycvzlmzZtwul08uCDDyLLMgMHDmThwoV89913DBkyhNq1azNixIig2VERjJYtBOgel1douWeRI2NRohOQI2NRc8+SteK9kLRwF7J//37sdnuZl9FWd7i4IbYQUHx4B2peJrItDknxdh6SYkK2xaHmZVJ8eEfIbSgqKiImJqZMWXWHixtiCwFq/jkAv9B8+D5rBedCbkNUVBSFhYVlyqo7XNwQWwgwxXrfpiK0sruofJ+VmMSQ29C0aVMyMjLKJHDes2dPtYaLG2ILAZFpbTDFJaE78vwCE5qK7sjDFJdEZFrFXqpREVRVxeVyoaoquq7jcrnweDzceOONWCwW5s+fj9vt5uuvv+bEiRPceuut/mtdLhcul3f86PF4cLlcIX1XghGpexmqEp0artno7NmzmTNnTpmyQYMGMWPGDPbt28czzzzDvn37aNCgAX/729/o1KmT/7xmzZqVu9+lkkYHI1LXENtlqOoD1j0uig/vQCs4hxKTSGRaG2RzRAgsDT1GWPhVjmyOIMre6con/kYwxmwGYcMQm0HYMMRmEDYMsRmEDUNsBmHDEJtB2DDEZhA2DLEZhA1DbDWcQMPCAVauXMktt9xCeno6DzzwQMhf522ILYS4VDdbft3Ot/t/YMuv23Gp7qDXEWhY+MGDB5k2bRrTp09n48aNNGrUiMmTJwfdvtIYy1Uh4kjOCeZt/YSsomwEIAG1oxJ4uONwGsUH7+3IgYaFL1u2jJ49e9KtWzcAJk2aRPfu3Tl27BgNGzYMmn2lMVq2EOBS3czb+glnC7OIiYgmITKOmIhozhZm8f6WT0LSwl3IlcLCMzIyyrw/Pj4+nrp165KRkREymwyxhYCdZ/aQVZRNnDUWk+yNkDDJCnHWWDKLstl5Zk/IbbhSWLjD4Qh72LghthBwzpGLAL/QfJhkBQlBdnFuyG24Uli4zWYLe9i4IbYQkGiLRwJUXStTruoaAomEyPiQ23ClsHC73c7evXv9x/Ly8jh16hR2uz1kNhliCwGtU1pQOyqBPGe+X3CqrpHnzCcpKoHWKS
2CVlegYeEDBgxgzZo1bNiwAafTyaxZs0hPTw/Z5ACMSN3LUpXo1KO5J3h/i3c2CgKBRFIIZqNVCQtfsWIFr7/+OllZWXTo0IGXX36ZlJSUi9ZjhIWHmKo+YJfqZueZPWQX55IQGU/rlBZEmCwhsDT0GGHhVzkRJgsdU9tWtxlXDcaYzSBsGGIzCBuG2AzChiE2g7BhiM0gbBhiMwgbFXZ9LFmypMI3veuuuwIwxeBap8Jie+utt8p8zsvLw+l0+hdui4qKsFqtxMfHG2IzuDgiAJYsWSJGjhwpDh486C87ePCgGD16tPjqq68CueVViaqqYvfu3UJV1eo25YqsXLlS9OvXT7Rt21bcfPPN4ttvvxVCCKFpmpg9e7bo2bOnSE9PF3379hVHjx6t9P2D8SwCEtvNN98s9u7dW658z549omfPngEbc7VR1QesOp0ia+MmcXL5CpG1cZNQnc4gW+jlxx9/FD179hSbN28WmqaJrKwscezYMSGEELNmzRIjRowQx44dE7qui0OHDonc3NxK1xEMsQW0XJWbm0t2dna58pycHPLz86vc2l4LFB0+wsF/vIfz7Pn8bNbkJJpMGEtU48ZBrWvWrFk88sgjdOzYEYDExEQSExPJz89n/vz5fPXVVzRo0ACAtLS0oNZdGQKajfbv358pU6bw2WefsWfPHvbu3ctnn33GlClT6NevX7BtrHFoLpdXaGfOYoqNxZKQgCk2FueZsxx85z00V/CyhWuaxs6dO8nJyeHWW2+lR48eTJ06lby8PDIyMlAUhe+++47u3bvTu3dv5s6dG9LskpcjoJbt2WefJSUlhZkzZ3LunDcZcWJiIvfeey9jx44NqoE1kdyfd+A8m4kpLg7Z5H3EssmEKS4O59lMcn/eQWLn4ORty8rKwuPxsGLFCj7++GNsNhuTJ0/mpZdeokePHhQUFHDw4EFWrVrFmTNn+OMf/0idOnW45557glJ/ZQhIbGazmUcffZRHH32UwsJChBDl4tl/y7hL/gP6hObD99mdHbxs4ZGRkQCMGDGCOnXqADBu3DgeeeQRf6DkI488gs1mIy0tjSFDhvDDDz9Ui9gCduoKIdi2bRurV6/27+DJy8vD7Q79zqGrHUuiNxu4rpbNFu77bEkIXrbw2NhY6tatiyRJ5Y75cuZe7Fh1EJDYfv31VwYMGMAf//hHpk2b5p8szJ49m5deeimoBtZE4tPbYE1OQs3L8wtMV1XUvDysyUnEpwcvWzjA4MGD+ec//0lmZiaFhYXMmzePXr160aBBAzp37sw777yDy+Xi+PHjfP755/Tq1Suo9VeUgMQ2ffp02rRpw6ZNm4iIOJ+QuE+fPqxfvz5oxtVUlIgImkwYizUlGbUgH3d2Nmp+PtaUZJpMGIsSEdwkzuPGjaNDhw7069ePW2+9lVq1avHnP/8ZgNdff53s7Gy6dOnCqFGjuPfee6vP6R6Iv6Rjx47i8OHDQggh0tPT/T6d48ePi9atWwfsh7naCJqf7ZvQ+tnCQbX52UwmEw6Ho1z5kSNHqFWrVpX/A1wrKBERQZt1XgsE1I3efvvtvPHGGxQUFPjL9u/fzyuvvELfvn2DZpzBtUVAYps6dSqJiYl069YNp9PJgAEDGDBgAGlpaeXSNhkY+AhoK19BQQFWq5WzZ89y4MABioqKaN68Odddd10obKw2grF97VqhWrbyqapK165dWbZsGddddx2pqakBVWzw26PS3ajJZKJRo0bGgrtBpQl4zDZjxgy2bNlCUVERuq6X+WNgcDECcn08/PDDAIwaNeqix/fsCX3+MYOaR0Bi++ijj4Jth8FvgIDEduONNwbbDoPfAAEnljl+/Diffvophw8fBrwRoMOGDfNHhBqAx61xKCOT/DwnsXFWrrMnYbb8dl0oAYltxYoVPPXUU7Rp04Y2bbwRDNu3b+ejjz7itdde44477giqkTWR0yfzWf7FDvJyiv1lcbUi6T+4DSn1YqvRsuojILG99tprPP
roo4wbN65M+Xvvvcerr75aY8Sm6zpOpxOr1Vomq3ZV8bg1ln+xg9zsYqKiLSiKjKbp5J4r5l9f7GD0uK5Ba+F69erFyJEj+de//sWRI0fo0KEDr7/+OnFxcTzxxBNs2rSJ4uJimjVrxnPPPeePcXv66aexWq1kZWWxfv166tevz6uvvkqLFsHLinkhAT3hnJwcbr/99nLlffr0ITc3t6o2hQ2n08mePXtwOp2AV3w7duyosvvmUEYmeTnnhQagKDJRMRbysovZvesUuYUuCovd6HrFF3B0XXD8TAGHTuZx7HQ+x88UoGo6i79cyl/+9go//LCGgoICFixYgMPhoFu3bqxcuZINGzbQqlWrci/VWL58OQ888ABbtmyhS5cuvPjii1e0weVR+e+uUyxff5j/7jqF061e8RofAbVsffr0YeXKleVatm+//ZbevXsHcsurAiEEHo+nyhtC8vO84vUJzYcsS2i64NTpAmLreMPoTYqL5FqRRFiu/FM4XB5UTcckS2Wib++6ZygJtZORTBb69OnDhg0bEEJwzz33+M977LHH+Oijj8jJyfFH5txyyy20b9/ee4+77uKLL764bP2qpvP8B//l9LnzET9JtWw8MrgtafXirmh/QGKrVasW8+bN4/vvv6d169ZIksSOHTvYv38/Q4cOZebMmf5zJ02aFEgVNZrYOCsAmqaXEZzLpQKC6FgrJkX2ilvVOZtTTHxMBJouMCkStggzslw+lNuj6ggh0HVAAt8pCYm1S+oTWK1WHA4HmqbxxhtvsHLlSrKzs/3DhNJiq127tv/evusuha4LCorcnM1xEBtlwaTIqJrOmXMO5n6xndcn9rzicwlIbLt27aJly5YA/vTmZrOZli1bsmvXLv95V0vse7i5zp5EXK1Ics8VExXj7Uo9Ho1ih4eYeCupjb0/tiRJKDK4PBqZOQ7/8zIpLmJsZpAkv/g8qkZ+oRtNBx0BCCQJvP/0tsSKcv55r1ixgu+++44PP/yQ+vXrU1hYSMeOHQNutYvdKpouiI+ygCSX2CkTH2PhbM6lRVqagMT28ccfV+i806dPo+t6UAffNQGzRaH/4Db8q9RsVNcF0XFWutzaFCFJeFQdWQZVEwgBkixhUmR0XeBya7jcGpLsVZNJlpEk0PRSAgN04f2npnt/eFuE2W+Dw+HAYrEQHx+P0+nk7bffrtJ3UjWfoGW0UkNak1Lx3zakCZz79u3L0qVLf5O+t5R6sYx6uDMHdp0kP89BhM1CZEocitmEpgtkdCzCiRUdXZJxy2Y0CTQAyYQQkldNgLtkwmJSJBB+rfmRJEiuFVmm6+3fvz+bNm2iZ8+exMfHV3k4YyppNTVN97ds4B3HVZSQpqZv164dy5Ytu2rF5nA42LNnDy1atMBms6FpGj///DPp6ekoihJwDJfQddSifNzZeQhNB0lC13VUFPLNUciyRiTFeN+PAEICXZJwmWU0GYQuI9yRSLoCkr+XBLxv98OnqRLhmRWJhnVi/WITQuBwOLDZbEEbyng8Kpu37WTet2eItp0fs+UWuElJtIVuzGZQHl0XOFwedLcbszMHvcjtFYkECAkdCUVoxHoK0SLA1z75dCQLgdWjU2RWkCQdk6kIyWNBCBkVM6JEYb5blm7htJK6oyND944FWZaIibKQnGDj9DmHv+6URO9stCIYYqsCPoE53RpFDg9C6CSSh65pfqFJeFsaBYGGhCSVTCVLfi0ZzgtHFlg0HZMukASA09sq4sIpItFKfq4L+yJdQEGRO6RiA+/47LkHO7PzQDbn8p0kxlppa0/CWgG3DRhiCxiXW+VsTjGqpqNp3u7QJrtRZA1dSAgEuiShl7RJiu4VnCpJCCF5xVSqh/O1VuaSgbiQQCoRpYKGVSrGIaL9LdyFFBar5BW6St4GKJUf2AWJCLOJzjfUDehaQ2wBoOuCsznFeDTdvwQjSV5RIECXwG2W0CUJn6JkIbB4fC1WCRcIDkASAiFJyHrZExU0zL
hwYz1/LgITHqSSd2Nl5gqUkpm/LIHJolW41QkHIbWkY8eOZXbM13R0XcfjclLsdCFUFZNiwb+yJUBFQUjglkBHKiMsXZJwWs6LSPi6WN8Jvg+SV0Rl8R6MkFx4RAQCCQUVq1SMzPnZoI4LTYpGSAoeTSczp5jUpOiLOoirgyqJTQhBZmYm6gUJVOrVqwfAvHnzLnu92+3m+eefZ8OGDeTk5FCvXj3Gjh3LgAEDAO8ic1ZWln8mWK9ePZYvX+6/fuXKlbz22mucO3eO9u3bX/atclXF43ZRkHsOXVPRBViEAM2FS7IhkL2tiybw6BK6VHbWKEr9Wy/leL1QUoom0ExSuTGZlxKfGx5UzH6hebtV7zEZHUkvRFXiUCSvWyLUE4fKEJDYcnJyeOGFF1i1ahWappU7XtGwcFVVSU5OZuHChaSmprJt2zbGjh1LgwYNaNeuHQBz5syhZ8/y0+qDBw8ybdo05s6dS/v27XnllVeYPHkyixYtCuQrXRZd1ynIPYemqn4fk0BHFhpmUYQQVmI9DkxCQzMJUGREZRqTkhZNL/GjSX6RXtjPSv6uU0HDJzJRckwAktCRhBskMwjvEtbVQkCu/RdffJGzZ8+yaNEirFYr7733Hi+//DJpaWnMmjWrwvex2WxMmjSJBg0aIMsyHTt2pH379vz0009XvHbZsmX07NmTbt26YbVamTRpEj/99BPHjh0L5CtdFtXtQlNVNCGhagJN94pAR0JGJ0YrQhEakiTKDfwrhE9c54d4pYq9YvGNDmVJECs7kBEo6JgkHRO6T2ol7Zx+0SWs6iaglu3HH3/kgw8+oFWrVkiSRIMGDejZsycJCQnMnj3bn4SusjgcDnbt2sXo0aP9ZU8//TS6rtO0aVMef/xxOnToAEBGRoY/cBMgPj6eunXrkpGRUem3AWua5v/j++z7WwiBqnrQhVdIsuRdXhLC11UKr+wkgSR5B/dBo6SV03zdNDJxkhPQUUv10ZIkUISOiq/VldEEWEwykRGmoKQ1Fd5k3xftySrq8A5IbKqqEhvrjTZNSEjg7NmzpKWl0bhxYzIyMgK5JUIIpk2bRps2bejRowcAr776KjfccAMAX375JQ899BBff/01qampOByOctkuY2NjKSoqqnTdF9q8c+dO/79NJhNOpwtJ6H6fGOK8r0vgnWnqMqiSjC5z0VnmJb/3haf6HcHnixS8qw8CMybJ4xWf8L2fGWa8MZMNGzfhKC4mJiaG2/sNYsToB1BEMSOGP8Thw4fxeDykpqYybtw4br755ko8HS+6ruPxeMo8Gx++BuBKBCS2Fi1a8Msvv/jHVnPmzKGwsJClS5cGlI1aCMFzzz3HmTNnmD9/vn+JxZf9GmD48OF88803rFmzhmHDhmGz2SgsLCxzn4KCAv9LQCqD3W73L1ft3LmT1q1b+5erDhzYj1Vz4SpZqpRKLQUKyeuqUE0+NwdllKN6PJw5cYjiwnwio2NJqX8dJvP5xfLLtTeKLpB1gWaRkFDI16OJkpz+4xYZPEJCF4Kh9wzi8UfHYY20kVXg5vGJj9HCnkbfvncwffp00tLSUBSFbdu28eCDD7JixYpKT6Q0TcNsNtOiRYvwvkn5T3/6kz/26cknn2Tq1Kk8+eSTNGzYsELRnqURQvD888+ze/duFixYgM1mu+S5kiT5uwS73e4PbwJvitVTp05ht9sr/X0URUHX4ODeLI5kFGEzZ3F98xQURUGoHiShYkbGI7yOWp+gvOMzyet8uKAly806zdbvl1NUkOcvi4qJo8PN/YmvnXJxoZUqlHVvP21SBG7NhLezVkqeAyhIyCWO4xZNGiMJgTU+AQ8FmEwKx48fw2Kx0LRpU++thUBRFFRV5eTJk/78uxVFkrwBm4qihFdsvpkiQEpKCgsWLAiocoAXXniB7du3s2DBAqKjo/3lJ0+e5OTJk/5x2ZIlS9i1axd///vfARgwYABDhgxhw4YNtGvXjlmzZpGenl7p8RpAVmYhq5
ZuIy+7GI/q4ei+X4hLOETfu29A9ajecZEQZVwYwCVnnKrHw9bvl1OYn4vVFoUsK+iaRmF+Llu//xe/GzgapVQLp0sg66VcbUIgJHBGysRrOud0K7IEmhKBJopRhIaOjCRJWBUJhM7MeQv4ZPESiouLqVevXpnUZcOHD2fHjh14PB66detG27bV8yrxKkV9FBcXc+7cuXID0IpGefz666/06tULi8WCqVRm7bFjx9K7d28mT57MsWPHMJvNNGnShMcff5zOnTv7z1uxYgWvv/46WVlZdOjQodJ+Nl/Ux76fnZw67sAWZcbtdmGxROAo9JCQFEmbzpHY68bjEqW6yit9r8P72Pr9ciIibcjy+VZA1zRcTgcdbu5HvTTvxpPSgvU5gTUFdJNErKpToEVRLCKIsCjUrR2Fq7gYuSgbSWh+F4mkmDDFJSGZLOzcuZPVq1czYsQIkpKS/EMSt9vNmjVrOH78OGPGjKnwM/JRLVmMAPbt28ef//xndu/eDXibaF8XJ0lShf1sqamp7Nu375LHly5detnr77jjjqDs5CrMdxEVbfF72n2bU3JzinGrVnQh0KWKe4mKC71Jd0oLDUAu+ZGKi84nUXRESCg6KJoEuowsCxQJTG6FTD0SgYzFLJNcKxKTImOKjkLYItHdxaCpoJiQLZEIJIqcHhpe1wz+73veffc9/vrXZ/z1WCwWevfuzX333UejRo2qJYlzQGKbNm0aycnJfPrpp9SuXfuaCP9WSvYEgDcezVVUhMfpIYJoyk/2L09ktHemrutauZYNIDKqZBYtgVkDk1YyRkPzRk9KJoiIIUY2EWk1EWUtuydBkmUU6/mJUOmgAICCIheHjhzD5Sm/NqppWkh8kRUhILEdOnSIt956i0aNGgXbnmpDVTVUj05xkRt3bi4m3MiyFR0dN3KlvN8p9a8jKibOO2aLjEJWvGM2Z3ER0bHxJNcvSZoowKyW+P8lr4fd68PzoHjyiIyOxeuTNXEpX0peXj6Ll35Dp643ERMVxZ7dO1m+bDH3jnyAHzf+RKRZ9bsmli5dys8//+zPJB5uAp4gHDp06JoRmzXSzLGD+d6WTQjcWJGIJF7JRqPye0hNZjMdbu7P1u//VWY2Gh0bT4eb+5dxf0jC50KhxGvmFZ+mqRQV5CJJErJiIia+NmZL+TXOYrfKt998zdyZr6NpGrWTkrjnDyO46+4h7N2zm5lvvsaxo0cxmUykpaUxc+ZM/2alcFPhCcKGDRv8/z59+jTvvPMOw4YNo2nTpmUG9wBdu3YNrpUhwjdB+GWzg2MH80s2DPseh06MdIDud99Iw/qpAUVO+P1sRQVERsWQUv+6MrNQH6VXqUr/GLKsICkKuq6hKCbia9cpt3kot9BFTr6zzMYTIUAXOkJAQqyVuOiqR96EdYJwsRnMq6++Wq6sMhOEq4WiApdfaL4f20Q2Zssh4MbKr3X67mE2k+qbdZaUlbnVBSsFFynwtmyygq6peFxOIiLL+iF9G1F8kzM/JaHCNXJttLQD9VrDu3nY9zOrWKU92Ez7kaTghftd+JP7lr78nfT5UI/z15S0ppLkXWbXtPKpDmwRZkyKC4+qY1J8e3WFf2209Pa+6iakGzrvvPNOTp06FcoqgohAkXKJNX2HzbQ/pDX5Y92k87vaxQUtmixJSL5wJn8ER3nxy7JEcq1IzCYZTRfeMHVdYJIlki7Y3lfdhDRS98SJE+UCK69eVKKUrZhKrT9CSUMTSNjQBZTuIM+HO0oISUaRdCIkbyvn1kuqU0xQ4rv0jdnMEdaL3jvCYiI1KRqHy4OmCRRZAt1DhDn4ueCq4ua6egLUqxmzdBZZKihTpmsquqai6Xo5B21lEZzvRryjQ2/LpZhMxMQnIOsaaCpWAYWOIu8OrZJuUymZjV4us4AsS/6IXO++UU+V7L0Qj8fjXx8NFENsAAhkqRjpIm6OM4f3EBMTS2JC/JUftPBvE0UHPOaSMZfw7kEQsozFoxBpNmGxWl
AUMyZLhFdEihnMoAAxETZvwKamoigm/zkXiyW7qBlCoOs6mqYFxeEuhODMmTPEx1fgGVwGQ2zg3TQinBftKU/u+4mYhBSKHMkV7kmFBKoiockgC28wo7fc61SLscSEdNeTL/WX2WwO2uqO1WolOTm5SvcIqdhqyjKWLJ3BZrr4Gq2uqexZ9zWSrFzy+4iSbZqqLOExe5egDqSayY02oWgxgOyfTcoWD+O7DqdNg9A5Vn1xeVWJPSuN1/1S9blkSMUWwjQiAOTn5/PXv/6VNWvWEB0dzbhx4xgxYkSl7xOp7OVSSZ/8gbO6Vi4GzRcCrgPOCK/oLC4ojpDIjjGDFolH1ShJF4PZpGCxyhSqBWF5F1ZVYs9CQUjFVpGNK1XhhRdeQNM01q5dy7FjxxgzZgxNmjShS5culbpPIO2vrzXTZPwL6hKgyBJS8xuItpwh2hKNx+11R5gUGbNFotCtkRAZH0CNNZ+AxNarV6+LdimSJGGxWGjYsCEDBgwI6btHHQ4HK1euZMmSJURHR9OyZUsGDRrE4sWLKy22SiG8ERpui4wrAjIamoksVoh2ySQ46/JL8Q30T2lOlmspZ4vOEWuNwSZbUHWNfGcByVGJtKxtr/BgPxAu3LgTakK64WXYsGF88MEH9OjRg9atWwPeTSLr1q1j9OjRnDp1iqeffprCwkKGDh0aSBVX5MiRIwBcf/31/rLmzZtXKmrYl6g5MqbWZc/z71AvcbqpJpkoScdkM3PO2QBhikdSa3EuUiZKEaiF57i5Vkc2uLdT6HYgSgK6G5jr0DW6LXt27a74F60CF9ucEgpatGhRoYzrAYlt69atPPnkkwwZMqRM+eeff87q1at59913adWqFQsXLgyZ2BwOR7nNLZXdXeVyuQCwd7ktYDuaX/KIlX51fhfwfWsSpXPcXY6AxLZx40aefvrpcuWdOnXy7xHo0aMHM2bMCOT2FcJms5UTVmV3V8XFxdG4cWMiIiJ+c6lYg43VevHVjdIEJLa6devy6aefMm3atDLln376KXXretMp5eTkEB8fH8jtK0Tjxo0BbxqGJk2aAN5gAd9uoopgMplITEwMhXkGFyEgsT377LM89thjrFq1ihYtWiBJErt376agoIDZs2cDsH///nLdbDCx2Wz06dOHmTNn8tJLL3HixAm+/PLLKicqNggdAe+uKigoYNmyZRw9ehQhBGlpadx5553ldqmHkvz8fJ555hnWrl1LVFQU48ePD8jPZhAeQprA2cCgNAE7dd1uNzt27OD06dPlwojuuuuuqtplcA0SUMu2d+9exo8fT25uLi6Xi5iYGPLy8rBarcTHx/P999+HwFSDmk7A+dl+97vfsWXLFiIiIvjiiy/4z3/+Q9u2bZkyZUqwbTS4RghIbLt372bMmDEoioLJZMLlclG3bl2mTJnCm2++GWwbDa4RAhKbzWbD4/FGgiYlJfmXjiRJ4ty5c0EzzuDaIqAJQvv27dm4cSPXX389vXv3Zvr06WzevJm1a9fSqVOnYNtocI0Q0ATh3LlzOJ1OUlNTUVWV999/nx07dtCgQQPGjx9PQkJCKGw1qOkIAyGEEHl5eWLixIkiPT1d9OjRQyxatKi6TSrH1KlTRatWrUR6err/z6+//uo/vm/fPjFkyBDRpk0b0a9fP7F58+Yy169YsUL06tVLtG3bVowZM0acPn06rPZXqhvdvHlzhc6riV1psAIxQ83999/Pk08+Wa7c4/Ewfvx4/vCHP7Bo0SJWrFjBhAkTWLVqFXFxcWFN5X8pKiW2UaNG+YMmxSV635qYfqHaAjGDyKZNm3A6nTz44IPIsszAgQNZuHAh3333HUOGDCmTyh+8r1Pv3r07x44dCyhbZyBUSmypqalomsbAgQMZMGCAP/KiphOMQMxw8dlnn/HZZ59Rp04dRo8ezeDBgwFv4IPdbi8TKtW8eXP27/fu7g9mKv9AqZTYVq9ezZYtW1i6dCnDhw8nLS2NQYMGcccdd/hT1d
dEghGIGQ5GjRrFlClTiIuLY8uWLUycOJGYmBj69OlDUVHRRVP1FxR4N14HM5V/oFTaz9axY0emT5/O2rVrGT16NP/3f//H7373OyZOnIjb7Q6FjSEnGIGY4aBVq1YkJCSgKAqdO3dmxIgRrFy5EoCoqKjLpuoPZir/QAk4PNVisXDrrbcyaNAgmjVrxn/+8x9/mHVNo3Qgpo/KBmJWB7J8PjVr06ZNycjI8O+rAG+4tu87BDOVf8D2BnLR1q1befbZZ+nevTvz58+nf//+rFmzJqyxbMGkdCBmYWEhe/fu5csvv+Tuu++ubtPK8M0331BYWIiu62zZsoVFixb5X9104403YrFYmD9/Pm63m6+//poTJ074jw8YMIA1a9awYcMGnE5nlVL5B0xl/CSzZs0SvXv3Fr///e/Fm2++KQ4ePBgah0w1kJeXJx577DGRnp4uunfvflX62YYPHy46dOgg0tPTRd++fcUnn3xS5vjevXvF4MGDRevWrUXfvn3Fpk2byhz/5ptvRK9evUSbNm2qxc9WqRWE5s2bU7duXTp06HDZDSIXy0hpYFCp2ehdd91VY/J3GFx9BCUs3O1243a7y7wOyMDgQio1QfB4PMycOZNx48Yxd+5cNE1j+vTptG/fnk6dOjFq1CgyMzNDZatBDadSLdvf//53Vq1axW233caGDRtITk7m9OnTPPLIIyiKwrvvvovdbueVV14Jpc0GNZXKzCZ69uwpfvzxRyGEECdPnhTNmjUT69at8x/fsmWL6NGjR7AmLwbXGJXqRjMzM/27z+vWrUtERAT169f3H2/YsKERqWtwSSolNl3Xy6RHkmW5jAuk9MtnDQwupNJh4fPmzSMyMhLwThgWLFjgX4QvLi4OrnUG1xSVmiCMGjWqQud9/PHHARtkEDyaNWvGhx9+6I9hq3aqecxY4xg5cqSw2+3CbreL5s2bi5tuuklMnz5duFwuIYQ3dNtut4u33367zHW6rotevXoJu90uNm7cGBZb7Xa7WL9+fVjqqghGUrIAuO+++1i3bh3ff/89M2bMYNWqVcydO9d/vE6dOixbtqzM+HXr1q016G03ocEQWwBERkaSlJRESkoK3bp147bbbisTCt+xY0d0XWfr1q3+siVLljBgwIAy98nKymLixIl0796ddu3aMWLEiHIh9Rs2bOD222+nTZs2jB07lvfff79Sr9w+ffo0999/P23btuXuu+8uE2a0bds2Ro0aRceOHenSpQtPPPEE2dnZlX0cFcYQWxU5deoUGzZs8OcWBu+s/M477/S/497lcvHtt98ycODAMtc6nU46duzI/Pnz+fLLL2nSpAnjx4/3xwXm5+fz6KOP0qNHD5YsWUKvXr344IMPKmXf3LlzGTlyJEuWLCE5ObnMW5QdDgfDhg1j8eLFzJs3j1OnTvH8888H+iiuTHX34zWNkSNH+rfTtW7dWtjtdjFmzBjhdruFEN4x2+TJk8WBAwdEx44dhcvlEsuXLxdDhw4VHo/nsmM2VVVFenq6PzRo0aJF4uabbxaapvnPeeKJJ8Tvf//7Ctlqt9vF+++/7/+8bds2YbfbRWFh4UXP/+mnn0TLli2FqqoVun9lMV4nFABDhgzh/vvvR9d1Tpw4wcsvv8xLL73Ec8895z+nSZMmNGrUiNWrV7NkyZJyrRp4XUezZ89m1apVZGZmomkaxcXF/tdmHjlyhObNm5fxZd5www2Ver9E6Ujc2rVrA5CdnU1UVBSnT5/mjTfeYNu2bWRnZyOEQFVVsrKySElJqfRzuRKG2AIgNjaWRo0aAZCWlkZBQQFPPvkkU6dOLXOebzvd3r17LxrjN2/ePL766iueeeYZ0tLSiIiIYMiQIf6JhLjw7cgBYC79PnrfG2lKQseffvppPB4PL774IsnJyZw4cYKHH37Yn8cl2BhjtiCgKAqappX7kfr168euXbvo0aPHRZNZb9++ndtvv50+ffpgt9uxWCzk5eX5j6elpbFnz54y+w
p27doVNLu3b9/OmDFj6Nq1K02aNCEnJydo974YRssWAMXFxWRmZiKE4Pjx4/zjH/+gQ4cO5fZgJCQksH79eiIiIi56nwYNGrB27Vp++eUXAF555ZUy59555528+eabzJgxg2HDhrFlyxbWrVsXtB1RDRo0YOnSpTRt2pSjR4/y3nvvBeW+l8Jo2QJg4cKF9OjRg549ezJp0iSuv/563nrrrYueGxcXd8l3BEyYMIH69eszfPhwHnvsMYYOHVqmBYyNjWXOnDn88MMPDBw4kH//+9+MGjUKi8USlO/x4osvcvToUfr378/MmTN5/PHHg3LfS2EkcK5h/OUvfyEzM5P333+/uk2pNEY3epXzxRdf0LRpU2rVqsX69etZunRpSN+cE0oMsV3lnDp1ilmzZpGTk0P9+vX5y1/+Qv/+/QFo167dRa+pV68ey5cvD6eZFcLoRmswR48evWi5yWQiNTU1zNZcGUNsBmHDmI0ahA1DbAZhwxCbQdgwxGYQNgyxGYQNQ2wGYcMQm0HY+H+Nwn3hnwlWdgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "for target in target_vars:\n", + " for group, group_df in dff.groupby(\"method\"):\n", + " print(group)\n", + " sns.set(\n", + " rc={\"figure.figsize\": (3.28125003459, 4)},\n", + " context=\"paper\",\n", + " style=\"whitegrid\",\n", + " )\n", + "\n", + " f = sns.lmplot(\n", + " x=target,\n", + " y=f\"{target}_pred\",\n", + " hue=\"C_qfrac\",\n", + " height=3.28125003459 / 2,\n", + " data=group_df,\n", + " fit_reg=False,\n", + " facet_kws={\"legend_out\": False},\n", + " )\n", + " # f.ax.set_title(group)\n", + " plt.tight_layout()\n", + " break" + ] + }, + { + "cell_type": "code", + "execution_count": 125, + "id": "f09b9d03", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BMag_ha\n", + "linear\n", + "RF\n", + "KPConv\n", + "PointNet\n", + "\\power{}\n", + "MSENet14\n", + "MSENet50\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXMAAADhCAYAAAA6Y1VuAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8o6BhiAAAACXBIWXMAAA9hAAAPYQGoP6dpAABtBUlEQVR4nO2deXgUVfb3v1W9ZF+AgEISSEjSwYVshFVENkF+g6CMQVFxxGFUcARRZxDFYVBUdFwRRIVRUUZ5RRBwIyCKrCHsCRDIRpbOHrJ10nvVff+orqI76U46nb1zP8+TB7qq+tY91d3funXuuecwhBACCoVCofRo2K7uAIVCoVDaDhVzCoVCcQOomFMoFIobQMWcQqFQ3AAq5hQKheIGUDGnUCgUN4CKOYVCobgBVMydhOd5aLVa8Dzf1V2hUChtwF1/y1TMnUSv1yMjIwNarbaru9Lh8DyPtLQ0t/uyO4La677Ys1H8Lev1+i7oUcdBxbyV9IYFs4QQmEymXmErQO11Z3qDjSJUzCkUCsUNoGJOoVAobkCPEfOXX34Zt99+OxISEjB58mR8/PHH0r7MzEzMnTsXsbGxmDlzJk6dOmXz3r1792LKlCmIi4vDY489hrKyss7uPoVCoXQoPUbM//KXv2D//v04c+YM/ve//2HPnj345ZdfYDKZsGjRIkydOhUnT57E3/72NyxevBi1tbUAgJycHKxYsQKvvvoqUlJSMGTIEDz33HNdbA2FQqG0Lz1GzCMjI+Hp6Sm9ZlkW+fn5SE1NhV6vx8KFC6FUKjF79myEhIRg3759AIA9e/ZgwoQJGDduHDw9PbF06VKcPXsWBQUFXWUKhUKhtDvyru5Aa3jnnXfw1VdfQafTITg4GLNmzcK+ffugUqnAstfvS8OGDUNWVhYAwQUTExMj7QsMDMTAgQORmZmJwYMHt7oPPM+D47i2G9ONEe1zdztFqL3uC8dxkMlkDvf1lGvgyAZrepSYP/fcc3j22WeRnp6OAwcOwN/fHw0NDfDz87M5zt/fHxqNBgCg1Wrt7m9oaHCpD9nZ2U4dxzCMS+13FxiGwYULF7q6G50Gtdd9SUhIsLs9MzOzk3viOiNGjGjxmB4l5oDwJYyJicHhw4exfv163Hjjjaivr7c5RqPRwMfHBwDg7e3d7P7WEhkZCV9f3xaPs35S6GkQQmAwGODh4dHjb0rOQO11X5qLM1epVPD29u7E3nQsPVZxOI5Dfn4+oqKikJmZabPSKyMjA1FRUQCED+zy5cvSvtraWpSUlEClUrl0XpZlIZPJWvxjGKbH/K1fvx4PPvigzTaz2WzzetiwYTh+/HiX97Wj/hrb21P+UlNTMWzYMHAc1y72rlixotn3NTQ04J577sEtt9yCP/74o9lji4uLMW/ePMTFxWHDhg1ttnX06NEYO3YsXn/99Ravyffffy+9doQzv+Pu8ueUNrmkaJ2MRqPBrl27UF9fD57ncfr0aXzzzTcYN24cRo0aBaVSic8++wxGoxE//PAD1Go17rzzTgDArFmzcOjQIRw/fhx6vR7r1q1DXFycS/7y3syRI0eQmJjYaefbsWMHJk+ejOHDh2P+/Pm4evVqs8dv27YN8+bNQ2xsLCZMmNBkf3l5OZ555hlMnjwZ0dHR2L59e5NjysrK8Oyzz2Ls2LEYMWIE/vGPf6Curg4AsHHjRixfvrzFfs+fPx/R0dH47rvvbLZrtVrEx8cjOjoaarW6xXaMRiPi4+NRWFjY4rGdhclkwtNPPw2FQoGnnnoKy5YtQ1pamsPjv/32W1RWVuLbb7/FggULmm27pc8PAH788Ue88MIL+PLLL5GTk9MmW9yRHiHmDMPg+++/x6RJkzBixAi89NJLWLBgAR5++GEoFAps3LgRycnJSExMxMcff4
wNGzYgMDAQABAREYHXXnsNK1euxOjRo3H16lW88847XWtQD6R///5QKpWdcq7jx4/jX//6F5588kns2LED/fr1wxNPPAGj0ejwPQaDAVOmTMG8efPs7jcajRgwYACWLVuG/v37N9nP8zz+/ve/o7a2Fp9//jm2bt2K0tJS/POf/wQATJo0CYcOHXIqn8mNN96I3bt322zbt28f/P39W3yviFKpxLhx4/DHH384/R5X0Wq1+Ne//oVJkybhxx9/xPTp07FmzZomx7300kvQaDT4/PPP8dRTT2Hp0qVYtGiRwxtOeXk5hg8fjujo6Bbdmi19fgAwYMAA/OlPfwIAVFRUNNmfn5+Pxx9/HEuWLMGaNWtwzz33YMeOHc2e160gFKdoaGggp06dInV1dU6/5+GHHyZr164lL730EomLiyOTJk0iBw8eJCUlJeQvf/kLiY2NJffffz9Rq9U279uyZQuZPHkyiYmJIXPmzCEpKSnSvuzsbLJw4UIyatQoMmLECLJw4UJSUFAg7U9JSSEqlYocO3aMzJgxg8TFxZFFixaRmpoah/1ct24deeCBB8imTZvImDFjSGJiInnjjTcIx3HSMSqVihw9epQQQkhFRQV5+umnybhx40hcXBx58MEHyaVLl6Rj9Xo9eemll8iYMWPI8OHDyfTp08n+/fudvm5PPfUUefbZZ6XXDQ0NJCYmxqk2duzYQW6//fZmj5k0aRL59ttvpdc8z5P09HSiUqlsPosrV64QlUpFcnJyCCGE3HHHHeTcuXPNtv3www+TV199lcTFxZGioiJp+6OPPkrefvttolKpSGFhobT9888/J2PHjiUJCQnkjTfeIM8++yxZvnw5IYSQb7/9lvz1r39t9nzOfN7ffvstmTVrFomNjSUTJ04k7733HqmqqiI8zxNCCHn33XfJpEmTSEpKClm8eDE5duwYWb9+vc153n77bXLfffc1+f5v2bKFTJs2jVy7dq1J35YvX06ee+65ZvvfGGc+P+vvojX3338/WbBgAdm2bRtZv349SU5OJtu3b29ynPhbbmhoaFXfujs9YmTek/n2228RFRWF77//HnfccQf++c9/4qWXXsJf/vIXadSwdu1a6fjvvvsOX375JVatWoUff/wR99xzDx5//HHp0Vyr1WL69On4+uuv8fXXX0OhUODZZ59tct6PPvoIa9euxZdffonMzExs3Lix2X5evnwZ586dw5dffolXXnkF3377Lb7//nu7x+r1eiQmJuKzzz7Dzp07ERERgUWLFsFgMAAAvvzyS1y8eBGbNm3CTz/9hBUrVtiMzKKjo7Fz506HfUlLS8OYMWOk197e3oiJicH58+ebtaEtmEwmALBZy+Dl5QUAOHv2LABgwoQJOHjwYItt+fj4YPLkydizZw8AwX1z7tw53HXXXTbHpaSk4O2338ayZcuwfft2mEwm/P7779L+O+64AydPnoROp2vxnM193oQQLF++HD/88AP+/e9/47vvvrP5bC9fvoypU6di9OjR8PPzw9ixY/HUU0/ZtP/cc89h+/btTSLDHnnkESQnJ6Nv375N+mQ0GqFQKFrse2uRy+V2n9KuXLmCBx98EGFhYRg4cCCmTZuGP//5z+1+/u4KFfMOJiEhAX/5y18QFhaGxYsXo6amBuPGjcOkSZMQERGB+fPnIzU1VTp+48aNeOmllzBhwgSEhoZi/vz5GDFihCQMw4cPx3333YeIiAioVCqsXr0aaWlpKC4utjnvP/7xD8TExGD48OFISkqyOYc9eJ7Ha6+9hqioKNx11124//778b///c/usSEhIXjkkUcQHR2N8PBwrFq1CrW1tZL/tLS0FDfddBNuvfVWhIaG4o477sDYsWOl94eHhzcRBWuqqqrQr18/m219+/bFtWvXmrWhLQwZMgQ33ngj3nnnHWi1WtTX1+O9994DAFRWVgIAJk6c6LTbY/bs2ZKrZffu3Zg0aVKTKKhvvvkGd911F5KSkjB06FC8+O
KLNq6YAQMGICIiAsePH2/xfM193nPnzsW4ceOkz2L+/Pk4cOCAtD82NhbJyck229pKRUUFzp49i7CwsHZrU2Tw4MH47bffYDabbbbHxsZiy5YtyMjIaPdz9gSomHcw1lEzQUFBAITwRpF+/fqhpqYGHMehoaEBarUay5YtQ3x8vPR34sQJyS+p0WiwevVqTJs2DQkJCZg2bRoAoKSkpNnzVlVVNdvPwYMHIyAgQHp9yy23OJx0NJlMePfddzFjxgwkJiYiMTEROp1O6sPs2bORnJyMOXPm4N13320Sz7x3715pgrq7oFAo8P777+PMmTNISEjAmDFj0K9fPwQFBUkREePGjUNOTg7Ky8tbbO+2226DRqNBWloa9uzZg9mzZzc5Ji8vD7feeqv0WiaTYdiwYTbHTJw40amngeY+7zNnzuCxxx7D7bffjvj4eKxfv94mP9Hf/vY33H///Xjrrbewe/duzJkzBz///HOL53TEv/71L4wfPx4DBgzAo48+arPv1KlTNt/txnmUnOHVV1/FTz/9hJiYGJv3v/322xgyZAg+/fRT/Pvf/8Zf//pXXLp0yWU7eho9Ls68pyGXX7/EoihYP3qK2wgh0uP022+/LYVWiohuirVr1+L8+fN48cUXERISArPZjNmzZzcZpTQ+b0sTd82FcDVm06ZN+P7777Fy5UqEh4fDw8MDSUlJUh9iYmJw4MABHDx4EIcPH8a8efPwzDPP4K9//atT7dsbhVdVVXV4BFJcXBz27t2LqqoqKBQKyOVybN26FSEhIQAEF8yoUaPwxx9/ICkpqdm2ZDIZZs6ciTfffBPV1dUYP358kygW4kSu7YkTJ+Lpp59u8ThHn3d9fT2eeOIJzJgxA0uWLEFAQAB++OEHGzeXQqHA4sWLsXjxYixatAijRo3C888/jwEDBrgUwbRkyRJMnToVzz33HPbs2WNzrW699Vbs2rVLen3DDTe0uv13330XcXFxeP755xEeHi5tDwoKwpo1a3DixAmkpKSgtLQUCxYswG+//ebU2pCeDh2ZdyP69euH/v37o6SkBEOGDLH5E0f158+fx3333YeJEyciMjKyyYIoV8nPz5fC8ADg0qVLNj8Ua86fP4+77roL06dPh0qlglKplBKbiQQGBuKee+7BO++8gyVLlrQqqiAmJgYnTpyQXut0OqSlpSE2NraVVrlG37594efnh+TkZCgUCowbN07a5+xIGQDuuecenDp1CjNnzrQbKxweHo6LFy9KrzmOs1kTAQhuNbPZ3GS7s1y9ehV1dXV4/vnnERcXh/DwcJSWljo8PiAgAAsWLEBUVJTLcxRBQUGYMGECxo4di3Pnztns8/T0tPleW89ROEtaWhrmzZuHm266yeH7Q0NDsWLFCtTW1rYY1uouUDHvRjAMgyeeeAIffPABduzYgYKCAqSnp+PTTz+V/KahoaFITk5GdnY2Tp06hbfeeqtdzs2yLFauXIns7Gzs27cP27Ztw4MPPmj32NDQUBw+fBgXL17ExYsXsXz5cnh4eEj7v/jiC/zyyy/Iy8vDlStXcPToUZsbw1133YX9+/c77MtDDz2EX375Bdu3b0dWVhZefPFFDBgwwCb+uHEbFRUVyMjIQHFxMcxmMzIyMpCRkWEzUWa9raSkRDpe5Oeff8apU6eQn5+P7du3Y/Xq1ViyZIkU5goIYn7s2LFmwyRFhg0bhpSUFIdZOufNm4dffvkF3333HXJzc/HGG2+grq7O5imJYRinJ17tMWjQICgUCnz99dcoLCzEN998g19//dXmmPXr1+Po0aPQaDQwm8347bffkJubi5tvvtmlc4p4e3tLk+It4cznJ2Iymeyu3Hz55Zdx8eJFGAwGaLVafPnll/D29u4Qv313hLpZuhnz58+HUqnE5s2bsWrVKgQGBiIuLg5Tp04FALzwwgtYvnw55syZg5CQEKxYsQILFy5s83mHDRuGW2+9FQ899BA4jsN9992HOXPm2D128eLFyM
vLw4MPPoh+/frh2WefRV5enrTfy8sLH330EQoKCuDp6YkxY8Zg5cqV0v6rV69KuXPsMXbsWKxevRofffQRKioqEBsbi08++cQmzr1xG9u2bcP69eul1/fccw8A4MCBA5KbRNwGABs2bMCGDRtw77334o033gAgzDu88cYbqKmpQUhICP75z382iXsODg5GSEgITp48idtuuw0vvPACioqK8NVXX9m1pU+fPg7tHDNmDJ5//nm88847MBqNSEpKwrhx45pEgEycOBGff/45nnzySajVakyZMgVffvklRo8e7bBtkX79+uGVV17B+++/j48//hjjx4/H448/jq1bt0rHhISE4IMPPkBOTg60Wi3OnDmD5557zmbS2hUYhnG6bJsznx9wvaanvXQZ/v7+eP7551FUVASe5xEREYEPP/yw2cl2t6JrIyN7Dq7EmfdUeJ4ndXV1Uhyyu9Nae//zn/+QNWvWEEKEuPJ169a1Wz+mTZtGNm3aZLNdo9GQ4cOHk6qqKnLixAmSmJjY7LoBZ87jyF4xxr09eOedd8jdd99NtFptu7V56tQpolKpSHZ2tsNjUlJSyI4dOwghxK6NNM6cQqEAAObMmQOVSgWtVovCwkI89thjLre1efNmZGVlITs7G6+++iqKi4ubxKP7+vrixRdfRF1dHY4ePYonnnjCJvKouzJ79mxUVlYiPj7epjKYq4wfPx4PPvigFNZLsYUhpBeVr24DWq0WGRkZUKlUbv/YRghBfX09fH19WxXl0lPpSnv/9re/IS0tDUajEVFRUfjnP//Z4TlwOtNenudRXl4OLy+vNt+A1Go1/Pz8WtUOIaSJjeJv+aabbnKrrInUZ06xi3WoW2+gq+zdtGlTl5y3s+xlWRY33nhju7Rl7TunNIW6WShN4HkeGRkZDmPTOZ5gf2o+Nu1Kw/7UfHB8z364a8led6M32dsbbBTpXcMvitM05337NTUfXydfAcfxOHK+BDxPMH1MWOd1rgPobd7G3mZvb4COzCmtJreoFhzHo38fL3Acj9yi2pbfRKFQOhQq5pRWMzQ4ADIZi4pqHWQyFkODu39kBYXi7lA3C6XVTB01BIAwQh8aHCC9plAoXQcVc0qrkbFMj/eRUyjuBnWzUCgUihtAR+YUSg+F4wl+Tc23cXfJWPdf5EWxDxVzCqWHYh0imnJBSGtL3V+9lx7hZjEajXjppZcwefJkxMfH409/+pNURg0AMjMzMXfuXMTGxmLmzJlNqpfs3bsXU6ZMQVxcHB577DGbKisUSlvheILklDxs3HEeySl5nbaIioaIUqzpEWJuNpsxYMAAbNmyBadPn8bq1auxevVqnD17FiaTCYsWLcLUqVNx8uRJ/O1vf8PixYulYgk5OTlYsWIFXn31VaSkpGDIkCEO80tTKK4gjpCPni/G18lX8Gtqfqecl4aIUqzpEW4Wb29vLF26VHqdmJiIhIQEnD17FlqtFnq9HgsXLgTLspg9eza2bNmCffv2ISkpCXv27MGECROkajFLly7FbbfdhoKCApfKkPE8D47j2s227ohon7vbKdJWe7MLa2DmePQP9EJFjQ7ZhTWYOrLjr92kESHgeR65RXUYGuyPSSNCnLKhN32+HMfZrfIk7usp18CRDdb0CDFvjFarxYULF/DII48gKysLKpXKJln9sGHDkJWVBUBwwcTExEj7AgMDMXDgQGRmZrok5tnZ2W03oIeQnp7e1V3oVFy1V8HXg/BmlFRqIGMBBV/bpFxaRzHAAxgwFACqkZ5W3ar39pbPd8SIEXa3Z2ZmdnJPXMeRDdb0ODEnhGDFihWIiYnB+PHjkZaW1iQlrb+/v1SFRqvV2t3f0NDg0vkjIyPdvjgsx3FIT0/H8OHDnRoR9HTaau/wGILQ0AJphDxl5OBuHVXSmz7f5kbeKpWKpsDtKgghWLVqFcrKyvDZZ5+BYRj4+Pg0KWqs0Wikavbe3t7N7m8tLMu6/Q9ARCaT9RpbAdftlcmAGe
OGdkCPOpbe9vk2xt3s7xEToIAg5KtXr8alS5ewefNm6Y4aFRWFzMxMm1SXGRkZiIqKAiDcfa0rm9fW1qKkpAQqlapT+99VEQ8U94J+jyiO6DEj81deeQXnz5/HF198YePmGDVqFJRKJT777DM88sgjSE5Ohlqtxp133gkAmDVrFpKSknD8+HHEx8dj3bp1iIuLc8lf3hZ6S0wwXcjSsfSW7xGl9fQIMS8qKsLXX38NpVKJiRMnStufeOIJPPnkk9i4cSNWrlyJdevWITQ0FBs2bEBgYCAAICIiAq+99hpWrlyJyspKjBgxAu+8806n22AdE1xRrXPbmGAqNh1Lb/keUVpPjxDz4OBgXLlyxeH+6OhobN++3eH+GTNmYMaMGR3RNacZGhyAlAulbh8TTMWmY+kt3yNK6+kRYu4O9Ja0sY7Ehrpf2ofe8j2itB4q5p1Eb0kb60hsqPulfWj8PRInROlNkkLFnNKuOLpp9Xb3S0c9mdCbJEWEijmlU+jtvt6OEt3efpOkXIeKOaVT6O2+3o4S3d5+k6Rch4o5pVPoLXMGjugo0e3tN0nKdaiYUyidQEeJbm+/SVKuQ8Wc4jI03NB5qOhSOhoq5m5EZ4srjaSgULoPVMzdiM4W194USUGfQijdHSrmbkRniKu1qOkM5l5TtsydnkI4nuB0dj1SctMRGRpIb0xuQpvEnBCCiooKmM1mm+2DBg1qU6cortEZYWrWosayDOJU/eHlIe+wSIruMiLOVtdAazBBKZdBazAhW12D6Z3ei/a5HgdOFuBgeh0YVovUS0Jx8556Y6JcxyUxr66uxiuvvIL9+/fbreSRkZHR5o5RWk9nhKk1Hv17ecix6M+x7X4eke4yIjYYOeHPwAGM8LoraOl6OCP2uUV14HhgYF8vVNa4t3usN+GSmK9Zswbl5eXYunUrFixYgA8++ABVVVX49NNPsWzZsvbuI8VJOiNiorMXqXQXv7ynUgZPhQwKBQuTiYensmsq1LR0PZy5+Q0N9sfR80B5tRYGI4eUCyXQGcz4+9x4KOU9pl4NpREuifmxY8ewefNm3HLLLWAYBqGhoZgwYQL69u2LDz/8UCoMQXE/OnuRSndZ4RgREogTF8vAcTy8PBWICAnskn60dD2cuflNGTkYhYWFOJVrRL3WhKo6Aw6eUQMAnn2w5cLBlO6JS2JuNpvh7+8PAOjbty/Ky8sRHh6OsLCwHlXxmtJ6OjteuruscOwp/XDm5idjGYyI9MW5/BoAgIdCBqOJQ46ault6Mi6J+U033YSLFy8iNDQU8fHxWL9+Perr67F7926Eh4e3dx+7Nd1lgq6r6Gj7nbl5dMZn0FI/OJ5g34k8HDpbBBBgQkIIpo3u/H605qYTERwAdXk9jCZhHiAixH2jkXoDLon5smXLoNVqAQDPP/88li9fjueffx6DBw/GmjVr2rWD3Z3uMkHXVXQH+7tLH7748RJ0BjNAgKvFtWCZzu9Ha56cFt8XC4ZhkKOuRURIAP4+N75jO0fpUFwS8/j46x/6DTfcgC+++KK9+tPj6C4TdF1Fd7C/u/TBZObBMAwYBjCa+W7/XVDKWeojdyPaNHWt0+mgVqtRWFho89febN26FXPmzMGtt97aJFomMzMTc+fORWxsLGbOnIlTp07Z7N+7dy+mTJmCuLg4PPbYYygrK2vXvg0NDnCbhTNi1ZqPd6bjdHY9OJ60+B5n7Rfb3rjjPJJT8pxq21m6w2cwNDgACjkLQgh4nkAp79nfBUrPw6WR+ZUrV/Diiy/i0qVLAITFQwzDSP+2d5z5gAEDsHjxYhw7dgzV1dXSdpPJhEWLFuH+++/H1q1b8csvv2Dx4sXYv38/AgICkJOTgxUrVmDDhg1ISEjAm2++ieeeew5bt25tt751l4mx9kB0V5g5HoQ3IzS0ADPGDW32Pc7a35GukO7wGUwdNQQ8ITY+c1f70dvnYSiu4ZKYr1ixAgMGDMA333yDoKAgMEzHftGmTZsGQF
iMZC3mqamp0Ov1WLhwIViWxezZs7Flyxbs27cPSUlJ2LNnDyZMmIBx48YBAJYuXYrbbrsNBQUFGDx4sEt94Xm+yUKpqSNDgZGhwgvCw846qh5BdmENzByP/oGeKKmsR466xu6isMY4Y//1tr1QUaNDdmENpo5seiDHExw4WYDcojoMDfbHlJGDbYTM0f62fAaijc7Y2hzTRg3GtFFW3ysXvwv7TuRj2/5MmDmClAsl4Hke00a33w2qveztCXAcB5nM/poAjuN6zDVwZIM1Lol5bm4u3nvvPQwZ0rWj0KysLKhUKrDsdW/RsGHDkJWVBUBwwcTExEj7AgMDMXDgQGRmZros5tnZ2W3rdDdGwdeD8GaUVNZDxgJKosG5c+eaHMfzBGdzG1BSbcLAPgrED/UB28LI8XrbGshYQMHX2m37dHY9DqbXwcwR/H4a+OnwFcSEeUvnEPdzPHD0PFBYWIgRkb7tYn96enq7tOPK9bEmNa0aeoMJAT4y1DaYkZqWiwEe1S2/sZW0l73dnREj7M8L9KQwakc2WOPyBGhubm6Xi3lDQwP8/Pxstvn7+0Oj0QAAtFqt3f0NDQ0unzMyMhK+vu0jHt2N4TEEoaEFyFHXQEk0mD97DJSKpl+RfSfycfRyJcwcQW6ZGaGhoS2OHMW2HY24RVJy08GwWvh6yFBZq0NxlRkavU46h7h/YF9hhG9iAxAXN7xNdnMch/T0dAwfPtypEVBLiNfHZOZxqdCA3EoWd8QHO7S5MeWGfOSWZaLBQODpocComKGIi2vfkXl72tvS01RX0tzIW6VSwdvbuxN707E4LebHjx+X/j9r1iy8/vrruHr1KqKioiCX2zYzduzY9uthM/j4+KC+vt5mm0ajgY+PDwDA29u72f2uwLJsu/wAWoszftS2+lplMmDGuKHgOA7nzp2DUiG3a2teiQYcRzDAEj2SV6Jp8ZqIbbdk19CQAKReKkNNvQEgQKCfB4xGTjpHZGggUi+VobJGB7mMxdCQQPx6srBd/MsymaxdPturxXXQ6c0gAPRGDlfyq1FSqQXLsk7NE0wbEw6WZTvcZ95e9v56Mg/b9meB43ikXipz2s6uZuOOdPzz0XFd3Y12w2kxX7BgQZNtb731VpNtHTEB6oioqChs3rwZPM9LrpaMjAzMmzcPgHDnvXz5snR8bW0tSkpKoFKpOqV/7YkzE4idFW/dnkvsG/f5gWnReHB6NA6eUSNXXQNNgxGEEOgMZnA8aTLZyROCr5Mzu1Wcv97IQW/iQCwBO54ecnCc86GKzsaKd5eJ0u4QGkpphZhbi2JnYzabwXEczGYzeJ6HwWAAy7IYNWoUlEolPvvsMzzyyCNITk6GWq2WcsPMmjULSUlJOH78OOLj47Fu3TrExcW57C/vSsQfTFAfLxRV1GP3oRxczL0GT6UMESFCTuq2/KishSFsoB+CFI5DB9szeqRxn/OKa7Hoz7GYOmoIPth2BsfSS8CwDM5nVeLX1HxMHxNmI3Qbd5zvdkLioZTBw5KIS28Qsi36eSvbPVSxOyyWArpP/pzW8ve5cV3dhXalQ4tT3H333fj0008xcODANrWzceNGrF+/Xnq9d+9e3HvvvVi7di02btyIlStXYt26dQgNDcWGDRsQGBgIAIiIiMBrr72GlStXorKyEiNGjMA777zTpr50FeIPpqiiHgYjh5JrDVCX18NTIcOJi2U2x7jyo7IVBga3DfNCQoL9Y9szP4ujPstYBl4ecngqZHaFWrz5FJRpYOZ4lFdb3C7dQEgiQwKRerEMZo6HjGEQ4OuBvgGe4AkBx5N2Gz13lxFxdwgNpXSwmKvV6iaFK1zh6aefxtNPP213X3R0NLZv3+7wvTNmzMCMGTPa3AdnaMtjb0vvFX8gew7loqJaC7mcRb3WBIWClR7hH79XiNxx5UdlLQzl1TqUVJtc6m9zdtjb15wQNHdzkmLizRwIgME3+GHiCNdju13FXk6WKSMHSzbpDGacy6xAYa
kG2/ZlgmU6/kbY2fTUYtXrvz2HGi3B64tv6+qutAu0bFw70pbH3pbea/2D+Tr5CrQGE8AAJpOQknVocECbklJZC4NcxmBgH4VL/W3ODkf7HPW5OaEXbz4D+nqjolqHwTf6dYmgtJSTZeOO8+B50iGjZzoiplhDxbwdactjr7PvFX+w2eoaGIycjc/cGRwJqrUwCD7zKpf625wdOeoa6PTC00Sd1oDdf+RINtl7gmnu5tTcqLQzJwZbysnSkaPnjhwRd5fJVYrzUDFvR9ryw3X2veIP2NX6k47E1loYhNBE20UqjX/cYYP87fa3sR1hgwKQnJKH3KJa5BbVQm/ioLOUXCuprMdnP1zEwdNqTEgIBsAgr9g58WhuVNpRE4P2BE7MyWI2mEGIUJHI+rNzZvTcVuHsCOHtLpOrHY27uFiADhbzjl7m391w9MN15sfWWY/MLd00BB9wPlLTqlFuyMeUUWH4/VQBDp5R42pRLWQy1iaEUHxCyFHXIDklD5MSr/uLG4cO6gwmsAwDAgKeADIZC53BjCsF1cgtrgUDQG5pH2hePJoblXbUxKA9gWspJ4szo+e2Cqf1/MHBM2ocPK2W5g9cFfXuMrlKcZ4OFXMiBtr2Ehz9cH9Nzcf/9l6G3mDGgZMFuJh7DUsfSLD5oXXWJFJLN41fU4W8IHqDCbllmcjIq8b5rEpotEaYzUJopEZrxI+HczFrwlAMDQ7Atn2CWItRNY5CB/NLTDbZEo1mXloYVKMxAAAGBvk4LR7O+P/b07VhT+BkLIMZY8MxY6zrRVkah53uOZQLwLH7ydH7PZRyaGp1uFJQjeJKYZWzq9+p7jK52tG8+NFRtxmdd6iYnz17tiOb73KcfbzNLaqF3mCG0cyD5wmOpZfglqFCzLSrj8iuvq/xTUNMTSu2k62ugZkjUMgYaLQmnM+qgNnMI9DXA5W1OlTV6sETgopqLb5OvoJB/X2aHcFZiwLDMlAwQsihzmCGn5cCBpMQh62Qs2AAG/FoyUZHI9pJiYNxMfeaVHRBfFpw6pqezHN4vsYCpzOYsXHH+WavvzOfU+OwU/HaXsy9Bi8PeYufr/j+xqtm2zKappOrPQ+XxHzy5Ml2XSgMw0CpVGLw4MGYNWsW/u///q/NHezOWKeM/eOMGgfPqDExoenj7dDgABw4WQCeJ2BZYaJM/KG58ojN8URYUJNWDJZhcDy9xKn3NWeDeP7YqCBwHI8GPQeGAeq1JsjlLGA0S3HfOoMZg/r7oLJGD8ITmDge+aUayGUMtHqTjcBZi4LOYMb5rEpwHA8/byUemBYN1nItwgb5o7HPvKVr48gV8PupAuk857Mq8fupAofXxtqt9FvGOZy6VAoTJ+Qj5wkwY+z19zW25VxmBXieNPu5OfP5Ng47HdTfB0WVDTiWVgxPpbzF74X4ftEVZjBybY6578xww66ebHWX0blLYj5v3jxs3rwZ48ePx/DhQpKj9PR0HDlyBI888ghKSkrwwgsvoL6+HnPnzm3XDncnrj/eylBRbcTF3GvILaoFT4jNY/fUUUNwMfeasJqRAbw9FNIPzRXf5K+p+TiWXgKjibdk4zO7vNqzoFRYdCPmWfFUyhA20B+X86vQ198LeqMZQ270x+Ab/Sw+cGDbviuorNFDJmPRv4838krqhHY5HsfSS0AIwDDAwdOFmDgi1KkY9OaurzOjfmGy1R/JKXnYcygXOr1JuuE4ujYcT/D+N6dx5HwxeJ4A0IInBAwDmMw8dv2RbVPH01rgnA05dObzbRx2WlmjByEAyzAICvREcUVDs64X8f3iDdCZ0TRvuYnllWi6PFqlt0y2djQuifnp06fx/PPPIykpyWb79u3bceDAAXz88ce45ZZbsGXLFrcWc1FMqmr10jadwYxDZ4tsxFzGMlj6QAJuGdr0h+bIN9mc8OUW1YJhAJZlwPMEPCEur/Y0cbyNeyMiJBDhg/xRUFoDg4mDQi7DxBEh1yNdeCKNpkW3jFzGYmCQD6
4W14LjBTEkBMjIq0JxpVArdvqYsFaP9qyvDcsyklsjbFAAAIKrRbWIjQqCh1KGyJBA8OR6DL7BxKG4okGKwXd0HQ6fK7by4wv/ilM95VVaKYVAc31rzqfcGt+zvZF/cUUD9CYOZdUNUuSPo8nN1lzfs7kNOHq5EhzX/JNFZ0AnW9sHl8Q8JSUFL7zwQpPtI0eOxGuvvQYAGD9+PNauXdu23nVzxB/f13svo0pjAAOLENiZ97UJ/bPOgzIoAA9MUyGvuM5G5JsbrQgCUQINJ6zSDBvoj0mJg52uEG+z2rNKi8ED/TH4BmHkPSlxMPafuIq+fnL4+vrijoTQ5qMzUvKQerEM6op6iJooiiHDMNBojTh4Rm1jlyurQ61dNH+cUYMAUMhYyGQsHpwejeljwqTJ1uD+viiqqEf/QG/MmjDU4Qi1OdFgGEAmYxwe46xfvjW+Z3vfkT2HclFeo4WflxKVTkxuOrMyN2ygH0qqjDBbZb7sSgHtLZOtHY1LYj5w4EB88803WLFihc32b775RsrDUl1dLeVIcVfEHx9PgC0/XoTRzEMpZzEhIaTZ9zUW6genR2PRn2NtjmlutCK5bdKKIZezKK/W4fdTBQDQZDUiAJuRtBgbLa32lMswMeH6yDs5JQ//b38mGnQmMNU1GNDHu9lc5dZ9kTGA2epGZuaEO9vlvCp8sO0MbgrvK0W+WIf2/Zqaj4On1cgtroWcZeyuDrWJiikV8tUPahT5ItpVWa2Dt4cCsyYMbXa0OTQ4AEoFC51BiHtXylmEDfRDfll9E3dYYw6cLEDqxVIYzTyu1epw4GSBjX9dxFXfc2PXi7OTm86tzGUQ0peFXMZ0CwHtDpOt7uA3d0nM//Wvf+Hpp5/G/v37cdNNN4FhGFy6dAkajQYffvghAKEKUGM3jDvCWdwcAb4e0BvNuDUyCDzhpUnASYmD8fupApsvqj2hdmZRjvUxFdU6eCjlTUZW4mpEAgKt3owvfrwIBkJMtzhJOyE+RHoaCBvkD55A6m+2ugY6AwczJ9hlHXljDykhllKOfoGeyCupA8/bHmPmCI6mFeNKQTU0WiMCfT1gMAp+flFkNFojTGYevl4Km9G8OHK3vgEp5SwIAYrK68E3kx63JVEQY8T/OK1GfUM9ZoyPxp2jw5p8XvY4dEYNrd4MlmWg1Zvx/cEsHDqrbvaJyBVaO7npzMrc8modFHIGD9ypsvGZdxU9NbdLd4MhLgaDazQa7NmzB/n5+SCEIDw8HHfffXeTyj7uglarRUZGBlQqlWSjGFVy+FwRzJzgK1bIWMjlrOQCiI0KktwDoksAgBQFw3E8woMD0D/QS4qOkMlYm0gPe9Edoq+bIwSEAOOGD8RN4f3wxY8XodU3TW7moZTBYFl56e0hw6N334oZY8OQnJIntSmTsYiJDMKRc0UwW3zjSqUMk0eE4vF7Y6xcQ7aRJ2aex5c/ZUBvNEtCzsDW28SywoQexxGAAbw85Hjs7luQW1SLo+eLoVTKUFGtk46XyxjcHhcsxeNb38gG3+iP307mI6eoDizLQCEXxM1eJJEziMU44uLinC7WsOKjI7iYew0MI8xbsIzFXgJ4e8qx4O5bmoSAdsYqz8afp+iCst0uZMX8633ju6TQSmdirwao+Fs+kG5Ajda+/PXEUbrLceZ+fn546KGH2rMvPY5fU/NxLK3Y4k4QfMVGMw9CiOQCyFE3HSmJ2Q1F10JBSR2yCmvAAgge4GuT11v8EX/6fRoKyjQwmczw9FCgQW+Cj5cCZoMZLMPgXGYFosP6YNTNN+B4eikMJttyWaKQA4DWwOHQGTVmjA1rMpKrqNYKETK8sErTbOKhM5ix70Qevkm+IsXLy1gGDAscOMVgyI1+YABJvcUJUGsIAeQKFn38lajRGBBueSIoKNPAZOZgMHGQpIkRRvPH0opxy9B+TSZPk1PykF+qEZ6KeAKTmUfG1WsoKheqSk0fEwajmcf6b89KPu2/z42HUi4UMGksjB
Pig3E6ux4puemIDA1s0ZcvYxlMiA+W8rIIl4uAgf38LFIoqSWayZUJR2dHr46eTlqbe4fS83BZzI1GI9LS0lBaWtokze0999zT1n51a6wnp3i+6Z3dxBFkq2uhlDOIU/VHes41lFfrwHE8Cso0+DU1X3K3FJZp0L+PlzCByBOoK+pBiBAVIwqSGE/OWESuXq8DIUBdvREEgK+XAmYzhyPnilFc0QCGtR0ZW7TZBp7wSE7JQ0GpBiaOR2llPYxmgmx1DXhC4KEADCahjfNZlaio0UlCzvHEZiVnZkENFHIW/QK9UFmtA8sIk58MC8gYBgTCJG1FjR5GS6GGAX28sW3fFZhMZhhMPHgixOBzPAEIpDbs+YZzi2rBMoyNXRwP1GuNyFHXAADWf3sWB88Ibo/CcsHH/uyDQlHcxn7lCzmVOJ1RB4bVIvWS7SpWRz7oaaPDwFr6pzOYceJiKXQO8rP8mpqPo2nFMJqExxaDkcPvpws7pvSfA9FvKfdOe9DV8eLtyYsfHbV53RNG6i6J+eXLl7Fo0SLU1NTAYDDAz88PtbW18PT0RGBgoNuL+b4Tefjix0s2bgWgqWgazQQ8iFAGzWoU/nXyFQDX/cCi0PM8wJl4yGUMzl4pF4TcKp5cLmchYxkYzbZiXa8zQSlnUVWnF3zSfh4AMcLLQw4vTzkIISi7prXpm5kj+Dr5CkxmDmYzb+kvb7Vf+NfTQwatwYTi8npJyO1hMvPQaI3w9pQjPDgAE+JDABApSqfx3EGOugYcx8PTQwGNThgM8FbDed4i6PZ8w2GD/AX3RqOumDjBfw4AOepagABKhQxGEye8tpCtrhFSCAMw6Uw4nyUUpx7U1wuVNbaTzc0lJhNdX9nqGoy6+QZU1ugAME3ys+QW1YJYdZYQIKuwRrqpWwsgT4j0BOQo9UNjuouI0njxrsUlMV+zZg3uuOMOvPzyy0hMTMR3330HuVyOF154Affff39797HbcehsEXQGs2UVrLBakOP5JuICABeyK/HMAyOw82A2dAYzfDwVMBiM2HkwGzKWQf9AT9Q2GFDXYJTew3EEBiOH7MIacDwPAuEHS0ycJOBNT0VQVqUFxxEhmsNTjjhVf5zPqkSD3mjTNzGunON4eCrlqNeZJLcIyzA2otpgEVqzxX3U2BcutckA/QO9cPftERBEXBCWx++NkYTF+oednJKHExcthZtheyP0tSzzH+owta/wlCL60q2pqBFi/iNCAlBYroHRxAGM8FrEYORgMFy/ljX1BijlDCpqmlYrEm+41k9MYrUga/Gy9k83ZmhwABovmCZEuKlYLyZLuVCKQf19HKZ+cMS+E/k20VSNV652FjRevGtxScwvXbqE1157DTKZDHK5HAaDAaGhofjnP/+JpUuXuv0yflhiycUfqLBQhrEZfYlwHMGLGw6juEKIDa7XmcAwgNbyGgwgbzSKIgAMJg419QaYrWL9GAZNIkVETByRVl4yDODvq8S5zArUNhiglNtOABEADTphmb7OEvImTpDyDubDxXkB675YH0qIUC6NZSBlSTyeXuIwv4go0r+fLkRWQbUw4QpBoFkG8PNWYmJCiN0RZl5xLeQyFmGDfHC1qNbmRlVSWY/klDwsui8OAGx85iKeSpnNow0hgI8ni1G3DJR85iLWoZfi3IS4kMhZ8RLbECfKxeun1Ztx7HyxNNksfjg8IXZTPziicWSNOB/S2dB48a7FJTH39vaGySQ8pvbv3x95eXmIjIwEwzC4du1au3awOzIhIQRXi2uF0RMhDl0PAFCnNaGuoMZm2/VRsDAaNdt5PwGg0dqWbuMcCLl1m+K/12p0MFpuBHqj7WQoA4BhgaGDAkBAkFVQDaNlAlIuZ2C23BjEYxv3ThwVWwu8nBVE0lrgisrrcSy9BJ4KGVIulAp+cYZBjroGeiMHpUIGk8V1w4CBh4cMo26+Ad6eCoQNEqJkVnx0pEm4n3U8uZeHHEYzJ/VZozVJbizRR97YDREeHN
Ck/44QQy89FDJ4KOWoqTdIYZOtyUG/9IEEEEJw5HwxCCGQsyyyCqulz54nwtPPhIQQ9O/jZTf1g0MY4Y9Y/b8r6A7x4r0Zl8Q8ISEBKSkpiIyMxNSpU/Hqq6/i5MmTOHz4MEaOHNnefWwzdXV1ePnll3Ho0CH4+vriySefbFMkzvjYYHz500WYzM2oqxPwzQimvQGyIxdHk+MYBiYHQiXdQMwEQYGe4AmQwVdfvxlYRolcM33zUAgjW4677qpgWRbhwQFgGUYSOJ4QMCwjjVwPnS1CcUUDdHoT9CYhXlq8hiwjuJYqa/RYs2iEpRxbht1ybDaRGYMCcCm3EofOFYEQIT9Mbb0BB09fj1Nv7Mt9YFo0xscOkoRVxjJo0PM4eEaN38+om/iphwYH4OAZNTS1OqEvlvh4IU5dGBnDEtHiqGCzjGXg7amAj6dCuh4Gy2Il8SlnQF9hgda00UNsUj9MShxsk9mysU/cOrJGIZdhQnywE9+S9sed48V7woSoS2K+atUq6PWCb3LJkiXw9PREWloabrvtNixatKhdO9gevPLKK+A4DocPH0ZBQQEWLFiAiIgIjBkzxqX2nnn3d9Tr2l6oWqSZgb0NTgk5ABBi92ZgfS5CCE5cKIXJ4psVaXx/YhiAIbbnNpl5aRQopjAQ3DOM3fwi4uKeqlo9zGYOCgXb5GmBJ0JDucW10ihaWAAlnFtnMGPnwWxkF1bDYBKSm11PVibEm4OQJu3Yc4fkFdfimXkjcGtEEHKLapFfWie4ejiLn9oqJBIQRpwHT6txpaAagX4eMFhWYIouoeLKBnAc32LB5sYj+ZjIfki9VGYRYRb33BHhcH6huYlF68gaeyPixk8mk0aEdKtEW5T2wSUx79ev3/UG5HIsXry43TrU3mi1Wuzduxe7du2Cr68vbr75Ztx7773YsWOHy2JebrW4pTOw6JRTYm7RMqcwSqNixzcUaYRu6YPQF2JzPMsAcjmLvOJaaXQmppa9ki+s+mQYBuXVWiGkkBXuAo3dU0oFCzl7XZQUchYmPS/ZVVzRgJKKBhAAHkoWh84WgYGQAdBsJtIchq+XQlpwBdifxASuC+Ivx3KRmV8tjapFYbQWwf59vFBc2QBjoxWYrZn0a+yGsLc62B5S8QoHGRRbGhE3fjLheR6Fhd0n0VZPpPFIXaQrR+xsaw4+efKkU3/diby8PABAZGSktG3YsGHIyspyqb2tW7dK/9fkH0bFyY3gTYK4m3VVqDi5EQ1F169BzeXdqDz7hfTaUJWDipMbYajKkbZVnv0CNZd3S68bik6i4uRGmHXCwg6zUYfykxuhyT8sHVOXvQ8VJzdKr411Rag4uRG6snRpW9WFbai6sE16rStLR8XJjTDWFUliXJa6EZrc/QAEwW4osG9TvfokGAAKOQtN1g+oPPsF5DJGskl9dAP+OHQEL6w/jJ+O5iDp/ofxzpuvoaJGBzNHUFtwAiUpH8FQfw1DB/ljzE2BqGhkU+Xlvcg7vB5hA/0waUQIJt8iR9XpT6CvSJfOdU20iQhRKZUFZ1GU8hGgLUGgrwe8PeXIP/IhKi7vRdhAP3Ach7y0fSg8uh5mgxYMgBNnLuHOu2bh7yvexS/HcjEhbhDM+T+i6twXUMhZeHrIYKrJwcxZ9+LTr37AkfPFOJdZjrLTn8Fc8AseuDMKk0aEYNu2bdjz+SqY9VUoq9JCr6vHto0vYvmq/8BoMoPjOLz99ttISkoCx3EA4THQuw6/bVsDruoSZAzB1JGhyDzyBfb+vw8AwoPjOPz8889ISkpCWloaOI5D2EA/XD30IdKO7IDexKG8WosPPtyImbPuRVVVlWBjXh6SkpKwbds2cBwHjuOwatUqPPbYY8gurIGZ4yHTFiDr9/dx6PBRlFSbYOZ4FJ/8LwrPfCdETnEctm3bhqSkJOTl5YHjOFRVVSEpKQmbNm2S2rW2ieM4pKWlISkpCT///LO0bdmyZVi2bJn0WrTp3Lnz+OVYLjZsP4
cZM+/Bf/7ztnTMpk2bkJSU5JRN4uujR48iKSkJR48elbY99thjWLVqlfTaEYSQdv8Tz9nef87QqpH5/PnzpaIUjrIAMAyDjIyM1jTboWi1Wvj4+Nhs8/f3R0NDg0vt1dTUABjU9o51InLZ9bhxa5RywT/OMgTB/RSo03Jo6aqYzDzMlplYcQKRZQECgqo6AxquVuFKQRUqqrWQeXnCW8mg1mw16csC3gojxkcp8YtlKCGO+hUyQOnBoK/sGr7YWYiM7GLIZYBS1jSaxmBZgAOLnxyEx23DvMAwwNcnGAzqK0eQogrnzlWjrLQUPM+BZQg85EB1rRlavQlX8qtQvOsCjpzORmiQEpxehluHeGJgHwU89OXQ6Y2Qm8zo50FQU2+GwWCC3qBHYWEhziuqoFarIWc4jBjqibwaoLBeqJp0OqMMn+84goQIX1zIKkZFdQM2bz8CgCAj8ypq6rTIzb2K/353BCXVJuQXX0OAtxznzp0DAOTn50Ov1yMzMxMmkwl9ZTzkMgBE+Lz6+LCo5XgY9UZcuHABPj4+KC8vh16vh1qtltqpqqpCQ0MDFHwtCG9GVZ2QjlhOtBjYR4Erah2MJg5KBQ8FX4tz585BrVZDr9cjIyMD165dQ0NDA/R6PUpKSqR2KyoqoNfrpddXr16FXq9Hfn6+tK22VnhCaWzTjwfTcKkyEBwP1NYbcDG7WDqmpKQEer3eKZvE1zk5OdDr9cjJyYGnpycAoKGhAVVVVdIxI0aMsPtd1up0qK9vP3cpADzz9v52be/Rqf0BOLbBmlblZpkyZQo4jsPs2bMxa9YshIWF2T2uO+V7uHTpEubOnYsLFy5I23bv3o3PP/8cu3btcrodMZ9DZGQkHn7lYPt3tI0o5AxMZvsfpVLOWJabCz7xxp84wwiRKKNvuREeCmGRUEFRJXx8fFGt0aOy1gAPhQwNeqEgs1zGQqlgbaJtrCLrAAiTYaIv3vp03h4yPDrzZkwbPQR7U/Lx1c8ZUnz0QzOGQc4y+ONsEa4W10FmybvSN8ATOYW1Nm17KFhoDZx0IxgU5I0Pn59kswz/wMkC5BbVYWiwkDrgy58zpFWaot2EAB4KFtPi/fHovWNtvrv7Tgj1UE1mXgrblMtYeHnKERsZBE+lHEOD/TFl5GD865PjuJRXJYSoEoKbw/rijoRgbNufCTNHwHHCPINcJmQrjIkMQlq2sFhJLmMwd6rK4hqqk9oUbRH7odWbYTBx8FDI4O0pxwN3qprNaCnS+FpMTAjGhfR0lBv6WHzmtudr6f3NHdsSH+9Mx7H0EvQP9EJFjQ7jhg/Ek3OGu9SWM3AcB6VSabNN/C3/mqZ3mJulu7DmybEAnNPUVo3MDxw4gFOnTmH37t148MEHER4ejnvvvRczZsyAv7+/a73tYMQbTk5ODiIiIgAIK1ijoqJcao9lW+WZ6hQYBpbIkKYhiLAsiyf89aRc9jBzBN6eCjx+bwz2pVyFtr4Wfv7eYFgWtQ0mGIxCVEnfQE8YjZwlnFBYYWmwWswk4uUhs5kkViqEyJUAPw/wYPDprgsYPNAfI2++QRKJjKtVOJ5eAo4jIBAWIRktRSas25exDAL8PKAzaqWFW30DvCCTySQft3VZt9RLZXhgmgpDBwXgSkE1PJUyaLQmaQ6A4wlSrtQj9JQaU0Zdz5oo5JqPxqGzRbh8tQpmnoAQHpzWhOPpJfBUypF6qQwsywr9INdvEAzDIK9EA86SM7xp2t46aV95tQ57DufiWo0OLMPgxEU5WJaVfNhiO8H9fVBc0YD+fa7naXe27uiMcUOv7+c4sCyDu8aGOSUSv57Mw7b9WeA4XrLXVf96ZGggUi+VodKyQCsyNLDLBn+MJUVGe9KVPvNWT4AmJiYiMTERL7/8Mn799Vfs3r0ba9euxe2334633367yV2wq/H29sb06dPxwQcf4PXXX4darcbOnTvx/vvvu9wmC6BtQYntDIGUk7vRZoAIQt1cLDwhwiSikFBLGA
XWa40wcVp4KGWQMQz69/VGrcYgpWAd0NcTWYU1TRJ6AcIqUtENIpcJ8dxiXpLyKh0270qHXM4K+WYg3IhOXiqDzmC2mVi9VqdDoK9nkxBQo5lH+TWtlJ9FoZRjQnywzUSf3sQ1SlxWh4kjQlBUUQ+dwSwt0hHzwdQ2cNi2PxOXrlbhxMVSqb+3xQ2yTPgKE6xiymMPhcxm0tN67YGnUoYJCSFgLas6pbS9sK7oFCDkvLGsxC23rN61VwZQiquv0cPLs+U87faW1VunDRh8oy/UhfYTi9mjPVd20lj0jsPlRFtKpRJ33nknWJZFbW0tfv/9dxgMhm4n5oAQSrly5Urcfvvt8PHxwZIlSzB27FiX29v2xky8tPEIshotBuoqmntQFFwoLBr0jidRWJaBQsHiXGaFEH3SYARvNYqXyVjERgYhPDhQiKkGQem1eoft+fkoAEJQa+ab+LrFmwohPAgIZKxQci6/VNMkoibQxwMPTo9GenYF/jhbbNuOZVStGtIHkyx1Rj/9Pk2K+sgv08BoJrhaXAelXFgZma2uwYA+XrhaUgel5SbloZRBqzejjw+L6gYzjqeXXPfHAzhyrhhylmkUvSPcRKwXC00dNaRJymIRcZTvKFdNQZkGWYXV0k21cRlAMe+6WEGKJ3AYzy6er7H4Wgv8wTPCRKuHsmliMXu058pOd4pF726x5i7XAN29ezf27t2LsLAwzJo1Cxs3buy2ucz9/f2xbt26dmvPSynDu0vvgNHMY922MzhxoQR6U7caqwO4HgPenJD38VWCI8AAy4pNjdZoI8B6AwcZwyAiJBCAEFOt05ugMzpuU6c3S5Ok9mAZWPK2M1DKWWnkyoBIq1aVchYPWHKdTEocDJZlcT6zAlUag9QOb1k5altOrxTFFQ3gzETKaaM3EaReKoVCxkJvEuwJ7i+M2Af190VxRT2uaYx2F1rxvJAszVMpg97IQcYykMsZhAcHYMiN/jYx2vZEqjnhsk7pW1zRAK3BJOWmb1yqj2UYwS6Ox7Z9V6QFVI3hLE9YehMHdUW9tILUWuDzSzUgPEH/wKaJxexBR9M9g1aJ+Ycffog9e/aA4zjcfffd2LZtG4YOHdryG90UpZzF8w8nAoBUf/OPM2qoy+uhsSTOcnZBUEdAIMRcN+hNDn3lA/p6o6JGL63YlMtZyGVCCgC5jAHDCMIljnzNZsf5WxgGCArwQrVG36zdrIyBjGEwdvhA3Dw0CHnFwsiVJwRHzhYBjLCqccrIIdLKx1uG9sPf58bj6bd/E9L8ioNSq8GpOILdlnzF5mZDLGGMgwb5SKmGxVHmhPhggBBs/zUDdVphktLY6MZMCGC2JMXheALeRDCgj3eTUn+tRfRtZ6trEBsVBE+lDBEh190e4v4cdQ3OZ1U2qdJkj19T83Eus0JwBfIEsVFBkotFHF0r5Cw4jkiJxcIG+Te7wtSdRtOu0N1G4I5olZhv2LABAwcOxIgRI1BSUoKPP/7Y7nFvvfVWu3SuJyFjGcwYG44ZY8NtJqDq9SakWNLY2oNhAF8vOTRaxyFSLAPc0Ncb1Ro99MbWPQHoDGYoZIw04rU5N4CyKi3iowfAy0MuFU3W6k2SL9vbU4H+gV749Ps06AxmmHli44awtt/Dkm5WqZCBgBOqCtk5TsYwGBczyGbJvHjNBt/oJ7kkVn1yFJmFNSBEuHGaOQJVaCDKq7QghMBDIcf4uEE2QgQwdv34hBCUV2nh7aFAbFSQTfIvEB6FhYVIPlNnkwZYRKmUISjAC6XXGuDlIYfJ4hdvKy1lXRT36/Qm6I3CJHNlrZCPJmxQAH45nielEpgQH4xpo4XVrjxPpLkCLw+5lK4XgJXPXA2zLACRoYHgCWjqWjegVWJ+zz33SHHmFMc0qbIekY+swhpcLa5FRZUWdVqTle8Y8PVSCnMP9Ua77XkoZCiv1oLjBQFm2eaTblnDMEC/QMF3au0+YRkgbKAfrtUa4OUht6lqlFVQjZKyCgy8oT
+MZl6KCmFZBoF+HjbhfSIcTxDo74G4qP4YPNAfX+/NQF2DCY3pGyBEw4giI77XuhKP7IxayBxpNAs2M0ISqj2Hc6A3cJBbaoCOvvVGsAxrI0SDgnwgk7Hw9Vag3ip0Ui5nMXigv93SchwHxA/1wZEMHYxWbhyRPn6eiAwNRHmNDkaOh7enQnI7tQVpZWcfLxRV1DdZ2SnuVyhY6IwcPD1k4DiCoZab3ZYfLwklApnrBTsc+bebFKfwrEFc3HDIZDKbYtk0dW3PpVVivnbtWrvbjUYjjEYjfH1926VT7oT4I5pulTngL6v3oqruumgYTBxCBviitr5pKS+Wha3LAM67bny9FGBZBn39PVGjMYAQ4Q4QERKAimodKmv0MPME+aV1eOd/p1BRowMDBuPjBmF0OIeEhBh8uusCeJ5IP3SxLXt1Rvv6e2LRn2ORnJLXZL+3hxwMI7g7xIpLySl5kgvgmKUSD8syMEFYhMRY0nyJqX31euHJwM9bKD0nFLWusRHEoop66I3CzYZlhRb6BnjCYOQw+AY/hyNOlmUQGxWEP84W2dyolAoWkYMD7bou2ooovEUV9TAYOVRUa6WMj9PHhEn7dXqT5Fby81Zi4ogQ5BYJkTMsK1RyMllK1YklCVvj36apa5vSU1wr1rRKzE0mEz766CNkZGRg+PDhePLJJ/H666/j//2//weO45CYmIh3330X/fv376j+ugUxkUH440wRhNRUwuubh/bD1eI6KW9IUIAn+vfxRum1BlTXGWwiVggBZAygEGO8G4m7p5IFxxMYjGYoFIKLIK/YUvxYxuDGfj5QymWo0uhRqzEgp6gWl3KrpHNkFlZjaowfKk35yC8R+pRfUgeFQobxcYMwIT4Yn/9w0SYckrE86gMQSrdZdUomY/DIzJsgZ1m7FZekMnCWhUYMy1jythOpbS8POWKignDyUhkqq43CaLRYyJnCsgzyS+tgNhMYDRx4CDdBOSss0NFojU0KS9hj8X2xYBgG2eoaeHvIETbIH1GhfZCjrrHruhCx9m3rjRw8lDJEhrQc8icK7Z5Duaio1mJQfx9U1uilkbG4P6uwGnnFddAazIgMCcSkxMEACqCUs9LIXCGXIWyQv0sVh+gEp3vQKjF/6623sH//fkybNg179+7FmTNnUFpairfeegsymQwff/wx3n77bbz55psd1V+34On7E8AwjE3hBBnL4EBqAa5Ywh0ravSobTBKI6/GyGQsPBQyoTSchwxlVTooFMKCpqGDApBfIvh/GQDXaoWJroFBPigqr8eJi6XwVMigN5rBsAyUchn0VsJsNPE4kFYL2YVLMFrCC1kGUAJgGRYzxobhh8O5KCyrl1LkeinlmDJSEAG9kYO1uzwyOAB3jQmXXAdi3VPxkX5ocACOp5cAMIMnBH0DPFFdq4dczkCn5+DpKceom2/Eovvi8Mrm47iSX41AXw/oDSaUV2vh5SmXKhaJ3idfLyVACHx9lNJinHOZFfhg2xmpEIc44Xi9oPNF3DK0X5MybWJVJEcjVxvftkkQ89SLLYf8Wbs+vk6+gsoavUPXyMlL5TCbOZy4WIqKT45hQkII5v/pZpsJY4Bxyffd3ARndylJ11n0xBG5SKvEfN++fVi7di3Gjh2LkpISTJo0Cf/9739x223CBQgKCsIzzzzTEf10K5RyViqcYI3OwNmsqjSbeYQHB6DIEoFhMvNCLnEAQ4P9hdWoBAgK9ILeWClNpNXUG2A081K6VoCBTCaEAHI8Dx4AASv4wcHAaCdxi1Ai09a9wxGCP86oceiMWipzR6TjORw4mY8ZY8PhoZTBQymDUi6D0cxJxSCEfjd9pG88MuQJwbZ9mdDpTeAhCMr5rEocOlOIiQkhKK5ogNEk3DDyioWbluiKER8IjGYO3h4K9PXzRIPWJBXNPpZWDIZlBDFXyHDiYlmzBZ2Blkeu1r5tvZGDUi4Dx/F2fc/WESx6gxmVNToQCE9nXh7Xo1nste+hlENTq8
OVgmoUVzbgwenReOOp8dJxLfm+pXMX1kDB12N4DEFLiy8bL0Ay8zyu5FXbDESU8taviu5tN4nOoFViXlFRIS2JHzhwIDw8PBASEiLtHzx4cK+oNNRR2NSthDBpV1mtkyIwzmddF+wb+vpIr4sq6hGn6i9FpJy4UAozx0u1QMfHBeNKfhVy1LXw9JChvFoHTYMJBEB0sB/CgwNw5nK5TWpfhQywDgohRJiEzCqovp7P3AqOF4R+xthwRIYEIvViGTiOh7eHApFWk4X20sDuO5EnLYgJGxSAKSOHgGUY7DmUi/IaLYL7+6LSIk7WPuH8kjrkWPKKm3BdyGUsg4jgAEwaEQqeANv2XUFFtc6yfJ+BQi6DwSDkVec4HjlFteB4YKCdgs7OYO3bBnP9RmLP9yyKo9Zgkp6GGADennVYcPctTUbI1nHjGp0RIECgnweMlpzq9vph7wnCZpIZgEJGEBpaYLPM3x6NFyDtOZSLksoGgACF5UKKAnsDk5borsWfX/zoaI8dnbdKzHmet8mjwLKsTa4SMckQxTXEOpU56loMDfbHsLB+KChtvrq9mOM6R12LWROGIkddA7mMQVCAF2rqDUL1HxZCBXozhwZL2KHSUgczPDgAT90Xh4++O4eDp4UoEr1R8MMrLWLHWMIJ+wV4CTnJLa4fvtFMbLUlEqS5FYvW+c5/Tc3Hqk+OIbOgWgoJzCysBnC9IPHXyVdQaSVO1u9f/uEhmyIXYpy7wWjGkBv9pePElZlivhadUfAzm0w8vDwViAgOQE2d1qags/XIUQzZdCQ84g2qsc/cukJQ2CB/AAx+PJwLnd4EhZyF3vLkw7IMjGb7I3nruHGWYcAqGCmlQuObRXNPEI0nmXleSOrVEo1vEHq9kKNHaQlDzVG7FvlCiz+3P61eAbpp0yZ4eXkBECZEv/jiCynJlk7XuUUb3A1H7hcRe9XtiysahBzXNUIkRP9ATxhMHExmHr5eCkxMCLF5TK+zhOsZjMLIVF0uFEAOF0eXBjNkLAN/Hxb3Tb0JchkrLUE3cwSf/XBBiplvXNSij6+HJF46g1moMMQTuysWxZGZRmu0yb1iNPFSQeKWxCmvRBgZMowgdDIZA6OJg1wus+t3lkIvC2uQV1wLrd6MyNBAPHHvcPxvTwNMbICUq8Qmz4tlbiHEsmpUFJ7GroIn5sQ28bVfX0KvBgPBVWUwcZBZVebmLKthhfh6W6zjxsurtBg80B+Db/Cz6+5pzvdtPcks3OAYDA1uOTle48/gQs41/HFWLTw9MsLTpCt0twianjoat6ZVYj5y5EhcvHhReh0fH4/MzEybYxITE9unZ5RmsY6EEF0RRRX1yCsxCtEgBIhT9bdZ/VejEUr9yWXCD5rniRRV8sC0aMSp+uNYeolQ4cdMILcshBL55fhVyFkWHMuDYRiEDfRDQakGJo5AIWMRFOjlMNFV45GXeIMJ9PNARePKTU64TnOLasGw11PtyuXCQiSbxUCNEMWOJ1dx+FwRTGYe1+r0iB4SiBGRvlLctXX/xMLUPEET4dl3Ig9f/HhJKvvGE2JzvRovoQeAwTf6oaiiHv0DveCllCG3uE5KrWsvy4616MnlMkxMCGnRHWHPH914kjk62ANTRg5u8To3vkFMShwMhoGNz9wVaARN+9MqMf/qq686qh+UVtIkEsLKJyzmHWm8+m/nwWyUVDRcXzzEWJbzW+piennI4amQoV+AJ9TlGvxw5CpYlpUmp/KK6yCXMfDxElw4DTozZDIWBEIagMpanSRejZfNhw0KsFmpGTZIECmDkYNSwcJsFlZEeVoyIAJCHu8tP16U8p3z5Lr7RRQ5LUwgvLAAytNSF9R6Ms2esB06WwSdwQyGYWA2mHH4bDGSxnrbXF9rEfX0kCM2KgiVtXobt1Hjdg6dLbIRc+s2FHIWDCDNgcyeEIHcolqUVemsapM2dXu4InqOsiaK7YQN9EOQosqlCceWnh6dpbenCOgIXM6aSOke2Cug7G
j1X466BtW1emFFoV4QocbHplwoRXFlA0xmQYgbL2I5eFqNuhphJF1mSZA15EY/VFTrwDDXo2a8PRSIiewnCeCl3Eqcz6oEzws1Jx+YpsKD06PtZhQUbTp0Rg2t3gyWFbIeiu4XR3aXXtPiRKOQQHvCJhRKtYp+saNpgt8f0nJ5AE0SXTVphzRtQ+yj6DPPK74uytb5UuxNWFrfhMSJ35YiQDie4OBptZDHpVHxaZsVoOeqm/lWUXoiVMx7OPZ8wo5GcREhgThhiTLx9VYKo80aPcAAPCFSnPjuQzkou9YgRJFYRXdMHTUEOw9mQ1thBmPJfGgy8zZJq6yrxIuRJGaOx+X8KhByfTVmXnGdTaIq677/mpov9J0BwFg00vJ/e3Z/9N056A1my03KJCxasmBvoq1x7vHbYgfhdLa6SX5vlhGyRHKcEMXT2G9uL4e5o8/GHi3NCTS5CaHlHCq/puYjt7jWJpqpo/3RPT3M0B385QAVc7eiteIhiq0w2swEy1h8yjyPr36+aBPdIbbf198TJZUNUoWfG/p6I07V3+6PWIx79lDKpLh0RwJjT7wmxAcjt6jW4pOWSe6XxugMZuiNHHRGIU5fXEULOI5rt849buZ4HExvGmfekt/cXg7z1ghbSxOW9qI9xG3lVVocPKNucp7colrIWdtopo72R3fXMMPeBhXzXkRj8XC0yGTKyMEoLCy0ie4QaSyw90yMsPETWyMKaY3GAIYBfDwVMJg4uwKTo64RQvasRtdPzIm1Gek3fo8onOezKqXUCATC6lkRe6Pfxtdhw/ZzduPMxf6XV2kBBujn54m+/p6YYJWsq7FoWUewtEXYxAlL4SYixJrfFN5XujGZeYKrRbUoLNXYnEfss9HECXlcEkI6fJTck8MM3WVUDlAx79U0l2GvcXSHyLTRYc0KrDXiPjEfC8vAocDojRz0Jk5K9ZpbVItPv0+T/MX2BEkcEYqjfjAWb4wDd4zj6+CPo+fR5ElE6v8ZNa4W1aJeJ6SiZRk4VeWnvFon2O6C+2HqqCG4mHtNWOTDMjifVYmbwvtJ8wwFpRoUNEqLYN3nzowS6W5hhr0VKua9kOaKIrREa6IQxGOt6086EhjrFABagwl5JRqUXtM2O7oVhbOPn4cw0QohIZcjd4wjHD2JiP3PtYyAnRl5Wgsbx/HILRZy0bR2lC5jGSm66Hq0S600zyA+ATSX6razoGGG3QMq5r2QlooitDfOCIx1CgAZw4Bh0aJ4WrsUvD3lGDooABNHhNisvLQ3Im7s1540IsThk4j1eZwZeVoLW0GZBgUldS67H5o7b3cSUHufb0+YFP373Liu7kK7QsW8F9KSK6AraBxqKFaub048HfnDW/JbN56w43keAzyut9tE7BMHNzmPI6yFzdHouTXXxDo8kiekSVqE7gqdFO18ur2Yp6SkYMOGDbh06RI8PT1x9OhRm/11dXV4+eWXcejQIfj6+uLJJ5/EQw89JO3PzMzEypUrceXKFYSGhuLf//53r1+l2pIrYOrIUJfadTWvN9C6EEt7MdjW7bc0Idd0fx0GWOWbai8hcmb03NwItnF4pHXEUXenJ0+K9lS6vZh7e3vjz3/+M2bNmoX333+/yf5XXnkFHMfh8OHDKCgowIIFCxAREYExY8bAZDJh0aJFuP/++7F161b88ssvWLx4Mfbv34+AgN47SdOiK8BFMXc1r3djWhp1tiS2LblFwgYF4I8zauSXaqCUs5YFPdcX0bRFiFrrXmjJlp4qinRStPPp9mIeExODmJgYnDhxosk+rVaLvXv3YteuXfD19cXNN9+Me++9Fzt27MCYMWOQmpoKvV6PhQsXgmVZzJ49G1u2bMG+ffuQlJTkUn94ngfHNc3/3dOYOjIUGBmKfSfysa2iHuXVOshlQr4V0b7W2pldWAMzJ+Qp0Vnyeps5HtmFNZg6sv2umXie/oFeqKjRNWl/0ogQ8DyP3KI6DA32x6QRITa28DwnLdYUsj8KSa
/EY8IG+iHlAmP3mrTEvhP5+GbfFegMHA6cLMCFnEo8PTfOoaC3ZEtb+uIIVz/f1tDSZ9BZcBxndx5E3NdTfsuObLCm24t5c+Tl5QEAIiMjpW3Dhg3DF198AQDIysqCSqWySdM7bNgwZGVluXzO7Oxsl9/bHQlSENw2zAsl1SYM7KNAkKIK6enCKDU9Pb1VbSn4ehDeDIOJgAGgN5jgoWCg4Gtx7ty5duuzeJ6SSg1kLOy2P8ADFtdJNdLTbJeun0qvBgiPIH8Zahs4nLmYh0Ej+0j22rsmzi5/T02rRoPOBDNHwBPg6Pli+MkbMCLSfn3cxrbIuVr897sj0rljw7xd7ktLtPbzbS3NfQadyYgR9nPJNE4S2J1xZIM1XSrmHMc5zH8u5Plo/m6k1Wrh4+Njs83f3x8NDQ0AgIaGBvj5+TXZr9FoXO5zZGSk2xWuTkiwfc1xHNLT0zF8uG10B8cTHDhZII22powcbDPiHB4jFDzIUdfCYOLgoZAhIiSgyXFtRTyPo360RLkhH7llmWgwEHh6KJB4aziAGht7G18ToGX7xbbT8y6CJ5aJShkLExuAuLjhTtnCE4Jvf82CmSPILTMjNDQUf73PTmfagKPP1x1pbuStUqng7e3tcH9Po0vF/NFHH0VqaqrdfUFBQU0mOxvj7e0tCbeIRqORBN7Hxwf19fUO97sCy7Ju/wMQkclkNrb+ejIP2/ZngeN4pF4qA8uyNv5dmQwtVq5pn3617TzTxoSDZVmb0MT0tJom9jamJfvFtjPyqoXFPgyESkuhgQ7bbWyLsCqXYIAYW16i6bDvW0v2ujvuZn+XinlbU+qGhYUBAHJycqRydpcvX0ZUVBQAICoqCps3bwbP85KrJSMjA/PmzWvTeXsrPXUyrjGNJ1id9Zs6Y7+MZbD0gQTcMrT5RVKOoBOHFFfp9j5znudhMplgMlkq5BgMYBgGSqUS3t7emD59Oj744AO8/vrrUKvV2LlzpxT1MmrUKCiVSnz22Wd45JFHkJycDLVajTvvvLMLLeq5uJPQWEedCPm9Wy536Kz9bYkB7w6LgZyNyOnKhUE9YVFSZ9PtxfzkyZN45JFHpNcxMTEIDg7Gb7/9BgBYtWoVVq5cidtvvx0+Pj5YsmQJxo4dCwBQKBTYuHEjVq5ciXXr1iE0NBQbNmxAYGBgV5jS4+kOQtNe2IYEMrhtmJddP7k1nWF/d1gM5GycfVcuDKKLkprS7cV89OjRuHLlisP9/v7+WLduncP90dHR2L59e0d0rdfRHYSmvWi8Crak2iSM9k46TgPgTvY3h7PutK50u7mLy689YVs+hEJxP4YGB0hVkeQyBgP7KHDgZAG+Tr6Co+eL8XXyFfyamt/V3ewSrK9Nc+4kZ4/ryj72Jrr9yJxC6Qjs1cRMzaujoz04707qSrebO7n82gsq5pReib2amEOD/ZF6qazXj/acdSd1pdupt7i8WgMVcwrFwpSRg23iz1s72qMRFpSuhIo5hWKhraM9GmFB6UqomFMo7URPirDgeYJ9J/KRV6KhTxFuAhVzCqWd6EmLqs7mNuDo5UpwHKFPEW4CFXMKpZ3oSREWJdVCZscBPeApguIcVMwplHaiJ0VYDOyjQG6ZuUc8RVCcg4o5hdILiR/qg9DQUBufOaVnQ8WcQukGtGdYozNtsSyDaaOHuFUK2N4OFXMKpRvQnmGNNESyd0Jzs1Ao3QDrsEaO49s0IdmebVF6DlTMKW4BxxMkp+Rh447zSE7JA8e3nJ+8O9GeiaNoEqreCXWzUNyCnu5aaM+wxp4UIklpP6iYU9yCnrT60h7tGdbYk0IkKe0HdbNQ3ALqWqD0dujInOIWUNcCpbdDxZziFlDXAqW30+3dLJs3b8bdd9+NhIQE3HHHHXjvvffAcZy0v66uDkuXLkV8fDxuv/12/O9//7N5f2ZmJubOnYvY2FjMnDkTp06d6mwTKBQKpcPp9mLO8zxef/11nDhxAl9//TV+//13/Pe//5
X2v/LKK+A4DocPH8Ynn3yCdevWISUlBQBgMpmwaNEiTJ06FSdPnsTf/vY3LF68GLW1PWtyjEKhUFqi27tZHn/8cen/wcHBuPvuu3H69GkAgFarxd69e7Fr1y74+vri5ptvxr333osdO3ZgzJgxSE1NhV6vx8KFC8GyLGbPno0tW7Zg3759SEpKcqk/PM/bPBm4I6J97m6nCLXXfeE4zmHKAo7jesw1cCbtQrcX88acPHkS0dHRAIC8vDwAQGRkpLR/2LBh+OKLLwAAWVlZUKlUYFnWZn9WVpbL58/Oznb5vT2N9PT0ru5Cp0LtdU9GjBhhd3tmZmYn98R1HNlgTZeKOcdxIMT+Sj2GYZrcjb766itkZmbizTffBCCMzH18fGyO8ff3R0NDAwCgoaEBfn5+TfZrNBqX+xwZGQlfX1+X398T4DgO6enpGD58eK9IxETtdV+aG3mrVCp4e3t3Ym86li4V80cffRSpqal29wUFBeHo0aPS6927d+OTTz7Bli1b0KdPHwCAt7e3JNwiGo1GEngfHx/U19c73O8KLMu6/Q9ARCaT9RpbAWpvb8Pd7O9SMf/qq6+cOu6HH37AW2+9hc8//xwRERHS9rCwMABATk6OtP3y5cuIiooCAERFRWHz5s3geV5ytWRkZGDevHntaAWFQqF0Pd0+muXHH3/Ea6+9hk2bNkGlUtns8/b2xvTp0/HBBx+gvr4ely9fxs6dOzFnzhwAwKhRo6BUKvHZZ5/BaDTihx9+gFqtxp133tkVplAoFEqH0e3F/N1334VGo8FDDz2E+Ph4xMfHY+HChdL+VatWAQBuv/12LFy4EEuWLMHYsWMBAAqFAhs3bkRycjISExPx8ccfY8OGDQgMDOwKUygUCqXD6PbRLL/99luz+/39/bFu3TqH+6Ojo7F9+/b27haFQqF0K7q9mHcXeJ4HAOj1ereaNLGHGAGg1Wrd3laA2uvOiHHmnp6eNiHK7ggVcycxGAwAgIKCgi7uSefRk+Jw2wNqr/ty0003uVUYoj0Y4ijQm2KD2WxGbW0tPDw83P4OT6G4G9Yjc57nodfr3W60TsWcQqFQ3AD3uS1RKBRKL4aKOYVCobgBVMwpFArFDaBiTqFQKG4AFXMKhUJxA6iYUygUihtAxZxCoVDcACrmFAqF4gZQMadQKBQ3gIo5hUKhuAFUzJ2grq4OS5cuRXx8PG6//Xb873//6+ouuczWrVsxZ84c3HrrrVi2bJnNvszMTMydOxexsbGYOXMmTp06ZbN/7969mDJlCuLi4vDYY4+hrKysM7veaoxGI1566SVMnjwZ8fHx+NOf/oQ9e/ZI+93NXgB4+eWXcfvttyMhIQGTJ0/Gxx9/LO1zR3sBoLq6GqNHj8bcuXOlbe5qa7MQSos899xz5KmnniIajYZcvHiRjBo1ihw/fryru+USycnJZP/+/WT16tXkmWeekbYbjUYyefJk8sknnxCDwUB27dpFRo4cSWpqagghhGRnZ5O4uDhy9OhRotPpyL///W/y0EMPdZUZTtHQ0EDef/99UlBQQDiOIydPniQJCQnkzJkzbmkvIYRkZWURnU5HCCGkuLiYzJgxg/z8889uay8hhCxfvpw8/PDDJCkpiRDint9lZ6Aj8xbQarXYu3cvnnnmGfj6+uLmm2/Gvffeix07dnR111xi2rRpmDp1qlQUWyQ1NRV6vR4LFy6EUqnE7NmzERISgn379gEA9uzZgwkTJmDcuHHw9PTE0qVLcfbs2W6dEtjb2xtLly5FaGgoWJZFYmIiEhIScPbsWbe0FwAiIyPh6ekpvWZZFvn5+W5r74kTJ1BQUIB77rlH2uautrYEFfMWyMvLAyD8SESGDRuGrKysLupRx5CVlQWVSmWTEtTazszMTAwbNkzaFxgYiIEDB/aonNharRYXLlxAVFSUW9v7zjvvIC4uDhMnToRWq8WsWbPc0l6j0YhXX30Vq1atAsMw0nZ3tNUZqJi3gFarhY+Pj802f3
9/NDQ0dFGPOoaGhgb4+fnZbLO2U6vVNru/u0MIwYoVKxATE4Px48e7tb3PPfcczp49i+3bt+Puu++W+u1u9n7yyScYP348oqOjbba7o63OQMW8Bby9vZt8yBqNponA93R8fHxQX19vs83aTm9v72b3d2cIIVi1ahXKysrw3nvvgWEYt7YXABiGQUxMDJRKJdavX+929ubl5WH37t14+umnm+xzN1udhYp5C4SFhQEAcnJypG2XL19GVFRUF/WoY4iKikJmZqZU6xQAMjIyJDtVKhUuX74s7autrUVJSQlUKlWn97U1EEKwevVqXLp0CZs3b5ZKh7mrvY3hOA75+fluZ++ZM2dQVlaGyZMnY/To0Xj11Vdx8eJFjB49GiEhIW5lq7NQMW8Bb29vTJ8+HR988AHq6+tx+fJl7Ny5E3PmzOnqrrmE2WyGwWCA2WwGz/MwGAwwmUwYNWoUlEolPvvsMxiNRvzwww9Qq9W48847AQCzZs3CoUOHcPz4cej1eqxbtw5xcXEYPHhwF1vUPK+88grOnz+P//73v/D19ZW2u6O9Go0Gu3btQn19PXiex+nTp/HNN99g3LhxbmfvjBkzsH//fuzevRu7d+/G0qVLoVKpsHv3btxxxx1uZavTdHE0TY+gtraWPP300yQuLo7cdtttZOvWrV3dJZdZt24dUalUNn/Lly8nhBBy+fJlct9995Hhw4eT//u//yOpqak27/3555/J5MmTSUxMDFmwYAEpLS3tChOcRq1WE5VKRW699VYSFxcn/W3cuJEQ4n72ajQa8sgjj5DExEQSFxdHpk+fTj755BPC8zwhxP3stWbHjh1SaCIh7m2rI2gNUAqFQnEDqJuFQqFQ3AAq5hQKheIGUDGnUCgUN4CKOYVCobgBVMwpFArFDaBiTqFQKG4AFXMKhUJxA6iYUygUihtAxZzS64iOjsaxY8e6uhttZsKECdi5c2dXd4PSTaBiTmkX5s+fj+joaERHR+Omm27ChAkTsGbNGhiNRgDACy+8gOjoaHzwwQc27yOEYMqUKYiOjsaJEye6ousUiltAxZzSbvzlL3/BkSNHcPDgQaxduxb79+/Hhg0bpP033ngj9uzZA+sMEqdPn4bZbO6K7nYLTCYTaEYNSntAxZzSbnh5eaF///644YYbMG7cOEybNg0ZGRnS/sTERCmbn8iuXbswa9Ysm3YqKyuxZMkS3HbbbYiPj8dDDz1k0w4AHD9+HHfddRdiYmLwxBNP4NNPP8XkyZOd7mtpaSkeffRRxMbGYs6cOTYpUc+cOYP58+cjMTERY8aMwbPPPouqqiqn2p0/fz7efPNNLF++HHFxcZg0aRJ+/vlnaf+JEycQHR2NP/74A//3f/+H2NhY1NXVQafTYfXq1RgzZgwSExPxxBNPQK1WS+8zGo14+eWXER8fjzvuuAO7du1y2lZK74CKOaVDKCkpwfHjxzF8+HBpG8MwuPvuu7F7924AgMFgQHJyMmbPnm3zXr1ej8TERHz22WfYuXMnIiIisGjRIhgMBgBAXV0d/v73v2P8+PHYtWsXJk+ejM2bN7eqfxs2bMDDDz+MXbt2YcCAAXjxxRelfVqtFvPmzcOOHTuwadMmlJSUYPXq1U63vW3bNgwePBg7d+7E3Llz8Y9//AP5+fk2x3z00UdYs2YNfvjhB3h5eWHVqlXIz8/Hpk2b8O2336Jv375YtGgROI4DAHz66af4/fff8eGHH+KTTz7Bjh07UFNT0yqbKW5Ol+ZspLgNDz/8MLnllltIXFwcGT58OFGpVGTBggXEaDQSQoQK6s899xzJzs4miYmJxGAwkJ9++onMnTuXmEwmolKpSEpKit22zWYziYuLk9KYbt26lUycOJFwHCcd8+yzz5JJkyY51VeVSkU+/fRT6fWZM2eISqUi9fX1do8/e/Ysufnmm4nZbHbqOlinYiWEkAceeICsXbuWEEJISkoKUalU5MSJE9L+wsJCcsstt0jV4wkRKszHxs
aSkydPEkIIGTt2LPn666+l/dnZ2USlUpEdO3Y4YTGlNyDv6psJxX1ISkrCo48+Cp7noVar8cYbb+D111/HqlWrpGMiIiIwZMgQHDhwALt27WoyKgcEP/KHH36I/fv3o6KiAhzHQafToaSkBIBQMmzYsGE2BXtvvfVWnD171um+WleVCQoKAgBUVVXBx8cHpaWleOedd3DmzBlUVVWBEAKz2YzKykrccMMNLbYdExPT5PXVq1dttt18883S/7Ozs2E2mzFx4kSbY/R6PdRqNaKjo3Ht2jWbdiMiInp8mTNK+0LFnNJu+Pv7Y8iQIQCA8PBwaDQaPP/881i+fLnNcbNnz8aWLVtw+fJlvPXWW03a2bRpE77//nusXLkS4eHh8PDwQFJSkjRRSgixqcbuCgqFQvq/2JZYZuyFF16AyWTCmjVrMGDAAKjVajz++OMwmUxtOqc1np6e0v+1Wi08PT3t+sH79esn9autNlPcG+ozp3QYMpkMHMc1EcE//elPuHDhAsaPH4/AwMAm7zt//jzuuusuTJ8+HSqVCkqlErW1tdL+8PBwZGRk2NR4vHDhQrv1+/z581iwYAHGjh2LiIgIVFdXt+r96enpTV6Hh4c7PD46Oho6nQ56vR5Dhgyx+fP19YW/vz/69euHtLQ06T25ubk9vpo8pX2hI3NKu6HT6VBRUQFCCAoLC7Fx40aMGDECfn5+Nsf17dsXR48ehYeHh912QkNDcfjwYVy8eBEA8Oabb9oce/fdd+Pdd9/F2rVrMW/ePJw6dQpHjhxpN7dDaGgodu/ejaioKOTn5+OTTz5p1fszMzOxceNG3HXXXdi3bx/OnTuH119/3eHxERERmDZtGp599lm88MILCAsLQ2lpKfbu3Yu///3v6NOnDx544AGsX78egwcPRt++ffH66687vH6U3gkVc0q7sWXLFmzZsgUMwyAoKAhjxozBP/7xD7vHBgQEOGxn8eLFyMvLw4MPPoh+/frh2WefRV5enrTf398f69evx7///W9s27YNY8eOxfz58/Hjjz+2ix1r1qzBypUrMXPmTKhUKjzzzDNYsmSJ0++///77kZ2djXvvvRcBAQH4z3/+g7CwsGbf8/bbb+O9997Diy++iOrqatxwww247bbb4OXlBQB48sknUVpaisWLF8PPzw/Lli2zuSYUCq0BSnELXnrpJVRUVODTTz/t0n7Mnz8fCQkJWLZsWZf2g9L7oCNzSo/ku+++Q1RUFPr06YOjR49i9+7dWLt2bVd3i0LpMqiYU3okJSUlWLduHaqrqxESEoKXXnoJM2fOBADEx8fbfc+gQYPw008/tem8CxcutFnBak1b26ZQ2gJ1s1DcjsarLUXkcjmCg4Pb1HZZWRn0er3dfcHBwZDL6fiI0jVQMadQKBQ3gMaZUygUihtAxZxCoVDcACrmFAqF4gZQMadQKBQ3gIo5hUKhuAFUzCkUCsUNoGJOoVAobsD/B/aKQ2qSUcaUAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "figsize = (3.28125003459, 1.8)\n", + "sns.set(\n", + " rc={\"figure.figsize\": figsize}, context=\"paper\", style=\"whitegrid\", font_scale=0.75\n", + ")\n", + "\n", + "max_dict = {\n", + " \"BMag_ha\": 400,\n", + " \"V_ha\": 1000,\n", + "}\n", + "\n", + "y_range = {\n", + " \"BMag_ha\": (-225, 325),\n", + " \"V_ha\": (-310, 440),\n", + "}\n", + "\n", + "y_unit = {\n", + " \"BMag_ha\": \"\\,Mg\\,ha\\$^{-1}\\$\",\n", + " \"V_ha\": \"\\,m$^3\\$\\,ha\\$^-1\\$\",\n", + "}\n", + "\n", + "\n", + "split = \"test\"\n", + "\n", + "for target in [\"BMag_ha\"]: #target_vars:\n", + " print(target)\n", + " # if target != \"BMag_ha\": continue\n", + " for method in result_scores[split].method.unique():\n", + " # f, ax = plt.subplots()\n", + " run_i = best[method]\n", + " dff = results_corrected[method].query(f\"run == {run_i} & split == @split\")\n", + " if len(dff) == 0: continue\n", + " print(method)\n", + " sns.set(\n", + " rc={\"figure.figsize\": figsize},\n", + " context=\"paper\",\n", + " style=\"whitegrid\",\n", + " font_scale=1,\n", + " )\n", + " y_min, y_max = y_range[target]\n", + " #y_max = (dff[target] - dff[f\"{target}_pred\"]).max()\n", + " #y_min = (dff[target] - dff[f\"{target}_pred\"]).min()\n", + " x_min = 0\n", + " x_max = max_dict[target]\n", + "\n", + " # n_bins = 500//10\n", + " # bins = np.linspace(0, 500, n_bins+1)#np.histogram_bin_edges(dff[target], n_bins, range=(0,500))\n", + " # err = abs(dff[target] - group_df[f\"{target}_pred\"]).values\n", + " # bin_err = []\n", + " # bin_x = []\n", + " # for i in range(n_bins):\n", + " # mask = (bins[i] <= dff[target]) & (dff[target] < bins[i+1])\n", + " # err_ = err[mask].std()\n", + " # err_ = 0 if np.isnan(err_) else err_\n", + " # bin_err.append(err_)\n", + " # bin_err.append(err_)\n", + " # bin_x.append(bins[i])\n", + " # bin_x.append(bins[i+1])\n", + " # bin_err = np.array(bin_err)\n", + "\n", + " # ax.fill_between(bin_x, 
-bin_err, bin_err, alpha=0.5, color=\"g\")\n", + "\n", + " # ax.scatter(group_df[target], group_df[target] - group_df[f\"{target}_pred\"], s=5, label=group)\n", + "\n", + " f = sns.jointplot(\n", + " y=target,\n", + " x=f\"{target}_pred\",\n", + " kind=\"resid\",\n", + " data=dff,\n", + " # label=group,\n", + " robust=False,\n", + " scatter_kws={\"s\": 5},\n", + " marginal_kws={\"edgecolor\": \".0\", \"linewidth\": 0.00},\n", + " )\n", + " # f.ax_joint.legend()\n", + "\n", + " # f.ax_marg_x.set_title(group)\n", + " fig = plt.gcf()\n", + " fig.delaxes(f.ax_marg_x)\n", + " fig.set_size_inches(figsize)\n", + " # plt.title(method)\n", + " plt.subplots_adjust(\n", + " left=0.125, right=1, top=1.2, bottom=0.12, hspace=0.1, wspace=0.1\n", + " )\n", + " me = (\n", + " result_scores[split]\n", + " .query(f\"method == '{method}' & target == '{target}' & corrected == True & run == {run_i} & not treeval\")[\n", + " \"mean bias\"\n", + " ]\n", + " .median()\n", + " )\n", + " f.ax_joint.text(\n", + " 0.525,\n", + " 0.9,\n", + " f\"mean bias: {me:0.3f}{y_unit[target]}\",\n", + " horizontalalignment=\"center\",\n", + " verticalalignment=\"center\",\n", + " transform=f.ax_joint.transAxes,\n", + " bbox={\"facecolor\": \"white\", \"alpha\": 0.75, \"pad\": .1},\n", + " )\n", + " f.ax_joint.set_xlim((x_min, x_max))\n", + " f.ax_joint.set_ylim((y_min, y_max))\n", + " f.ax_marg_y.set_ylim((y_min, y_max))\n", + " f.ax_marg_y.margins(y=0.01)\n", + " f.ax_marg_x.margins(x=0)\n", + " f.ax_joint.margins(y=0.01, x=0)\n", + " group_name = method.replace(\" \", \"\").replace(\"\\\\\", \"\").replace(\"{}\", \"\")\n", + " #plt.savefig(f\"figures/{target}_{group_name}_resid.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e23c3131-25e2-4f2c-85a3-e1f67ee06ad5", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/eval_scripts/eval_deep_learning_v2_size.ipynb b/eval_scripts/eval_deep_learning_v2_size.ipynb new file mode 100644 index 0000000..308b61d --- /dev/null +++ b/eval_scripts/eval_deep_learning_v2_size.ipynb @@ -0,0 +1,1250 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "38cde886", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/stefan/.conda/envs/pts/lib/python3.8/site-packages/geopandas/_compat.py:123: UserWarning: The Shapely GEOS version (3.8.0-CAPI-1.13.1 ) is incompatible with the GEOS version PyGEOS was compiled with (3.10.4-CAPI-1.16.2). Conversions between both will be slow.\n", + " warnings.warn(\n", + "/home/stefan/.conda/envs/pts/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.4\n", + " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import geopandas as gpd\n", + "import seaborn as sns\n", + "\n", + "sns.set_context(\"paper\")\n", + "sns.set_style(\"whitegrid\")\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.rcParams[\"svg.fonttype\"] = \"none\"\n", + "from sklearn.metrics import mean_absolute_percentage_error, mean_squared_error, r2_score\n", + "\n", + "sns.set_color_codes()\n", + "from glob import glob\n", + "from itertools import product\n", + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "358d4bc2", + "metadata": {}, + "outputs": [], + "source": [ + "target_vars = [\"BMag_ha\", \"V_ha\"]" + ] + }, + { + 
"cell_type": "code", + "execution_count": 3, + "id": "30bad65e", + "metadata": {}, + "outputs": [], + "source": [ + "bias_correct_splits = [\"val\", \"train\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4d80056e", + "metadata": {}, + "outputs": [], + "source": [ + "# choose one of test, train, val\n", + "splits = [\"train\", \"val\", \"test\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "93cd866d", + "metadata": {}, + "outputs": [], + "source": [ + "models = {\n", + " \"MSENet14\": (\n", + " f\"results_new/SENet14_xy_??.gpkg\",\n", + " ),\n", + " \n", + " \"MSENet14 < 1y 100%\": (\n", + " f\"results_size/SENet14_1y_xy_treeadd_??.gpkg\",\n", + " ),\n", + "\n", + " \"MSENet14 < 1y 75%\": (\n", + " f\"results_size/SENet14_75_1y_xy_treeadd_??.gpkg\",\n", + " ), \n", + "\n", + " \"MSENet14 < 1y 50%\": (\n", + " f\"results_size/SENet14_50_1y_xy_treeadd_??.gpkg\",\n", + " ),\n", + "\n", + " \"MSENet14 < 1y 25%\": (\n", + " f\"results_size/SENet14_25_1y_xy_treeadd_??.gpkg\",\n", + " ),\n", + "\n", + " \"MSENet14 < 1y 12.5%\": (\n", + " f\"results_size/SENet14_12_1y_xy_treeadd_??.gpkg\",\n", + " ),\n", + "\n", + " \"MSENet14 < 1y 6.25%\": (\n", + " f\"results_size/SENet14_6_1y_xy_treeadd_??.gpkg\",\n", + " ),\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "49266c45-aa7e-41a0-a704-f6bcc07bff48", + "metadata": {}, + "outputs": [], + "source": [ + "with open('results_size.pickle', 'rb') as handle:\n", + " results = pickle.load(handle)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "aebad451-aaa4-4ad8-812e-b2713cdffe8d", + "metadata": {}, + "outputs": [], + "source": [ + "# set number of instance\n", + "results[\"MSENet14\"].eval(\"n_samples = 4270\", inplace=True)\n", + "results[\"MSENet14 < 1y 100%\"].eval(\"n_samples = 2636\", inplace=True)\n", + "results[\"MSENet14 < 1y 75%\"].eval(\"n_samples = 1977\", inplace=True)\n", + "results[\"MSENet14 < 1y 50%\"].eval(\"n_samples 
= 1318\", inplace=True)\n", + "results[\"MSENet14 < 1y 25%\"].eval(\"n_samples = 659\", inplace=True)\n", + "results[\"MSENet14 < 1y 12.5%\"].eval(\"n_samples = 330\", inplace=True)\n", + "results[\"MSENet14 < 1y 6.25%\"].eval(\"n_samples = 165\", inplace=True)" + ] + }, + { + "cell_type": "markdown", + "id": "6f20b57b-0f00-4b69-8585-71a5ee12e3e0", + "metadata": {}, + "source": [ + "# Bias correction" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "d4c831b4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14 0\n", + "0.9151888974556669\n", + "MSENet14 2\n", + "0.9151888974556669\n", + "MSENet14 3\n", + "0.9151888974556669\n", + "MSENet14 4\n", + "0.9151888974556669\n", + "MSENet14 5\n", + "0.9151888974556669\n", + "MSENet14 < 1y 100% 1\n", + "0.9151888974556669\n", + "MSENet14 < 1y 100% 2\n", + "0.9151888974556669\n", + "MSENet14 < 1y 100% 3\n", + "0.9151888974556669\n", + "MSENet14 < 1y 100% 4\n", + "0.9151888974556669\n", + "MSENet14 < 1y 100% 5\n", + "0.9151888974556669\n", + "MSENet14 < 1y 75% 0\n", + "0.9151888974556669\n", + "MSENet14 < 1y 75% 1\n", + "0.9151888974556669\n", + "MSENet14 < 1y 75% 2\n", + "0.9151888974556669\n", + "MSENet14 < 1y 75% 3\n", + "0.9151888974556669\n", + "MSENet14 < 1y 75% 4\n", + "0.9151888974556669\n", + "MSENet14 < 1y 50% 0\n", + "0.9151888974556669\n", + "MSENet14 < 1y 50% 1\n", + "0.9151888974556669\n", + "MSENet14 < 1y 50% 2\n", + "0.9151888974556669\n", + "MSENet14 < 1y 50% 3\n", + "0.9151888974556669\n", + "MSENet14 < 1y 50% 5\n", + "0.9151888974556669\n", + "MSENet14 < 1y 25% 1\n", + "0.9151888974556669\n", + "MSENet14 < 1y 25% 2\n", + "0.9151888974556669\n", + "MSENet14 < 1y 25% 3\n", + "0.9151888974556669\n", + "MSENet14 < 1y 25% 4\n", + "0.9151888974556669\n", + "MSENet14 < 1y 25% 5\n", + "0.9151888974556669\n", + "MSENet14 < 1y 12.5% 0\n", + "0.9151888974556669\n", + "MSENet14 < 1y 12.5% 1\n", + "0.9151888974556669\n", + "MSENet14 < 1y 12.5% 
2\n", + "0.9151888974556669\n", + "MSENet14 < 1y 12.5% 3\n", + "0.9151888974556669\n", + "MSENet14 < 1y 12.5% 4\n", + "0.9151888974556669\n", + "MSENet14 < 1y 6.25% 0\n", + "0.9151888974556669\n", + "MSENet14 < 1y 6.25% 1\n", + "0.9151888974556669\n", + "MSENet14 < 1y 6.25% 2\n", + "0.9151888974556669\n", + "MSENet14 < 1y 6.25% 3\n", + "0.9151888974556669\n", + "MSENet14 < 1y 6.25% 5\n", + "0.9151888974556669\n" + ] + } + ], + "source": [ + "# get bias correction\n", + "# we do not include the 0 predictions into the adjustment since they come from a different data distribution\n", + "\n", + "deltas = {}\n", + "results_corrected = {}\n", + "exclude_1y = False\n", + "exclude_pred_0 = False\n", + "clip_0 = False\n", + "for model in models:\n", + " if \"treeval\" in model: # using the original correction\n", + " continue\n", + " corrected = []\n", + " corrected_treeval = []\n", + " for run in pd.unique(results[model][\"run\"]):\n", + " print(model, run)\n", + " pred_vars = [f\"{v}_pred\" for v in target_vars]\n", + " preds_cal = pd.concat(\n", + " [\n", + " results[model].query(f\"(run == {run}) & (split == @split)\")\n", + " for split in bias_correct_splits\n", + " ],\n", + " axis=0,\n", + " )[target_vars + pred_vars + [\"mask\", \"temp_diff_years\"] ].copy(deep=True) \n", + " \n", + " #reds_cal = preds_cal.sample(len(preds_cal))\n", + " \n", + " mask = np.ones_like(preds_cal[\"mask\"])\n", + " if exclude_1y:\n", + " mask &= (preds_cal[\"temp_diff_years\"] <= 1)\n", + " if exclude_pred_0:\n", + " mask &= ~preds_cal[\"mask\"]\n", + " \n", + " correct_ = ~mask == (preds_cal[target_vars] == 0).any(axis=1)\n", + " print(correct_.sum() / len(correct_))\n", + " #print(f\"num vals != 0: {mask.sum()}\")\n", + " y_cal_ = preds_cal[target_vars][mask].values\n", + " preds_cal_ = preds_cal[pred_vars][mask].values\n", + "\n", + " '''\n", + " ds = []\n", + " num_vals = 100\n", + " for i in range(0, len(y_cal_), num_vals):\n", + " mm = np.ones(len(y_cal_), dtype=bool)\n", + " 
mm[i:i+num_vals] = False\n", + " ds.append((\n", + " y_cal_[mm].astype(np.float64).sum(0)\n", + " - preds_cal_[mm].astype(np.float64).sum(0)\n", + " ) / (mm.sum()))\n", + " delta = np.median(ds, 0)\n", + " '''\n", + " delta = (y_cal_.astype(np.float64).sum(0)\n", + " - preds_cal_.astype(np.float64).sum(0)) / (len(y_cal_))\n", + " deltas[model, run] = delta\n", + " \n", + " # check if calibration is close to 0 on calibration set\n", + " assert np.isclose(0, y_cal_.sum(0) - ((preds_cal_ + delta).sum(0))).all() \n", + " \n", + " # apply delta to all values\n", + " df = results[model].query(f\"run == {run}\")[target_vars + pred_vars + [\"run\", \"mask\", \"split\", \"n_samples\"]]\n", + " dff = df[pred_vars]\n", + " if exclude_pred_0:\n", + " mask = ~df[[\"mask\"]].values\n", + " else:\n", + " mask = np.ones_like(df[[\"mask\"]])\n", + " df[pred_vars] = (dff + delta) * mask + (~mask) * dff\n", + " if clip_0:\n", + " df[pred_vars] = df[pred_vars].mask(dff < 0.00, 0.0)\n", + " corrected.append(df)\n", + "\n", + " results_corrected[model] = pd.concat(corrected, axis=0)" + ] + }, + { + "cell_type": "markdown", + "id": "64349efc-e498-4cd7-9b47-bb677f5a0aab", + "metadata": {}, + "source": [ + "# Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "eb2c8782", + "metadata": {}, + "outputs": [], + "source": [ + "def cohen_d(y1_pred, y2_pred):\n", + " mse1 = (y1_pred**2).mean()\n", + " mse2 = (y2_pred**2).mean()\n", + " \n", + " diff = mse1 - mse2\n", + " s_pooled = np.sqrt((mse1 + mse2) / 2)\n", + " cohens_d = diff / s_pooled\n", + " return cohens_d\n", + "\n", + "def evaluate(name, results):\n", + " print(name)\n", + " columns = [\n", + " \"method\",\n", + " \"target\",\n", + " \"R2\",\n", + " \"MSE\",\n", + " \"RMSE\",\n", + " \"MAPE\",\n", + " \"mean error\",\n", + " \"mean bias\",\n", + " \"rel. 
error\",\n", + " \"n_samples\",\n", + " \"run\",\n", + " ]\n", + " results_df = []\n", + "\n", + " for target in target_vars:\n", + " pred = target + \"_pred\"\n", + " for run, result in results.groupby(\"run\"):\n", + " mask = mm = result[target] != 0\n", + " #mm = result[pred] != 0\n", + " \n", + " results_df.append(\n", + " pd.DataFrame(\n", + " [\n", + " [\n", + " name,\n", + " target,\n", + " r2_score(result[target], result[pred]),\n", + " mean_squared_error(\n", + " result[target], result[pred]\n", + " ),\n", + " mean_squared_error(\n", + " result[target], result[pred], squared=False\n", + " ),\n", + " mean_absolute_percentage_error(\n", + " result[target][mask], result[pred][mask]\n", + " )\n", + " * 100,\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[pred][mm])\n", + " ),\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[target][mm])\n", + " ,\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / (result[target][mm]).sum()\n", + " )\n", + " * 100,\n", + " result[\"n_samples\"].median(),\n", + " run,\n", + " ]\n", + " ],\n", + " columns=columns,\n", + " )\n", + " )\n", + " results_df = pd.concat(results_df, axis=0)\n", + " return results, results_df\n", + "\n", + "'''\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[pred][mm])\n", + " ),\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / len(result[pred][mm])\n", + " ,\n", + " abs(\n", + " (result[target][mm] - result[pred][mm]).sum()\n", + " / (result[pred][mm]).sum()\n", + " )\n", + " * 100,\n", + "''';\n", + "'''\n", + " abs(\n", + " (result[target] - result[pred]).sum()\n", + " / len(result[pred])\n", + " ),\n", + " (result[target] - result[pred]).sum()\n", + " / len(result[pred])\n", + " ,\n", + " abs(\n", + " (result[target] - result[pred]).sum()\n", + " / (result[pred]).sum()\n", + " )\n", + " * 100,\n", + "''';" + ] + }, + { + "cell_type": "code", + 
"execution_count": 13, + "id": "4635281c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14\n", + "MSENet14\n", + "MSENet14 < 1y 100%\n", + "MSENet14 < 1y 100%\n", + "MSENet14 < 1y 75%\n", + "MSENet14 < 1y 75%\n", + "MSENet14 < 1y 50%\n", + "MSENet14 < 1y 50%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, 
\"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14 < 1y 25%\n", + "MSENet14 < 1y 25%\n", + "MSENet14 < 1y 12.5%\n", + "MSENet14 < 1y 12.5%\n", + "MSENet14 < 1y 6.25%\n", + "MSENet14 < 1y 6.25%\n", + "MSENet14\n", + "MSENet14\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, 
\"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14 < 1y 100%\n", + "MSENet14 < 1y 100%\n", + "MSENet14 < 1y 75%\n", + "MSENet14 < 1y 75%\n", + "MSENet14 < 1y 50%\n", + "MSENet14 < 1y 50%\n", + "MSENet14 < 1y 25%\n", + "MSENet14 < 1y 25%\n", + "MSENet14 < 1y 12.5%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, 
\"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14 < 1y 12.5%\n", + "MSENet14 < 1y 6.25%\n", + "MSENet14 < 1y 6.25%\n", + "MSENet14\n", + "MSENet14\n", + "MSENet14 < 1y 100%\n", + "MSENet14 < 1y 100%\n", + "MSENet14 < 1y 75%\n", + "MSENet14 < 1y 75%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in 
the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MSENet14 < 1y 50%\n", + "MSENet14 < 1y 50%\n", + "MSENet14 < 1y 25%\n", + "MSENet14 < 1y 25%\n", + "MSENet14 < 1y 12.5%\n", + "MSENet14 < 1y 12.5%\n", + "MSENet14 < 1y 6.25%\n", + "MSENet14 < 1y 6.25%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: 
https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats 
in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n", + "/tmp/ipykernel_45615/1417926861.py:9: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = False\n", + "/tmp/ipykernel_45615/1417926861.py:16: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " file.loc[:, \"corrected\"] = True\n" + ] + } + ], + "source": [ + "result_dict = {}\n", + "result_dict_corrected = {}\n", + "result_scores = {}\n", + "for split in splits:\n", + " result_score = []\n", + " for name in models.keys():\n", + " # use corrected version except for linear regressor (optimal already)\n", + " file, scores = evaluate(name, results[name].query(\"split == @split\"))\n", + " file.loc[:, \"corrected\"] = False\n", + " scores.loc[:, \"corrected\"] = False\n", + "\n", + " result_dict[name] = file\n", + " result_score.append(scores)\n", + "\n", + " file, scores = evaluate(name, results_corrected[name].query(\"split == @split\"))\n", + " file.loc[:, \"corrected\"] = True\n", + " scores.loc[:, \"corrected\"] = True\n", + "\n", + " result_dict_corrected[name] = file\n", + " result_score.append(scores)\n", + "\n", + " result_score = pd.concat(result_score, axis=0)\n", + " result_scores[split] = result_score" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "2a606535-70d0-4251-9ba7-87a44e604999", + "metadata": 
{}, + "outputs": [], + "source": [ + "result_scores[\"test\"] = result_scores[\"test\"].query(\"corrected == True\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f7494784-6232-490e-9e16-897f139b890b", + "metadata": {}, + "outputs": [], + "source": [ + "def abs_min(x): return x.iloc[np.argmin(abs(x))]\n", + "def abs_max(x): return x.iloc[np.argmax(abs(x))]\n", + "def abs_median(x): return np.median(abs(x))\n", + "def avg_sign(x): return np.mean(np.sign(x))\n", + "def abs_mean(x): return np.mean(abs(x))\n", + "def arg_abs_min(x): return np.argmin(abs(x))\n", + "def arg_abs_max(x): return np.argmax(abs(x))\n", + "def arg_max(x): return np.argmax(abs(x))\n", + "\n", + "agg = {\n", + " \"R2\": [\"median\", \"max\"],\n", + " #'MSE' : ['median', 'min'],\n", + " 'RMSE' : ['median', 'min'],\n", + " 'MAPE' : ['median', 'min'],\n", + " #\"mean error\": [\"median\", \"max\", \"min\"],\n", + " \"mean bias\": [abs_median, abs_min],\n", + " #'rel. error' : ['median', \"min\"],\n", + "}\n", + "\n", + "rr = (\n", + " result_scores[\"test\"].query(\"target == 'BMag_ha'\")\n", + " .groupby([\"target\", \"n_samples\"])\n", + " .agg(agg)\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "abd3ae88-8c20-473a-b415-4e05c58b643c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
R2RMSEMAPEmean bias
medianmaxmedianminmedianminabs_medianabs_min
targetn_samples
BMag_ha165.00.7748400.77734648.02592447.757922859.315691337.0231620.7451730.130643
330.00.7900700.79488746.37324145.838127956.707443355.4326650.6111450.193949
659.00.7881640.81050046.58335344.059004828.811509438.9796170.3410790.228971
1318.00.8079080.81293444.35938743.775150583.615801189.1330630.2421150.200385
1977.00.8128190.82261443.78865742.627488812.118072734.5218550.448185-0.025055
2636.00.8235940.83017542.50956341.709073545.921768398.3767410.2192230.028558
4270.00.8247250.82938842.37314941.805641299.496832192.7774400.665678-0.290542
\n", + "
" + ], + "text/plain": [ + " R2 RMSE MAPE \\\n", + " median max median min median \n", + "target n_samples \n", + "BMag_ha 165.0 0.774840 0.777346 48.025924 47.757922 859.315691 \n", + " 330.0 0.790070 0.794887 46.373241 45.838127 956.707443 \n", + " 659.0 0.788164 0.810500 46.583353 44.059004 828.811509 \n", + " 1318.0 0.807908 0.812934 44.359387 43.775150 583.615801 \n", + " 1977.0 0.812819 0.822614 43.788657 42.627488 812.118072 \n", + " 2636.0 0.823594 0.830175 42.509563 41.709073 545.921768 \n", + " 4270.0 0.824725 0.829388 42.373149 41.805641 299.496832 \n", + "\n", + " mean bias \n", + " min abs_median abs_min \n", + "target n_samples \n", + "BMag_ha 165.0 337.023162 0.745173 0.130643 \n", + " 330.0 355.432665 0.611145 0.193949 \n", + " 659.0 438.979617 0.341079 0.228971 \n", + " 1318.0 189.133063 0.242115 0.200385 \n", + " 1977.0 734.521855 0.448185 -0.025055 \n", + " 2636.0 398.376741 0.219223 0.028558 \n", + " 4270.0 192.777440 0.665678 -0.290542 " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display(rr)" + ] + }, + { + "cell_type": "markdown", + "id": "7b282eb4-ea1d-40d6-abd7-f77021248f5f", + "metadata": {}, + "source": [ + "# Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "113fd06e-010e-4e99-a939-5b15eed6bfdd", + "metadata": {}, + "outputs": [], + "source": [ + "result_scores[\"test\"].columns = result_scores[\"test\"].columns.map(lambda x: x.replace(' ', '_'))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "ac10dda3-5bd0-40e4-8760-839a957d760a", + "metadata": {}, + "outputs": [], + "source": [ + "targets = [\"BMag_ha\", \"V_ha\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "3ac99322-4d59-454a-8d54-d94f5cbc9020", + "metadata": {}, + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'figures/BMag_ha_R2_size.svg'", + "output_type": "error", + "traceback": [ + 
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[19], line 17\u001b[0m\n\u001b[1;32m 15\u001b[0m ax\u001b[38;5;241m.\u001b[39mset_xlim(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m4300\u001b[39m)\n\u001b[1;32m 16\u001b[0m plt\u001b[38;5;241m.\u001b[39msubplots_adjust(left\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.15\u001b[39m, right\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, top\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m1\u001b[39m, bottom\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.15\u001b[39m, hspace\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.15\u001b[39m, wspace\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.1\u001b[39m)\n\u001b[0;32m---> 17\u001b[0m \u001b[43mplt\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msavefig\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43mf\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mfigures/\u001b[39;49m\u001b[38;5;132;43;01m{\u001b[39;49;00m\u001b[43mtarget\u001b[49m\u001b[38;5;132;43;01m}\u001b[39;49;00m\u001b[38;5;124;43m_R2_size.svg\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/pyplot.py:954\u001b[0m, in \u001b[0;36msavefig\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 951\u001b[0m \u001b[38;5;129m@_copy_docstring_and_deprecators\u001b[39m(Figure\u001b[38;5;241m.\u001b[39msavefig)\n\u001b[1;32m 952\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21msavefig\u001b[39m(\u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 953\u001b[0m fig \u001b[38;5;241m=\u001b[39m gcf()\n\u001b[0;32m--> 954\u001b[0m res \u001b[38;5;241m=\u001b[39m 
\u001b[43mfig\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msavefig\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 955\u001b[0m fig\u001b[38;5;241m.\u001b[39mcanvas\u001b[38;5;241m.\u001b[39mdraw_idle() \u001b[38;5;66;03m# Need this if 'transparent=True', to reset colors.\u001b[39;00m\n\u001b[1;32m 956\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m res\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/figure.py:3274\u001b[0m, in \u001b[0;36mFigure.savefig\u001b[0;34m(self, fname, transparent, **kwargs)\u001b[0m\n\u001b[1;32m 3270\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m ax \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maxes:\n\u001b[1;32m 3271\u001b[0m stack\u001b[38;5;241m.\u001b[39menter_context(\n\u001b[1;32m 3272\u001b[0m ax\u001b[38;5;241m.\u001b[39mpatch\u001b[38;5;241m.\u001b[39m_cm_set(facecolor\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m'\u001b[39m, edgecolor\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mnone\u001b[39m\u001b[38;5;124m'\u001b[39m))\n\u001b[0;32m-> 3274\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcanvas\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprint_figure\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/backend_bases.py:2338\u001b[0m, in \u001b[0;36mFigureCanvasBase.print_figure\u001b[0;34m(self, filename, dpi, facecolor, edgecolor, orientation, format, bbox_inches, pad_inches, 
bbox_extra_artists, backend, **kwargs)\u001b[0m\n\u001b[1;32m 2334\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m 2335\u001b[0m \u001b[38;5;66;03m# _get_renderer may change the figure dpi (as vector formats\u001b[39;00m\n\u001b[1;32m 2336\u001b[0m \u001b[38;5;66;03m# force the figure dpi to 72), so we need to set it again here.\u001b[39;00m\n\u001b[1;32m 2337\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m cbook\u001b[38;5;241m.\u001b[39m_setattr_cm(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mfigure, dpi\u001b[38;5;241m=\u001b[39mdpi):\n\u001b[0;32m-> 2338\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mprint_method\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2339\u001b[0m \u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2340\u001b[0m \u001b[43m \u001b[49m\u001b[43mfacecolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfacecolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2341\u001b[0m \u001b[43m \u001b[49m\u001b[43medgecolor\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43medgecolor\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2342\u001b[0m \u001b[43m \u001b[49m\u001b[43morientation\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43morientation\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2343\u001b[0m \u001b[43m \u001b[49m\u001b[43mbbox_inches_restore\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_bbox_inches_restore\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 2344\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2345\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m 2346\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m bbox_inches \u001b[38;5;129;01mand\u001b[39;00m restore_bbox:\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/backend_bases.py:2204\u001b[0m, in 
\u001b[0;36mFigureCanvasBase._switch_canvas_and_return_print_method..\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 2200\u001b[0m optional_kws \u001b[38;5;241m=\u001b[39m { \u001b[38;5;66;03m# Passed by print_figure for other renderers.\u001b[39;00m\n\u001b[1;32m 2201\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdpi\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mfacecolor\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124medgecolor\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124morientation\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 2202\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mbbox_inches_restore\u001b[39m\u001b[38;5;124m\"\u001b[39m}\n\u001b[1;32m 2203\u001b[0m skip \u001b[38;5;241m=\u001b[39m optional_kws \u001b[38;5;241m-\u001b[39m {\u001b[38;5;241m*\u001b[39minspect\u001b[38;5;241m.\u001b[39msignature(meth)\u001b[38;5;241m.\u001b[39mparameters}\n\u001b[0;32m-> 2204\u001b[0m print_method \u001b[38;5;241m=\u001b[39m functools\u001b[38;5;241m.\u001b[39mwraps(meth)(\u001b[38;5;28;01mlambda\u001b[39;00m \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: \u001b[43mmeth\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 2205\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43m{\u001b[49m\u001b[43mk\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mv\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m 
\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mitems\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mif\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mk\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01mnot\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mskip\u001b[49m\u001b[43m}\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m 2206\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m: \u001b[38;5;66;03m# Let third-parties do as they see fit.\u001b[39;00m\n\u001b[1;32m 2207\u001b[0m print_method \u001b[38;5;241m=\u001b[39m meth\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/_api/deprecation.py:410\u001b[0m, in \u001b[0;36mdelete_parameter..wrapper\u001b[0;34m(*inner_args, **inner_kwargs)\u001b[0m\n\u001b[1;32m 400\u001b[0m deprecation_addendum \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 401\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mIf any parameter follows \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mname\u001b[38;5;132;01m!r}\u001b[39;00m\u001b[38;5;124m, they should be passed as \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 402\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mkeyword, not positionally.\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 403\u001b[0m warn_deprecated(\n\u001b[1;32m 404\u001b[0m since,\n\u001b[1;32m 405\u001b[0m name\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mrepr\u001b[39m(name),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m deprecation_addendum,\n\u001b[1;32m 409\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 410\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minner_args\u001b[49m\u001b[43m,\u001b[49m\u001b[43m 
\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minner_kwargs\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/backends/backend_svg.py:1389\u001b[0m, in \u001b[0;36mFigureCanvasSVG.print_svg\u001b[0;34m(self, filename, bbox_inches_restore, metadata, *args)\u001b[0m\n\u001b[1;32m 1355\u001b[0m \u001b[38;5;129m@_api\u001b[39m\u001b[38;5;241m.\u001b[39mdelete_parameter(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m3.5\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124margs\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 1356\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mprint_svg\u001b[39m(\u001b[38;5;28mself\u001b[39m, filename, \u001b[38;5;241m*\u001b[39margs, bbox_inches_restore\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 1357\u001b[0m metadata\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 1358\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 1359\u001b[0m \u001b[38;5;124;03m Parameters\u001b[39;00m\n\u001b[1;32m 1360\u001b[0m \u001b[38;5;124;03m ----------\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1387\u001b[0m \u001b[38;5;124;03m __ DC_\u001b[39;00m\n\u001b[1;32m 1388\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1389\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[43mcbook\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mopen_file_cm\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mw\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mutf-8\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mas\u001b[39;00m 
fh:\n\u001b[1;32m 1390\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m cbook\u001b[38;5;241m.\u001b[39mfile_requires_unicode(fh):\n\u001b[1;32m 1391\u001b[0m fh \u001b[38;5;241m=\u001b[39m codecs\u001b[38;5;241m.\u001b[39mgetwriter(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m'\u001b[39m)(fh)\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/cbook/__init__.py:506\u001b[0m, in \u001b[0;36mopen_file_cm\u001b[0;34m(path_or_file, mode, encoding)\u001b[0m\n\u001b[1;32m 504\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mopen_file_cm\u001b[39m(path_or_file, mode\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124m\"\u001b[39m, encoding\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 505\u001b[0m \u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Pass through file objects and context-manage path-likes.\"\"\"\u001b[39;00m\n\u001b[0;32m--> 506\u001b[0m fh, opened \u001b[38;5;241m=\u001b[39m \u001b[43mto_filehandle\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath_or_file\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 507\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m fh \u001b[38;5;28;01mif\u001b[39;00m opened \u001b[38;5;28;01melse\u001b[39;00m contextlib\u001b[38;5;241m.\u001b[39mnullcontext(fh)\n", + "File \u001b[0;32m~/.conda/envs/pts/lib/python3.8/site-packages/matplotlib/cbook/__init__.py:492\u001b[0m, in \u001b[0;36mto_filehandle\u001b[0;34m(fname, flag, return_opened, encoding)\u001b[0m\n\u001b[1;32m 490\u001b[0m fh \u001b[38;5;241m=\u001b[39m bz2\u001b[38;5;241m.\u001b[39mBZ2File(fname, flag)\n\u001b[1;32m 491\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 
492\u001b[0m fh \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mopen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mfname\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mflag\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mencoding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mencoding\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 493\u001b[0m opened \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 494\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(fname, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mseek\u001b[39m\u001b[38;5;124m'\u001b[39m):\n", + "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'figures/BMag_ha_R2_size.svg'" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "b_r2 = {\n", + " \"BMag_ha\": 0.760720,\n", + " \"V_ha\": 0.763\n", + "}\n", + "\n", + "for target in targets:\n", + " f, ax = plt.subplots(figsize=(7.48031/3, 1.5))\n", + " sns.lineplot(x=\"n_samples\", y=\"R2\", data=result_scores[\"test\"].query(\"target == @target\"), errorbar=\"se\", markers=[\"x\"], dashes=False, ax=ax)\n", + " ax.plot(np.arange(4300), [b_r2[target]] * 4300, linestyle=\"--\",c=\"orange\")\n", + " ax.scatter([4270], [b_r2[target]], c=\"orange\", label=\"\\power{} baseline\")\n", + " ax.set_xticks(range(0, 4300, 1000), labels=[f\"{n/1000}K\" if n > 0 else 0 for n in range(0, 4300, 1000)])\n", + " #ax.legend()\n", + " ax.set_ylabel(\"\\$R^2\\$\")\n", + " ax.set_xlabel(\"training samples \\$n\\$\")\n", + " ax.set_xlim(0, 4300)\n", + " plt.subplots_adjust(left=0.15, right=1, top=1, bottom=0.15, hspace=0.15, wspace=0.1)\n", + " plt.savefig(f\"figures/{target}_R2_size.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e5e5641-4927-4862-823b-798dd44c35c8", + "metadata": {}, + "outputs": [], + "source": [ + "b_rmse = {\n", + " \"BMag_ha\": 49.508961,\n", + " \"V_ha\": 92.82\n", + "}\n", + "\n", + "for target in targets:\n", + " f, ax = plt.subplots(figsize=(7.48031/3, 1.5))\n", + " sns.lineplot(x=\"n_samples\", y=\"RMSE\", data=result_scores[\"test\"].query(\"target == @target\"), markers=True, errorbar=\"se\", ax=ax)\n", + " ax.plot(np.arange(4300), [b_rmse[target]] * 4300, linestyle=\"--\", c=\"orange\")\n", + " ax.scatter([4270], [b_rmse[target]], c=\"orange\", label=\"\\power{} baseline\")\n", + " ax.set_xlim(0, 4300)\n", + " ax.set_xticks(range(0, 4300, 1000), labels=[f\"{n/1000}K\" if n > 0 else 0 for n in range(0, 4300, 1000)])\n", + " plt.subplots_adjust(left=0.15, right=1, top=1, bottom=0.15, hspace=0.15, wspace=0.1)\n", + " ax.set_xlabel(\"training samples \\$n\\$\")\n", + " #ax.legend()\n", + " 
plt.savefig(f\"figures/{target}_RMSE_size.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "df355b4a-ab63-457d-9362-722744fa6af2", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "b_mbias = {\n", + " \"BMag_ha\": 2.03,\n", + " \"V_ha\": 4.50\n", + "}\n", + "ticks = {\n", + " \"BMag_ha\": np.arange(0, 2.1, .5),\n", + " \"V_ha\": np.arange(0, 4.6, 1)\n", + "}\n", + "\n", + "for target in targets:\n", + " f, ax = plt.subplots(figsize=(7.48031/3, 1.5))\n", + " sns.lineplot(x=\"n_samples\", y=\"mean_bias\", data=result_scores[\"test\"].query(\"target == @target\").eval(\"mean_bias = abs(mean_bias)\"), markers=True, errorbar=\"se\", ax=ax, label=\"MSENet14\")\n", + " ax.plot(np.arange(4300), [b_mbias[target]] * 4300, linestyle=\"--\", c=\"orange\")\n", + " ax.scatter([4270], [b_mbias[target]], c=\"orange\", label=\"\\power{} baseline\")\n", + " ax.set_xticks(range(0, 4300, 1000), labels=[f\"{n/1000}K\" if n > 0 else 0 for n in range(0, 4300, 1000)])\n", + " ax.set_yticks(ticks[target])\n", + " ax.set_ylabel(\"abs. mean bias\")\n", + " ax.set_xlabel(\"training samples \\$n\\$\")\n", + " ax.set_xlim(0, 4300)\n", + " ax.legend()\n", + " plt.subplots_adjust(left=0.15, right=1, top=1, bottom=0.15, hspace=0.15, wspace=0.1)\n", + " #plt.savefig(f\"figures/{target}_mbias_size.svg\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "e2d4332d-38ab-4b13-a60c-e31315f43e1e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "sns.lineplot(x=\"n_samples\", y=\"MAPE\", data=result_scores[\"test\"].query(\"target == 'BMag_ha'\"), markers=True, errorbar=\"se\")\n", + "plt.plot(np.arange(4300), [365.34] * 4300, linestyle=\"--\", label=\"\\power{} baseline\")\n", + "plt.xlim(0, 4300)\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aea2c2c8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/eval_scripts/result_scores_agg.pickle b/eval_scripts/result_scores_agg.pickle new file mode 100644 index 0000000..75a3879 Binary files /dev/null and b/eval_scripts/result_scores_agg.pickle differ diff --git a/eval_scripts/results_new.pickle b/eval_scripts/results_new.pickle new file mode 100644 index 0000000..32f5940 Binary files /dev/null and b/eval_scripts/results_new.pickle differ diff --git a/eval_scripts/results_size.pickle b/eval_scripts/results_size.pickle new file mode 100644 index 0000000..a92b84b Binary files /dev/null and b/eval_scripts/results_size.pickle differ diff --git a/pointcloud_stats_method/learn_with_stats.ipynb b/pointcloud_stats_method/learn_with_stats.ipynb new file mode 100644 index 0000000..64b350c --- /dev/null +++ b/pointcloud_stats_method/learn_with_stats.ipynb @@ -0,0 +1,898 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "f7927e81", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/home/stefan/.conda/envs/pts/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.24.4\n", + " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from pathlib import Path\n", + "import glob\n", + "from joblib import load\n", + "from sklearn.metrics import r2_score" + ] + }, + { + "cell_type": "markdown", + "id": "954b6c05", + "metadata": {}, + "source": [ + "# look at stats" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "836b6c7c", + "metadata": {}, + "outputs": [], + "source": [ + "# we only include the anonymised set as agreed with the data owners\n", + "\n", + "df_train = pd.read_csv(\"../nfi-data/train_split.csv\")\n", + "df_val = pd.read_csv(\"../nfi-data/val_split.csv\")\n", + "df_test = pd.read_csv(\"../nfi-data/test_split.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "229e28c0", + "metadata": {}, + "outputs": [], + "source": [ + "df_train.eval(\"temp_diff_years = temp_diff_days / 365\", inplace=True)\n", + "df_val.eval(\"temp_diff_years = temp_diff_days / 365\", inplace=True)\n", + "df_test.eval(\"temp_diff_years = temp_diff_days / 365\", inplace=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "01da773a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "919" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(df_val)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f7e8a4e2", + "metadata": {}, + "outputs": [], + "source": [ + "variable_list = [\n", + " \"h_mean_1_\",\n", + " \"h_mean_2_\",\n", + " \"h_std_1_\",\n", + " \"h_std_2_\",\n", + " \"h_coov_1_\",\n", + " \"h_coov_2_\",\n", + " \"h_kur_1_\",\n", + " \"h_kur_2_\",\n", + " \"h_skew_1_\",\n", + " 
\"h_skew_2_\",\n", + " \"IR_\",\n", + " *[f\"h_q{i}_1_\" for i in [5, 10, 25, 50, 75, 90, 95, 99]],\n", + " *[f\"h_q{i}_2_\" for i in [5, 10, 25, 50, 75, 90, 95, 99]],\n", + " \"temp_diff_years\"\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "62dc65ee", + "metadata": {}, + "outputs": [], + "source": [ + "target_list = [\"BMag_ha\", \"V_ha\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a202a5b8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + "
h_mean_1_h_mean_2_h_std_1_h_std_2_h_coov_1_h_coov_2_h_kur_1_h_kur_2_h_skew_1_h_skew_2_...h_q99_1_h_q5_2_h_q10_2_h_q25_2_h_q50_2_h_q75_2_h_q90_2_h_q95_2_h_q99_2_temp_diff_years
count919.000000919.000000919.000000919.000000919.000000919.000000919.000000919.000000919.000000919.000000...919.000000919.000000919.000000919.000000919.00000919.000000919.000000919.000000919.000000919.000000
mean7.75688610.7375375.0214823.6190871.0672210.38217822.0670382.0802550.551701-0.235734...16.7719354.3937715.8400248.44227411.0842013.30409415.00724515.91396617.312287-0.070031
std5.6509145.8267002.6188111.7756451.1043520.218708280.69339714.6060874.2206331.242230...7.5695413.8625394.6776385.6693296.379346.7919607.0770317.2507017.4579280.555424
min-0.0248251.1128570.0532500.087785-4.6569720.069645-1.926685-2.000000-13.816676-3.078436...0.0700001.0100001.0200001.0300001.075001.1500001.1970001.2525001.330500-1.000000
25%3.1796626.2398913.0993132.3837290.5351940.262913-1.173350-0.621735-0.683866-0.871200...11.3717501.5900002.1850003.7437506.012508.0450009.77600010.44000011.883950-0.613699
50%6.6729529.6470544.6351123.4010800.7987200.354456-0.470218-0.0289350.003507-0.356868...16.8602002.8650004.1720007.0600009.9500012.36000014.56000015.69650017.2900000.068493
75%11.12915014.8997066.8804954.6253161.1976070.4445631.4943421.1753100.8863960.223805...22.5609005.9340008.33500012.13250015.7600018.27375020.24600021.30700022.7899500.410959
max26.87360428.32845115.99349214.31947113.5157062.9336475577.448026253.40650368.14599312.024318...44.66000021.68000024.49000027.53000029.8050037.08000040.99400042.59000046.5100000.936986
\n", + "

8 rows × 28 columns

\n", + "
" + ], + "text/plain": [ + " h_mean_1_ h_mean_2_ h_std_1_ h_std_2_ h_coov_1_ h_coov_2_ \\\n", + "count 919.000000 919.000000 919.000000 919.000000 919.000000 919.000000 \n", + "mean 7.756886 10.737537 5.021482 3.619087 1.067221 0.382178 \n", + "std 5.650914 5.826700 2.618811 1.775645 1.104352 0.218708 \n", + "min -0.024825 1.112857 0.053250 0.087785 -4.656972 0.069645 \n", + "25% 3.179662 6.239891 3.099313 2.383729 0.535194 0.262913 \n", + "50% 6.672952 9.647054 4.635112 3.401080 0.798720 0.354456 \n", + "75% 11.129150 14.899706 6.880495 4.625316 1.197607 0.444563 \n", + "max 26.873604 28.328451 15.993492 14.319471 13.515706 2.933647 \n", + "\n", + " h_kur_1_ h_kur_2_ h_skew_1_ h_skew_2_ ... h_q99_1_ \\\n", + "count 919.000000 919.000000 919.000000 919.000000 ... 919.000000 \n", + "mean 22.067038 2.080255 0.551701 -0.235734 ... 16.771935 \n", + "std 280.693397 14.606087 4.220633 1.242230 ... 7.569541 \n", + "min -1.926685 -2.000000 -13.816676 -3.078436 ... 0.070000 \n", + "25% -1.173350 -0.621735 -0.683866 -0.871200 ... 11.371750 \n", + "50% -0.470218 -0.028935 0.003507 -0.356868 ... 16.860200 \n", + "75% 1.494342 1.175310 0.886396 0.223805 ... 22.560900 \n", + "max 5577.448026 253.406503 68.145993 12.024318 ... 
44.660000 \n", + "\n", + " h_q5_2_ h_q10_2_ h_q25_2_ h_q50_2_ h_q75_2_ h_q90_2_ \\\n", + "count 919.000000 919.000000 919.000000 919.00000 919.000000 919.000000 \n", + "mean 4.393771 5.840024 8.442274 11.08420 13.304094 15.007245 \n", + "std 3.862539 4.677638 5.669329 6.37934 6.791960 7.077031 \n", + "min 1.010000 1.020000 1.030000 1.07500 1.150000 1.197000 \n", + "25% 1.590000 2.185000 3.743750 6.01250 8.045000 9.776000 \n", + "50% 2.865000 4.172000 7.060000 9.95000 12.360000 14.560000 \n", + "75% 5.934000 8.335000 12.132500 15.76000 18.273750 20.246000 \n", + "max 21.680000 24.490000 27.530000 29.80500 37.080000 40.994000 \n", + "\n", + " h_q95_2_ h_q99_2_ temp_diff_years \n", + "count 919.000000 919.000000 919.000000 \n", + "mean 15.913966 17.312287 -0.070031 \n", + "std 7.250701 7.457928 0.555424 \n", + "min 1.252500 1.330500 -1.000000 \n", + "25% 10.440000 11.883950 -0.613699 \n", + "50% 15.696500 17.290000 0.068493 \n", + "75% 21.307000 22.789950 0.410959 \n", + "max 42.590000 46.510000 0.936986 \n", + "\n", + "[8 rows x 28 columns]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_val[variable_list].describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "cd534652", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
BMag_haV_ha
count919.000000919.000000
mean112.378628210.213042
std105.828078195.837800
min0.0000000.000000
25%29.29285552.422856
50%86.551500154.638388
75%165.194985325.370432
max668.0819761176.918249
\n", + "
" + ], + "text/plain": [ + " BMag_ha V_ha\n", + "count 919.000000 919.000000\n", + "mean 112.378628 210.213042\n", + "std 105.828078 195.837800\n", + "min 0.000000 0.000000\n", + "25% 29.292855 52.422856\n", + "50% 86.551500 154.638388\n", + "75% 165.194985 325.370432\n", + "max 668.081976 1176.918249" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_val[target_list].describe()" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "958358bd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([,\n", + " ], dtype=object)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "df_val[target_list].plot.hist(bins=100, subplots=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "08199065", + "metadata": {}, + "outputs": [], + "source": [ + "X_train = df_train[variable_list + target_list ]\n", + "X_val = df_val[variable_list + target_list]\n", + "X_test = df_test[variable_list + target_list ]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "02ed1cb4", + "metadata": {}, + "outputs": [], + "source": [ + "def rmse_loss(x, y):\n", + " return ((x-y)**2).mean()**.5" + ] + }, + { + "cell_type": "markdown", + "id": "e5eb81bf", + "metadata": {}, + "source": [ + "## linear model with sklearn" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "485019d1", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.linear_model import LinearRegression\n", + "from sklearn.impute import SimpleImputer" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e00a025a", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RMSE:\n", + "\tBMag_ha: 51.999\n", + "\tV_ha: 96.744\n", + "R2 score:\n", + "\tBMag_ha: 0.742\n", + "\tV_ha: 0.747\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_10413/256522388.py:2: FutureWarning: The frame.append method is deprecated and will be removed from pandas in a future version. 
Use pandas.concat instead.\n", + " X_trainval = X_train.append(X_val)\n" + ] + }, + { + "data": { + "text/plain": [ + "[None, None]" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# pure linear model\n", + "X_trainval = X_train.append(X_val)\n", + "\n", + "imputer = SimpleImputer().fit(X_trainval[variable_list])\n", + "X_train_ = imputer.transform(X_trainval[variable_list])\n", + "model = LinearRegression().fit(\n", + " X_train_, \n", + " X_trainval[target_list]\n", + ")\n", + "y_pred = model.predict(imputer.transform(X_test[variable_list]))\n", + "y_pred = np.clip(y_pred, a_min=0, a_max=None)\n", + "rmse = []\n", + "r2 = []\n", + "for i, name in enumerate(target_list):\n", + " rmse.append(rmse_loss(X_test[target_list[i]], y_pred[:, i]))\n", + " r2.append(r2_score(X_test[target_list[i]], y_pred[:, i]))\n", + " \n", + "print(f\"RMSE:\")\n", + "[print(f\"\\t{target}: {score:.3f}\") for target, score in zip(target_list, rmse)]\n", + " \n", + "print(f\"R2 score:\")\n", + "[print(f\"\\t{target}: {score:.3f}\") for target, score in zip(target_list, r2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "73a90ee1", + "metadata": {}, + "outputs": [], + "source": [ + "y_pred_train = model.predict(imputer.transform(X_train[variable_list]))\n", + "# np.savetxt(\"linreg_train.csv\", y_pred_train)\n", + "\n", + "y_pred_val = model.predict(imputer.transform(X_val[variable_list]))\n", + "# np.savetxt(\"linreg_val.csv\", y_pred_val)\n", + "\n", + "y_pred_test = model.predict(imputer.transform(X_test[variable_list]))\n", + "# np.savetxt(\"linreg_test.csv\", y_pred_test)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "cb2e5a83", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RMSE:\n", + "\tBMag_ha: 54.368\n", + "\tV_ha: 95.053\n", + "R2 score:\n", + "\tBMag_ha: 0.736\n", + "\tV_ha: 0.764\n" + ] + }, + { + "data": { + 
"text/plain": [ + "[None, None]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred_val = np.clip(y_pred_val, a_min=0, a_max=None)\n", + "rmse = []\n", + "r2 = []\n", + "for i, name in enumerate(target_list):\n", + " rmse.append(rmse_loss(X_val[target_list[i]], y_pred_val[:, i]))\n", + " r2.append(r2_score(X_val[target_list[i]], y_pred_val[:, i]))\n", + "\n", + "print(f\"RMSE:\")\n", + "[print(f\"\\t{target}: {score:.3f}\") for target, score in zip(target_list, rmse)]\n", + " \n", + "print(f\"R2 score:\")\n", + "[print(f\"\\t{target}: {score:.3f}\") for target, score in zip(target_list, r2)]" + ] + }, + { + "cell_type": "markdown", + "id": "c66446ac", + "metadata": {}, + "source": [ + "## RF with sklearn" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "6456ddfd", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.ensemble import RandomForestRegressor" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "56441402", + "metadata": {}, + "outputs": [], + "source": [ + "imputer = SimpleImputer(strategy=\"constant\", fill_value=-100).fit(X_trainval[variable_list])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "65256bf6", + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import ParameterGrid\n", + "param_grid = {\n", + " # Number of features to consider at every split\n", + " 'max_features': np.arange(0.1, 1.1, 0.1),\n", + " # depth of tree (None means full depth)\n", + " 'max_depth': list(np.arange(5, 21)) + [None],\n", + " # Minimum number of samples required at each leaf node\n", + " 'min_samples_leaf': [1] + list(np.arange(2, 17, 2)),\n", + " # Method of selecting samples for training each tree\n", + " \"max_samples\": np.arange(0.1, 1.1, 0.1),\n", + " 'bootstrap': [True],\n", + "}\n", + "pgrid = ParameterGrid(param_grid)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"b3cbedf8", + "metadata": {}, + "outputs": [], + "source": [ + "from tqdm.auto import tqdm\n", + "best_score = -np.inf\n", + "best_p = {}\n", + "best_rf = None\n", + "pbar = tqdm(pgrid)\n", + "for p in pbar:\n", + " pbar.set_postfix_str(str(p))\n", + " pbar.refresh()\n", + " rf = RandomForestRegressor(1000, n_jobs=-1, oob_score=True, **p).fit(\n", + " imputer.transform(X_trainval[variable_list]), \n", + " X_trainval[target_list]\n", + " )\n", + " if rf.oob_score_ > best_score:\n", + " best_score = rf.oob_score_\n", + " best_p = p\n", + " best_rf = rf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ed215d65", + "metadata": {}, + "outputs": [], + "source": [ + "# print(f\"best params: {best_p}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "34f9d264", + "metadata": {}, + "outputs": [], + "source": [ + "best_p = {'bootstrap': True, 'max_depth': 11, 'max_features': 0.9, 'max_samples': 0.2, 'min_samples_leaf': 6}" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "7e25184b", + "metadata": {}, + "outputs": [], + "source": [ + "best_rf = RandomForestRegressor(5000, n_jobs=-1, oob_score=True, **best_p).fit(\n", + " imputer.transform(X_trainval[variable_list]), \n", + " X_trainval[target_list]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "6c3092c8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RMSE:\n", + "\tBMag_ha: 50.176\n", + "\tV_ha: 87.487\n", + "R2 score:\n", + "\tBMag_ha: 0.775\n", + "\tV_ha: 0.800\n" + ] + }, + { + "data": { + "text/plain": [ + "[None, None]" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y_pred = best_rf.predict(imputer.transform(X_val[variable_list]))\n", + "y_pred = np.clip(y_pred, a_min=0, a_max=None)\n", + "rmse = []\n", + "r2 = []\n", + "for i, name in enumerate(target_list):\n", + " 
rmse.append(rmse_loss(X_val[target_list[i]], y_pred[:, i]))\n", + " r2.append(r2_score(X_val[target_list[i]], y_pred[:, i]))\n", + "print(f\"RMSE:\")\n", + "[print(f\"\\t{target}: {score:.3f}\") for target, score in zip(target_list, rmse)]\n", + " \n", + "print(f\"R2 score:\")\n", + "[print(f\"\\t{target}: {score:.3f}\") for target, score in zip(target_list, r2)]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "12454bdb", + "metadata": {}, + "outputs": [], + "source": [ + "# np.savetxt(\"rf_val.csv\", y_pred)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/torch-points3d/calibrate_bn.py b/torch-points3d/calibrate_bn.py new file mode 100644 index 0000000..61633ea --- /dev/null +++ b/torch-points3d/calibrate_bn.py @@ -0,0 +1,25 @@ +import hydra +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf + +from torch_points3d.trainer import Trainer + + +@hydra.main(config_path="conf", config_name="calibrate_bn") +def main(cfg): + if cfg.pretty_print: + print(OmegaConf.to_yaml(cfg)) + + epochs = cfg["epochs"] + + OmegaConf.set_struct(cfg, False) # This allows getattr and hasattr methods to function correctly + trainer = Trainer(cfg) + trainer.iterate_epochs(epochs) + + # # https://github.com/facebookresearch/hydra/issues/440 + GlobalHydra.get_state().clear() + return 0 + + +if __name__ == "__main__": + main() diff --git a/torch-points3d/compile_wrappers.sh b/torch-points3d/compile_wrappers.sh new file mode 100755 index 0000000..8d34270 --- /dev/null +++ b/torch-points3d/compile_wrappers.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# 
Compile cpp subsampling +cd torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling +python3 setup.py build_ext --inplace +cd .. + +# Compile cpp neighbors +cd cpp_neighbors +python3 setup.py build_ext --inplace +cd ../../../.. \ No newline at end of file diff --git a/torch-points3d/conf/calibrate_bn.yaml b/torch-points3d/conf/calibrate_bn.yaml new file mode 100644 index 0000000..c972575 --- /dev/null +++ b/torch-points3d/conf/calibrate_bn.yaml @@ -0,0 +1,32 @@ +defaults: + - visualization: default + - task: ??? + - data: ??? + - debugging: default + +num_workers: 0 +batch_size: 2 +cuda: 0 +weight_name: "latest" # Used during resume, select which model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: ??? # "{your_path}/outputs/2020-01-28/11-04-13" for example +model_name: ??? +precompute_multi_scale: False # Compute multiscale features on cpu for faster training / inference +epochs: 1 + +pretty_print: True + +wandb: + project: ??? + log: True + public: True + +tracker_options: # Extra options for the tracker + full_res: False + make_submission: True + +hydra: + run: + dir: ${checkpoint_dir}/calibrate + + diff --git a/torch-points3d/conf/config.yaml b/torch-points3d/conf/config.yaml new file mode 100644 index 0000000..11d04b9 --- /dev/null +++ b/torch-points3d/conf/config.yaml @@ -0,0 +1,25 @@ +defaults: # for loading the default.yaml config + - task: ??? + + - visualization: default + - lr_scheduler: exponential + - training: default + - debugging: default + - data: ??? +models: ??? + +job_name: benchmark # prefix name for saving the experiment file. +model_name: ??? 
# Name of the specific model to load +update_lr_scheduler_on: "on_epoch" # ["on_epoch", "on_num_batch", "on_num_sample"] +selection_stage: "" +pretty_print: False +eval_frequency: 1 + +tracker_options: # Extra options for the tracker + full_res: False + make_submission: False + track_boxes: False + +hydra: + run: + dir: ./outputs/${now:%Y-%m-%d}/${now:%H-%M-%S-%f}/ \ No newline at end of file diff --git a/torch-points3d/conf/data/instance/NFI/default.yaml b/torch-points3d/conf/data/instance/NFI/default.yaml new file mode 100644 index 0000000..b282607 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/default.yaml @@ -0,0 +1,44 @@ +# @package data +class: las_dataset.LasDataset +name: LASRegression +dataset_name: biomass +task: instance +dataroot: data +transform_type: ??? +areas: { + NFI: { + type: object, + pt_files: [2014/*/*.las, 2018/*/*.las, 2019/*/*.las], + label_files: nfi.gpkg, + check_pt_crs: False, + pt_identifier: las_file + }, +} +xy_radius: 15 +x_scale: 30 +y_scale: 30 +z_scale: 40 +x_center: 0.5 +y_center: 0.5 +first_subsampling: 0.0125 +split_col: "split" +log_train_metrics: False +save_local_stats: False +in_memory: True +min_pts_outer: 100 +min_pts_inner: 0 +skip_list: [ "y_mol", "y_mol_mask", "y_cls", "y_cls_mask", "y_reg", "y_reg_mask"] +features: [ ] +stats: [ ] +pre_transform: + - transform: DBSCANZOutlierRemoval + params: + eps: 1.5 # in m + min_samples: 10 + skip_list: ${data.skip_list} + - transform: StartZFromZero + - transform: ZFilter + params: + z_min: -1e-5 + z_max: 50 + skip_keys: ${data.skip_list} diff --git a/torch-points3d/conf/data/instance/NFI/noground/default.yaml b/torch-points3d/conf/data/instance/NFI/noground/default.yaml new file mode 100644 index 0000000..e70d135 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/noground/default.yaml @@ -0,0 +1,50 @@ +# @package data +class: las_dataset.LasDataset +name: LASRegression +dataset_name: biomass +task: instance +dataroot: data +transform_type: ??? 
+areas: { + NFI: { + type: object, + pt_files: [2014/*/*.las, 2018/*/*.las, 2019/*/*.las], + label_files: nfi.gpkg, + check_pt_crs: False, + pt_identifier: las_file + }, +} +xy_radius: 15 +x_scale: 30 +y_scale: 30 +z_scale: 40 +x_center: 0.5 +y_center: 0.5 +first_subsampling: 0.0125 +split_col: "split" +log_train_metrics: False +save_local_stats: False +in_memory: True +min_pts_outer: 100 +min_pts_inner: 0 +skip_list: [ "y_mol", "y_mol_mask", "y_cls", "y_cls_mask", "y_reg", "y_reg_mask"] +features: [ "classification" ] +stats: [ ] +pre_transform: + - transform: DBSCANZOutlierRemoval + params: + eps: 1.5 # in m + min_samples: 10 + skip_list: ${data.skip_list} + - transform: StartZFromZero + - transform: ZFilter + params: + z_min: -1e-5 + z_max: 50 + skip_keys: ${data.skip_list} + - transform: ClassificationFilter + params: + feature_index: 0 + keep: False + class_indices: [2] + remove_feat: True diff --git a/torch-points3d/conf/data/instance/NFI/noground/reg.yaml b/torch-points3d/conf/data/instance/NFI/noground/reg.yaml new file mode 100755 index 0000000..de71e20 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/noground/reg.yaml @@ -0,0 +1,31 @@ +# @package data +defaults: + - instance/default + - instance/NFI/noground/default + - instance/NFI/transforms/xy + - instance/NFI/transforms/xy-grid + - instance/NFI/transforms/xy-treeadd-eval + - instance/NFI/transforms/xy-eval + - instance/NFI/transforms/sparse + - instance/NFI/transforms/sparse-xy + - instance/NFI/transforms/sparse-ori + - instance/NFI/transforms/sparse-skeleton + - instance/NFI/transforms/sparse-treeadd + - instance/NFI/transforms/sparse-xy-treeadd + - instance/NFI/transforms/sparse-eval + - instance/NFI/transforms/sparse-xy-eval + - instance/NFI/transforms/sparse-treeadd-inner + - instance/NFI/transforms/sparse-treeadd-eval + - instance/NFI/transforms/sparse-xy-treeadd-eval + - instance/NFI/transforms/fixed + - instance/NFI/transforms/fixed-xy + - instance/NFI/transforms/fixed-xy-treeadd-eval 
+ - instance/NFI/transforms/fixed-xy-eval + - instance/NFI/transforms/fixed-skeleton + +processed_folder: "processed_nfi_reg_noground" +targets: { + BMag_ha: { task: regression, weight: 0.5 }, + V_ha: { task: regression, weight: 0.5 }, +} # metrics: m m cm + diff --git a/torch-points3d/conf/data/instance/NFI/reg.yaml b/torch-points3d/conf/data/instance/NFI/reg.yaml new file mode 100755 index 0000000..7f54e94 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/reg.yaml @@ -0,0 +1,25 @@ +# @package data +defaults: + - instance/default + - instance/NFI/default + - instance/NFI/transforms/xy + - instance/NFI/transforms/xy-treeadd-eval + - instance/NFI/transforms/xy-eval + - instance/NFI/transforms/sparse + - instance/NFI/transforms/sparse-xy + - instance/NFI/transforms/sparse-ori + - instance/NFI/transforms/sparse-eval + - instance/NFI/transforms/sparse-xy-eval + - instance/NFI/transforms/sparse-treeadd-eval + - instance/NFI/transforms/sparse-xy-treeadd-eval + - instance/NFI/transforms/fixed + - instance/NFI/transforms/fixed-xy + - instance/NFI/transforms/fixed-xy-treeadd-eval + - instance/NFI/transforms/fixed-xy-eval + +processed_folder: "processed_nfi_reg" +targets: { + BMag_ha: { task: regression, weight: 0.5 }, + V_ha: { task: regression, weight: 0.5 }, +} # metrics: m m cm + diff --git a/torch-points3d/conf/data/instance/NFI/transforms/eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/eval.yaml new file mode 100755 index 0000000..39bfb59 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/eval.yaml @@ -0,0 +1,7 @@ +# @package data + +eval: + test_transform: {} + train_transform: {} + val_transform: {} + diff --git a/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy-eval.yaml new file mode 100644 index 0000000..4ebc1db --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy-eval.yaml @@ -0,0 +1,45 @@ +# @package data + 
+fixed_xy_eval: + num_points: 12000 + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: FixedPointsOwn + params: + num: ${data.fixed.num_points} + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + train_transform: ${data.fixed_xy_eval.test_transform} + val_transform: ${data.fixed_xy_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy-treeadd-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy-treeadd-eval.yaml new file mode 100644 index 0000000..faf7ea0 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy-treeadd-eval.yaml @@ -0,0 +1,64 @@ +# @package data + +fixed_xy_treeadd_eval: + num_points: 12000 + test_transform: + - transform: RadiusObjectAdder + params: + areas: { + treeDB: { type: object }, + } + root_folder: ${data.dataroot} + dataset_name: treeDB + processed_folder: processed_treeDB_ALS + split: train + rot_x: 0.0 + rot_y: 0.0 + rot_z: 180 + min_radius: 15.1 + max_radius: 20 + n_max_objects: { scene: 10 } + adjust_point_density: False + in_memory: True + zero_center_z: True + p: 1.0 + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: 
${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: FixedPointsOwn + params: + num: ${data.fixed.num_points} + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + train_transform: ${data.fixed_xy_treeadd_eval.test_transform} + val_transform: ${data.fixed_xy_treeadd_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy.yaml b/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy.yaml new file mode 100644 index 0000000..64ef39c --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/fixed-xy.yaml @@ -0,0 +1,131 @@ +# @package data + +fixed_xy: + num_points: 12000 + train_transform: + - transform: RandomGroundRemoval + params: + min_v: 0.05 # at least 5 cm + max_v: 0.5 # at most 50 cm + p: 0.1 + min_points: 500 + skip_list: ${data.skip_list} + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + params: + sigma: 0.0025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: RandomShiftPos + params: 
+ p: 0.5 + max_x: 0.01 + max_y: 0.01 + max_z: 0.0 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: AddRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + - transform: CopyJitterRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + sigma: 0.005 + clip: 0.015 + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - transform: FixedPointsOwn + params: + num: ${data.fixed.num_points} + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes +# - transform: RandomScaling +# params: +# scales: [ 0.9, 1.1 ] + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: FixedPointsOwn + params: + num: ${data.fixed.num_points} + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: 
${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + val_transform: ${data.fixed_xy.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/fixed.yaml b/torch-points3d/conf/data/instance/NFI/transforms/fixed.yaml new file mode 100755 index 0000000..3255446 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/fixed.yaml @@ -0,0 +1,123 @@ +# @package data + +fixed: + num_points: 12000 + train_transform: + - transform: RandomGroundRemoval + params: + min_v: 0.05 # at least 5 cm + max_v: 0.5 # at most 50 cm + p: 0.1 + min_points: 500 + skip_list: ${data.skip_list} + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + params: + sigma: 0.0025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: RandomShiftPos + params: + p: 0.5 + max_x: 0.01 + max_y: 0.01 + max_z: 0.0 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: AddRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + - transform: CopyJitterRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + sigma: 0.005 + clip: 0.015 + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - transform: FixedPointsOwn + params: + num: 
${data.fixed.num_points} + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes +# - transform: RandomScaling +# params: +# scales: [ 0.9, 1.1 ] + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: FixedPointsOwn + params: + num: ${data.fixed.num_points} + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + val_transform: ${data.fixed.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse-eval.yaml new file mode 100755 index 0000000..a12c30d --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse-eval.yaml @@ -0,0 +1,53 @@ +# @package data + +sparse_eval: + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + 
[ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + train_transform: ${data.sparse_eval.test_transform} + val_transform: ${data.sparse_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse-ori.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse-ori.yaml new file mode 100755 index 0000000..95f75b6 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse-ori.yaml @@ -0,0 +1,99 @@ +# @package data + +sparse_ori: + train_transform: + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + params: + sigma: 0.025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - 
transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + - transform: RandomCoordsFlip + params: + ignored_axis: "z" + p: 0.5 + - transform: ShiftVoxels + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + val_transform: ${data.sparse_ori.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse-treeadd-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse-treeadd-eval.yaml new file mode 100755 index 0000000..4646fd1 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse-treeadd-eval.yaml @@ -0,0 +1,69 @@ +# @package data + +sparse_treeadd_eval: + test_transform: + - transform: RadiusObjectAdder + params: + areas: { + treeDB: { type: object }, + } + root_folder: ${data.dataroot} + dataset_name: treeDB + processed_folder: processed_treeDB_ALS + split: train + #processed_folder: 
merge_processed_instance_extra + rot_x: 0.0 + rot_y: 0.0 + rot_z: 180 + min_radius: 15.1 + max_radius: 20 + n_max_objects: { scene: 10 } + adjust_point_density: False + in_memory: True + zero_center_z: True + p: 1.0 + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + train_transform: ${data.sparse_treeadd_eval.test_transform} + val_transform: ${data.sparse_treeadd_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy-eval.yaml new file mode 100755 index 0000000..ce57440 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy-eval.yaml @@ -0,0 +1,63 @@ +# @package data + +sparse_xy_eval: + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: AddGround # only triggers for empty 
plots + params: + max_points: 1 + n_points: 1000 + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: AddGround # only triggers for empty plots + params: + max_points: 1 + n_points: 1000 + xy_min: 0.25 + xy_max: 0.75 + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + train_transform: ${data.sparse_xy_eval.test_transform} + val_transform: ${data.sparse_xy_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy-treeadd-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy-treeadd-eval.yaml new file mode 100755 index 0000000..252756b --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy-treeadd-eval.yaml @@ -0,0 +1,72 @@ +# @package data + +sparse_xy_treeadd_eval: + test_transform: + - transform: RadiusObjectAdder + params: + areas: { + treeDB: { type: object }, + } + root_folder: ${data.dataroot} + dataset_name: treeDB + processed_folder: processed_treeDB_ALS + split: train + rot_x: 0.0 + rot_y: 0.0 + rot_z: 180 + min_radius: 15.1 + max_radius: 20 + n_max_objects: { scene: 10 } + adjust_point_density: False + in_memory: True + zero_center_z: True + p: 1.0 + - transform: ScalePos + params: + scale_x: ${data.x_scale} + 
scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + train_transform: ${data.sparse_xy_treeadd_eval.test_transform} + val_transform: ${data.sparse_xy_treeadd_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy.yaml new file mode 100644 index 0000000..5467c50 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse-xy.yaml @@ -0,0 +1,153 @@ +# @package data + +sparse_xy: + train_transform: + - transform: RandomGroundRemoval + params: + min_v: 0.05 # at least 5 cm + max_v: 0.5 # at most 50 cm + p: 0.1 + min_points: 500 + skip_list: ${data.skip_list} + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + 
params: + sigma: 0.0025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: RandomShiftPos + params: + p: 0.5 + max_x: 0.01 + max_y: 0.01 + max_z: 0.0 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: AddRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + - transform: CopyJitterRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + sigma: 0.005 + clip: 0.015 + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes +# - transform: RandomScaling +# params: +# scales: [ 0.9, 1.1 ] + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + - transform: RandomCoordsFlip + params: + ignored_axis: "z" + p: 0.5 + - transform: ShiftVoxels + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + 
params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + val_transform: ${data.sparse_xy.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/sparse.yaml b/torch-points3d/conf/data/instance/NFI/transforms/sparse.yaml new file mode 100755 index 0000000..40186ba --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/sparse.yaml @@ -0,0 +1,145 @@ +# @package data + +sparse: + train_transform: + - transform: RandomGroundRemoval + params: + min_v: 0.05 # at least 5 cm + max_v: 0.5 # at most 50 cm + p: 0.1 + min_points: 500 + skip_list: ${data.skip_list} + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + params: + sigma: 0.0025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: RandomShiftPos + params: + p: 0.5 + max_x: 0.01 + max_y: 0.01 + max_z: 0.0 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - 
transform: StartZFromZero + - transform: AddRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + - transform: CopyJitterRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + sigma: 0.005 + clip: 0.015 + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes +# - transform: RandomScaling +# params: +# scales: [ 0.9, 1.1 ] + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + - transform: RandomCoordsFlip + params: + ignored_axis: "z" + p: 0.5 + - transform: ShiftVoxels + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 16000 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddFeatsByKeys + 
params: + list_add_to_x: [ True, True ] + feat_names: [ ones, pos_z ] + delete_feats: [ True, True ] + input_nc_feats: [ 1, 1 ] + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + val_transform: ${data.sparse.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/treeadd-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/treeadd-eval.yaml new file mode 100755 index 0000000..46b7b27 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/treeadd-eval.yaml @@ -0,0 +1,31 @@ +# @package data + +treeadd_eval: + test_transform: + - transform: RadiusObjectAdder + params: + areas: { + treeDB: { type: object }, + } + root_folder: ${data.dataroot} + dataset_name: treeDB + processed_folder: processed_treeDB_ALS + split: train + #processed_folder: merge_processed_instance_extra + rot_x: 0.0 + rot_y: 0.0 + rot_z: 180 + min_radius: 15.1 + max_radius: 20 + n_max_objects: { scene: 20 } + adjust_point_density: False + in_memory: True + zero_center_z: True + p: 1.0 + indicator_key: tree_add + - transform: CylinderExtend + params: + radius: 15.0 + skip_list: ${data.skip_list} + train_transform: ${data.treeadd_eval.test_transform} + val_transform: ${data.treeadd_eval.test_transform} \ No newline at end of file diff --git a/torch-points3d/conf/data/instance/NFI/transforms/xy-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/xy-eval.yaml new file mode 100644 index 0000000..feeef83 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/xy-eval.yaml @@ -0,0 +1,48 @@ +# @package data + +xy_eval: + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 
], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 6144 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + train_transform: ${data.xy_eval.test_transform} + val_transform: ${data.xy_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/xy-grid.yaml b/torch-points3d/conf/data/instance/NFI/transforms/xy-grid.yaml new file mode 100644 index 0000000..c803599 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/xy-grid.yaml @@ -0,0 +1,148 @@ +# @package data + +xy_grid: + train_transform: + - transform: RandomGroundRemoval + params: + min_v: 0.05 # at least 5 cm + max_v: 0.5 # at most 50 cm + p: 0.1 + min_points: 500 + skip_list: ${data.skip_list} + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + params: + sigma: 0.0025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: RandomShiftPos + params: + p: 0.5 + max_x: 0.01 + max_y: 0.01 + max_z: 0.0 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: AddRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + 
add_ratio_max: 0.2 + p: 0.25 + - transform: CopyJitterRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + sigma: 0.005 + clip: 0.015 + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 6144 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes +# - transform: RandomScaling +# params: +# scales: [ 0.9, 1.1 ] + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: GridSampling3D + params: + size: ${data.first_subsampling} + quantize_coords: True + mode: "last" + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 6144 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - 
transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + val_transform: ${data.xy_grid.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/xy-treeadd-eval.yaml b/torch-points3d/conf/data/instance/NFI/transforms/xy-treeadd-eval.yaml new file mode 100644 index 0000000..40171ea --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/xy-treeadd-eval.yaml @@ -0,0 +1,67 @@ +# @package data + +xy_treeadd_eval: + test_transform: + - transform: RadiusObjectAdder + params: + areas: { + treeDB: { type: object }, + } + root_folder: ${data.dataroot} + dataset_name: treeDB + processed_folder: processed_treeDB_ALS + split: train + rot_x: 0.0 + rot_y: 0.0 + rot_z: 180 + min_radius: 15.1 + max_radius: 20 + n_max_objects: { scene: 10 } + adjust_point_density: False + in_memory: True + zero_center_z: True + p: 1.0 + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 6144 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: 
[ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + train_transform: ${data.xy_treeadd_eval.test_transform} + val_transform: ${data.xy_treeadd_eval.test_transform} diff --git a/torch-points3d/conf/data/instance/NFI/transforms/xy.yaml b/torch-points3d/conf/data/instance/NFI/transforms/xy.yaml new file mode 100644 index 0000000..7746940 --- /dev/null +++ b/torch-points3d/conf/data/instance/NFI/transforms/xy.yaml @@ -0,0 +1,138 @@ +# @package data + +xy: + train_transform: + - transform: RandomGroundRemoval + params: + min_v: 0.05 # at least 5 cm + max_v: 0.5 # at most 50 cm + p: 0.1 + min_points: 500 + skip_list: ${data.skip_list} + - transform: RandomDropout + params: + dropout_ratio: 0.2 + dropout_application_ratio: 0.5 + min_points: 500 + skip_list: ${data.skip_list} + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: RandomNoise + params: + sigma: 0.0025 + - transform: Random3AxisRotation + params: + apply_rotation: True + rot_x: 0 + rot_y: 0 + rot_z: 180 + - transform: RandomShiftPos + params: + p: 0.5 + max_x: 0.01 + max_y: 0.01 + max_z: 0.0 + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: AddRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + - transform: CopyJitterRandomPoints + params: + n_max_points: 12000 + add_ratio_min: 0.01 + add_ratio_max: 0.2 + p: 0.25 + sigma: 0.005 + clip: 0.015 + - transform: RandomPolygon2dExtend + params: + polygons: [ + [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ], + ] + rotate: 180 + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 6144 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - 
transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes +# - transform: RandomScaling +# params: +# scales: [ 0.9, 1.1 ] + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + test_transform: + - transform: ScalePos + params: + scale_x: ${data.x_scale} + scale_y: ${data.y_scale} + scale_z: ${data.z_scale} + op: "div" + - transform: MoveCenterPosPerSample + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: StartZFromZero + - transform: Polygon2dExtend + params: + polygon: [ + [ 0., 0.5 ], [ 0.25, 0.9330127 ], [ 0.75, 0.9330127 ], + [ 1., 0.5 ], [ 0.75, 0.0669873 ], [ 0.25, 0.0669873 ] + ] + skip_list: ${data.skip_list} + - transform: MaxPoints + params: + num: 6144 + skip_list: ${data.skip_list} + - transform: MinPoints + params: + num: 500 + skip_list: ${data.skip_list} + - transform: XYZFeature + params: + add_x: False + add_y: False + add_z: True + - transform: AddOnes + - transform: AddXYDistanceToCenter + params: + center_x: ${data.x_center} + center_y: ${data.y_center} + - transform: AddFeatsByKeys + params: + list_add_to_x: [ True, True, True ] + feat_names: [ ones, pos_z, xy_distance ] + delete_feats: [ True, True, True ] + input_nc_feats: [ 1, 1, 1 ] + val_transform: ${data.xy.test_transform} diff --git a/torch-points3d/conf/data/instance/default.yaml b/torch-points3d/conf/data/instance/default.yaml new file mode 100644 index 0000000..47c2da5 --- /dev/null +++ b/torch-points3d/conf/data/instance/default.yaml @@ -0,0 +1,2 @@ +# @package data +task: instance \ No newline at end of file diff --git a/torch-points3d/conf/data/instance/treeDB/ALS.yaml b/torch-points3d/conf/data/instance/treeDB/ALS.yaml new file mode 100755 index 0000000..b698464 --- 
/dev/null +++ b/torch-points3d/conf/data/instance/treeDB/ALS.yaml @@ -0,0 +1,30 @@ +# @package data +defaults: + - instance/default + - instance/treeDB/default + - instance/trees-sparse + - instance/trees-fixed + + +areas: { + treeDB: { + type: object, + pt_files: [ ALS/*.laz ], + label_files: treeDB_epsg_25832.gpkg, + # 'Carpinus betulus', 'Picea abies', 'Larix decidua', + # 'Quercus petraea', 'Fagus sylvatica', 'Quercus rubra', + # 'Pinus sylvestris', 'Pseudotsuga menziesii', 'Quercus robur', + # 'Abies alba', 'Prunus avium', 'Fraxinus excelsior', + # 'Acer pseudoplatanus', 'Tilia spec.', 'Tsuga heterophylla', + # 'Juglans regia', 'Acer campestre', 'Betula pendula', + # 'Prunus serotina', 'Robinia pseudoacacia', 'Sorbus torminalis', + # 'Salix caprea' + alias_targets: [ height_m, mean_crown_diameter_m, DBH_cm, species ], + targets_must_be_present: [ False, False, False, False ], + pt_identifier: file_path, + test_ratio: 0.1, + val_ratio: 0.0 + }, +} +features: [ "return_number", "classification" ] +processed_folder: processed_treeDB_ALS \ No newline at end of file diff --git a/torch-points3d/conf/data/instance/treeDB/default.yaml b/torch-points3d/conf/data/instance/treeDB/default.yaml new file mode 100755 index 0000000..b9dc739 --- /dev/null +++ b/torch-points3d/conf/data/instance/treeDB/default.yaml @@ -0,0 +1,47 @@ +# @package data +class: las_dataset.LasDataset +name: LASRegression +dataset_name: treeDB +task: instance +dataroot: data +transform_type: ??? 
+xy_radius: 30 +x_scale: 6 +y_scale: 6 +z_scale: 6 +x_center: 0.5 +y_center: 0.5 +first_subsampling: 0.1 +split_col: "split" +log_train_metrics: False +save_local_stats: False +min_pts_outer: 10 +min_pts_inner: 0 +# if samples are stored, you need to reprocess file when you change anything below (until center_z) +targets: { + height_m: { task: mol, num_mixtures: 10, class_tol: 1, weight: 0.25 }, # tolerance of 1m + mean_crown_diameter_m: { task: mol, num_mixtures: 10, class_tol: 0.1, weight: 0.25 }, # tolerance of 10cm + DBH_cm: { task: mol, num_mixtures: 10, class_tol: 1, weight: 0.25 }, # tolerance of 1cm + tree_species: { task: classification, class_names: [ RGR, BOG, DGR, Rest ], + class_mapping: { + "Picea abies": 0, "Fagus sylvatica": 1, "Pseudotsuga menziesii": 2, + "Abies alba": 3, "Quercus petraea": 3, "QUERCUS": 3, "Quercus rubra": 3, + "Quercus robur": 3, "Larix decidua": 3, "Tsuga heterophylla": 3, "Sorbus torminalis": 3, + } }, +} # metrics: m m cm +"features": [ "return_number", "classification" ] +stats: [ ] +skip_list: [ "y_mol", "y_mol_mask", "y_cls", "y_cls_mask", "y_reg", "y_reg_mask" ] +pre_transform: + - transform: DBSCANZOutlierRemoval + params: + eps: 1.5 # in m + min_samples: 10 + skip_list: ${data.skip_list} + - transform: StartZFromZero + - transform: CenterXYbyZ + params: + center_x: 0 + center_y: 0 + z_thresh_min: 0.0 # 0 cm over lowest point + z_thresh_max: 2.5 # 2.5 m over lowest point \ No newline at end of file diff --git a/torch-points3d/conf/debugging/default.yaml b/torch-points3d/conf/debugging/default.yaml new file mode 100644 index 0000000..0bae882 --- /dev/null +++ b/torch-points3d/conf/debugging/default.yaml @@ -0,0 +1,5 @@ +# @package debugging +find_neighbour_dist: False +num_batches: 50 +early_break: False +profiling: False \ No newline at end of file diff --git a/torch-points3d/conf/debugging/early_break.yaml b/torch-points3d/conf/debugging/early_break.yaml new file mode 100644 index 0000000..acde870 --- /dev/null +++ 
b/torch-points3d/conf/debugging/early_break.yaml @@ -0,0 +1,2 @@ +# @package _group_ +early_break: True \ No newline at end of file diff --git a/torch-points3d/conf/debugging/find_neighbour_dist.yaml b/torch-points3d/conf/debugging/find_neighbour_dist.yaml new file mode 100644 index 0000000..b85bcb0 --- /dev/null +++ b/torch-points3d/conf/debugging/find_neighbour_dist.yaml @@ -0,0 +1,3 @@ +# @package _group_ +find_neighbour_dist: True +num_batches: 20 \ No newline at end of file diff --git a/torch-points3d/conf/eval.yaml b/torch-points3d/conf/eval.yaml new file mode 100644 index 0000000..161fb9b --- /dev/null +++ b/torch-points3d/conf/eval.yaml @@ -0,0 +1,33 @@ +defaults: + - visualization: eval + - task: ??? + - data: ??? + - debugging: default + +num_workers: 0 +batch_size: 2 +cuda: 0 +weight_name: "latest" # Used during resume, select which model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: ??? # "{your_path}/outputs/2020-01-28/11-04-13" for example +model_name: ??? +precompute_multi_scale: False # Compute multiscale features on cpu for faster training / inference +enable_dropout: False +voting_runs: 1 +eval_stages: [ "val", "test" ] +pretty_print: True + +wandb: + project: ??? 
+ log: False + public: True + +tracker_options: # Extra options for the tracker + full_res: False + make_submission: True + +hydra: + run: + dir: ${checkpoint_dir}/eval/${now:%Y-%m-%d_%H-%M-%S-%f} + + diff --git a/torch-points3d/conf/hydra/job_logging/custom.yaml b/torch-points3d/conf/hydra/job_logging/custom.yaml new file mode 100644 index 0000000..f374a2c --- /dev/null +++ b/torch-points3d/conf/hydra/job_logging/custom.yaml @@ -0,0 +1,19 @@ +# @package _group_ +formatters: + simple: + format: "%(message)s" +root: + handlers: [debug_console_handler, file_handler] +version: 1 +handlers: + debug_console_handler: + level: DEBUG + formatter: simple + class: logging.StreamHandler + stream: ext://sys.stdout + file_handler: + level: DEBUG + formatter: simple + class: logging.FileHandler + filename: train.log +disable_existing_loggers: False diff --git a/torch-points3d/conf/hydra/output/custom.yaml b/torch-points3d/conf/hydra/output/custom.yaml new file mode 100644 index 0000000..f7d9099 --- /dev/null +++ b/torch-points3d/conf/hydra/output/custom.yaml @@ -0,0 +1,4 @@ +# @package _global_ +hydra: + run: + dir: ./outputs/${job_name}/${job_name}-${model_name}-${now:%Y%m%d_%H%M%S} diff --git a/torch-points3d/conf/lr_scheduler/cosine.yaml b/torch-points3d/conf/lr_scheduler/cosine.yaml new file mode 100644 index 0000000..b40d164 --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/cosine.yaml @@ -0,0 +1,4 @@ +# @package _group_ +class: CosineAnnealingLR +params: + T_max: 10 \ No newline at end of file diff --git a/torch-points3d/conf/lr_scheduler/cosineawr.yaml b/torch-points3d/conf/lr_scheduler/cosineawr.yaml new file mode 100644 index 0000000..c64dd80 --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/cosineawr.yaml @@ -0,0 +1,5 @@ +# @package _group_ +class: CosineAnnealingWarmRestarts +params: + T_0: 10 + T_mult: 2 \ No newline at end of file diff --git a/torch-points3d/conf/lr_scheduler/cyclic.yaml b/torch-points3d/conf/lr_scheduler/cyclic.yaml new file mode 100644 index 
0000000..67b09cc --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/cyclic.yaml @@ -0,0 +1,5 @@ +# @package _group_ +class: CyclicLR +params: + base_lr: ${training.optim.base_lr} + max_lr: 0.1 diff --git a/torch-points3d/conf/lr_scheduler/exponential.yaml b/torch-points3d/conf/lr_scheduler/exponential.yaml new file mode 100644 index 0000000..43d511a --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/exponential.yaml @@ -0,0 +1,4 @@ +# @package _group_ +class: ExponentialLR +params: + gamma: 0.9885 # = 0.1**(1/200.) divide by 10 every 200 epochs \ No newline at end of file diff --git a/torch-points3d/conf/lr_scheduler/multi_step.yaml b/torch-points3d/conf/lr_scheduler/multi_step.yaml new file mode 100644 index 0000000..11f2c5b --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/multi_step.yaml @@ -0,0 +1,5 @@ +# @package _group_ +class: MultiStepLR +params: + milestones: [80,120,160] + gamma: 0.2 diff --git a/torch-points3d/conf/lr_scheduler/multi_step_reg.yaml b/torch-points3d/conf/lr_scheduler/multi_step_reg.yaml new file mode 100644 index 0000000..cdcae5d --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/multi_step_reg.yaml @@ -0,0 +1,5 @@ +# @package _group_ +class: MultiStepLR +params: + milestones: [600, 1200, 1800, 3000] + gamma: 0.5 diff --git a/torch-points3d/conf/lr_scheduler/plateau.yaml b/torch-points3d/conf/lr_scheduler/plateau.yaml new file mode 100644 index 0000000..2260719 --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/plateau.yaml @@ -0,0 +1,4 @@ +# @package _group_ +class: ReduceLROnPlateau +params: + mode: "min" \ No newline at end of file diff --git a/torch-points3d/conf/lr_scheduler/poly_lr.yaml b/torch-points3d/conf/lr_scheduler/poly_lr.yaml new file mode 100644 index 0000000..188861b --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/poly_lr.yaml @@ -0,0 +1,9 @@ +# @package _group_ +class: PolyLR +params: + on_epoch: + max_iter: 150 + power: 0.9 + on_num_batch: + max_iter: 60000 + power: 2 diff --git 
a/torch-points3d/conf/lr_scheduler/step.yaml b/torch-points3d/conf/lr_scheduler/step.yaml new file mode 100644 index 0000000..921f63e --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/step.yaml @@ -0,0 +1,6 @@ +# @package _group_ +class: StepLR +params: + step_size: 10 + gamma: 0.9 + last_epoch: -1 \ No newline at end of file diff --git a/torch-points3d/conf/lr_scheduler/warmupcosine.yaml b/torch-points3d/conf/lr_scheduler/warmupcosine.yaml new file mode 100644 index 0000000..cdb83ac --- /dev/null +++ b/torch-points3d/conf/lr_scheduler/warmupcosine.yaml @@ -0,0 +1,5 @@ +# @package _group_ +class: LinearWarmupCosineAnnealingLR +params: + warmup_epochs: 10 + max_epochs: ${training.epochs} \ No newline at end of file diff --git a/torch-points3d/conf/models/default.yaml b/torch-points3d/conf/models/default.yaml new file mode 100644 index 0000000..af203a0 --- /dev/null +++ b/torch-points3d/conf/models/default.yaml @@ -0,0 +1 @@ +# @package models diff --git a/torch-points3d/conf/models/instance/default.yaml b/torch-points3d/conf/models/instance/default.yaml new file mode 100644 index 0000000..97c23be --- /dev/null +++ b/torch-points3d/conf/models/instance/default.yaml @@ -0,0 +1,3 @@ +# @package models +defaults: + - /models/default \ No newline at end of file diff --git a/torch-points3d/conf/models/instance/kpconv.yaml b/torch-points3d/conf/models/instance/kpconv.yaml new file mode 100755 index 0000000..3259122 --- /dev/null +++ b/torch-points3d/conf/models/instance/kpconv.yaml @@ -0,0 +1,89 @@ +# @package models + +KPConv: + class: kpconv.KPConv + conv_type: "PARTIAL_DENSE" + config: + ################## + # Input parameters + ################## + + # Dimension of input points + in_points_dim: 3 + + # Dimension of input features + in_features_dim: FEAT + + # Radius of the input sphere (ignored for models, only used for point clouds) + in_radius: 1.0 + + ################## + # Model parameters + ################## + + # Architecture definition. 
List of blocks + architecture: [ 'simple', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'resnetb_strided', + 'resnetb', + 'resnetb', + 'global_sum' ] + + # Dimension of the first feature maps + first_features_dim: 64 + + # Batch normalization parameters + use_batch_norm: True + batch_norm_momentum: 0.02 + + ################### + # KPConv parameters + ################### + + # Activation function + activation: relu + + # Number of kernel points + num_kernel_points: 15 + + # Size of the first subsampling grid + first_subsampling_dl: ${data.first_subsampling} + + # Radius of convolution in "number grid cell". (2.5 is the standard value) + conv_radius: 2.5 + + # Radius of deformable convolution in "number grid cell". Larger so that deformed kernel can spread out + deform_radius: 5.0 + + # Kernel point influence radius + KP_extent: 1.0 + + # Influence function when d < KP_extent. ('constant', 'linear', 'gaussian') When d > KP_extent, always zero + KP_influence: 'linear' + + # Aggregation function of KPConv in ('closest', 'sum') + # Decide if you sum all kernel point influences, or if you only take the influence of the closest KP + aggregation_mode: 'sum' + + # Fixed points in the kernel : 'none', 'center' or 'verticals' + fixed_kernel_points: 'center' + + # Use modulation in deformable convolutions + modulated: False + + # Deformable offset loss + # 'point2point' fitting geometry by penalizing distance from deform point to input points + # 'point2plane' fitting geometry by penalizing distance from deform point to input point triplet (not implemented) + deform_fitting_mode: 'point2point' + deform_fitting_power: 1.0 # Multiplier for the fitting/repulsive loss + deform_lr_factor: 0.1 # Multiplier for learning rate applied to the deformations + repulse_extent: 1.2 # Distance of repulsion for deformed kernel points diff --git 
a/torch-points3d/conf/models/instance/minkowski_baseline.yaml b/torch-points3d/conf/models/instance/minkowski_baseline.yaml new file mode 100644 index 0000000..f3afbdd --- /dev/null +++ b/torch-points3d/conf/models/instance/minkowski_baseline.yaml @@ -0,0 +1,124 @@ +# @package models +# Minkowski Engine: https://github.com/StanfordVL/MinkowskiEngine/blob/master/examples/minkunet.py + +MPointNet: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "MinkowskiPointNet" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + global_pool: mean + add_pos: True + +ResNet14: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "ResNet14_" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +ResNet18: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "ResNet18_" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +ResNet34: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "ResNet34_" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +ResNet50: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "ResNet50_" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +ResNet101: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "ResNet101_" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + + +SENet14: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "SENet14" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +SENet18: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "SENet18" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + 
global_pool: mean + +SENet34: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "SENet34" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +SENet50: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "SENet50" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean + +SENet101: + class: minkowski.MinkowskiBaselineModel + conv_type: "SPARSE" + model_name: "SENet101" + D: 3 + activation: "relu" + first_stride: 2 + dropout: 0.0 + drop_path: 0.01 + global_pool: mean \ No newline at end of file diff --git a/torch-points3d/conf/models/instance/minkowski_selfsupervised.yaml b/torch-points3d/conf/models/instance/minkowski_selfsupervised.yaml new file mode 100644 index 0000000..8f7bf51 --- /dev/null +++ b/torch-points3d/conf/models/instance/minkowski_selfsupervised.yaml @@ -0,0 +1,38 @@ +# @package models +# Minkowski Engine: https://github.com/StanfordVL/MinkowskiEngine/blob/master/examples/minkunet.py + +BarlowTwins: + class: minkowski.MinkowskiBarlowTwins + conv_type: "SPARSE" + model_name: "BarlowTwins" + backbone: "SENet34" + D: 3 + activation: elu + proj_activation: relu + first_stride: 1 + dropout: 0 + global_pool: mean + proj_layers: [ 2048, 2048, 2048 ] + proj_last_norm: True + loss_fn: "smoothl1" + scale_loss: { "lambda": 0.0051, "all": 0.1 , } + mode: "train" + backbone_lr: "base_lr" + +VICReg: + class: minkowski.MinkowskiVICReg + conv_type: "SPARSE" + model_name: "BarlowTwins" + backbone: "SENet34" + D: 3 + activation: elu + proj_activation: relu + first_stride: 1 + dropout: 0 + global_pool: mean + proj_layers: [ 2048, 2048, 2048 ] + proj_last_norm: False + loss_fn: "smoothl1" + mode: "train" + backbone_lr: "base_lr" + scale_loss: { "invariance": 25., "variance": 25. , "covariance": 1. 
} \ No newline at end of file diff --git a/torch-points3d/conf/models/instance/pointnet.yaml b/torch-points3d/conf/models/instance/pointnet.yaml new file mode 100755 index 0000000..f4138a9 --- /dev/null +++ b/torch-points3d/conf/models/instance/pointnet.yaml @@ -0,0 +1,10 @@ +# @package models +# Minkowski Engine: https://github.com/StanfordVL/MinkowskiEngine/blob/master/examples/minkunet.py + +PointNet: + class: pointnext.PointNext + conv_type: "PARTIAL_DENSE" + arch: "pointnet" + radius: ${data.first_subsampling} + stride: 4 + num_points: 8192 diff --git a/torch-points3d/conf/models/instance/pointnext.yaml b/torch-points3d/conf/models/instance/pointnext.yaml new file mode 100755 index 0000000..a92468c --- /dev/null +++ b/torch-points3d/conf/models/instance/pointnext.yaml @@ -0,0 +1,14 @@ +# @package models + +PointNext: + class: pointnext.PointNext + conv_type: "PARTIAL_DENSE" + arch: "pointnext_s" + radius: ${data.first_subsampling} + radius_scaling: 2 + nsample: 32 + stride: 4 + activation: relu + num_points: 8192 + use_mlps: True + diff --git a/torch-points3d/conf/models/instance/pointnext_selfsupervised.yaml b/torch-points3d/conf/models/instance/pointnext_selfsupervised.yaml new file mode 100755 index 0000000..c0874b2 --- /dev/null +++ b/torch-points3d/conf/models/instance/pointnext_selfsupervised.yaml @@ -0,0 +1,37 @@ +# @package models + +BarlowTwins: + class: pointnext.PointNextBarlowTwins + conv_type: "PARTIAL_DENSE" + arch: "pointnext_s" + radius: ${data.first_subsampling} + activation: relu + proj_activation: relu + num_points: 8192 + loss_fn: smoothl1 + stride: 4 + dropout: 0 + global_pool: mean + proj_layers: [ 2048, 2048, 2048 ] + proj_last_norm: True + scale_loss: { "lambda": 0.0051, "all": 0.1 , } + mode: "train" + backbone_lr: "base_lr" + +VICReg: + class: pointnext.PointNextVICReg + conv_type: "PARTIAL_DENSE" + arch: "pointnext_b" + radius: ${data.first_subsampling} + activation: relu + proj_activation: relu + num_points: 8192 + loss_fn: smoothl1 
+ stride: 4 + dropout: 0 + global_pool: mean + proj_layers: [ 2048, 2048, 2048 ] + proj_last_norm: False + mode: "train" + backbone_lr: "base_lr" + scale_loss: { "invariance": 25., "variance": 25. , "covariance": 1. } \ No newline at end of file diff --git a/torch-points3d/conf/models/instance/simplestnet.yaml b/torch-points3d/conf/models/instance/simplestnet.yaml new file mode 100755 index 0000000..b49c8fe --- /dev/null +++ b/torch-points3d/conf/models/instance/simplestnet.yaml @@ -0,0 +1,5 @@ +# @package models + +SimplestNet: + class: simplestnet.SimplestNet + conv_type: "PARTIAL_DENSE" diff --git a/torch-points3d/conf/sota.yaml b/torch-points3d/conf/sota.yaml new file mode 100644 index 0000000..22c7101 --- /dev/null +++ b/torch-points3d/conf/sota.yaml @@ -0,0 +1,26 @@ +# @package sota +s3dis5: + miou: 67.1 + mrec: 72.8 + +s3dis: + acc: 88.2 + macc: 81.5 + miou: 70.6 + +scannet: + miou: 72.5 + +semantic3d: + miou: 76.0 + acc: 94.4 + +semantickitti: + miou: 50.3 + +modelnet40: + acc: 92.9 + +shapenet: + mciou: 85.1 + miou: 86.4 \ No newline at end of file diff --git a/torch-points3d/conf/task/default.yaml b/torch-points3d/conf/task/default.yaml new file mode 100644 index 0000000..6ffc0e9 --- /dev/null +++ b/torch-points3d/conf/task/default.yaml @@ -0,0 +1,10 @@ +# @package task +defaults: + - /data@_group_: default + - /models@_group_: default + +# By default.yaml we turn off recursive instantiation, allowing the user to instantiate themselves at the appropriate times. 
+_recursive_: false + +#_target_: lightning_transformers.core.model.TaskTransformer +lr_scheduler: ${lr_scheduler} diff --git a/torch-points3d/conf/task/instance.yaml b/torch-points3d/conf/task/instance.yaml new file mode 100644 index 0000000..9258141 --- /dev/null +++ b/torch-points3d/conf/task/instance.yaml @@ -0,0 +1,7 @@ +# @package task +defaults: + - /task/default + - override /data@_group_: instance/default + - override /models@_group_: instance/default + +name: instance \ No newline at end of file diff --git a/torch-points3d/conf/training/default.yaml b/torch-points3d/conf/training/default.yaml new file mode 100644 index 0000000..bbfb90e --- /dev/null +++ b/torch-points3d/conf/training/default.yaml @@ -0,0 +1,55 @@ +# @package training +# These arguments define the training hyper-parameters +epochs: 100 +num_workers: 6 +batch_size: 16 +shuffle: True +cuda: 0 # -1 -> no cuda otherwise takes the specified index +precompute_multi_scale: False # Compute multi-scale features on cpu for faster training / inference +optim: + base_lr: 0.001 + # accumulated_gradient: -1 # Accumulate gradient accumulated_gradient * batch_size + grad_clip: -1 + optimizer: + class: Adam + params: + lr: ${training.optim.base_lr} # The path is cut from training + lr_scheduler: ${lr_scheduler} + bn_scheduler: + bn_policy: "step_decay" + params: + bn_momentum: 0.1 + bn_decay: 0.9 + decay_step: 10 + bn_clip: 1e-2 +weight_name: "latest" # Used during resume, selects which model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: "" + +# These arguments within experiment define which model, dataset and task are created for benchmarking +# parameters for Weights and Biases +wandb: + project: default + log: True + name: dev + public: True # If True, the model is displayed in the wandb log; otherwise it is not. 
+ config: + model_name: ${model_name} + +# parameters for TensorBoard Visualization +tensorboard: + log: True + pytorch_profiler: + log: True # activate PyTorch Profiler in TensorBoard + nb_epoch: 3 # number of epochs to profile (0 -> all). + skip_first: 10 # number of first iterations to skip. + wait: 5 # number of iterations where the profiler is disable. + warmup: 3 # number of iterations where the profiler starts tracing but the results are discarded. This is for reducing the profiling overhead. The overhead at the beginning of profiling is high and easy to bring skew to the profiling result. + active: 5 # number of iterations where the profiler is active and records events. + repeat: 0 # number of cycle wait/warmup/active to realise before stoping profiling (0 -> all). + record_shapes: True # save information about operator’s input shapes. + profile_memory: True # track tensor memory allocation/deallocation. + with_stack: True # record source information (file and line number) for the ops. + with_flops: True # use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution). 
+ +enable_mixed: True diff --git a/torch-points3d/conf/training/default_reg.yaml b/torch-points3d/conf/training/default_reg.yaml new file mode 100644 index 0000000..d08bfc2 --- /dev/null +++ b/torch-points3d/conf/training/default_reg.yaml @@ -0,0 +1,41 @@ +# @package training +# These arguments define the training hyper-parameters +epochs: 6000 +num_workers: 6 +batch_size: 64 +shuffle: True +cuda: 0 +precompute_multi_scale: False # Compute multi-scale features on cpu for faster training / inference +optim: + base_lr: 0.001 + # accumulated_gradient: -1 # Accumulate gradient accumulated_gradient * batch_size + grad_clip: -1 + optimizer: + class: Adam + params: + lr: ${training.optim.base_lr} # The path is cut from training + + lr_scheduler: ${lr_scheduler} + bn_scheduler: + bn_policy: "step_decay" + params: + bn_momentum: 0.1 + bn_decay: 0.9 + decay_step: 3000 + bn_clip: 1e-2 +weight_name: "latest" # Used during resume, selects which model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: "" + +# These arguments within experiment define which model, dataset and task are created for benchmarking +# parameters for Weights and Biases +wandb: + project: default + log: False + notes: + name: + public: True # If True, the model is displayed in the wandb log; otherwise it is not. 
+ + # parameters for TensorBoard Visualization +tensorboard: + log: True diff --git a/torch-points3d/conf/training/nfi/kpconv.yaml b/torch-points3d/conf/training/nfi/kpconv.yaml new file mode 100644 index 0000000..c228f32 --- /dev/null +++ b/torch-points3d/conf/training/nfi/kpconv.yaml @@ -0,0 +1,64 @@ +# @package training +# Those arguments defines the training hyper-parameters +epochs: 310 +num_workers: 4 +batch_size: 32 +cuda: 0 +shuffle: True +optim: + base_lr: 0.005 + grad_clip: 100 + optimizer: +# class: SGD +# params: +# momentum: 0.98 +# lr: ${training.optim.base_lr} # The path is cut from training +# weight_decay: 1e-3 + class: AdaBelief + params: + lr: ${training.optim.base_lr} # The path is cut from training + weight_decay: 1e-2 + lr_scheduler: ${lr_scheduler} +# bn_scheduler: +# bn_policy: "step_decay" +# params: +# bn_momentum: 0.98 +# bn_decay: 0.9 +# decay_step: 1000 +# bn_clip: 1e-2 +weight_name: "latest" # Used during resume, select with model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: "" + +# Those arguments within experiment defines which model, dataset and task to be created for benchmarking +# parameters for Weights and Biases +wandb: + project: nfi + log: True + name: ${model_name} + public: True # It will be display the model within wandb log, else not. + config: + model_name: ${model_name} + features: ${data.features} + batch_size: ${training.batch_size} + first_subsampling: ${data.first_subsampling} + base_lr: ${training.optim.base_lr} + + +# parameters for TensorBoard Visualization +tensorboard: + log: False + pytorch_profiler: + log: False # activate PyTorch Profiler in TensorBoard + nb_epoch: 3 # number of epochs to profile (0 -> all). + skip_first: 10 # number of first iterations to skip. + wait: 5 # number of iterations where the profiler is disable. + warmup: 3 # number of iterations where the profiler starts tracing but the results are discarded. This is for reducing the profiling overhead. 
The overhead at the beginning of profiling is high and easy to bring skew to the profiling result. + active: 5 # number of iterations where the profiler is active and records events. + repeat: 0 # number of cycle wait/warmup/active to realise before stoping profiling (0 -> all). + record_shapes: True # save information about operator’s input shapes. + profile_memory: True # track tensor memory allocation/deallocation. + with_stack: True # record source information (file and line number) for the ops. + with_flops: True # use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution). + +enable_mixed: True diff --git a/torch-points3d/conf/training/nfi/minkowski.yaml b/torch-points3d/conf/training/nfi/minkowski.yaml new file mode 100644 index 0000000..a14ce6f --- /dev/null +++ b/torch-points3d/conf/training/nfi/minkowski.yaml @@ -0,0 +1,68 @@ +# @package training +# Those arguments defines the training hyper-parameters +epochs: 310 +num_workers: 4 +batch_size: 32 +cuda: 0 +shuffle: True +optim: + base_lr: 0.005 + grad_clip: 100 + optimizer: +# class: SGD +# params: +# momentum: 0.98 +# lr: ${training.optim.base_lr} # The path is cut from training +# weight_decay: 1e-3 + class: AdaBelief + params: + lr: ${training.optim.base_lr} # The path is cut from training + weight_decay: 1e-2 + lr_scheduler: ${lr_scheduler} +# bn_scheduler: +# bn_policy: "step_decay" +# params: +# bn_momentum: 0.98 +# bn_decay: 0.9 +# decay_step: 1000 +# bn_clip: 1e-2 +weight_name: "latest" # Used during resume, select with model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: "" + +# Those arguments within experiment defines which model, dataset, and task to be created for benchmarking +# parameters for Weights and Biases +wandb: + project: nfi + log: True + name: ${model_name} + public: True # It will be display the model within wandb log, else not. 
+ config: + model_name: ${model_name} + features: ${data.features} + batch_size: ${training.batch_size} + first_subsampling: ${data.first_subsampling} + base_lr: ${training.optim.base_lr} + activation: ${models.${model_name}.activation} + first_stride: ${models.${model_name}.first_stride} + transform_type: ${data.transform_type} + + +# parameters for TensorBoard Visualization +tensorboard: + log: False + pytorch_profiler: + log: False # activate PyTorch Profiler in TensorBoard + nb_epoch: 3 # number of epochs to profile (0 -> all). + skip_first: 10 # number of first iterations to skip. + wait: 5 # number of iterations where the profiler is disable. + warmup: 3 # number of iterations where the profiler starts tracing but the results are discarded. This is for reducing the profiling overhead. The overhead at the beginning of profiling is high and easy to bring skew to the profiling result. + active: 5 # number of iterations where the profiler is active and records events. + repeat: 0 # number of cycle wait/warmup/active to realise before stoping profiling (0 -> all). + record_shapes: True # save information about operator’s input shapes. + profile_memory: True # track tensor memory allocation/deallocation. + with_stack: True # record source information (file and line number) for the ops. + with_flops: True # use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution). 
+ + +enable_mixed: True diff --git a/torch-points3d/conf/training/nfi/pointnet.yaml b/torch-points3d/conf/training/nfi/pointnet.yaml new file mode 100644 index 0000000..8e2a911 --- /dev/null +++ b/torch-points3d/conf/training/nfi/pointnet.yaml @@ -0,0 +1,65 @@ +# @package training +# Those arguments defines the training hyper-parameters +epochs: 310 +num_workers: 4 +batch_size: 32 +cuda: 0 +shuffle: True +optim: + base_lr: 0.005 + grad_clip: 100 + optimizer: +# class: SGD +# params: +# momentum: 0.98 +# lr: ${training.optim.base_lr} # The path is cut from training +# weight_decay: 1e-3 + class: AdaBelief + params: + lr: ${training.optim.base_lr} # The path is cut from training + weight_decay: 1e-2 + lr_scheduler: ${lr_scheduler} +# bn_scheduler: +# bn_policy: "step_decay" +# params: +# bn_momentum: 0.98 +# bn_decay: 0.9 +# decay_step: 1000 +# bn_clip: 1e-2 +weight_name: "latest" # Used during resume, select with model to load from [miou, macc, acc..., latest] +enable_cudnn: True +checkpoint_dir: "" + +# Those arguments within experiment defines which model, dataset, and task to be created for benchmarking +# parameters for Weights and Biases +wandb: + project: nfi + log: True + name: ${model_name} + public: True # It will be display the model within wandb log, else not. + config: + model_name: ${model_name} + features: ${data.features} + batch_size: ${training.batch_size} + base_lr: ${training.optim.base_lr} + transform_type: ${data.transform_type} + + +# parameters for TensorBoard Visualization +tensorboard: + log: False + pytorch_profiler: + log: False # activate PyTorch Profiler in TensorBoard + nb_epoch: 3 # number of epochs to profile (0 -> all). + skip_first: 10 # number of first iterations to skip. + wait: 5 # number of iterations where the profiler is disable. + warmup: 3 # number of iterations where the profiler starts tracing but the results are discarded. This is for reducing the profiling overhead. 
The overhead at the beginning of profiling is high and easy to bring skew to the profiling result. + active: 5 # number of iterations where the profiler is active and records events. + repeat: 0 # number of cycle wait/warmup/active to realise before stoping profiling (0 -> all). + record_shapes: True # save information about operator’s input shapes. + profile_memory: True # track tensor memory allocation/deallocation. + with_stack: True # record source information (file and line number) for the ops. + with_flops: True # use formula to estimate the FLOPS of specific operators (matrix multiplication and 2D convolution). + + +enable_mixed: True diff --git a/torch-points3d/conf/visualization/default.yaml b/torch-points3d/conf/visualization/default.yaml new file mode 100644 index 0000000..c503387 --- /dev/null +++ b/torch-points3d/conf/visualization/default.yaml @@ -0,0 +1,14 @@ +# @package _group_ +activate: False +format: ["ply", "tensorboard"] +num_samples_per_epoch: 10 +deterministic: True # False -> Randomly sample elements from epoch to epoch +deterministic_seed: 42 +saved_keys: + pos: [['x', 'float'], ['y', 'float'], ['z', 'float']] + y: [['l', 'float']] + pred: [['p', 'float']] +ply_format: 'binary_big_endian' +tensorboard_mesh: + label: 'y' + prediction: 'pred' diff --git a/torch-points3d/conf/visualization/eval.yaml b/torch-points3d/conf/visualization/eval.yaml new file mode 100644 index 0000000..67e2316 --- /dev/null +++ b/torch-points3d/conf/visualization/eval.yaml @@ -0,0 +1,14 @@ +# @package _group_ +activate: True +format: ["csv"] # image will come later +num_samples_per_epoch: -1 +deterministic: True # False -> Randomly sample elements from epoch to epoch +deterministic_seed: 42 +saved_keys: + pos: [['x', 'float'], ['y', 'float'], ['z', 'float']] + y: [['l', 'float']] + pred: [['p', 'float']] +ply_format: 'binary_big_endian' +tensorboard_mesh: + label: 'y' + prediction: 'pred' diff --git a/torch-points3d/conf/visualization/predict.yaml 
b/torch-points3d/conf/visualization/predict.yaml new file mode 100644 index 0000000..67e2316 --- /dev/null +++ b/torch-points3d/conf/visualization/predict.yaml @@ -0,0 +1,14 @@ +# @package _group_ +activate: True +format: ["csv"] # image will come later +num_samples_per_epoch: -1 +deterministic: True # False -> Randomly sample elements from epoch to epoch +deterministic_seed: 42 +saved_keys: + pos: [['x', 'float'], ['y', 'float'], ['z', 'float']] + y: [['l', 'float']] + pred: [['p', 'float']] +ply_format: 'binary_big_endian' +tensorboard_mesh: + label: 'y' + prediction: 'pred' diff --git a/torch-points3d/env.yml b/torch-points3d/env.yml new file mode 100644 index 0000000..1557a31 --- /dev/null +++ b/torch-points3d/env.yml @@ -0,0 +1,62 @@ +name: pts +channels: + - pytorch + - nvidia + - metric-learning + - anaconda + - conda-forge + - pyg + - defaults +dependencies: + - cloudpickle + - colorama + - cudatoolkit=11.8 + - gdal + - gdown + - geopandas=0.12.2 + - llvmlite + - numba + - hydra-core + - matplotlib + - openblas + - pandas=1.5.2 + - proj + - python=3.8 + - pyyaml + - rtree + - scikit-image + - scikit-learn + - scipy + - tensorboard + - tqdm + - wandb + - werkzeug + - wheel + - yaml + - ipython + - h5py + - plyfile + - pip + - pytorch::pytorch=2.0 + - pytorch::pytorch-cuda=11.8 + - pyg::pyg=2.3.1 + - pyg::pytorch-cluster + - pyg::pytorch-scatter + - pip: + - jupyter-client + - jupyter-core + - jupyterlab + - lazrs + - open3d + - numba + - widgetsnbextension + - laspy + - plyfile + - dbscan1d + - multimethod + - termcolor + - shortuuid + - easydict + - tabulate + - torchnet + - visdom diff --git a/torch-points3d/env_cpu.yml b/torch-points3d/env_cpu.yml new file mode 100644 index 0000000..e84f623 --- /dev/null +++ b/torch-points3d/env_cpu.yml @@ -0,0 +1,61 @@ +name: pts +channels: + - pytorch + - metric-learning + - anaconda + - conda-forge + - pyg + - defaults +dependencies: + - cloudpickle + - colorama + - gdal + - gdown + - geopandas=0.12.2 + - llvmlite + - 
@hydra.main(config_path="conf", config_name="eval")
def main(cfg):
    """Run evaluation for every stage listed in the Hydra config.

    Reads ``random_seed`` (default 21) to seed NumPy and PyTorch, forces
    ``shuffle`` and ``drop_last`` to ``False`` on the config, optionally
    prints the resolved config when ``cfg.pretty_print`` is truthy, then
    builds a ``Trainer`` and calls ``trainer.eval(stage)`` for each entry of
    ``eval_stages`` (default ``[""]``).

    Parameters
    ----------
    cfg:
        The config object handed over by ``hydra.main``.

    Returns
    -------
    int
        Always ``0``.
    """
    rs = cfg.get("random_seed", 21)

    # disable random shuffling and dropping of last batch in training loader
    OmegaConf.set_struct(cfg, True)
    with open_dict(cfg):
        cfg["shuffle"] = False
        cfg["drop_last"] = False

    # Seed the *global* NumPy RNG. The previous ``np.random.default_rng(rs)``
    # only constructed a fresh Generator that was immediately discarded, so
    # code relying on ``np.random.*`` was left unseeded.
    np.random.seed(rs)
    torch.random.manual_seed(rs)

    OmegaConf.set_struct(cfg, False)  # This allows getattr and hasattr methods to function correctly
    if cfg.pretty_print:
        print(OmegaConf.to_yaml(cfg))

    trainer = Trainer(cfg)
    eval_stages = cfg.get("eval_stages", [""])
    for stage in eval_stages:
        trainer.eval(stage)

    # Clear the global Hydra state so the process can re-initialise Hydra:
    # https://github.com/facebookresearch/hydra/issues/440
    GlobalHydra.get_state().clear()
    return 0
b/torch-points3d/torch_points3d/applications/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch-points3d/torch_points3d/applications/conf/kpconv/encoder_4.yaml b/torch-points3d/torch_points3d/applications/conf/kpconv/encoder_4.yaml new file mode 100644 index 0000000..0ea0664 --- /dev/null +++ b/torch-points3d/torch_points3d/applications/conf/kpconv/encoder_4.yaml @@ -0,0 +1,68 @@ +class: kpconv.KPConvPaper +conv_type: "PARTIAL_DENSE" +define_constants: + in_grid_size: 0.02 + in_feat: 64 + bn_momentum: 0.2 + output_nc: 256 + max_neighbors: 25 +down_conv: + down_conv_nn: + [ + [[FEAT + 1, in_feat], [in_feat, 2*in_feat]], + [[2*in_feat, 2*in_feat], [2*in_feat, 4*in_feat]], + [[4*in_feat, 4*in_feat], [4*in_feat, 8*in_feat]], + [[8*in_feat, 8*in_feat], [8*in_feat, 16*in_feat]], + [[16*in_feat, 16*in_feat], [16*in_feat, 32 * in_feat]], + ] + grid_size: + [ + [in_grid_size, in_grid_size], + [2*in_grid_size, 2*in_grid_size], + [4*in_grid_size, 4*in_grid_size], + [8*in_grid_size, 8*in_grid_size], + [16*in_grid_size, 16*in_grid_size], + ] + prev_grid_size: + [ + [in_grid_size, in_grid_size], + [in_grid_size, 2*in_grid_size], + [2*in_grid_size, 4*in_grid_size], + [4*in_grid_size, 8*in_grid_size], + [8*in_grid_size, 16*in_grid_size], + ] + block_names: + [ + ["SimpleBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ["ResnetBBlock", "ResnetBBlock"], + ] + has_bottleneck: + [ + [False, True], + [True, True], + [True, True], + [True, True], + [True, True], + ] + deformable: + [ + [False, False], + [False, False], + [False, False], + [False, False], + [False, False], + ] + max_num_neighbors: + [[max_neighbors,max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors], [max_neighbors, max_neighbors]] + module_name: KPDualBlock +innermost: + module_name: GlobalBaseModule + activation: + name: LeakyReLU + negative_slope: 0.2 
def Minkowski(
    architecture: str = None, input_nc: int = None, num_layers: int = None, config: DictConfig = None, *args, **kwargs
):
    """Build a Minkowski backbone (deprecated entry point).

    Based on the architecture proposed in https://arxiv.org/abs/1904.08755.
    A deprecation warning is logged on every call; the model itself is built
    by ``MinkowskiFactory``.

    Parameters
    ----------
    architecture : str, optional
        Architecture of the model, choose from unet, encoder and decoder
    input_nc : int, optional
        Number of channels for the input
    num_layers : int, optional
        Depth of the network
    config : DictConfig, optional
        Custom config, overrides the num_layers and architecture parameters
    *args :
        Accepted for call compatibility; not used by this function.
    **kwargs :
        Forwarded unchanged to ``MinkowskiFactory``. The original
        documentation lists the following among them:

        output_nc : int
            If specified, a fully connected head is added at the end of the
            network to provide the requested dimension
        in_feat :
            Size of the first layer
        block :
            Type of resnet block, ResBlock by default but can be any of the
            blocks in modules/MinkowskiEngine/api_modules.py

    Returns
    -------
    Whatever ``MinkowskiFactory.build()`` returns for the requested architecture.
    """
    deprecation_notice = "Minkowski API is deprecated in favor of the SparseConv3d API. It should be a simple drop in replacement (no change to the API)."
    log.warning(deprecation_notice)

    # The named parameters cannot also appear in ``kwargs``, so merging them
    # into one mapping is equivalent to passing them as separate keywords.
    factory_options = dict(kwargs, architecture=architecture, num_layers=num_layers, input_nc=input_nc, config=config)
    return MinkowskiFactory(**factory_options).build()
class MinkowskiEncoder(BaseMinkowski):
    """Encoder variant: down modules, optional inner module, optional MLP head."""

    def forward(self, data, *args, **kwargs):
        """Encode a batch and return the resulting features.

        Parameters
        ----------
        data:
            Input batch handed to ``_set_input``, which reads ``data.x``,
            ``data.coords``, ``data.batch`` and ``data.pos`` to build the
            sparse tensor stored in ``self.input``.
            (Original docs: F -- Features [N, C], coords -- Coords [N, 4].)

        Returns
        -------
        out:
            A ``Batch`` with
            - x: features after the down modules, the inner module (unless
              it is an ``Identity``) and the MLP head (when one is present);
              the original docs give the shape as [1, output_nc]
            - batch: batch index taken from the first coordinate column
        """
        self._set_input(data)

        # Run the sparse tensor through every down-sampling module in order.
        sparse = self.input
        for down_module in self.down_modules:
            sparse = down_module(sparse)

        features = sparse.F
        batch_index = sparse.C[:, 0].long().to(features.device)
        out = Batch(x=features, batch=batch_index)

        inner = self.inner_modules[0]
        if not isinstance(inner, Identity):
            out = inner(out)

        if self.has_mlp_head:
            out.x = self.mlp(out.x)
        return out
class ModelFactory:
    """Base factory that builds a model for one of the known architectures.

    Subclasses override ``_build_unet``, ``_build_encoder`` and/or
    ``_build_decoder``; ``build`` dispatches to the one matching the
    architecture given at construction time.
    """

    MODEL_ARCHITECTURES = [e.value for e in ModelArchitectures]

    @staticmethod
    def raise_enum_error(arg_name, arg_value, options):
        """Raise an ``Exception`` stating that ``arg_value`` is not one of ``options``."""
        raise Exception("The provided argument {} with value {} isn't within {}".format(arg_name, arg_value, options))

    def __init__(
        self,
        architecture: str = None,
        input_nc: int = None,
        num_layers: int = None,
        config: DictConfig = None,
        **kwargs
    ):
        """Store the build parameters and validate the architecture name.

        Parameters
        ----------
        architecture : str
            One of ``MODEL_ARCHITECTURES`` (case-insensitive). Required.
        input_nc : int, optional
            Number of input channels, exposed as ``num_features``.
        num_layers : int, optional
            Depth of the network, exposed as ``num_layers``.
        config : DictConfig, optional
            Custom model config; when set, an info message is logged.
        **kwargs :
            Extra options kept as ``self.kwargs`` for the subclasses.

        Raises
        ------
        ValueError
            If ``architecture`` is missing or empty.
        Exception
            If ``architecture`` is not one of ``MODEL_ARCHITECTURES``.
        """
        if not architecture:
            raise ValueError(
                "An architecture must be provided, one of {}".format(self.MODEL_ARCHITECTURES)
            )
        self._architecture = architecture.lower()
        # Explicit check instead of ``assert``: asserts are stripped when
        # Python runs with -O, which would skip the validation entirely.
        if self._architecture not in self.MODEL_ARCHITECTURES:
            ModelFactory.raise_enum_error("model_architecture", self._architecture, self.MODEL_ARCHITECTURES)

        self._input_nc = input_nc
        self._num_layers = num_layers
        self._config = config
        self._kwargs = kwargs

        if self._config:
            log.info("The config will be used to build the model")

    @property
    def modules_lib(self):
        raise NotImplementedError

    @property
    def kwargs(self):
        return self._kwargs

    @property
    def num_layers(self):
        return self._num_layers

    @property
    def num_features(self):
        return self._input_nc

    def _build_unet(self):
        raise NotImplementedError

    def _build_encoder(self):
        raise NotImplementedError

    def _build_decoder(self):
        raise NotImplementedError

    def build(self):
        """Build and return the model for the configured architecture."""
        if self._architecture == ModelArchitectures.UNET.value:
            return self._build_unet()
        elif self._architecture == ModelArchitectures.ENCODER.value:
            return self._build_encoder()
        elif self._architecture == ModelArchitectures.DECODER.value:
            return self._build_decoder()
        else:
            raise NotImplementedError

    @staticmethod
    def resolve_model(model_config, num_features, kwargs):
        """Parse the model config and evaluate any expression that may contain constants.

        Any constant listed under ``define_constants`` is overridden by a
        keyword argument of the same name when one was passed to the
        constructor. ``model_config`` is resolved in place via ``resolve``.
        """
        # placeholders to substitute
        constants = {
            "FEAT": max(num_features, 0),
        }

        # user defined constants to substitute
        if "define_constants" in model_config.keys():
            define_constants = model_config.define_constants
            constants.update(dict(define_constants))
            for key in define_constants.keys():
                value = kwargs.get(key)
                # ``is not None`` (rather than truthiness) so that explicit
                # falsy overrides such as 0 or False are honoured.
                if value is not None:
                    constants[key] = value
        resolve(model_config, constants)
def download_file(url, out_file):
    """Download ``url`` to ``out_file`` unless that file already exists.

    Parameters
    ----------
    url : str
        Location to fetch, passed to ``urllib.request.urlretrieve``.
    out_file : str
        Destination path. Missing parent directories are created.
    """
    if os.path.exists(out_file):
        log.warning("WARNING: skipping download of existing file " + out_file)
        return

    parent_dir = os.path.dirname(out_file)
    # A bare file name has an empty dirname and ``os.makedirs("")`` raises, so only
    # create directories when there is one; ``exist_ok`` avoids a check-then-create race.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    urllib.request.urlretrieve(url, out_file)
"https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/1ejjs4s2/pointnet2_largemsg.pt", + "pointnet2_largemsg-s3dis-5": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/etxij0j6/pointnet2_largemsg.pt", + "pointnet2_largemsg-s3dis-6": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/8n8t391d/pointnet2_largemsg.pt", + "pointgroup-scannet": "https://api.wandb.ai/files/nicolas/panoptic/2ta6vfu2/PointGroup.pt", + "minkowski-res16-s3dis-1": "https://api.wandb.ai/files/nicolas/s3dis-benchmark/1fyr7ri9/Res16UNet34C.pt", + "minkowski-res16-s3dis-2": "https://api.wandb.ai/files/nicolas/s3dis-benchmark/1gdgx2ni/Res16UNet34C.pt", + "minkowski-res16-s3dis-3": "https://api.wandb.ai/files/nicolas/s3dis-benchmark/gt3ttamp/Res16UNet34C.pt", + "minkowski-res16-s3dis-4": "https://api.wandb.ai/files/nicolas/s3dis-benchmark/36yxu3yc/Res16UNet34C.pt", + "minkowski-res16-s3dis-5": "https://api.wandb.ai/files/nicolas/s3dis-benchmark/2r0tsub1/Res16UNet34C.pt", + "minkowski-res16-s3dis-6": "https://api.wandb.ai/files/nicolas/s3dis-benchmark/30yrkk5p/Res16UNet34C.pt", + "minkowski-registration-3dmatch": "https://api.wandb.ai/files/humanpose1/registration/2wvwf92e/MinkUNet_Fragment.pt", + "minkowski-registration-kitti": "https://api.wandb.ai/files/humanpose1/KITTI/2xpy7u1i/MinkUNet_Fragment.pt", + "minkowski-registration-modelnet": "https://api.wandb.ai/files/humanpose1/modelnet/39u5v3bm/MinkUNet_Fragment.pt", + "rsconv-s3dis-1": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/2b99o12e/RSConv_MSN_S3DIS.pt", + "rsconv-s3dis-2": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/1onl4h59/RSConv_MSN_S3DIS.pt", + "rsconv-s3dis-3": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/2cau6jua/RSConv_MSN_S3DIS.pt", + "rsconv-s3dis-4": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/1qqmzgnz/RSConv_MSN_S3DIS.pt", + "rsconv-s3dis-5": 
"https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/378enxsu/RSConv_MSN_S3DIS.pt", + "rsconv-s3dis-6": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/23f4upgc/RSConv_MSN_S3DIS.pt", + "kpconv-s3dis-1": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/okiba8gp/KPConvPaper.pt", + "kpconv-s3dis-2": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/2at56wrm/KPConvPaper.pt", + "kpconv-s3dis-3": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/1ipv9lso/KPConvPaper.pt", + "kpconv-s3dis-4": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/2c13jhi0/KPConvPaper.pt", + "kpconv-s3dis-5": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/1kf8yg5s/KPConvPaper.pt", + "kpconv-s3dis-6": "https://api.wandb.ai/files/loicland/benchmark-torch-points-3d-s3dis/2ph7ejss/KPConvPaper.pt", + } + + MOCK_USED_PROPERTIES = { + "pointnet2_largemsg-s3dis-1": {"feature_dimension": 4, "num_classes": 13}, + "pointnet2_largemsg-s3dis-2": {"feature_dimension": 4, "num_classes": 13}, + "pointnet2_largemsg-s3dis-3": {"feature_dimension": 4, "num_classes": 13}, + "pointnet2_largemsg-s3dis-4": {"feature_dimension": 4, "num_classes": 13}, + "pointnet2_largemsg-s3dis-5": {"feature_dimension": 4, "num_classes": 13}, + "pointnet2_largemsg-s3dis-6": {"feature_dimension": 4, "num_classes": 13}, + "pointgroup-scannet": {}, + "rsconv-s3dis-1": {"feature_dimension": 4, "num_classes": 13}, + "rsconv-s3dis-2": {"feature_dimension": 4, "num_classes": 13}, + "rsconv-s3dis-3": {"feature_dimension": 4, "num_classes": 13}, + "rsconv-s3dis-4": {"feature_dimension": 4, "num_classes": 13}, + "rsconv-s3dis-5": {"feature_dimension": 4, "num_classes": 13}, + "rsconv-s3dis-6": {"feature_dimension": 4, "num_classes": 13}, + "minkowski-res16-s3dis-1": {"feature_dimension": 4, "num_classes": 13}, + "minkowski-res16-s3dis-2": {"feature_dimension": 4, "num_classes": 13}, + "minkowski-res16-s3dis-3": 
{"feature_dimension": 4, "num_classes": 13}, + "minkowski-res16-s3dis-4": {"feature_dimension": 4, "num_classes": 13}, + "minkowski-res16-s3dis-5": {"feature_dimension": 4, "num_classes": 13}, + "minkowski-res16-s3dis-6": {"feature_dimension": 4, "num_classes": 13}, + "minkowski-registration-3dmatch": {"feature_dimension": 1}, + "minkowski-registration-kitti": {"feature_dimension": 1}, + "minkowski-registration-modelnet": {"feature_dimension": 1}, + "kpconv-s3dis-1": {"feature_dimension": 4, "num_classes": 13}, + "kpconv-s3dis-2": {"feature_dimension": 4, "num_classes": 13}, + "kpconv-s3dis-3": {"feature_dimension": 4, "num_classes": 13}, + "kpconv-s3dis-4": {"feature_dimension": 4, "num_classes": 13}, + "kpconv-s3dis-5": {"feature_dimension": 4, "num_classes": 13}, + "kpconv-s3dis-6": {"feature_dimension": 4, "num_classes": 13}, + } + + @staticmethod + def from_pretrained(model_tag, download=True, out_file=None, weight_name="latest", mock_dataset=True): + # Convert inputs to registry format + + if PretainedRegistry.MODELS.get(model_tag) is not None: + url = PretainedRegistry.MODELS.get(model_tag) + else: + raise Exception( + "model_tag {} doesn't exist within available models. 
Here is the list of pre-trained models {}".format( + model_tag, PretainedRegistry.available_models() + ) + ) + + checkpoint_name = model_tag + ".pt" + out_file = os.path.join(CHECKPOINT_DIR, checkpoint_name) + + if download: + download_file(url, out_file) + + weight_name = weight_name if weight_name is not None else "latest" + + checkpoint: ModelCheckpoint = ModelCheckpoint( + CHECKPOINT_DIR, model_tag, weight_name if weight_name is not None else "latest", resume=False, + ) + if mock_dataset: + dataset = checkpoint.dataset_properties.copy() + if PretainedRegistry.MOCK_USED_PROPERTIES.get(model_tag) is not None: + for k, v in PretainedRegistry.MOCK_USED_PROPERTIES.get(model_tag).items(): + dataset[k] = v + + else: + dataset = instantiate_dataset(checkpoint.data_config) + + model: BaseModel = checkpoint.create_model(dataset, weight_name=weight_name) + + Wandb.set_urls_to_model(model, url) + + BaseDataset.set_transform(model, checkpoint.data_config) + + return model + + @staticmethod + def from_file(path, weight_name="latest", mock_property=None): + """ + Load a pretrained model trained with torch-points3d from file. 
def SparseConv3d(
    architecture: str = None,
    input_nc: int = None,
    num_layers: int = None,
    config: DictConfig = None,
    backend: str = "minkowski",
    *args,
    **kwargs
):
    """Create a Sparse Conv backbone model based on the architecture proposed in
    https://arxiv.org/abs/1904.08755

    Two backends are available at the moment:
        - https://github.com/mit-han-lab/torchsparse
        - https://github.com/NVIDIA/MinkowskiEngine

    The ``SPARSE_BACKEND`` environment variable takes precedence over ``backend``
    when it is set to a value accepted by ``sp3d.nn.backend_valid``.

    Parameters
    ----------
    architecture : str, optional
        Architecture of the model, choose from unet, encoder and decoder
    input_nc : int, optional
        Number of channels for the input
    output_nc : int, optional
        Forwarded through ``kwargs``. If specified, a fully connected head is added at the
        end of the network to provide the requested dimension
    num_layers : int, optional
        Depth of the network
    config : DictConfig, optional
        Custom config, overrides the num_layers and architecture parameters
    block:
        Forwarded through ``kwargs``. Type of resnet block, ResBlock by default but can be
        any of the blocks in modules/SparseConv3d/modules.py
    backend:
        torchsparse or minkowski
    """
    env_backend = os.environ.get("SPARSE_BACKEND")
    use_env = env_backend is not None and sp3d.nn.backend_valid(env_backend)
    sp3d.nn.set_backend(env_backend if use_env else backend)

    return SparseConv3dFactory(
        architecture=architecture, num_layers=num_layers, input_nc=input_nc, config=config, **kwargs
    ).build()
class BaseSparseConv3d(UnwrappedUnetBasedModel):
    """Shared base of the sparse-convolution backbones (weight init, input unpacking, optional MLP head)."""

    CONV_TYPE = "sparse"

    def __init__(self, model_config, model_type, dataset, modules, *args, **kwargs):
        """Build the backbone, initialise its weights and attach an MLP head on request.

        ``kwargs["default_output_nc"]`` (when truthy) gives the backbone output width;
        otherwise it is read from ``model_config`` with ``extract_output_nc``.
        When ``kwargs`` contains ``output_nc`` an MLP head mapping to that width is added.
        """
        super().__init__(model_config, model_type, dataset, modules)
        self.weight_initialization()

        backbone_nc = kwargs.get("default_output_nc", None) or extract_output_nc(model_config)

        self._output_nc = backbone_nc
        self._has_mlp_head = "output_nc" in kwargs
        if self._has_mlp_head:
            self._output_nc = kwargs["output_nc"]
            self.mlp = MLP([backbone_nc, self.output_nc], activation=torch.nn.ReLU(), bias=False)

    @property
    def has_mlp_head(self):
        return self._has_mlp_head

    @property
    def output_nc(self):
        return self._output_nc

    def weight_initialization(self):
        """Kaiming-normal init for sparse conv kernels; batch norm reset to weight 1, bias 0."""
        conv_types = (sp3d.nn.Conv3d, sp3d.nn.Conv3dTranspose)
        for module in self.modules():
            if isinstance(module, conv_types):
                torch.nn.init.kaiming_normal_(module.kernel, mode="fan_out", nonlinearity="relu")

            if isinstance(module, sp3d.nn.BatchNorm):
                torch.nn.init.constant_(module.bn.weight, 1)
                torch.nn.init.constant_(module.bn.bias, 0)

    def _set_input(self, data):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters
        -----------
        data:
            object exposing ``x``, ``coords``, ``batch`` and ``pos``; stored as
            ``self.input`` (sparse tensor) and ``self.xyz`` (``pos``, or ``coords`` when
            ``pos`` is None).
        """
        self.input = sp3d.nn.SparseTensor(data.x, data.coords, data.batch, self.device)
        self.xyz = data.pos if data.pos is not None else data.coords
def extract_output_nc(model_config):
    """Return the number of output channels of the network described by ``model_config``.

    The value is the last entry of the last ``up_conv.up_conv_nn`` list when the config has
    an ``up_conv`` section, otherwise the last entry of ``innermost.nn``.

    Raises
    ------
    ValueError
        If the config has neither an ``up_conv`` nor an ``innermost`` section.
    """
    # The decoder, when present, determines the output width.
    if model_config.get('up_conv') is not None:
        return model_config.up_conv.up_conv_nn[-1][-1]

    if model_config.get('innermost') is not None:
        return model_config.innermost.nn[-1]

    raise ValueError("Input model_config does not match expected pattern")
class MultiHeadClassifier(BaseModule):
    """ Allows segregated segmentation in case the category of an object is known. This is the case in ShapeNet
    for example.

    Arguments:
        in_features -- size of the input channel
        cat_to_seg {[type]} -- category to segment maps for example:
            {
                'Airplane': [0,1,2],
                'Table': [3,4]
            }

    Keyword Arguments:
        dropout_proba  (default: {0.5})
        bn_momentum  -- batch norm momentum (default: {0.1})
    """

    def __init__(self, in_features, cat_to_seg, dropout_proba=0.5, bn_momentum=0.1):
        super().__init__()
        self._cat_to_seg = {}
        self._num_categories = len(cat_to_seg)
        self._max_seg_count = 0
        self._max_seg = 0
        self._shifts = torch.zeros((self._num_categories,), dtype=torch.long)
        # Categories are re-indexed by their position in ``cat_to_seg``.
        for i, seg in enumerate(cat_to_seg.values()):
            self._max_seg_count = max(self._max_seg_count, len(seg))
            self._max_seg = max(self._max_seg, max(seg))
            self._shifts[i] = min(seg)
            self._cat_to_seg[i] = seg

        # One ``in_features``-wide slice per category.
        self.channel_rasing = MLP(
            [in_features, self._num_categories * in_features], bn_momentum=bn_momentum, bias=False
        )
        if dropout_proba:
            self.channel_rasing.add_module("Dropout", nn.Dropout(p=dropout_proba))

        self.classifier = UnaryConv((self._num_categories, in_features, self._max_seg_count))
        self._bias = Parameter(torch.zeros(self._max_seg_count,))

    def forward(self, features, category_labels, **kwargs):
        """Compute per-point log-probabilities over the segments of each point's category.

        Parameters
        ----------
        features : torch.Tensor
            2D tensor [num_points, in_features].
        category_labels : torch.Tensor
            Category index of each point (position of the category in ``cat_to_seg``).

        Returns
        -------
        torch.Tensor
            [num_points, max_seg + 1]; for each point the columns of its category's segments
            hold log-softmax values, all other columns are zero.
        """
        assert features.dim() == 2
        self._shifts = self._shifts.to(features.device)
        in_dim = features.shape[-1]
        features = self.channel_rasing(features)
        features = features.reshape((-1, self._num_categories, in_dim))
        features = features.transpose(0, 1)  # [num_categories, num_points, in_dim]
        features = self.classifier(features) + self._bias  # [num_categories, num_points, max_seg]
        ind = category_labels.unsqueeze(-1).repeat(1, 1, features.shape[-1]).long()

        # Pick, for every point, the logits of the head matching its category.
        # The previous code called an unrelated module-level ``gather(0, ind)`` and dropped ``features``.
        logits = features.gather(0, ind).squeeze(0)
        softmax = torch.nn.functional.log_softmax(logits, dim=-1)

        output = torch.zeros(logits.shape[0], self._max_seg + 1, device=features.device)
        cats_in_batch = torch.unique(category_labels)
        for cat in cats_in_batch:
            cat_mask = category_labels == cat
            seg_indices = self._cat_to_seg[cat.item()]
            probs = softmax[cat_mask, : len(seg_indices)]
            # Written as a contiguous slice from the first to the last listed segment id.
            output[cat_mask, seg_indices[0] : seg_indices[-1] + 1] = probs
        return output
def gather(x, idx, method=2):
    """
    https://github.com/pytorch/pytorch/issues/15245
    implementation of a custom gather operation for faster backwards.

    Note: entries of ``idx`` equal to -1 are rewritten IN PLACE to ``x.shape[0] - 1``
    (the shadow point), so the caller's tensor is modified.

    :param x: input with shape [N, D_1, ... D_d]
    :param idx: indexing with shape [n_1, ..., n_m]
    :param method: Choice of the method (0: fancy indexing, 1: expand + gather for 2D ``x``
        and 2D ``idx``, 2: expand + gather for arbitrary ranks)
    :return: x[idx] with shape [n_1, ..., n_m, D_1, ... D_d]
    :raises ValueError: if ``method`` is not 0, 1 or 2
    """
    idx[idx == -1] = x.shape[0] - 1  # Shadow point
    if method == 0:
        return x[idx]
    elif method == 1:
        x = x.unsqueeze(1)
        x = x.expand((-1, idx.shape[-1], -1))
        idx = idx.unsqueeze(2)
        idx = idx.expand((-1, -1, x.shape[-1]))
        # Gather on the expanded tensor itself; the previous code called an unrelated
        # module-level ``gather(0, idx)`` and never used ``x``.
        return x.gather(0, idx)
    elif method == 2:
        # Broadcast x over the trailing index dimensions ...
        for i, ni in enumerate(idx.size()[1:]):
            x = x.unsqueeze(i + 1)
            new_s = list(x.size())
            new_s[i + 1] = ni
            x = x.expand(new_s)
        # ... and idx over the feature dimensions, so both share one shape past dim 0.
        n = len(idx.size())
        for i, di in enumerate(x.size()[n:]):
            idx = idx.unsqueeze(i + n)
            new_s = list(idx.size())
            new_s[i + n] = di
            idx = idx.expand(new_s)
        return x.gather(0, idx)
    else:
        raise ValueError("Unknown method")
Learns and applies a linear transformation to trans_x based on feat_x. + feat_x and trans_x may be the same or different. + """ + global_feature = self.nn(feat_x, batch) + trans = self.fc_layer(global_feature) + + # needed so that transform is initialized to identity + trans = trans + self.identity.to(feat_x.device) + trans = trans.view(-1, self.k, self.k) + self.trans = trans + + # convert trans_x from (N, K) to (B, N, K) to do batched matrix multiplication + # batch_x = trans_x.view(self.batch_size, -1, trans_x.shape[1]) + if trans_x.squeeze().dim() == 2: + batch_x = trans_x.view(trans_x.shape[0], 1, trans_x.shape[1]) + x_transformed = torch.bmm(batch_x[:, :, :trans.shape[-1]], trans[batch]) + if batch_x.shape[-1] > trans.shape[-1]: + x_transformed = torch.cat([x_transformed, batch_x[:, :, trans.shape[-1]:]], dim=-1) + return x_transformed.view(len(trans_x), trans_x.shape[1]) + else: + x_transformed = torch.bmm(trans_x[:, :, :trans.shape[-1]], trans) + if trans_x.shape[-1] > trans.shape[-1]: + x_transformed = torch.cat([x_transformed, trans_x[:, :, trans.shape[-1]:]], dim=-1) + return x_transformed + + def get_orthogonal_regularization_loss(self): + loss = torch.mean( + torch.norm( + torch.bmm(self.trans, self.trans.transpose(2, 1)) + - self.identity.to(self.trans.device).view(-1, self.k, self.k), + dim=(1, 2), + ) + ) + + return loss diff --git a/torch-points3d/torch_points3d/core/data_transform/__init__.py b/torch-points3d/torch_points3d/core/data_transform/__init__.py new file mode 100644 index 0000000..44aa58a --- /dev/null +++ b/torch-points3d/torch_points3d/core/data_transform/__init__.py @@ -0,0 +1,249 @@ +import sys + +import numpy as np +import torch_geometric.transforms as T +from .transforms import * +from .grid_transform import * +from .sparse_transforms import * +from .inference_transforms import * +from .feature_augment import * +from .features import * +from .filters import * +from .precollate import * +from .prebatchcollate import * +from 
omegaconf.dictconfig import DictConfig +from omegaconf.listconfig import ListConfig +from omegaconf import OmegaConf + +_custom_transforms = sys.modules[__name__] +_torch_geometric_transforms = sys.modules["torch_geometric.transforms"] +_intersection_names = set(_custom_transforms.__dict__) & set(_torch_geometric_transforms.__dict__) +_intersection_names = set([module for module in _intersection_names if not module.startswith("_")]) +L_intersection_names = len(_intersection_names) > 0 +_intersection_cls = [] + +for transform_name in _intersection_names: + transform_cls = getattr(_custom_transforms, transform_name) + if not "torch_geometric.transforms." in str(transform_cls): + _intersection_cls.append(transform_cls) +L_intersection_cls = len(_intersection_cls) > 0 + +if L_intersection_names: + if L_intersection_cls: + raise Exception( + "It seems that you are overriding a transform from pytorch gemetric, \ + this is forbidden, please rename your classes {} from {}".format( + _intersection_names, _intersection_cls + ) + ) + else: + raise Exception( + "It seems you are importing transforms {} from pytorch geometric within the current code base. 
def instantiate_transform(transform_option, attr="transform"):
    """ Creates a transform from an OmegaConf dict such as

        transform: GridSampling3D
        params:
            size: 0.01

    The class named by ``attr`` is looked up first among the custom transforms of this
    package, then in ``torch_geometric.transforms``. It is called with ``lparams`` as
    positional arguments and ``params`` as keyword arguments when those entries are set.

    Raises ``ValueError`` when no class with that name is found.
    """
    tr_name = getattr(transform_option, attr, None)

    def _optional(key):
        # ``.get`` is the OmegaConf 2.0 way of reading a possibly-missing entry.
        try:
            return transform_option.get(key)
        except KeyError:
            return None

    keyword_params = _optional('params')
    positional_params = _optional('lparams')

    cls = getattr(_custom_transforms, tr_name, None) or getattr(_torch_geometric_transforms, tr_name, None)
    if not cls:
        raise ValueError("Transform %s is nowhere to be found" % tr_name)

    return cls(*(positional_params or ()), **(keyword_params or {}))
class ComposeTransform(object):
    """
    Chain several transforms declared in YAML into a single transform
    (Compose of torch_geometric does not work).
    Example :
    .. code-block:: yaml

        - transform: ComposeTransform
          params:
            transform_options:
              - transform: GridSampling3D
                params:
                  size: 0.1
              - transform: RandomNoise
                params:
                  sigma: 0.05


    Parameters:
        transform_options: Omegaconf Dict
            contains a list of transform
    """

    def __init__(self, transform_options):
        # The composite built here exposes the individual steps as ``.transforms``.
        self.transform = instantiate_transforms(transform_options)

    def __call__(self, data):
        composed = self.transform
        return composed(data)

    def __repr__(self):
        # Every entry is followed by ", " (including the last one), as before.
        pieces = ["{}, ".format(step.__repr__()) for step in self.transform.transforms]
        return "ComposeTransform([" + "".join(pieces) + "])"
code-block:: yaml + + transform: RandomParamTransform + params: + transform_name: RandomSphereDropout + transform_params: + radius: + min: 1 + max: 2 + type: "float" + num_sphere: + min: 1 + max: 5 + type: "int" + + + Parameters + ---------- + transform_name: string: + the name of the transform + transform_options: Omegaconf Dict + contains the name of a variables as a key and min max type as value to specify the range of the parameters and + the type of the parameters or it contains the value "value" to specify a variables (see Example above) + + """ + + def __init__(self, transform_name, transform_params): + self.transform_name = transform_name + self.transform_params = transform_params + self.random_transform = self._instanciate_transform_with_random_params() + + def _instanciate_transform_with_random_params(self): + dico = dict() + for p, rang in self.transform_params.items(): + if "max" in rang and "min" in rang: + assert rang["max"] - rang["min"] > 0 + v = np.random.random() * (rang["max"] - rang["min"]) + rang["min"] + + if rang["type"] == "float": + v = float(v) + elif rang["type"] == "int": + v = int(v) + else: + raise NotImplementedError + dico[p] = v + elif "value" in rang: + v = rang["value"] + dico[p] = v + else: + raise NotImplementedError + + trans_opt = DictConfig(dict(params=dico, transform=self.transform_name)) + random_transform = instantiate_transform(trans_opt, attr="transform") + return random_transform + + def __call__(self, data): + self.random_transform = self._instanciate_transform_with_random_params() + return self.random_transform(data) + + def __repr__(self): + return "RandomParamTransform({}, params={})".format(self.transform_name, self.transform_params) diff --git a/torch-points3d/torch_points3d/core/data_transform/feature_augment.py b/torch-points3d/torch_points3d/core/data_transform/feature_augment.py new file mode 100644 index 0000000..3c856c5 --- /dev/null +++ b/torch-points3d/torch_points3d/core/data_transform/feature_augment.py 
class ChromaticAutoContrast(object):
    """ Rescale colors between 0 and 1 to enhance contrast

    With probability 0.2 each colour channel is stretched to span [0, 1] and
    blended with the original colours. Channels that hold a single constant
    value carry no contrast to stretch and are left untouched (this avoids the
    division by zero that used to produce NaN/inf colours).

    Parameters
    ----------
    randomize_blend_factor :
        Blend factor is random
    blend_factor:
        Weight of the contrast-stretched colours in the result
        (``1 - blend_factor`` is the weight of the original colours)
    """

    def __init__(self, randomize_blend_factor=True, blend_factor=0.5):
        self.randomize_blend_factor = randomize_blend_factor
        self.blend_factor = blend_factor

    def __call__(self, data):
        assert hasattr(data, "rgb")
        assert data.rgb.max() <= 1 and data.rgb.min() >= 0
        if random.random() < 0.2:
            feats = data.rgb
            lo = feats.min(0, keepdim=True)[0]
            hi = feats.max(0, keepdim=True)[0]

            # A zero span means the channel is constant: dividing by it would
            # give NaN/inf, so use a dummy span of 1 and keep the original
            # values for those channels.
            span = hi - lo
            has_contrast = span > 0
            safe_span = torch.where(has_contrast, span, torch.ones_like(span))
            contrast_feats = torch.where(has_contrast, (feats - lo) / safe_span, feats)

            blend_factor = random.random() if self.randomize_blend_factor else self.blend_factor
            data.rgb = (1 - blend_factor) * feats + blend_factor * contrast_feats
        return data

    def __repr__(self):
        return "{}(randomize_blend_factor={}, blend_factor={})".format(
            self.__class__.__name__, self.randomize_blend_factor, self.blend_factor
        )
class Random3AxisRotation(object):
    """
    Rotate pointcloud with random angles along x, y, z axis

    The angles should be given `in degrees`. For every axis the rotation angle
    is drawn uniformly in ``[-rot, rot]`` where ``rot`` is the absolute value
    of the configured angle, capped at 180 degrees.

    Parameters
    -----------
    apply_rotation: bool:
        Whether to apply the rotation
    rot_x: float
        Rotation angle in degrees on x axis
    rot_y: float
        Rotation angle in degrees on y axis
    rot_z: float
        Rotation angle in degrees on z axis
    p: float, optional
        Probability, evaluated independently per axis, that the axis is
        rotated at all (default: 1)
    """

    def __init__(self, apply_rotation: bool = True, rot_x: float = None, rot_y: float = None, rot_z: float = None,
                 p: float = None):
        self._apply_rotation = apply_rotation
        if apply_rotation and rot_x is None and rot_y is None and rot_z is None:
            raise Exception("At least one rot_ should be defined")

        self._rot_x = self._clamp_angle(rot_x)
        self._rot_y = self._clamp_angle(rot_y)
        self._rot_z = self._clamp_angle(rot_z)
        self._p = 1 if p is None else p

        self._degree_angles = [self._rot_x, self._rot_y, self._rot_z]

    @staticmethod
    def _clamp_angle(angle):
        """Return the magnitude of ``angle`` capped at 180 degrees (0 when unset)."""
        # Take the absolute value *before* capping; capping first let negative
        # angles such as -270 slip through as 270.
        return min(abs(angle), 180) if angle else 0

    def generate_random_rotation_matrix(self):
        thetas = torch.zeros(3, dtype=torch.float)
        for axis_ind, deg_angle in enumerate(self._degree_angles):
            # The probability draw happens only for axes with a non-zero range.
            if deg_angle > 0 and random.random() < self._p:
                rand_deg_angle = random.random() * 2 * deg_angle - deg_angle
                rand_radian_angle = float(rand_deg_angle * np.pi) / 180.0
                thetas[axis_ind] = rand_radian_angle
        return euler_angles_to_rotation_matrix(thetas, random_order=True)

    def __call__(self, data):
        if self._apply_rotation:
            pos = data.pos.float()
            M = self.generate_random_rotation_matrix()
            data.pos = pos @ M.T
            # Normals, when present, are rotated with the same matrix.
            if getattr(data, "norm", None) is not None:
                data.norm = data.norm.float() @ M.T
        return data

    def __repr__(self):
        return "{}(apply_rotation={}, rot_x={}, rot_y={}, rot_z={})".format(
            self.__class__.__name__, self._apply_rotation, self._rot_x, self._rot_y, self._rot_z
        )
torch.tensor(delta_max) + self.delta_min = torch.tensor(delta_min) + + def __call__(self, data): + pos = data.pos + trans = torch.rand(3) * (self.delta_max - self.delta_min) + self.delta_min + data.pos = pos + trans + return data + + def __repr__(self): + return "{}(delta_min={}, delta_max={})".format(self.__class__.__name__, self.delta_min, self.delta_max) + + +class AddFeatsByKeys(object): + """This transform takes a list of attributes names and if allowed, add them to x + + Example: + + Before calling "AddFeatsByKeys", if data.x was empty + + - transform: AddFeatsByKeys + params: + list_add_to_x: [False, True, True] + feat_names: ['normal', 'rgb', "elevation"] + input_nc_feats: [3, 3, 1] + + After calling "AddFeatsByKeys", data.x contains "rgb" and "elevation". Its shape[-1] == 4 (rgb:3 + elevation:1) + If input_nc_feats was [4, 4, 1], it would raise an exception as rgb dimension is only 3. + + Paremeters + ---------- + list_add_to_x: List[bool] + For each boolean within list_add_to_x, control if the associated feature is going to be concatenated to x + feat_names: List[str] + The list of features within data to be added to x + input_nc_feats: List[int], optional + If provided, evaluate the dimension of the associated feature shape[-1] found using feat_names and this provided value. It allows to make sure feature dimension didn't change + stricts: List[bool], optional + Recommended to be set to list of True. If True, it will raise an Exception if feat isn't found or dimension doesn t match. + delete_feats: List[bool], optional + Wether we want to delete the feature from the data object. List length must match teh number of features added. 
class AddFeatByKey(object):
    """This transform is responsible to get an attribute under feat_name and add it to x if add_to_x is True

    Parameters
    ----------
    add_to_x: bool
        Control if the feature is going to be added/concatenated to x
    feat_name: str
        The feature to be found within data to be added/concatenated to x
    input_nc_feat: int, optional
        If provided, check if feature last dimension matches provided value.
    strict: bool, optional
        Recommended to be set to True. If False, it won't break if feat isn't found or dimension doesn t match. (default: ``True``)

    Raises
    ------
    Exception
        In strict mode when the feature is missing or its dimension differs
        from ``input_nc_feat``; always when ``x`` and the feature have a
        different number of rows and therefore cannot be concatenated.
    """

    def __init__(self, add_to_x, feat_name, input_nc_feat=None, strict=True):

        self._add_to_x: bool = add_to_x
        self._feat_name: str = feat_name
        self._input_nc_feat = input_nc_feat
        self._strict: bool = strict

    def __call__(self, data: "Data"):
        if not self._add_to_x:
            return data

        feat = getattr(data, self._feat_name, None)
        if feat is None:
            if self._strict:
                raise Exception("Data should contain the attribute {}".format(self._feat_name))
            return data

        if self._input_nc_feat:
            # A 1-D feature counts as a single channel.
            feat_dim = 1 if feat.dim() == 1 else feat.shape[-1]
            if self._input_nc_feat != feat_dim and self._strict:
                raise Exception("The shape of feat: {} doesn t match {}".format(feat.shape, self._input_nc_feat))

        x = getattr(data, "x", None)
        if x is None:
            # No x yet: the feature becomes x, provided it has one row per point.
            if self._strict and data.pos.shape[0] != feat.shape[0]:
                raise Exception("We expected to have an attribute x")
            if feat.dim() == 1:
                feat = feat.unsqueeze(-1)
            data.x = feat
            return data

        if x.shape[0] != feat.shape[0]:
            # Report the row counts of the tensors themselves; tensors have no
            # ``pos`` attribute, so going through it raised AttributeError and
            # hid this message.
            raise Exception(
                "The tensor x and {} can't be concatenated, x: {}, feat: {}".format(
                    self._feat_name, x.shape[0], feat.shape[0]
                )
            )

        # Promote 1-D tensors to column vectors so they can be stacked channel-wise.
        if x.dim() == 1:
            x = x.unsqueeze(-1)
        if feat.dim() == 1:
            feat = feat.unsqueeze(-1)
        data.x = torch.cat([x, feat], dim=-1)
        return data

    def __repr__(self):
        return "{}(add_to_x: {}, feat_name: {}, strict: {})".format(
            self.__class__.__name__, self._add_to_x, self._feat_name, self._strict
        )
class PCACompute(object):
    r"""
    compute Principal Component Analysis (PCA) of a point cloud :math:`x_1,\dots, x_n`.
    It computes the eigenvalues and the eigenvectors of the matrix :math:`C` which is the covariance matrix of the point cloud:

    .. math::
        \bar{x} &= \frac{1}{n} \sum_{i=1}^n x_i

        C &= \frac{1}{n} \sum_{i=1}^n (x_i - \bar{x})(x_i - \bar{x})^T

    store the eigen values and the eigenvectors in data.
    in eigenvalues attribute and eigenvectors attributes.
    data.eigenvalues is a tensor :math:`(\lambda_1, \lambda_2, \lambda_3)` such that :math:`\lambda_1 \leq \lambda_2 \leq \lambda_3`.

    data.eigenvectors is a 3 x 3 matrix such that the column are the eigenvectors associated to their eigenvalues
    Therefore, the first column of data.eigenvectors estimates the normal at the center of the pointcloud.
    """

    def __call__(self, data):
        pos_centered = data.pos - data.pos.mean(dim=0)
        cov_matrix = pos_centered.T.mm(pos_centered) / len(pos_centered)
        # ``torch.symeig`` was deprecated and later removed from PyTorch;
        # ``torch.linalg.eigh`` is its replacement and likewise returns the
        # eigenvalues in ascending order with eigenvectors as columns. Keep the
        # old call only as a fallback for installations that predate it.
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "eigh"):
            eig, v = linalg.eigh(cov_matrix)
        else:
            eig, v = torch.symeig(cov_matrix, eigenvectors=True)
        data.eigenvalues = eig
        data.eigenvectors = v
        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__)
class FCompose(object):
    """
    allow to compose different filters using the boolean operation

    Every filter is called exactly once, in order, and the results are folded
    from left to right with ``boolean_operation``.

    Parameters
    ----------
    list_filter: list
        list of different filter functions we want to apply
    boolean_operation: function, optional
        boolean function to compose the filter (take a pair and return a boolean)
    """

    def __init__(self, list_filter, boolean_operation=np.logical_and):
        self.list_filter = list_filter
        self.boolean_operation = boolean_operation

    def __call__(self, data):
        assert len(self.list_filter) > 0
        remaining = iter(self.list_filter)
        # The first filter seeds the result; folding must continue with the
        # *second* one. Re-running the first filter wasted work, changed the
        # acceptance rate of random filters and broke non-idempotent
        # operations such as xor.
        res = next(remaining)(data)
        for filter_fn in remaining:
            res = self.boolean_operation(res, filter_fn(data))
        return res

    def __repr__(self):
        rep = "{}([".format(self.__class__.__name__)
        for filt in self.list_filter:
            rep = rep + filt.__repr__() + ", "
        rep = rep + "])"
        return rep
class ClassificationFilter(object):
    """
    Select specific classes from "classification" feature to remove or keep.
    Keep is prioritized.

    Every tensor attribute of ``data`` that has one row per point is reduced
    to the selected points, and an explicit ``num_nodes`` entry is updated to
    the number of points that remain.

    Parameters
    ----------
    feature_index: int
        which index the classification is expected in
    class_indices:
        which class indices to select for keeping or removing
    keep: bool, optional
        keep the given class indices if true, else remove them (default: True)
    remove_feat: bool, optional
        if the feature should be removed after filtering (default: True)

    """

    def __init__(self, feature_index: int, class_indices: list, keep: bool = True, remove_feat: bool = True):
        self.class_indices = class_indices
        self.keep = keep
        self.feature_index = feature_index
        self.remove_feat = remove_feat

    def __call__(self, data):
        cls = data.x[:, self.feature_index]
        if len(self.class_indices) > 0:
            mask = torch.stack([cls == i for i in self.class_indices]).any(0)
        else:
            # ``torch.stack`` rejects an empty list: no class selected means
            # that no point matches.
            mask = torch.zeros_like(cls, dtype=torch.bool)
        if not self.keep:
            mask = ~mask

        num_nodes = data.num_nodes
        num_kept = int(mask.sum())
        # Snapshot the (key, item) pairs so entries can be reassigned safely
        # while looping.
        for key, item in list(data):
            if key == 'num_nodes':
                # Number of points that survive the filter, not the size of
                # the mask (which is the original point count).
                data.num_nodes = num_kept
            elif (torch.is_tensor(item) and item.dim() > 0 and item.size(0) == num_nodes
                  and item.size(0) != 1):
                data[key] = item[mask]

        if self.remove_feat:
            if data.x.shape[1] == 1:
                # The classification column was the only feature.
                data.x = None
            else:
                data.x = torch.cat([data.x[:, :self.feature_index], data.x[:, self.feature_index + 1:]], 1)

        return data

    def __repr__(self):
        return "{}(feature_index={},class_indices={},keep={},remove_feat={})".format(
            self.__class__.__name__, self.feature_index, self.class_indices, self.keep, self.remove_feat
        )
torch_geometric.nn.pool.consecutive import consecutive_cluster +from torch_scatter import scatter_mean, scatter_add + +log = logging.getLogger(__name__) + +# Label will be the majority label in each voxel +_INTEGER_LABEL_KEYS = ["y", "y_cls", "instance_labels"] + + +def shuffle_data(data): + num_points = data.pos.shape[0] + shuffle_idx = torch.randperm(num_points) + for key in set(data.keys): + item = data[key] + if torch.is_tensor(item) and num_points == item.shape[0]: + data[key] = item[shuffle_idx] + return data + + +def group_data(data, cluster=None, unique_pos_indices=None, mode="last", skip_keys=[]): + """ Group data based on indices in cluster. + The option ``mode`` controls how data gets aggregated within each cluster. + + Parameters + ---------- + data : Data + [description] + cluster : torch.Tensor + Tensor of the same size as the number of points in data. Each element is the cluster index of that point. + unique_pos_indices : torch.tensor + Tensor containing one index per cluster, this index will be used to select features and labels + mode : str + Option to select how the features and labels for each voxel is computed. Can be ``last`` or ``mean``. + ``last`` selects the last point falling in a voxel as the represented, ``mean`` takes the average. + skip_keys: list + Keys of attributes to skip in the grouping + """ + + assert mode in ["mean", "last"] + if mode == "mean" and cluster is None: + raise ValueError("In mean mode the cluster argument needs to be specified") + if mode == "last" and unique_pos_indices is None: + raise ValueError("In last mode the unique_pos_indices argument needs to be specified") + + num_nodes = data.num_nodes + for key, item in data: + if bool(re.search("edge", key)): + raise ValueError("Edges not supported. 
Wrong data type.") + if key in skip_keys: + continue + + if torch.is_tensor(item) and item.size(0) == num_nodes: + if mode == "last" or key == "batch" or key == SaveOriginalPosId.KEY: + data[key] = item[unique_pos_indices] + elif mode == "mean": + is_item_bool = item.dtype == torch.bool + if is_item_bool: + item = item.int() + if key in _INTEGER_LABEL_KEYS: + item_min = item.min() + item = F.one_hot(item - item_min) + item = scatter_add(item, cluster, dim=0) + data[key] = item.argmax(dim=-1) + item_min + else: + data[key] = scatter_mean(item, cluster, dim=0) + if is_item_bool: + data[key] = data[key].bool() + return data + + +class GridSampling3D: + """ Clusters points into voxels with size :attr:`size`. + Parameters + ---------- + size: float + Size of a voxel (in each dimension). + quantize_coords: bool + If True, it will convert the points into their associated sparse coordinates within the grid and store + the value into a new `coords` attribute + mode: string: + The mode can be either `last` or `mean`. + If mode is `mean`, all the points and their features within a cell will be averaged + If mode is `last`, one random points per cell will be selected with its associated features + """ + + def __init__(self, size, quantize_coords=False, mode="mean", verbose=False): + self._grid_size = size + self._quantize_coords = quantize_coords + self._mode = mode + if verbose: + log.warning( + "If you need to keep track of the position of your points, use SaveOriginalPosId transform before using GridSampling3D" + ) + + if self._mode == "last": + log.warning( + "The tensors within data will be shuffled each time this transform is applied. 
class ElasticDistortion:
    """Apply elastic distortion on sparse coordinate space. First projects the position onto a
    voxel grid and then apply the distortion to the voxel grid.

    One distortion pass is applied per (granularity, magnitude) pair, in
    order. The whole transform is skipped with probability 0.05.

    Parameters
    ----------
    apply_distorsion: bool
        Whether the distortion is applied at all
    granularity: List[float]
        Granularity of the noise in meters
    magnitude:List[float]
        Noise multiplier in meters
    Returns
    -------
    data: Data
        Returns the same data object with distorted grid
    """

    def __init__(
        self, apply_distorsion: bool = True, granularity: List = [0.2, 0.8], magnitude=[0.4, 1.6],
    ):
        assert len(magnitude) == len(granularity)
        self._apply_distorsion = apply_distorsion
        # Copy the sequences so that instances never alias the shared default
        # lists (or a caller-owned list).
        self._granularity = list(granularity)
        self._magnitude = list(magnitude)

    @staticmethod
    def elastic_distortion(coords, granularity, magnitude):
        """Return ``coords`` displaced by smoothed Gaussian noise.

        A noise grid with cell size ``granularity`` is drawn with
        ``np.random``, blurred along each spatial axis, interpolated at every
        point and scaled by ``magnitude``. The result is a float tensor on the
        same device as ``coords``.
        """
        # Import the submodules explicitly: a bare ``import scipy`` does not
        # guarantee that ``scipy.ndimage`` / ``scipy.interpolate`` are loaded,
        # and the ``scipy.ndimage.filters`` namespace is deprecated.
        from scipy.interpolate import RegularGridInterpolator
        from scipy.ndimage import convolve

        if coords.shape[0] == 0:
            # Nothing to distort; min/max below are undefined on empty input.
            return coords

        device = coords.device
        coords = coords.detach().cpu().numpy()
        blur_kernels = [
            np.ones(shape).astype("float32") / 3 for shape in ((3, 1, 1, 1), (1, 3, 1, 1), (1, 1, 3, 1))
        ]
        coords_min = coords.min(0)

        # Create Gaussian noise tensor of the size given by granularity.
        noise_dim = ((coords - coords_min).max(0) // granularity).astype(int) + 3
        noise = np.random.randn(*noise_dim, 3).astype(np.float32)

        # Smoothing: two box-blur passes along x, y and z.
        for _ in range(2):
            for kernel in blur_kernels:
                noise = convolve(noise, kernel, mode="constant", cval=0)

        # Trilinear interpolate noise filters for each spatial dimensions.
        ax = [
            np.linspace(d_min, d_max, d)
            for d_min, d_max, d in zip(coords_min - granularity, coords_min + granularity * (noise_dim - 2), noise_dim)
        ]
        interp = RegularGridInterpolator(ax, noise, bounds_error=False, fill_value=0)
        coords = coords + interp(coords) * magnitude
        return torch.tensor(coords).float().to(device)

    def __call__(self, data):
        if self._apply_distorsion:
            if random.random() < 0.95:
                for granularity, magnitude in zip(self._granularity, self._magnitude):
                    data.pos = ElasticDistortion.elastic_distortion(data.pos, granularity, magnitude)
        return data

    def __repr__(self):
        return "{}(apply_distorsion={}, granularity={}, magnitude={})".format(
            self.__class__.__name__, self._apply_distorsion, self._granularity, self._magnitude,
        )
+ + Parameters + ---------- + checkpoint_dir: str + Path to a checkpoint directory + model_name: str + Model name, the file ``checkpoint_dir/model_name.pt`` must exist + """ + + def __init__(self, checkpoint_dir, model_name, weight_name, feat_name, num_classes=None, mock_dataset=True): + # Checkpoint + from torch_points3d.datasets.base_dataset import BaseDataset + from torch_points3d.datasets.dataset_factory import instantiate_dataset + from torch_points3d.utils.mock import MockDataset + import torch_points3d.metrics.model_checkpoint as model_checkpoint + + checkpoint = model_checkpoint.ModelCheckpoint(checkpoint_dir, model_name, weight_name, strict=True) + if mock_dataset: + dataset = MockDataset(num_classes) + dataset.num_classes = num_classes + else: + dataset = instantiate_dataset(checkpoint.data_config) + BaseDataset.set_transform(self, checkpoint.data_config) + self.model = checkpoint.create_model(dataset, weight_name=weight_name) + self.model.eval() + + def __call__(self, data): + raise NotImplementedError + + +class PointNetForward(ModelInference): + """ Transform for running a PointNet inference on a Data object. It assumes that the + model has been trained for segmentation. + + Parameters + ---------- + checkpoint_dir: str + Path to a checkpoint directory + model_name: str + Model name, the file ``checkpoint_dir/model_name.pt`` must exist + weight_name: str + Type of weights to load (best for iou, best for loss etc...) 
class ClampBatchSize:
    """Pre-batch transform that keeps the total point count of a batch bounded.

    Samples are visited in order; a sample is skipped whenever adding it would
    push the running point count above the limit. Later (smaller) samples may
    still be kept after a skip.

    Parameters
    ----------
    num_points : int, optional
        Maximum number of points per batch, by default 100000. A falsy value
        (``0`` / ``None``) disables the limit.
    """

    def __init__(self, num_points=100000):
        self._num_points = num_points

    def __call__(self, datas):
        assert isinstance(datas, list)
        limit = self._num_points
        kept = []
        batch_num_points = 0
        for sample in datas:
            sample_points = sample.pos.shape[0]
            # Skip the sample when it would overflow the point budget.
            if limit and batch_num_points + sample_points > limit:
                continue
            batch_num_points += sample_points
            kept.append(sample)

        # At least one sample was dropped: report what happened.
        if len(kept) != len(datas):
            num_full_points = sum(len(sample.pos) for sample in datas)
            num_full_batch_size = len(kept)
            log.warning(
                f"\t\tCannot fit {num_full_points} points into {self._num_points} points "
                f"limit. Truncating batch size at {num_full_batch_size} out of {len(datas)} with {batch_num_points}."
            )
        return kept

    def __repr__(self):
        return "{}(num_points={})".format(type(self).__name__, self._num_points)
class RandomCoordsFlip(object):
    """Randomly mirror sparse (integer) coordinates along the non-ignored axes."""

    def __init__(self, ignored_axis, is_temporal=False, p=0.95):
        """This transform is used to flip sparse coords using a given axis. Usually, it would be x or y

        Parameters
        ----------
        ignored_axis: str
            Axes that must NOT be flipped, given as characters among x, y, z
            (e.g. ``"z"`` or ``"xz"``).
        is_temporal : bool
            Used to indicate if the pointcloud is actually 4 dimensional. When
            True the 4th coordinate column is also a flipping candidate.
        p : float
            Probability, evaluated independently per axis, that the axis is
            flipped. A higher value increases the chance of flipping.

        Returns
        -------
        data: Data
            Returns the same data object with ``coords`` mirrored in place
            along every axis that was selected for flipping.
        """
        # The original message claimed the opposite of what __call__ does
        # (flip happens when random.random() < p).
        assert 0 <= p <= 1, "p should be within 0 and 1. Higher probability increases the chance of flipping"
        self._is_temporal = is_temporal
        self._D = 4 if is_temporal else 3
        mapping = {"x": 0, "y": 1, "z": 2}
        self._ignored_axis = [mapping[axis] for axis in ignored_axis]
        # Use the rest of axes for flipping.
        self._horz_axes = set(range(self._D)) - set(self._ignored_axis)
        self._p = p

    def __call__(self, data):
        for curr_ax in self._horz_axes:
            if random.random() < self._p:
                coords = data.coords
                # Mirror around the current maximum so coordinates stay non-negative
                # when they were non-negative before.
                coord_max = torch.max(coords[:, curr_ax])
                data.coords[:, curr_ax] = coord_max - coords[:, curr_ax]
        return data

    def __repr__(self):
        return "{}(flip_axis={}, prob={}, is_temporal={})".format(
            self.__class__.__name__, self._horz_axes, self._p, self._is_temporal
        )
class RemoveAttributes(object):
    """This transform allows to remove unnecessary attributes from data for optimization purposes

    Parameters
    ----------
    attr_names: list, optional
        Remove the attributes from data using the provided `attr_name` within attr_names.
        Defaults to an empty list (nothing is removed).
    strict: bool=False
        When True, an exception is raised if a provided attr_name isn't within the data keys.
        When False, names that are not present are silently skipped.
    """

    def __init__(self, attr_names=None, strict=False):
        # Avoid a shared mutable default argument.
        self._attr_names = [] if attr_names is None else attr_names
        self._strict = strict

    def __call__(self, data):
        # `keys` is read as an attribute here; also accept it being a method.
        keys = data.keys() if callable(data.keys) else data.keys
        keys = set(keys)

        missing = [attr_name for attr_name in self._attr_names if attr_name not in keys]
        if missing and self._strict:
            # Validate everything before deleting so a failure leaves data untouched.
            raise Exception("attr_name: {} isn t within keys: {}".format(missing[0], keys))

        for attr_name in self._attr_names:
            # In non-strict mode a missing attribute must not reach delattr.
            if attr_name in keys:
                delattr(data, attr_name)
        return data

    def __repr__(self):
        return "{}(attr_names={}, strict={})".format(self.__class__.__name__, self._attr_names, self._strict)
class GridSphereSampling(object):
    """Fits the point cloud to a grid and for each point in this grid,
    create a sphere with a radius r

    Parameters
    ----------
    radius: float
        Radius of the sphere to be sampled. A string is evaluated as a Python expression.
    grid_size: float, optional
        Grid_size to be used with GridSampling3D to select spheres center. If None, radius will be used
    delattr_kd_tree: bool, optional
        If True, KDTREE_KEY should be deleted as an attribute if it exists
    center: bool, optional
        If True, a centre transform is apply on each sphere.
    """

    KDTREE_KEY = KDTREE_KEY

    def __init__(self, radius, grid_size=None, delattr_kd_tree=True, center=True):
        # NOTE(security): string arguments go through eval(); they must come from
        # trusted configuration only.
        self._radius = eval(radius) if isinstance(radius, str) else float(radius)
        # The documented default (None) used to hit float(None) and raise TypeError.
        if grid_size is None:
            grid_size = self._radius
        elif isinstance(grid_size, str):
            grid_size = eval(grid_size)
        else:
            grid_size = float(grid_size)
        self._grid_sampling = GridSampling3D(size=grid_size if grid_size else self._radius)
        self._delattr_kd_tree = delattr_kd_tree
        self._center = center

    def _process(self, data):
        """Return one sphere sample (a new data object) per grid centre of ``data``."""
        if not hasattr(data, self.KDTREE_KEY):
            tree = KDTree(np.asarray(data.pos), leaf_size=50)
        else:
            tree = getattr(data, self.KDTREE_KEY)

        # The kdtree has been attached to data for optimization reason.
        # However, it won't be used for down the transform pipeline and should be removed before any collate func call.
        if hasattr(data, self.KDTREE_KEY) and self._delattr_kd_tree:
            delattr(data, self.KDTREE_KEY)

        # apply grid sampling
        grid_data = self._grid_sampling(data.clone())

        datas = []
        for grid_center in np.asarray(grid_data.pos):
            pts = np.asarray(grid_center)[np.newaxis]

            # Find closest point within the original data
            ind = torch.LongTensor(tree.query(pts, k=1)[1][0])
            grid_label = data.y[ind]

            # The sampler performs the radius query itself, so no separate
            # (previously unused) query_radius call is needed here.
            sampler = SphereSampling(self._radius, grid_center, align_origin=self._center)
            new_data = sampler(data)
            new_data.center_label = grid_label

            datas.append(new_data)
        return datas

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in tq(data)]
            data = list(chain(*data))  # 2d list needs to be flatten
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(radius={}, center={})".format(self.__class__.__name__, self._radius, self._center)


class GridCylinderSampling(object):
    """Fits the point cloud to a grid and for each point in this grid,
    create a cylinder with a radius r

    Parameters
    ----------
    radius: float
        Radius of the cylinder to be sampled. A string is evaluated as a Python expression.
    grid_size: float, optional
        Grid_size to be used with GridSampling3D to select cylinders center. If None, radius will be used
    delattr_kd_tree: bool, optional
        If True, KDTREE_KEY should be deleted as an attribute if it exists
    center: bool, optional
        If True, a centre transform is apply on each cylinder.
    """

    KDTREE_KEY = KDTREE_KEY

    def __init__(self, radius, grid_size=None, delattr_kd_tree=True, center=True):
        # NOTE(security): string arguments go through eval(); they must come from
        # trusted configuration only.
        self._radius = eval(radius) if isinstance(radius, str) else float(radius)
        # The documented default (None) used to hit float(None) and raise TypeError.
        if grid_size is None:
            grid_size = self._radius
        elif isinstance(grid_size, str):
            grid_size = eval(grid_size)
        else:
            grid_size = float(grid_size)
        self._grid_sampling = GridSampling3D(size=grid_size if grid_size else self._radius)
        self._delattr_kd_tree = delattr_kd_tree
        self._center = center

    def _process(self, data):
        """Return one cylinder sample (a new data object) per unique grid centre of ``data``.

        The tree built here indexes every column of ``pos`` except the last one.
        """
        if not hasattr(data, self.KDTREE_KEY):
            tree = KDTree(np.asarray(data.pos[:, :-1]), leaf_size=50)
        else:
            tree = getattr(data, self.KDTREE_KEY)

        # The kdtree has been attached to data for optimization reason.
        # However, it won't be used for down the transform pipeline and should be removed before any collate func call.
        if hasattr(data, self.KDTREE_KEY) and self._delattr_kd_tree:
            delattr(data, self.KDTREE_KEY)

        # apply grid sampling
        grid_data = self._grid_sampling(data.clone())

        datas = []
        # Several grid cells can share the same centre once the last column is dropped.
        for grid_center in np.unique(grid_data.pos[:, :-1], axis=0):
            pts = np.asarray(grid_center)[np.newaxis]

            # Find closest point within the original data
            ind = torch.LongTensor(tree.query(pts, k=1)[1][0])
            grid_label = data.y[ind]

            # The sampler performs the radius query itself, so no separate
            # (previously unused) query_radius call is needed here.
            sampler = CylinderSampling(self._radius, grid_center, align_origin=self._center)
            new_data = sampler(data)
            new_data.center_label = grid_label

            datas.append(new_data)
        return datas

    def __call__(self, data):
        if isinstance(data, list):
            data = [self._process(d) for d in tq(data)]
            data = list(chain(*data))  # 2d list needs to be flatten
        else:
            data = self._process(data)
        return data

    def __repr__(self):
        return "{}(radius={}, center={})".format(self.__class__.__name__, self._radius, self._center)
node. + """ + + def __init__(self, leaf_size): + self._leaf_size = leaf_size + + def _process(self, data): + data.kd_tree = KDTree(np.asarray(data.pos), leaf_size=self._leaf_size) + return data + + def __call__(self, data): + if isinstance(data, list): + data = [self._process(d) for d in data] + else: + data = self._process(data) + return data + + def __repr__(self): + return "{}(leaf_size={})".format(self.__class__.__name__, self._leaf_size) + + +class RandomSphere(object): + """Select points within a sphere of a given radius. The centre is chosen randomly within the point cloud. + + Parameters + ---------- + radius: float + Radius of the sphere to be sampled. + strategy: str + choose between `random` and `freq_class_based`. The `freq_class_based` \ + favors points with low frequency class. This can be used to balance unbalanced datasets + center: bool + if True then the sphere will be moved to the origin + """ + + def __init__(self, radius, strategy="random", class_weight_method="sqrt", center=True): + self._radius = eval(radius) if isinstance(radius, str) else float(radius) + self._sampling_strategy = SamplingStrategy(strategy=strategy, class_weight_method=class_weight_method) + self._center = center + + def _process(self, data): + # apply sampling strategy + random_center = self._sampling_strategy(data) + random_center = np.asarray(data.pos[random_center])[np.newaxis] + sphere_sampling = SphereSampling(self._radius, random_center, align_origin=self._center) + return sphere_sampling(data) + + def __call__(self, data): + if isinstance(data, list): + data = [self._process(d) for d in data] + else: + data = self._process(data) + return data + + def __repr__(self): + return "{}(radius={}, center={}, sampling_strategy={})".format( + self.__class__.__name__, self._radius, self._center, self._sampling_strategy + ) + + +class SphereSampling: + """ Samples points within a sphere + + Parameters + ---------- + radius : float + Radius of the sphere + sphere_centre : 
torch.Tensor or np.array + Centre of the sphere (1D array that contains (x,y,z)) + align_origin : bool, optional + move resulting point cloud to origin + """ + + KDTREE_KEY = KDTREE_KEY + + def __init__(self, radius, sphere_centre, align_origin=True): + self._radius = radius + self._centre = np.asarray(sphere_centre) + if len(self._centre.shape) == 1: + self._centre = np.expand_dims(self._centre, 0) + self._align_origin = align_origin + + def __call__(self, data): + num_points = data.pos.shape[0] + if not hasattr(data, self.KDTREE_KEY): + tree = KDTree(np.asarray(data.pos), leaf_size=50) + setattr(data, self.KDTREE_KEY, tree) + else: + tree = getattr(data, self.KDTREE_KEY) + + t_center = torch.FloatTensor(self._centre) + ind = torch.LongTensor(tree.query_radius(self._centre, r=self._radius)[0]) + new_data = Data() + for key in set(data.keys): + if key == self.KDTREE_KEY: + continue + item = data[key] + if torch.is_tensor(item) and num_points == item.shape[0]: + item = item[ind] + if self._align_origin and key == "pos": # Center the sphere. 
+ item -= t_center + elif torch.is_tensor(item): + item = item.clone() + setattr(new_data, key, item) + return new_data + + def __repr__(self): + return "{}(radius={}, center={}, align_origin={})".format( + self.__class__.__name__, self._radius, self._centre, self._align_origin + ) + + +class CylinderSampling: + """ Samples points within a cylinder + + Parameters + ---------- + radius : float + Radius of the cylinder + cylinder_centre : torch.Tensor or np.array + Centre of the cylinder (1D array that contains (x,y,z) or (x,y)) + align_origin : bool, optional + move resulting point cloud to origin + """ + + KDTREE_KEY = KDTREE_KEY + + def __init__(self, radius, cylinder_centre, align_origin=True): + self._radius = radius + if cylinder_centre.shape[0] == 3: + cylinder_centre = cylinder_centre[:-1] + self._centre = np.asarray(cylinder_centre) + if len(self._centre.shape) == 1: + self._centre = np.expand_dims(self._centre, 0) + self._align_origin = align_origin + + def __call__(self, data): + num_points = data.pos.shape[0] + if not hasattr(data, self.KDTREE_KEY): + tree = KDTree(np.asarray(data.pos[:, :-1]), leaf_size=50) + setattr(data, self.KDTREE_KEY, tree) + else: + tree = getattr(data, self.KDTREE_KEY) + + t_center = torch.FloatTensor(self._centre) + ind = torch.LongTensor(tree.query_radius(self._centre, r=self._radius)[0]) + + new_data = Data() + for key in set(data.keys): + if key == self.KDTREE_KEY: + continue + item = data[key] + if torch.is_tensor(item) and num_points == item.shape[0]: + item = item[ind] + if self._align_origin and key == "pos": # Center the cylinder. 
class CylinderNormalizeScale(object):
    """Centre a point cloud on its mean and rescale it to just under unit extent.

    All columns of ``pos`` except the last are divided by their largest
    absolute value (times 0.999999). The last column is rescaled the same way,
    independently, when ``normalize_z`` is True.
    """

    def __init__(self, normalize_z=True):
        self._normalize_z = normalize_z

    def _process(self, data):
        # In-place centring on the per-column mean.
        data.pos -= data.pos.mean(dim=0, keepdim=True)

        planar = data.pos[:, :-1]
        planar_scale = (1 / planar.abs().max()) * 0.999999
        data.pos[:, :-1] = planar * planar_scale

        if not self._normalize_z:
            return data

        height = data.pos[:, -1]
        height_scale = (1 / height.abs().max()) * 0.999999
        data.pos[:, -1] = height * height_scale
        return data

    def __call__(self, data):
        if not isinstance(data, list):
            return self._process(data)
        return [self._process(sample) for sample in data]

    def __repr__(self):
        return "{}(normalize_z={})".format(type(self).__name__, self._normalize_z)
class StatZOutlierRemoval:
    """Remove points whose z value is a statistical outlier.

    A point is kept when the absolute z-score of its third ``pos`` column is
    strictly below ``threshold``.

    Parameters
    ----------
    threshold: float
        Maximum allowed absolute z-score (in standard deviations).
    skip_list: list, optional
        Forwarded to ``apply_mask``; converted with ``OmegaConf.to_object``
        when provided.
    """

    def __init__(self, threshold: float = 4, skip_list: list = None):
        self.skip_list = [] if skip_list is None else OmegaConf.to_object(skip_list)
        self.threshold = threshold  # std deviation

    def __call__(self, data):
        z = data.pos[:, 2]
        m = z.mean()
        s = z.std()
        # NOTE(review): a zero / undefined std makes every score NaN, so the
        # mask would be all False — confirm how apply_mask should treat that.
        out = abs((z - m) / s)
        mask = out < self.threshold
        data = apply_mask(data, mask, self.skip_list)
        return data

    def __repr__(self):
        # Previously formatted the non-existent attribute `self.p`, which raised
        # AttributeError whenever the transform was printed.
        return "{}(threshold={})".format(self.__class__.__name__, self.threshold)
"{}(eps={},min_samples={})".format(self.__class__.__name__, self.eps, self.min_samples) + + +class OPTICSZOutlierRemoval: + def __init__(self, eps: float = 1, min_samples: int = 10, skip_list: list = None): + self.skip_list = [] if skip_list is None else OmegaConf.to_object(skip_list) + self.eps = eps + self.min_samples = min_samples + self.dbscan = OPTICS(eps=eps, min_samples=min_samples, cluster_method="dbscan") + + def __call__(self, data): + z = data.pos[:, 2] + label = torch.tensor(self.dbscan.fit_predict(z[:, None])) + mask = label != -1 + mask = (z <= z[mask].max()) & (z >= z[mask].min()) + # if ~(mask).any(): + # from openpoints.dataset import vis_points + # vis_points(data['pos'], mask) + data = apply_mask(data, mask, self.skip_list) + return data + + def __repr__(self): + return "{}(eps={},min_samples={})".format(self.__class__.__name__, self.eps, self.min_samples) + + +class KernelDensityZOutlierRemoval: + def __init__(self, bandwidth: float = 1, p: float = 0.05, skip_list: list = None): + self.skip_list = [] if skip_list is None else OmegaConf.to_object(skip_list) + self.bandwidth = bandwidth + self.p = p + self.kd = KernelDensity(kernel="gaussian", bandwidth=bandwidth) + + def __call__(self, data): + z = data.pos[:, 2] + label = torch.tensor(self.kd.fit(z[:, None]).score_samples(z[:, None])) + mask = label > np.log(self.p) + mask = (z <= z[mask].max()) & (z >= z[mask].min()) + # if ~(mask).any(): + # from openpoints.dataset import vis_points + # vis_points(data['pos'], mask) + data = apply_mask(data, mask, self.skip_list) + return data + + def __repr__(self): + return "{}(bandwidth={},p={})".format(self.__class__.__name__, self.bandwidth, self.p) + + +class ScalePos: + def __init__(self, scale_x=1., scale_y=1., scale_z=1., op="mul"): + self.scale = torch.tensor([scale_x, scale_y, scale_z]).unsqueeze(0) + self.op_str = op + self.op = torch.mul if op == "mul" else torch.div + + def __call__(self, data): + data.pos = self.op(data.pos, self.scale) + return 
def maxmin_center(data):
    """Return the midpoint between the per-column max and min of ``data.pos`` (leading dim kept)."""
    upper = data.pos.amax(dim=0, keepdim=True)
    lower = data.pos.amin(dim=0, keepdim=True)
    return (upper + lower) / 2.


def quantile_center(data):
    """Return the midpoint between the 0.99 and 0.01 per-column quantiles of ``data.pos``."""
    q_high = torch.quantile(data.pos, 0.99, dim=0, keepdim=True)
    q_low = torch.quantile(data.pos, 0.01, dim=0, keepdim=True)
    return (q_high + q_low) / 2.


def mean_center(data):
    """Return the per-column mean of ``data.pos`` (leading dim kept)."""
    return data.pos.mean(axis=0, keepdims=True)


class CenterPosPerSample:
    r"""Subtract a per-sample centre from the point positions, axis by axis.

    Parameters
    -----------
    center_x: bool
        whether the x-axis is centred.
    center_y: bool
        whether the y-axis is centred.
    center_z: bool
        whether the z-axis is centred.
    center: str
        name of the centre function (one of 'mean', 'quantile', 'maxmin').
    """

    def __init__(self, center_x: bool = True, center_y: bool = True, center_z: bool = False, center: str = "mean"):
        # Per-axis 0/1 multiplier: disabled axes get a zero offset.
        self.center_ = torch.FloatTensor([[center_x, center_y, center_z]])
        self.center_x = center_x
        self.center_y = center_y
        self.center_z = center_z
        self.center_any = center_x or center_z or center_y
        self.center = center

        aggregators = {
            "mean": mean_center,
            "quantile": quantile_center,
            "maxmin": maxmin_center,
        }
        if center not in aggregators:
            raise Exception(f"Unknown center function: {center} (should be 'mean', 'quantile', or 'maxmin')")
        self.agg = aggregators[center]

    def __call__(self, data):
        if not self.center_any:
            return data
        offset = self.agg(data) * self.center_
        data.pos -= offset
        return data

    def __repr__(self):
        return "{}(center_x={},center_y={},center_z={},center={})".format(
            type(self).__name__, self.center_x, self.center_y, self.center_z, self.center
        )
class FixedCenterPosPerSample:
    r"""Move the bounding-box midpoint of the positions to a fixed location.

    The midpoint between the per-axis max and min of ``data.pos`` is
    subtracted, then the target centre is added.

    Parameters
    -----------
    center_x: float
        target x coordinate of the bounding-box midpoint.
    center_y: float
        target y coordinate of the bounding-box midpoint.
    center_z: float
        target z coordinate of the bounding-box midpoint.
    """

    def __init__(self, center_x: float = 0.5, center_y: float = 0.5, center_z: float = 0.5):
        target = [center_x, center_y, center_z]
        self.center_ = torch.FloatTensor([target])

    def __call__(self, data):
        upper = data.pos.amax(0, keepdim=True)
        lower = data.pos.amin(0, keepdim=True)
        midpoint = (upper + lower) / 2.
        data.pos -= midpoint
        data.pos += self.center_
        return data

    def __repr__(self):
        cx, cy, cz = self.center_[0, 0], self.center_[0, 1], self.center_[0, 2]
        return "{}(center_x={},center_y={},center_z={})".format(type(self).__name__, cx, cy, cz)
class RandomShiftPos:
    """Randomly translate the whole point cloud by a small per-axis offset.

    Parameters
    ----------
    max_x, max_y, max_z: float
        Maximum absolute shift along each axis; the offset of an axis is drawn
        uniformly from ``[-max, max]``.
    p: float
        Probability of applying the shift (same convention as the other
        random transforms in this module).
    """

    def __init__(self, max_x: float = 0.01, max_y: float = 0.01, max_z: float = 0.01, p: float = 0.5):
        # The z bound previously reused max_y, so max_z was silently ignored.
        self.max_ = torch.FloatTensor([[max_x, max_y, max_z]])
        self.max_x = max_x
        self.max_y = max_y
        self.max_z = max_z
        self.p = p

    def __call__(self, data):
        # Apply with probability p (the comparison used to be inverted, i.e. 1 - p).
        if random.random() < self.p:
            data.pos += (torch.rand(1, 3) * 2 * self.max_) - self.max_
        return data

    def __repr__(self):
        return "{}(max_x={},max_y={},max_z={},p={})".format(
            self.__class__.__name__, self.max_x, self.max_y, self.max_z, self.p
        )


class StartZFromZero:
    """Shift the third ``pos`` column in place so that its minimum becomes zero."""

    def __call__(self, data):
        data.pos[:, 2] -= data.pos[:, 2].min()
        return data

    def __repr__(self):
        return "{}()".format(self.__class__.__name__, )


class AddRandomPoints:
    r"""Add points randomly within existing cloud bounds. Only works without additional features.
    Intended for regression or classification (will not add point-wise labels).

    Parameters
    -----------
    n_max_points: int
        Maximal total number of points (will not add points if there are already many).
    add_ratio_min: float
        Minimal amount of points to add according to existing number of points.
    add_ratio_max: float
        Maximal amount of points to add according to existing number of points.
    p: float
        Probability of applying the augmentation.
    """

    def __init__(self, n_max_points: int, add_ratio_min: float, add_ratio_max: float, p: float = 0.5):
        self.n_max_points = n_max_points
        self.add_ratio_min = add_ratio_min
        self.add_ratio_max = add_ratio_max
        self.p = p

    def __call__(self, data):
        n_ori_points = len(data.pos)
        if n_ori_points >= self.n_max_points:
            return data

        if self.p > random.random():
            ratio = random.random() * (self.add_ratio_max - self.add_ratio_min) + self.add_ratio_min
            n_points = int(ratio * n_ori_points)
            # Never exceed n_max_points in total.
            n_points += min(0, self.n_max_points - (n_ori_points + n_points))
            if n_points <= 0:
                return data

            min_ = data.pos.amin(0, keepdim=True)
            # Upper bound of the cloud: this used amin, which collapsed every
            # new point onto the minimum corner.
            max_ = data.pos.amax(0, keepdim=True)
            random_points = torch.rand(n_points, data.pos.shape[1], device=data.pos.device) * (max_ - min_) + min_

            data.pos = torch.cat([data.pos, random_points.to(data.pos.dtype)], 0)
        return data

    def __repr__(self):
        return "{}(n_max_points={},add_ratio_min={},add_ratio_max={},p={})".format(
            self.__class__.__name__, self.n_max_points, self.add_ratio_min, self.add_ratio_max, self.p
        )
class RandomScaling:
    r"""Scale node positions by three independently sampled factors
    ``s1, s2, s3``, i.e. apply the transformation matrix

    .. math::
        \left[
        \begin{array}{ccc}
            s1 & 0 & 0 \\
            0 & s2 & 0 \\
            0 & 0 & s3 \\
        \end{array}
        \right]

    to three-dimensional positions. If ``data.norm`` is present it is divided
    by the same factors and re-normalised along ``dim=1``.

    Parameters
    -----------
    scales:
        scaling factor interval ``(a, b)`` with ``a <= b``; each factor is
        sampled uniformly from that range.
    """

    def __init__(self, scales=None):
        is_pair = is_iterable(scales) and len(scales) == 2
        assert is_pair
        assert scales[0] <= scales[1]
        self.scales = scales

    def __call__(self, data):
        lower, upper = self.scales[0], self.scales[1]
        # one factor per axis
        factors = lower + torch.rand((3,)) * (upper - lower)
        data.pos = data.pos * factors

        normals = getattr(data, "norm", None)
        if normals is None:
            return data

        # normals transform with the inverse scaling and must stay unit length
        data.norm = torch.nn.functional.normalize(normals / factors, dim=1)
        return data

    def __repr__(self):
        return f"{self.__class__.__name__}({self.scales})"
+ This currently only works on PARTIAL_DENSE formats + + Parameters + ----------- + strategies: Dict[str, object] + Dictionary that contains the samplers and neighbour_finder + """ + + def __init__(self, strategies): + self.strategies = strategies + self.num_layers = len(self.strategies["sampler"]) + + @staticmethod + def __inc__wrapper(func, special_params): + def new__inc__(key, num_nodes, special_params=None, func=None): + if key in special_params: + return special_params[key] + else: + return func(key, num_nodes) + + return partial(new__inc__, special_params=special_params, func=func) + + def __call__(self, data: Data) -> MultiScaleData: + # Compute sequentially multi_scale indexes on cpu + data.contiguous() + ms_data = MultiScaleData.from_data(data) + precomputed = [Data(pos=data.pos)] + upsample = [] + upsample_index = 0 + for index in range(self.num_layers): + sampler, neighbour_finder = self.strategies["sampler"][index], self.strategies["neighbour_finder"][index] + support = precomputed[index] + new_data = Data(pos=support.pos) + if sampler: + query = sampler(new_data.clone()) + query.contiguous() + + if len(self.strategies["upsample_op"]): + if upsample_index >= len(self.strategies["upsample_op"]): + raise ValueError("You are missing some upsample blocks in your network") + + upsampler = self.strategies["upsample_op"][upsample_index] + upsample_index += 1 + pre_up = upsampler.precompute(query, support) + upsample.append(pre_up) + special_params = {} + special_params["x_idx"] = query.num_nodes + special_params["y_idx"] = support.num_nodes + setattr(pre_up, "__inc__", self.__inc__wrapper(pre_up.__inc__, special_params)) + else: + query = new_data + + s_pos, q_pos = support.pos, query.pos + if hasattr(query, "batch"): + s_batch, q_batch = support.batch, query.batch + else: + s_batch, q_batch = ( + torch.zeros((s_pos.shape[0]), dtype=torch.long), + torch.zeros((q_pos.shape[0]), dtype=torch.long), + ) + + idx_neighboors = neighbour_finder(s_pos, q_pos, 
class ShiftVoxels:
    """ Trick to make Sparse conv invariant to even and odds coordinates
    https://github.com/chrischoy/SpatioTemporalSegmentation/blob/master/lib/train.py#L78

    With probability ``p`` one random integer offset per axis (drawn from
    ``rand(3) * 100``) is added in place to the first three columns of
    ``data.coords``.

    Parameters
    -----------
    apply_shift: bool:
        Whether to apply the shift on indices
    p: float
        Probability of shifting when ``apply_shift`` is set
    """

    def __init__(self, apply_shift=True, p=0.5):
        self.p = p
        self._apply_shift = apply_shift

    def __call__(self, data):
        # the random draw only happens when shifting is enabled at all
        shift_now = self._apply_shift and random.random() < self.p
        if not shift_now:
            return data

        if not hasattr(data, "coords"):
            raise Exception("should quantize first using GridSampling3D")

        coords = data.coords
        if not isinstance(coords, torch.IntTensor):
            raise Exception("The pos are expected to be coordinates, so torch.IntTensor")

        offset = (torch.rand(3) * 100).type_as(coords)
        coords[:, :3] += offset
        return data

    def __repr__(self):
        return f"{self.__class__.__name__}(apply_shift={self._apply_shift})"
def apply_mask(data, mask, skip_keys=None):
    """Filter every per-point tensor attribute of ``data`` with ``mask``.

    An attribute is treated as per-point when it is a tensor with at least one
    dimension whose first dimension equals ``len(data.pos)`` (measured before
    any masking). Everything else is left untouched.

    Parameters
    ----------
    data:
        Object exposing ``pos``, ``keys`` (attribute or method) and item
        get/set access by key.
    mask:
        Boolean mask or index used to subset the first dimension.
    skip_keys:
        Keys that must not be masked. ``None`` (the default) skips nothing;
        several callers in this module forward an optional list as-is.

    Returns
    -------
    The same ``data`` object, modified in place.
    """
    skipped = () if skip_keys is None else skip_keys
    size_pos = len(data.pos)
    # `keys` is an attribute on some data classes and a method on others
    keys = data.keys() if callable(data.keys) else data.keys
    for k in list(keys):
        item = data[k]
        # 0-dim tensors have no length and cannot be per-point attributes
        if torch.is_tensor(item) and item.dim() > 0 and size_pos == len(item) and k not in skipped:
            data[k] = item[mask]
    return data
class RandomGroundRemoval:
    r"""Randomly remove the lowest part of the cloud and re-base the heights.

    With probability ``p`` a cut height is drawn uniformly from
    ``[min_v, max_v)``. Points whose z is not above that height are removed
    via ``apply_mask`` and the cut height is subtracted from all z values.
    If fewer than ``min_points`` points would remain, ``data`` is returned
    unchanged.

    Parameters
    -----------
    min_v: float
        Lower bound of the cut height.
    max_v: float
        Upper bound of the cut height.
    p: float
        Probability of applying the transform.
    min_points: int
        Minimal number of points that must survive the cut.
    skip_list: list
        Keys that are not masked (forwarded to ``apply_mask``).
    """

    def __init__(self, min_v: float, max_v: float, p: float = 0.5, min_points: int = 500, skip_list: list = None):
        # list() copies plain Python lists and OmegaConf list configs alike
        self.skip_list = [] if skip_list is None else list(skip_list)
        self.min_v = min_v
        self.max_v = max_v
        self.range = max_v - min_v
        self.p = p
        self.min_points = min_points

    def __call__(self, data):
        if random.random() < self.p:
            pos = data.pos
            remove_v = random.random() * self.range + self.min_v
            cond = pos[:, 2] > remove_v
            if cond.sum() < self.min_points:
                # too few points would survive: leave the sample untouched
                return data
            # shift before masking; `pos` is the tensor stored on `data`
            pos[:, 2] -= remove_v
            data = apply_mask(data, cond, self.skip_list)

        return data

    def __repr__(self):
        return "{}(min_v={}, max_v={}, p={}, min_points={}, skip_list={})".format(
            self.__class__.__name__, self.min_v, self.max_v, self.p, self.min_points, self.skip_list
        )
self.random_rotation = Random3AxisRotation(rot_x=rot_x, rot_y=rot_y, rot_z=rot_z) + if isinstance(n_max_objects, int): + n_max_objects = { + "object": n_max_objects, + "scene": n_max_objects, + } + self.n_max_objects: dict = n_max_objects + self.p: float = p + self.zero_center_z = zero_center_z + self.indicator_key = indicator_key + self.only_doubled_batch = only_doubled_batch + + def __call__(self, data): + if len(self.object_files) == 0: + self.object_files = list(chain(*[glob(str(self.processed_dir / f"{area}/*.pt")) for area in self.areas])) + assert len(self.object_files) > 0, "no objects given for RadiusObjectAdder" + ori_n = None + if random.random() < self.p and ( + not self.only_doubled_batch or (self.only_doubled_batch and data.get("is_double", False))): + sample_type = "object" if data.area_name in self.areas else "scene" + n_objects = random.randint(1, self.n_max_objects[sample_type]) + files = np.random.choice(self.object_files, n_objects, replace=True) + pos_ = [] + feat_ = [] + i = 0 + while i < len(files): + file = files[i] + i += 1 + if self.in_memory: + new_object = self.memory.get(file, None) + if new_object is None: + new_object = torch.load(file) + self.memory[file] = new_object.clone() + else: + new_object = new_object.clone() + else: + new_object = torch.load(file) + if self.zero_center_z: + new_object.pos[:, 2] -= new_object.pos[:, 2].min() + new_object = self.random_rotation(new_object) + + if self.adjust_point_density: + # only removes points if too dense, will not add points + sample_density = data["local_stats"][self.density_index] + obj_density = new_object["local_stats"][self.density_index] + density_adjustment_factor = random.random() + density_adjustment_factor *= self.density_adjustment[1] - self.density_adjustment[0] + density_adjustment_factor += self.density_adjustment[0] + drop_ratio = (sample_density * density_adjustment_factor) / obj_density + if drop_ratio < 1: + if self.density_topview_sample: + new_object = 
topview_sample(new_object, int(drop_ratio * len(new_object.pos))) + else: + new_object = FP(int(drop_ratio * len(new_object.pos)), replace=False)(new_object) + + # random point in outer circle + angle = random.uniform(0, 2 * math.pi) + + min_radius = self.min_radius + max_radius = self.max_radius + # add safety margin if we are given center deviation + if "pos_deviation" in new_object: + min_radius += (new_object["pos_deviation"] ** 2).sum() ** .5 / 2 # pythagoras + if min_radius > max_radius: # add another object to list and skip this one + files = np.concatenate([files, np.random.choice(self.object_files, 1)], axis=0) + continue + radius = random.uniform(min_radius, max_radius) + shift = torch.tensor(([[math.cos(angle), math.sin(angle), 0]])) * radius # no shift in z + + pos_.append(new_object.pos + shift) + feat_.append(new_object.x) + + ori_n = len(data.pos) + data.pos = torch.cat([data.pos, *pos_], 0) + if data.x is not None: + if len(feat_) > 0 and feat_[0] is not None: + data.x = torch.cat([data.x, *feat_], 0) + else: + data.x = torch.cat([data.x, torch.zeros(len(data.pos) - ori_n, data.x.shape[1])], 0) + + if self.indicator_key is not None: + if ori_n is not None: + indicator = torch.zeros(len(data.pos)) + indicator[ori_n:] = True + else: + indicator = torch.zeros(len(data.pos)) + + data[self.indicator_key] = indicator + return data + + +class CubeCrop(object): + """ + Crop cubically the point cloud. This function take a cube of size c + centered on a random point, then points outside the cube are rejected. 
class FixedPointsOwn(object):
    r"""Samples a fixed number of :obj:`num` points and features from a point
    cloud (functional name: :obj:`fixed_points`).

    Args:
        num (int): The number of points to sample.
        replace (bool, optional): If set to :obj:`True`, samples points
            with replacement. (default: :obj:`False`)
        allow_duplicates (bool, optional): In case :obj:`replace` is
            :obj:`False` and :obj:`num` is greater than the number of points,
            this option determines whether to add duplicated nodes to the
            output points or not.
            In case :obj:`allow_duplicates` is :obj:`False`, the number of
            output points might be smaller than :obj:`num`.
            In case :obj:`allow_duplicates` is :obj:`True`, the number of
            duplicated points are kept to a minimum. (default: :obj:`True`)
        skip_list (list, optional): Keys that are never re-sampled, except
            :obj:`"pos"`, which is always re-sampled. (default: :obj:`None`)
    """

    def __init__(self, num, replace=False, allow_duplicates=True, skip_list: list = None):
        # list() copies plain Python lists and OmegaConf list configs alike
        self.skip_list = [] if skip_list is None else list(skip_list)
        self.num = num
        self.replace = replace
        self.allow_duplicates = allow_duplicates

    def __call__(self, data):
        num_nodes = data.num_nodes

        if self.replace:
            choice = np.random.choice(num_nodes, self.num, replace=True)
            choice = torch.from_numpy(choice).to(torch.long)
        elif not self.allow_duplicates:
            # may yield fewer than `num` indices when the cloud is small
            choice = torch.randperm(num_nodes)[:self.num]
        else:
            if num_nodes == 0:
                raise ValueError("cannot sample points from an empty point cloud")
            # concatenate whole permutations so duplicates are kept to a minimum
            choice = torch.cat([
                torch.randperm(num_nodes)
                for _ in range(math.ceil(self.num / num_nodes))
            ], dim=0)[:self.num]

        for key, item in data:
            if key == 'num_nodes':
                data.num_nodes = choice.size(0)
            elif bool(re.search('edge', key)):
                continue
            elif key == "pos" or (
                torch.is_tensor(item) and item.dim() > 0 and item.size(0) == num_nodes
                and item.size(0) != 1 and key not in self.skip_list
            ):
                # "pos" is always re-sampled; other tensors only when they are
                # per-point attributes that were not explicitly skipped
                data[key] = item[choice]

        # Compare against the number of selected indices: without duplicates
        # this can legitimately be smaller than `num`.
        assert data.pos.shape[0] == choice.size(0), (
            f"pos: {data.pos.shape}, choice: {len(choice)}, num: {self.num}"
        )
        return data
class RectangleExtend(object):
    """
    Restrict the extent of the point cloud to a rectangle. This function takes a rectangle of size
    (e_x, e_y, e_z) centered at the origin, then points outside are rejected.

    Parameters
    ----------
    e_x: float, optional
        half size of the x axis of the rectangle
    e_y: float, optional
        half size of the y axis of the rectangle
    e_z: float, optional
        half size of the z axis of the rectangle
    """

    def __init__(self, e_x: float = 1, e_y: float = 1, e_z: float = 1, ):
        self.e_x = e_x
        self.e_y = e_y
        self.e_z = e_z

    def _compute_mask(self, pos: torch.Tensor):
        """Return a boolean mask of the rows of ``pos`` strictly inside the box."""
        posx = pos[:, 0]
        posy = pos[:, 1]
        posz = pos[:, 2]
        # each axis is checked against its own half size
        # (the lower y bound used to be tested on the x coordinate)
        return (posx < self.e_x) & (posx > -self.e_x) & \
               (posy < self.e_y) & (posy > -self.e_y) & \
               (posz < self.e_z) & (posz > -self.e_z)

    def __call__(self, data):
        mask = self._compute_mask(data.pos)
        data = apply_mask(data, mask)
        return data

    def __repr__(self):
        return "{}(e_x={}, e_y={}, e_z={})".format(self.__class__.__name__, self.e_x, self.e_y, self.e_z)
len(heights) + n_pts = len(skeleton) + skeleton = skeleton.repeat_interleave(n_heights, 0) + skeleton[:, 2] *= heights.reshape(-1).repeat(n_pts) + else: + + skeleton *= self.height_skeleton_pts + num_skeleton_pts = len(skeleton) + indicator = torch.zeros(len(data.pos) + num_skeleton_pts) + indicator[-num_skeleton_pts:] = 1.0 + # add empty features for skeleton + size_pos = len(data.pos) + for k in data.keys: + if torch.is_tensor(data[k]) and size_pos == len(data[k]) and k not in ["pos"] + self.skip_list: + dtype = data[k].dtype + if len(data[k].shape) > 1: + n_feat = data[k].shape[1] + data[k] = torch.cat([data[k], torch.ones(num_skeleton_pts, n_feat, dtype=dtype)], 0) + else: + data[k] = torch.cat([data[k], torch.ones(num_skeleton_pts, dtype=dtype)], 0) + data["skeleton"] = indicator + data["pos"] = torch.cat([data["pos"], skeleton], 0) + + +class Polygon2dExtend(object): + """ + Restrict extend the point cloud to a given polygon. This function takes point tuples of size + (e.g., [[0, 1], [1, 0], [1, 1]]). + centered at the origin, then points outside are rejected. 
+ + Parameters + ---------- + polygon: list + List of tuples containing the border points of the polygon + """ + + def __init__(self, polygon, skip_list: list = None, add_skeleton_pts: bool = False, + num_skeleton_pts: int = 100, height_skeleton_pts: float = 1.0, + cage_skeleton: bool = False): + self.polygon = Path(polygon) + + self.skip_list = [] if skip_list is None else OmegaConf.to_object(skip_list) + + self.add_skeleton_pts = add_skeleton_pts + self.num_skeleton_pts = num_skeleton_pts + self.height_skeleton_pts = height_skeleton_pts + self.cage_skeleton = cage_skeleton + + if add_skeleton_pts: + skeleton = torch.tensor(self.polygon.interpolated(self.num_skeleton_pts).vertices).float() + self.skeleton = torch.cat([skeleton, torch.ones(len(skeleton), 1)], 1) + + def __call__(self, data): + pos = data.pos[:, [0, 1]] + mask = self.polygon.contains_points(pos) + data = apply_mask(data, mask, self.skip_list) + if self.add_skeleton_pts: + append_skeleton(self, data, self.skeleton) + + return data + + def __repr__(self): + return "{}(polygon={})".format(self.__class__.__name__, self.polygon.to_polygons()) + + +class RandomPolygon2dExtend(object): + """ + Restrict extend the point cloud to a given polygon. This function takes point tuples of size + (e.g., [[0, 1], [1, 0], [1, 1]]). + centered at the origin, then points outside are rejected. 
+ + Parameters + ---------- + polygons: list + List of polygons, each defined by tuples containing the border points + """ + + def __init__(self, polygons: list, skip_list: list = None, size_min: float = 1, size_max: float = 1, + rotate: float = 180, + add_skeleton_pts: bool = False, num_skeleton_pts: int = 100, height_skeleton_pts: float = 1.0, + cage_skeleton: bool = False): + self.polygons = [polygon for polygon in polygons] + self.n_p = len(self.polygons) + self.size_min = size_min + self.size_max = size_max + self.rotate = rotate + + self.skip_list = [] if skip_list is None else OmegaConf.to_object(skip_list) + + self.add_skeleton_pts = add_skeleton_pts + self.num_skeleton_pts = num_skeleton_pts + self.height_skeleton_pts = height_skeleton_pts + self.cage_skeleton = cage_skeleton + + def __call__(self, data): + pos = data.pos[:, [0, 1]] + polygon = self.polygons[np.random.choice(self.n_p)] + if polygon != "None": + rand_scale = np.random.rand() * (self.size_max - self.size_min) + self.size_min + trans = (1 - rand_scale) / 2 + rand_rotate = np.random.rand() * self.rotate * np.sign(np.random.rand() - .5) + A = Affine2D().scale(rand_scale).translate(trans, trans).rotate_deg_around(0.5, 0.5, rand_rotate) + polygon = Path(polygon).transformed(A) + mask = polygon.contains_points(pos) + if mask.sum() > 0: # apply masking if any points remain + data = apply_mask(data, mask, self.skip_list) + if self.add_skeleton_pts: + skeleton = torch.tensor(polygon.interpolated(self.num_skeleton_pts).vertices).to(pos.dtype) + skeleton = torch.cat([skeleton, torch.ones(len(skeleton), 1)], 1) + + append_skeleton(self, data, skeleton) + elif self.add_skeleton_pts: + data["skeleton"] = torch.zeros(len(data.pos), 1) + return data + + def __repr__(self): + return "{}(polygons={}, size_min={}, size_max={}, rotate={})".format( + self.__class__.__name__, str(self.polygons), self.size_min, self.size_max, self.rotate + ) + + +class EllipsoidCrop(object): + """ + + """ + + def __init__( + self, 
class ZFilter(object):
    """
    Remove points lower or higher than certain values

    Parameters
    ----------
    z_min: float
        points with ``z <= z_min`` are removed
    z_max: float
        points with ``z >= z_max`` are removed
    skip_keys: list, optional
        list of attributes of data to skip when we apply the mask
    """

    def __init__(self, z_min, z_max, skip_keys: List = None):
        self.z_min = z_min
        self.z_max = z_max
        # fresh list per instance; a shared `[]` default would leak edits
        # made on one transform into every other one
        self.skip_keys = [] if skip_keys is None else skip_keys

    def __call__(self, data):
        z = data.pos[:, 2]
        # both bounds are exclusive
        mask = (z > self.z_min) & (z < self.z_max)

        data = apply_mask(data, mask, self.skip_keys)
        return data

    def __repr__(self):
        return "{}(z_min={}, z_max={}, skip_keys={})".format(
            self.__class__.__name__, self.z_min, self.z_max, self.skip_keys
        )
+ minimum number of neighbors to be dense + skip_keys: int, optional + list of attributes of data to skip when we apply the mask + """ + + def __init__(self, radius_nn: float = 0.04, min_num: int = 6, skip_keys: List = []): + self.radius_nn = radius_nn + self.min_num = min_num + self.skip_keys = skip_keys + + def __call__(self, data): + ind, dist = ball_query(data.pos, data.pos, radius=self.radius_nn, max_num=-1, mode=0) + + mask = (dist > 0).sum(1) > self.min_num + data = apply_mask(data, mask, self.skip_keys) + return data + + def __repr__(self): + return "{}(radius_nn={}, min_num={}, skip_keys={})".format( + self.__class__.__name__, self.radius_nn, self.min_num, self.skip_keys + ) + + +class IrregularSampling(object): + """ + a sort of soft crop. the more we are far from the center, the more it is unlikely to choose the point + """ + + def __init__(self, d_half=2.5, p=2, grid_size_center=0.1, skip_keys=[]): + self.d_half = d_half + self.p = p + self.skip_keys = skip_keys + self.grid_sampling = GridSampling3D(grid_size_center, mode="last") + + def __call__(self, data): + data_temp = self.grid_sampling(data.clone()) + i = torch.randint(0, len(data_temp.pos), (1,)) + center = data_temp.pos[i] + + d_p = (torch.abs(data.pos - center) ** self.p).sum(1) + + sigma_2 = (self.d_half ** self.p) / (2 * np.log(2)) + thresh = torch.exp(-d_p / (2 * sigma_2)) + + mask = torch.rand(len(data.pos)) < thresh + data = apply_mask(data, mask, self.skip_keys) + return data + + def __repr__(self): + return "{}(d_half={}, p={}, skip_keys={})".format(self.__class__.__name__, self.d_half, self.p, self.skip_keys) + + +class PeriodicSampling(object): + """ + sample point at a periodic distance + """ + + def __init__(self, period=0.1, prop=0.1, box_multiplier=1, skip_keys=[]): + self.pulse = 2 * np.pi / period + self.thresh = np.cos(self.pulse * prop * period * 0.5) + self.box_multiplier = box_multiplier + self.skip_keys = skip_keys + + def __call__(self, data): + data_temp = data.clone() + 
class AddGround:
    '''Replace sparse samples by "n_points" random ground points.

    When "data.num_nodes" is below "max_points", "data.pos" is overwritten
    with "n_points" points whose x/y are drawn uniformly from
    [xy_min, xy_min + (xy_max - xy_min) / 2) and whose z is 0. Samples with at
    least "max_points" nodes are returned untouched.

    NOTE(review): the existing positions are replaced rather than extended and
    only half of the [xy_min, xy_max] range is covered - confirm both are
    intended.
    '''

    def __init__(self, max_points: int, n_points: int, xy_min: float = 0, xy_max: float = 1):
        self.xy_min = xy_min
        self.xy_range = (xy_max - xy_min) / 2.
        self.n_points = n_points
        self.max_points = max_points

    def __call__(self, data):
        if data.num_nodes >= self.max_points:
            return data

        ground = torch.rand(self.n_points, 3) * self.xy_range + self.xy_min
        ground[:, 2] = 0.0  # flatten onto the ground plane
        data.pos = ground
        return data

    def __repr__(self):
        return f"{self.__class__.__name__}(max_points={self.max_points}, n_points={self.n_points})"
+ """ + + def __init__(self, num, skip_list: list = None): + super().__init__(num, False, True, skip_list) + + def __call__(self, data): + num_nodes = data.num_nodes + + if num_nodes < self.num: + # TODO verify state is persistent + state = np.random.get_state() + np.random.set_state(np.random.RandomState(42).get_state()) + data = super().__call__(data) + np.random.set_state(state) + return data + + return data + + def __repr__(self): + return "{}(num={}, skip_list={})".format( + self.__class__.__name__, self.num, self.skip_list + ) + + +class MaxPoints(FixedPointsOwn): + r"""Samples a maximal number of :obj:`num` points and features from a point + cloud. + + Args: + num (int): The number to maximal number of points in point_idxs, resamples without replacement. + """ + + def __init__(self, num, skip_list: list = None): + super().__init__(num, False, False, skip_list) + + def __call__(self, data): + num_nodes = data.num_nodes + + if num_nodes > self.num: + # TODO verify state is persistent + data = super().__call__(data) + return data + + return data + + def __repr__(self): + return "{}(num={}, skip_list={})".format( + self.__class__.__name__, self.num, self.skip_list + ) diff --git a/torch-points3d/torch_points3d/core/initializer/__init__.py b/torch-points3d/torch_points3d/core/initializer/__init__.py new file mode 100644 index 0000000..07c9955 --- /dev/null +++ b/torch-points3d/torch_points3d/core/initializer/__init__.py @@ -0,0 +1 @@ +from .initializer import * diff --git a/torch-points3d/torch_points3d/core/initializer/initializer.py b/torch-points3d/torch_points3d/core/initializer/initializer.py new file mode 100644 index 0000000..6cc0b7b --- /dev/null +++ b/torch-points3d/torch_points3d/core/initializer/initializer.py @@ -0,0 +1,35 @@ +import torch +from torch.nn import init + + +def init_weights(net, init_type="normal", gain=0.02): + def init_func(m): + classname = m.__class__.__name__ + if hasattr(m, "weight") and (classname.find("Conv") != -1 or 
def focal_ce(input, target, alpha=None, gamma=2, reduction="mean", label_smoothing=0.0):
    """
    Focal cross entropy: ``(1 - pt)^gamma * ce``, where ``ce`` is the
    (optionally class-weighted and label-smoothed) cross entropy and ``pt``
    is the predicted probability of the true class.

    input: [N, C] or [N, C, d1, ..., dK], float32
    target: [N, ] or [N, d1, ..., dK], int64
    alpha: optional per-class weights handed to ``nll_loss``
    gamma: focusing exponent
    reduction: "mean" or "sum"; any other value returns the per-element loss
    label_smoothing: value in [0, 1)
    """
    assert 0 <= label_smoothing < 1

    if input.ndim > 2:
        # (N, C, d1, d2, ..., dK) --> (N * d1 * ... * dK, C)
        num_classes = input.shape[1]
        spatial_dims = range(2, input.ndim)
        input = input.permute(0, *spatial_dims, 1).reshape(-1, num_classes)
        # (N, d1, d2, ..., dK) --> (N * d1 * ... * dK,)
        target = target.view(-1)

    log_probs = F.log_softmax(input, dim=-1)

    # weighted cross entropy term: -alpha * log(pt)
    ce = F.nll_loss(log_probs, target, weight=alpha, reduction="none")
    if label_smoothing != 0:
        uniform_term = -log_probs.mean(dim=-1)
        ce = (1.0 - label_smoothing) * ce + label_smoothing * uniform_term

    # focal term (1 - pt)^gamma, using the true-class column of every row
    row_idx = torch.arange(len(input))
    pt = log_probs[row_idx, target].exp()
    loss = (1 - pt) ** gamma * ce

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss
Modified from Adam in PyTorch + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-16) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant of this + algorithm from the paper `On the Convergence of Adam and Beyond`_ + (default: False) + decoupled_decay (boolean, optional): (default: True) If set as True, then + the optimizer uses decoupled weight decay as in AdamW + fixed_decay (boolean, optional): (default: False) This is used when weight_decouple + is set as True. + When fixed_decay == True, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay$. + When fixed_decay == False, the weight decay is performed as + $W_{new} = W_{old} - W_{old} \times decay \times lr$. Note that in this case, the + weight decay ratio decreases with learning rate (lr). 
+ rectify (boolean, optional): (default: True) If set as True, then perform the rectified + update similar to RAdam + degenerated_to_sgd (boolean, optional) (default:True) If set as True, then perform SGD update + when variance of gradient is high + reference: AdaBelief Optimizer, adapting stepsizes by the belief in observed gradients, NeurIPS 2020 + + For a complete table of recommended hyperparameters, see https://github.com/juntang-zhuang/Adabelief-Optimizer' + For example train/args for EfficientNet see these gists + - link to train_scipt: https://gist.github.com/juntang-zhuang/0a501dd51c02278d952cf159bc233037 + - link to args.yaml: https://gist.github.com/juntang-zhuang/517ce3c27022b908bb93f78e4f786dc3 + """ + + def __init__( + self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, amsgrad=False, + decoupled_decay=True, fixed_decay=False, rectify=True, degenerated_to_sgd=True): + + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad, + degenerated_to_sgd=degenerated_to_sgd, decoupled_decay=decoupled_decay, rectify=rectify, + fixed_decay=fixed_decay, buffer=[[None, None, None] for _ in range(10)]) + super(AdaBelief, self).__init__(params, defaults) + + def __setstate__(self, state): + super(AdaBelief, self).__setstate__(state) + for group in self.param_groups: + 
group.setdefault('amsgrad', False) + + @torch.no_grad() + def reset(self): + for group in self.param_groups: + for p in group['params']: + state = self.state[p] + amsgrad = group['amsgrad'] + + # State initialization + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state['max_exp_avg_var'] = torch.zeros_like(p) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad + if grad.dtype in {torch.float16, torch.bfloat16}: + grad = grad.float() + if grad.is_sparse: + raise RuntimeError( + 'AdaBelief does not support sparse gradients, please consider SparseAdam instead') + + p_fp32 = p + if p.dtype in {torch.float16, torch.bfloat16}: + p_fp32 = p_fp32.float() + + amsgrad = group['amsgrad'] + beta1, beta2 = group['betas'] + state = self.state[p] + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p_fp32) + # Exponential moving average of squared gradient values + state['exp_avg_var'] = torch.zeros_like(p_fp32) + if amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. 
values + state['max_exp_avg_var'] = torch.zeros_like(p_fp32) + + # perform weight decay, check if decoupled weight decay + if group['decoupled_decay']: + if not group['fixed_decay']: + p_fp32.mul_(1.0 - group['lr'] * group['weight_decay']) + else: + p_fp32.mul_(1.0 - group['weight_decay']) + else: + if group['weight_decay'] != 0: + grad.add_(p_fp32, alpha=group['weight_decay']) + + # get current state variable + exp_avg, exp_avg_var = state['exp_avg'], state['exp_avg_var'] + + state['step'] += 1 + bias_correction1 = 1 - beta1 ** state['step'] + bias_correction2 = 1 - beta2 ** state['step'] + + # Update first and second moment running average + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + grad_residual = grad - exp_avg + exp_avg_var.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2) + + if amsgrad: + max_exp_avg_var = state['max_exp_avg_var'] + # Maintains the maximum of all 2nd moment running avg. till now + torch.max(max_exp_avg_var, exp_avg_var.add_(group['eps']), out=max_exp_avg_var) + + # Use the max. for normalizing running avg. 
of gradient + denom = (max_exp_avg_var.sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + else: + denom = (exp_avg_var.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps']) + + # update + if not group['rectify']: + # Default update + step_size = group['lr'] / bias_correction1 + p_fp32.addcdiv_(exp_avg, denom, value=-step_size) + else: + # Rectified update, forked from RAdam + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + num_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + num_sma_max = 2 / (1 - beta2) - 1 + num_sma = num_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = num_sma + + # more conservative since it's an approximated value + if num_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * + (num_sma - 4) / (num_sma_max - 4) * + (num_sma - 2) / num_sma * + num_sma_max / (num_sma_max - 2)) / (1 - beta1 ** state['step']) + elif group['degenerated_to_sgd']: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + if num_sma >= 5: + denom = exp_avg_var.sqrt().add_(group['eps']) + p_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) + elif step_size > 0: + p_fp32.add_(exp_avg, alpha=-step_size * group['lr']) + + if p.dtype in {torch.float16, torch.bfloat16}: + p.copy_(p_fp32) + + return loss diff --git a/torch-points3d/torch_points3d/core/regularizer/__init__.py b/torch-points3d/torch_points3d/core/regularizer/__init__.py new file mode 100644 index 0000000..843295f --- /dev/null +++ b/torch-points3d/torch_points3d/core/regularizer/__init__.py @@ -0,0 +1 @@ +from .regularizers import * diff --git a/torch-points3d/torch_points3d/core/regularizer/regularizers.py b/torch-points3d/torch_points3d/core/regularizer/regularizers.py new file mode 100644 index 0000000..4f8fb87 --- /dev/null +++ b/torch-points3d/torch_points3d/core/regularizer/regularizers.py 
class L1Regularizer(_Regularizer):
    """
    L1 penalty: adds ``lambda_reg * sum(|w|)`` to a running loss.
    """

    def __init__(self, model, lambda_reg=0.01):
        super().__init__(model=model)
        self.lambda_reg = lambda_reg

    def regularized_param(self, param_weights, reg_loss_function):
        """Add the L1 penalty of one tensor to ``reg_loss_function`` and return it."""
        penalty = self.__add_l1(var=param_weights)
        reg_loss_function += self.lambda_reg * penalty
        return reg_loss_function

    def regularized_all_param(self, reg_loss_function):
        """Add the L1 penalty of the model's selected weight tensors.

        Only parameters whose name ends with ``"weight"`` are penalised, and
        names containing ``"1.weight"`` or ``"bn"`` are skipped.
        """
        for name, value in self.model.named_parameters():
            if not name.endswith("weight"):
                continue
            if "1.weight" in name or "bn" in name:
                continue
            reg_loss_function += self.lambda_reg * self.__add_l1(var=value)
        return reg_loss_function

    @staticmethod
    def __add_l1(var):
        # Sum of absolute values of every entry.
        return var.abs().sum()
class ElasticNetRegularizer(_Regularizer):
    """
    Elastic-net penalty: ``lambda_reg * ((1 - alpha_reg) * L2 + alpha_reg * L1)``.
    """

    def __init__(self, model, lambda_reg=0.01, alpha_reg=0.01):
        super().__init__(model=model)
        self.lambda_reg = lambda_reg
        self.alpha_reg = alpha_reg

    def regularized_param(self, param_weights, reg_loss_function):
        """Add the elastic-net penalty of one tensor to ``reg_loss_function``."""
        reg_loss_function += self.lambda_reg * self.__blend(param_weights)
        return reg_loss_function

    def regularized_all_param(self, reg_loss_function):
        """Add the penalty of every model parameter whose name ends with ``"weight"``."""
        weight_tensors = (
            value for name, value in self.model.named_parameters() if name.endswith("weight")
        )
        for tensor in weight_tensors:
            reg_loss_function += self.lambda_reg * self.__blend(tensor)
        return reg_loss_function

    def __blend(self, var):
        # Convex mix of the squared-L2 and L1 terms, before scaling by lambda_reg.
        ridge = (1 - self.alpha_reg) * ElasticNetRegularizer.__add_l2(var=var)
        lasso = self.alpha_reg * ElasticNetRegularizer.__add_l1(var=var)
        return ridge + lasso

    @staticmethod
    def __add_l1(var):
        return var.abs().sum()

    @staticmethod
    def __add_l2(var):
        return var.pow(2).sum()
class GroupLassoRegularizer(_Regularizer):
    """
    GroupLasso Regularizer:
    The first dimension represents the input layer and the second dimension represents the output layer.
    The groups are defined by the column in the matrix W
    """

    def __init__(self, model, lambda_reg=0.01):
        super(GroupLassoRegularizer, self).__init__(model=model)
        self.lambda_reg = lambda_reg

    def regularized_param(self, param_weights, reg_loss_function, group_name="input_group"):
        """Add the group-lasso penalty of one tensor to ``reg_loss_function``.

        Args:
            param_weights: tensor to penalise.
            reg_loss_function: running loss; updated with ``+=`` and returned.
            group_name: ``"input_group"`` or ``"hidden_group"`` (norm over
                dim 1, summed) or ``"bias_group"`` (one group for the whole
                tensor). Any other value only prints a message and leaves
                the loss unchanged.
        """
        if group_name in ("input_group", "hidden_group"):
            # Input and hidden layers share the same per-group norm.
            reg_loss_function += self.lambda_reg * GroupLassoRegularizer.__inputs_groups_reg(
                layer_weights=param_weights
            )
        elif group_name == "bias_group":
            reg_loss_function += self.lambda_reg * GroupLassoRegularizer.__bias_groups_reg(
                bias_weights=param_weights
            )
        else:
            print(
                "The group {} is not supported yet. Please try one of this: [input_group, hidden_group, bias_group]".format(
                    group_name
                )
            )
        return reg_loss_function

    def regularized_all_param(self, reg_loss_function):
        """Add the penalty of every ``*weight`` and ``*bias`` parameter of the model."""
        for model_param_name, model_param_value in self.model.named_parameters():
            if model_param_name.endswith("weight"):
                reg_loss_function += self.lambda_reg * GroupLassoRegularizer.__inputs_groups_reg(
                    layer_weights=model_param_value
                )
            if model_param_name.endswith("bias"):
                reg_loss_function += self.lambda_reg * GroupLassoRegularizer.__bias_groups_reg(
                    bias_weights=model_param_value
                )
        return reg_loss_function

    @staticmethod
    def __grouplasso_reg(groups, dim):
        if dim == -1:
            # We only have single group
            return groups.norm(2)
        return groups.norm(2, dim=dim).sum()

    @staticmethod
    def __inputs_groups_reg(layer_weights):
        # Parameters named "*weight" are not always matrices (normalisation
        # layers carry 1-D weights). A norm over dim=1 does not exist for
        # those, so they are treated as one single group instead of raising.
        if layer_weights.dim() < 2:
            return GroupLassoRegularizer.__grouplasso_reg(groups=layer_weights, dim=-1)
        return GroupLassoRegularizer.__grouplasso_reg(groups=layer_weights, dim=1)

    @staticmethod
    def __bias_groups_reg(bias_weights):
        # dim=-1 selects the single-group branch (the original author was
        # undecided between -1 and 0 here).
        return GroupLassoRegularizer.__grouplasso_reg(groups=bias_weights, dim=-1)
def set_bn_momentum_default(bn_momentum):
    """
    Build a callback that assigns `bn_momentum` to the ``momentum`` attribute of
    every module that is an instance of one of `BATCH_NORM_MODULES`.
    """

    def fn(module):
        # Leave every non batch-norm module untouched.
        if not isinstance(module, BATCH_NORM_MODULES):
            return
        module.momentum = bn_momentum

    return fn
def instantiate_bn_scheduler(model, bn_scheduler_opt):
    """Return a batch normalization momentum scheduler.

    Parameters:
        model            -- the nn network
        bn_scheduler_opt -- mapping-like options read through ``.get``:
            ``bn_policy``           name of the policy; only ``"step_decay"`` is implemented
            ``params``              object exposing ``bn_momentum``, ``bn_decay``,
                                    ``decay_step`` and ``bn_clip``
            ``update_scheduler_on`` forwarded to the scheduler

    Raises:
        NotImplementedError: if ``bn_policy`` is not ``"step_decay"``.
    """
    update_scheduler_on = bn_scheduler_opt.get("update_scheduler_on")
    bn_scheduler_params = bn_scheduler_opt.get("params")
    bn_policy = bn_scheduler_opt.get("bn_policy")

    if bn_policy != "step_decay":
        # The exception used to be *returned* (with an unformatted message),
        # so callers received an exception object instead of a scheduler.
        raise NotImplementedError("bn_policy [%s] is not implemented" % bn_policy)

    def bn_lambda(epoch):
        # Momentum decays by ``bn_decay`` every ``decay_step`` epochs and is
        # never allowed to drop below ``bn_clip``.
        decayed = bn_scheduler_params.bn_momentum * bn_scheduler_params.bn_decay ** (
            int(epoch // bn_scheduler_params.decay_step)
        )
        return max(decayed, bn_scheduler_params.bn_clip)

    bn_scheduler = BNMomentumScheduler(model, bn_lambda, update_scheduler_on)
    # Keep the plain-container options on the scheduler so it can be re-created.
    bn_scheduler.scheduler_opt = OmegaConf.to_container(bn_scheduler_opt)
    return bn_scheduler
class LinearWarmupCosineAnnealingLR(_LRScheduler):
    """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
    and base_lr followed by a cosine annealing schedule between base_lr and eta_min.

    .. warning::
        It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
        after each iteration as calling it after each epoch will keep the starting lr at
        warmup_start_lr for the first epoch which is 0 in most cases.

    .. warning::
        passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
        It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
        :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
        epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
        train and validation methods.

    Example:
        >>> layer = nn.Linear(10, 1)
        >>> optimizer = Adam(layer.parameters(), lr=0.02)
        >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
        >>> #
        >>> # the default case
        >>> for epoch in range(40):
        ...     # train(...)
        ...     # validate(...)
        ...     scheduler.step()
        >>> #
        >>> # passing epoch param case
        >>> for epoch in range(40):
        ...     scheduler.step(epoch)
        ...     # train(...)
        ...     # validate(...)
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_epochs: int,
        max_epochs: int,
        warmup_start_lr: float = 0.00003,
        eta_min: float = 0.0,
        last_epoch: int = -1,
    ) -> None:
        """
        from https://github.com/Lightning-AI/lightning-bolts/blob/master/pl_bolts/optimizers/lr_scheduler.py

        Args:
            optimizer (Optimizer): Wrapped optimizer.
            warmup_epochs (int): Maximum number of iterations for linear warmup
            max_epochs (int): Maximum number of iterations
            warmup_start_lr (float): Learning rate to start the linear warmup. Default: 3e-5.
            eta_min (float): Minimum learning rate. Default: 0.
            last_epoch (int): The index of last epoch. Default: -1.
        """
        # These must be set before the base constructor runs, because it
        # performs an initial ``step()`` that already calls ``get_lr``.
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min

        super().__init__(optimizer, last_epoch)

    def _warmup_intervals(self) -> int:
        """Number of linear increments in the warmup, never less than one.

        A warmup of N epochs has N - 1 increments; clamping to 1 keeps the
        warmup formulas free of a division by zero when ``warmup_epochs == 1``.
        """
        return max(1, self.warmup_epochs - 1)

    def get_lr(self) -> List[float]:
        """Compute learning rate using chainable form of the scheduler."""
        if not self._get_lr_called_within_step:
            warnings.warn(
                "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
                UserWarning,
            )

        if self.last_epoch == 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup: add one constant increment to the current lr.
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / self._warmup_intervals()
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        if self.last_epoch == self.warmup_epochs:
            return self.base_lrs
        if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
            # Start of a new cosine cycle.
            return [
                group["lr"]
                + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]

        # Recursive cosine annealing step expressed through the current lr.
        return [
            (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            / (
                1
                + math.cos(
                    math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
                )
            )
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Called when epoch is passed as a param to the `step` function of the scheduler."""
        if self.last_epoch < self.warmup_epochs:
            return [
                self.warmup_start_lr
                + self.last_epoch * (base_lr - self.warmup_start_lr) / self._warmup_intervals()
                for base_lr in self.base_lrs
            ]

        return [
            self.eta_min
            + 0.5
            * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]

    @property
    def last_step(self):
        """Use last_epoch for the step counter"""
        return self.last_epoch

    @last_step.setter
    def last_step(self, v):
        self.last_epoch = v
def instantiate_scheduler(optimizer, scheduler_opt):
    """Return a learning rate scheduler wrapped in ``LRScheduler``.

    Parameters:
        optimizer     -- the optimizer of the network
        scheduler_opt -- options object providing:
            ``class``               name of the scheduler class, looked up first in
                                    ``torch.optim.lr_scheduler`` and then in this module
            ``params``              constructor keyword arguments (optionally split per
                                    update policy, see ``collect_params``)
            ``update_scheduler_on`` the update policy
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        AttributeError: if the class exists in neither module.
        NotImplementedError: for ``ReduceLROnPlateau`` (any letter case).
    """

    update_scheduler_on = scheduler_opt.update_scheduler_on
    scheduler_cls_name = getattr(scheduler_opt, "class")
    scheduler_params = collect_params(scheduler_opt.params, update_scheduler_on)

    try:
        scheduler_cls = getattr(lr_scheduler, scheduler_cls_name)
    except AttributeError:
        # Only a missing attribute means "not a torch scheduler"; anything
        # else (e.g. KeyboardInterrupt) must not be swallowed here.
        scheduler_cls = getattr(_custom_lr_scheduler, scheduler_cls_name)
        log.info("Created custom lr scheduler")

    if scheduler_cls_name.lower() == "ReduceLROnPlateau".lower():
        raise NotImplementedError("This scheduler is not fully supported yet")

    scheduler = scheduler_cls(optimizer, **scheduler_params)
    # used to re_create the scheduler
    # Replace each configured value with the one the scheduler actually stores.
    for key in scheduler_params.keys():
        scheduler_params[key] = getattr(scheduler, key)

    setattr(scheduler, "_scheduler_opt", OmegaConf.to_container(scheduler_opt))
    return LRScheduler(scheduler, scheduler_params, update_scheduler_on)
b/torch-points3d/torch_points3d/core/spatial_ops/__init__.py @@ -0,0 +1,2 @@ +from .sampling import * +from .interpolate import * diff --git a/torch-points3d/torch_points3d/core/spatial_ops/interpolate.py b/torch-points3d/torch_points3d/core/spatial_ops/interpolate.py new file mode 100644 index 0000000..fbb29ca --- /dev/null +++ b/torch-points3d/torch_points3d/core/spatial_ops/interpolate.py @@ -0,0 +1,70 @@ +import torch +from torch_geometric.nn import knn_interpolate, knn +from torch_scatter import scatter_add +from torch_geometric.data import Data + + +class KNNInterpolate: + def __init__(self, k): + self.k = k + + def precompute(self, query, support): + """ Precomputes a data structure that can be used in the transform itself to speed things up + """ + pos_x, pos_y = query.pos, support.pos + if hasattr(support, "batch"): + batch_y = support.batch + else: + batch_y = torch.zeros((support.num_nodes,), dtype=torch.long) + if hasattr(query, "batch"): + batch_x = query.batch + else: + batch_x = torch.zeros((query.num_nodes,), dtype=torch.long) + + with torch.no_grad(): + assign_index = knn(pos_x, pos_y, self.k, batch_x=batch_x, batch_y=batch_y) + y_idx, x_idx = assign_index + diff = pos_x[x_idx] - pos_y[y_idx] + squared_distance = (diff * diff).sum(dim=-1, keepdim=True) + weights = 1.0 / torch.clamp(squared_distance, min=1e-16) + normalisation = scatter_add(weights, y_idx, dim=0, dim_size=pos_y.size(0)) + + return Data(num_nodes=support.num_nodes, x_idx=x_idx, y_idx=y_idx, weights=weights, normalisation=normalisation) + + def __call__(self, query, support, precomputed: Data = None): + """ Computes a new set of features going from the query resolution position to the support + resolution position + Args: + - query: data structure that holds the low res data (position + features) + - support: data structure that holds the position to which we will interpolate + Returns: + - torch.tensor: interpolated features + """ + if precomputed: + num_points = support.pos.size(0) 
class BaseSampler(ABC):
    """Common configuration for samplers.

    Exactly one way of sizing the sample is stored: an absolute
    ``num_to_sample``, a ``ratio`` of the input size, or a
    ``subsampling_param`` interpreted by the concrete sampler.
    """

    def __init__(self, ratio=None, num_to_sample=None, subsampling_param=None):
        if num_to_sample is None:
            # No absolute count given: fall back to ratio, then subsampling_param.
            if ratio is not None:
                self._ratio = ratio
            elif subsampling_param is not None:
                self._subsampling_param = subsampling_param
            else:
                raise Exception('At least ["ratio, num_to_sample, subsampling_param"] should be defined')
            return

        if not (ratio is None and subsampling_param is None):
            raise ValueError("Can only specify ratio or num_to_sample or subsampling_param, not several !")
        self._num_to_sample = num_to_sample

    def __call__(self, pos, x=None, batch=None):
        return self.sample(pos, batch=batch, x=x)

    def _get_num_to_sample(self, batch_size) -> int:
        """Absolute sample count: the stored count, else ``floor(batch_size * ratio)``."""
        return self._num_to_sample if hasattr(self, "_num_to_sample") else math.floor(batch_size * self._ratio)

    def _get_ratio_to_sample(self, batch_size) -> float:
        """Sampling ratio: the stored ratio, else ``num_to_sample / batch_size``."""
        return self._ratio if hasattr(self, "_ratio") else self._num_to_sample / float(batch_size)

    @abstractmethod
    def sample(self, pos, x=None, batch=None):
        pass
Otherwise sample floor(pos[0] * ratio) points + """ + + def sample(self, pos=None, x=None, batch=None): + if len(pos.shape) != 2: + raise ValueError("This class is for sparse data and expects the pos tensor to be of dimension 2") + + pool = voxel_grid(pos, batch, self._subsampling_param) + pool, perm = consecutive_cluster(pool) + batch = pool_batch(perm, batch) + if x is not None: + return pool_pos(pool, x), pool_pos(pool, pos), batch + else: + return None, pool_pos(pool, pos), batch + + +class DenseFPSSampler(BaseSampler): + """If num_to_sample is provided, sample exactly + num_to_sample points. Otherwise sample floor(pos[0] * ratio) points + """ + + def sample(self, pos, **kwargs): + """ Sample pos + + Arguments: + pos -- [B, N, 3] + + Returns: + indexes -- [B, num_sample] + """ + if len(pos.shape) != 3: + raise ValueError(" This class is for dense data and expects the pos tensor to be of dimension 2") + return tp.furthest_point_sample(pos, self._get_num_to_sample(pos.shape[1])) + + +class RandomSampler(BaseSampler): + """If num_to_sample is provided, sample exactly + num_to_sample points. Otherwise sample floor(pos[0] * ratio) points + """ + + def sample(self, pos, batch, **kwargs): + if len(pos.shape) != 2: + raise ValueError(" This class is for sparse data and expects the pos tensor to be of dimension 2") + idx = torch.randint(0, pos.shape[0], (self._get_num_to_sample(pos.shape[0]),)) + return idx + + +class DenseRandomSampler(BaseSampler): + """If num_to_sample is provided, sample exactly + num_to_sample points. 
Otherwise sample floor(pos[0] * ratio) points + Arguments: + pos -- [B, N, 3] + """ + + def sample(self, pos, **kwargs): + if len(pos.shape) != 3: + raise ValueError(" This class is for dense data and expects the pos tensor to be of dimension 2") + idx = torch.randint(0, pos.shape[1], (self._get_num_to_sample(pos.shape[1]),)) + return idx diff --git a/torch-points3d/torch_points3d/datasets/__init__.py b/torch-points3d/torch_points3d/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch-points3d/torch_points3d/datasets/base_dataset.py b/torch-points3d/torch_points3d/datasets/base_dataset.py new file mode 100644 index 0000000..ccad2f0 --- /dev/null +++ b/torch-points3d/torch_points3d/datasets/base_dataset.py @@ -0,0 +1,592 @@ +import os +from abc import abstractmethod +import logging +import functools +from functools import partial +import numpy as np +import torch +import torch_geometric +from torch_geometric.transforms import Compose +import copy + +from torch_points3d.models import model_interface +from torch_points3d.core.data_transform import instantiate_transforms, MultiScaleTransform +from torch_points3d.core.data_transform import instantiate_filters +from torch_points3d.datasets.batch import SimpleBatch +from torch_points3d.datasets.multiscale_data import MultiScaleBatch +from torch_points3d.utils.enums import ConvolutionFormat +from torch_points3d.utils.config import ConvolutionFormatFactory +from torch_points3d.utils.colors import COLORS + +# A logger for this file +log = logging.getLogger(__name__) + + +def explode_transform(transforms): + """ Returns a flattened list of transform + Arguments: + transforms {[list | T.Compose]} -- Contains list of transform to be added + + Returns: + [list] -- [List of transforms] + """ + out = [] + if transforms is not None: + if isinstance(transforms, Compose): + out = copy.deepcopy(transforms.transforms) + elif isinstance(transforms, list): + out = copy.deepcopy(transforms) + else: + raise 
Exception("Transforms should be provided either within a list or a Compose") + return out + + +def save_used_properties(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + # Save used_properties for mocking dataset when calling pretrained registry + result = func(self, *args, **kwargs) + if isinstance(result, torch.Tensor): + self.used_properties[func.__name__] = result.numpy().tolist() + elif isinstance(result, np.ndarray): + self.used_properties[func.__name__] = result.tolist() + else: + self.used_properties[func.__name__] = result + return result + + return wrapper + + +class BaseDataset: + def __init__(self, dataset_opt): + self.dataset_opt = dataset_opt + + # Default dataset path + dataset_name = dataset_opt.get("dataset_name", None) + if dataset_name: + self._data_path = os.path.join(dataset_opt.dataroot, dataset_name) + else: + class_name = self.__class__.__name__.lower().replace("dataset", "") + self._data_path = os.path.join(dataset_opt.dataroot, class_name) + self._batch_size = None + self.strategies = {} + self._contains_dataset_name = False + + self.train_sampler = None + self.test_sampler = None + self.val_sampler = None + + self._train_dataset = None + self._test_dataset = None + self._val_dataset = None + + self.train_pre_batch_collate_transform = None + self.val_pre_batch_collate_transform = None + self.test_pre_batch_collate_transform = None + + transforms = dataset_opt.get(dataset_opt.transform_type) + if "pre_transform" in dataset_opt: + transforms["pre_transform"] = dataset_opt.get("pre_transform") + BaseDataset.set_transform(self, transforms) + self.set_filter(dataset_opt) + + self.used_properties = {} + + @staticmethod + def remove_transform(transform_in, list_transform_class): + """ Remove a transform if within list_transform_class + + Arguments: + transform_in {[type]} -- [Compose | List of transform] + list_transform_class {[type]} -- [List of transform class to be removed] + + Returns: + [type] -- [description] + """ + 
if isinstance(transform_in, Compose) or isinstance(transform_in, list): + if len(list_transform_class) > 0: + transform_out = [] + transforms = transform_in.transforms if isinstance(transform_in, Compose) else transform_in + for t in transforms: + if not isinstance(t, tuple(list_transform_class)): + transform_out.append(t) + transform_out = Compose(transform_out) + else: + transform_out = transform_in + return transform_out + + @staticmethod + def set_transform(obj, dataset_opt): + """This function create and set the transform to the obj as attributes + """ + obj.pre_transform = None + obj.test_transform = None + obj.train_transform = None + obj.val_transform = None + obj.inference_transform = None + + for key_name in dataset_opt.keys(): + if "transform" in key_name: + new_name = key_name.replace("transforms", "transform") + try: + transform = instantiate_transforms(getattr(dataset_opt, key_name)) + except Exception: + log.exception("Error trying to create {}, {}".format(new_name, getattr(dataset_opt, key_name))) + continue + setattr(obj, new_name, transform) + inference_transform = explode_transform(obj.pre_transform) + inference_transform += explode_transform(obj.test_transform) + obj.inference_transform = Compose(inference_transform) if len(inference_transform) > 0 else None + + def set_filter(self, dataset_opt): + """This function create and set the pre_filter to the obj as attributes + """ + self.pre_filter = None + for key_name in dataset_opt.keys(): + if "filter" in key_name: + new_name = key_name.replace("filters", "filter") + try: + filt = instantiate_filters(getattr(dataset_opt, key_name)) + except Exception: + log.exception("Error trying to create {}, {}".format(new_name, getattr(dataset_opt, key_name))) + continue + setattr(self, new_name, filt) + + @staticmethod + def _collate_fn(batch, collate_fn=None, pre_collate_transform=None): + if pre_collate_transform: + batch = pre_collate_transform(batch) + return collate_fn(batch) + + @staticmethod + def 
_get_collate_function(conv_type, is_multiscale, pre_collate_transform=None): + is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type) + if is_multiscale: + if conv_type.lower() == ConvolutionFormat.PARTIAL_DENSE.value.lower(): + fn = MultiScaleBatch.from_data_list + else: + raise NotImplementedError( + "MultiscaleTransform is activated and supported only for partial_dense format" + ) + else: + if is_dense: + fn = SimpleBatch.from_data_list + else: + fn = torch_geometric.data.batch.Batch.from_data_list + return partial(BaseDataset._collate_fn, collate_fn=fn, pre_collate_transform=pre_collate_transform) + + @staticmethod + def get_num_samples(batch, conv_type): + is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type) + if is_dense: + return batch.pos.shape[0] + else: + return batch.batch.max() + 1 + + @staticmethod + def get_sample(batch, key, index, conv_type): + assert hasattr(batch, key) + is_dense = ConvolutionFormatFactory.check_is_dense_format(conv_type) + if is_dense: + return batch[key][index] + else: + return batch[key][batch.batch == index] + + def create_dataloaders( + self, + model: model_interface.DatasetInterface, + batch_size: int, + shuffle: bool, + drop_last: bool, + num_workers: int, + precompute_multi_scale: bool, + ): + """ Creates the data loaders. 
Must be called in order to complete the setup of the Dataset + """ + conv_type = model.conv_type + self._batch_size = batch_size + + if self.train_sampler is not None: + log.info(self.train_sampler) + + if self.train_dataset: + self._train_loader = self._dataloader( + self.train_dataset, + self.train_pre_batch_collate_transform, + conv_type, + precompute_multi_scale, + batch_size=batch_size, + shuffle=shuffle and not self.train_sampler, + num_workers=num_workers, + sampler=self.train_sampler, + drop_last=drop_last and not self.train_sampler + ) + + if self.val_dataset: + self._val_loader = self._dataloader( + self.val_dataset, + self.val_pre_batch_collate_transform, + conv_type, + precompute_multi_scale, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=self.val_sampler, + ) + + if self.test_dataset: + self._test_loaders = [ + self._dataloader( + dataset, + self.test_pre_batch_collate_transform, + conv_type, + precompute_multi_scale, + batch_size=batch_size, + shuffle=False, + num_workers=num_workers, + sampler=self.test_sampler, + ) + for dataset in self.test_dataset + ] + + if precompute_multi_scale: + self.set_strategies(model) + + def _dataloader(self, dataset, pre_batch_collate_transform, conv_type, precompute_multi_scale, **kwargs): + batch_collate_function = self.__class__._get_collate_function( + conv_type, precompute_multi_scale, pre_batch_collate_transform + ) + num_workers = kwargs.get("num_workers", 0) + persistent_workers = num_workers > 0 + dataloader = partial( + torch.utils.data.DataLoader, + collate_fn=batch_collate_function, + worker_init_fn=np.random.seed, + persistent_workers=persistent_workers, + ) + return dataloader(dataset, **kwargs) + + @property + def has_train_loader(self): + return hasattr(self, "_train_loader") + + @property + def has_val_loader(self): + return hasattr(self, "_val_loader") + + @property + def has_test_loader(self): + return hasattr(self, "_test_loaders") + + @property + def 
train_dataset(self): + return self._train_dataset + + @train_dataset.setter + def train_dataset(self, value): + self._train_dataset = value + if not hasattr(self._train_dataset, "name"): + setattr(self._train_dataset, "name", "train") + + @property + def val_dataset(self): + return self._val_dataset + + @val_dataset.setter + def val_dataset(self, value): + self._val_dataset = value + if not hasattr(self._val_dataset, "name"): + setattr(self._val_dataset, "name", "val") + + @property + def test_dataset(self): + return self._test_dataset + + @test_dataset.setter + def test_dataset(self, value): + if isinstance(value, list): + self._test_dataset = value + else: + self._test_dataset = [value] + + for i, dataset in enumerate(self._test_dataset): + if not hasattr(dataset, "name"): + if self.num_test_datasets > 1: + setattr(dataset, "name", "test_%i" % i) + else: + setattr(dataset, "name", "test") + else: + self._contains_dataset_name = True + + # Check for uniqueness + all_names = [d.name for d in self.test_dataset] + if len(set(all_names)) != len(all_names): + raise ValueError("Datasets need to have unique names. 
Current names are {}".format(all_names)) + + @property + def train_dataloader(self): + return self._train_loader + + @property + def val_dataloader(self): + return self._val_loader + + @property + def test_dataloaders(self): + if self.has_test_loader: + return self._test_loaders + else: + return [] + + @property + def _loaders(self): + loaders = [] + if self.has_train_loader: + loaders += [self.train_dataloader] + if self.has_val_loader: + loaders += [self.val_dataloader] + if self.has_test_loader: + loaders += self.test_dataloaders + return loaders + + @property + def num_test_datasets(self): + return len(self._test_dataset) if self._test_dataset else 0 + + @property + def _test_datatset_names(self): + if self.test_dataset: + return [d.name for d in self.test_dataset] + else: + return [] + + @property + def available_stage_names(self): + out = self._test_datatset_names + if self.has_val_loader: + out += [self._val_dataset.name] + return out + + @property + def available_dataset_names(self): + return ["train"] + self.available_stage_names + + def get_raw_data(self, stage, idx, **kwargs): + assert stage in self.available_dataset_names + dataset = self.get_dataset(stage) + if hasattr(dataset, "get_raw_data"): + return dataset.get_raw_data(idx, **kwargs) + else: + raise Exception("Dataset {} doesn t have a get_raw_data function implemented".format(dataset)) + + def has_labels(self, stage: str) -> bool: + """ Tests if a given dataset has labels or not + + Parameters + ---------- + stage : str + name of the dataset to test + """ + assert stage in self.available_dataset_names + dataset = self.get_dataset(stage) + if hasattr(dataset, "has_labels"): + return dataset.has_labels + + sample = dataset[0] + if hasattr(sample, "y"): + return sample.y is not None + return False + + @property # type: ignore + @save_used_properties + def is_hierarchical(self): + """ Used by the metric trackers to log hierarchical metrics + """ + return False + + @property # type: ignore + 
@save_used_properties + def class_to_segments(self): + """ Use this property to return the hierarchical map between classes and segment ids, example: + { + 'Airplaine': [0,1,2], + 'Boat': [3,4,5] + } + """ + return None + + @property # type: ignore + @save_used_properties + def num_classes(self): + if self.train_dataset: + return self.train_dataset.num_classes + elif self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_classes + else: + return self.test_dataset.num_classes + elif self.val_dataset is not None: + return self.val_dataset.num_classes + else: + raise NotImplementedError() + + @property + def weight_classes(self): + return getattr(self.train_dataset, "weight_classes", None) + + @property # type: ignore + @save_used_properties + def feature_dimension(self): + if self.train_dataset: + return self.train_dataset.num_features + elif self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_features + else: + return self.test_dataset.num_features + elif self.val_dataset is not None: + return self.val_dataset.num_features + else: + raise NotImplementedError() + + @property + def batch_size(self): + return self._batch_size + + @property + def num_batches(self): + out = { + "train": len(self._train_loader) if self.has_train_loader else 0, + "val": len(self._val_loader) if self.has_val_loader else 0, + } + if self.test_dataset: + for loader in self._test_loaders: + stage_name = loader.dataset.name + out[stage_name] = len(loader) + return out + + def get_dataset(self, name): + """ Get a dataset by name. Raises an exception if no dataset was found + + Parameters + ---------- + name : str + """ + all_datasets = [self.train_dataset, self.val_dataset] + if self.test_dataset: + all_datasets += self.test_dataset + for dataset in all_datasets: + if dataset is not None and dataset.name == name: + return dataset + raise ValueError("No dataset with name %s was found." 
% name) + + def _set_composed_multiscale_transform(self, attr, transform): + current_transform = getattr(attr.dataset, "transform", None) + if current_transform is None: + setattr(attr.dataset, "transform", transform) + else: + if ( + isinstance(current_transform, Compose) and transform not in current_transform.transforms + ): # The transform contains several transformations + current_transform.transforms += [transform] + elif current_transform != transform: + setattr( + attr.dataset, "transform", Compose([current_transform, transform]), + ) + + def _set_multiscale_transform(self, transform): + for _, attr in self.__dict__.items(): + if isinstance(attr, torch.utils.data.DataLoader): + self._set_composed_multiscale_transform(attr, transform) + for loader in self.test_dataloaders: + self._set_composed_multiscale_transform(loader, transform) + + def set_strategies(self, model): + strategies = model.get_spatial_ops() + transform = MultiScaleTransform(strategies) + self._set_multiscale_transform(transform) + + @abstractmethod + def get_tracker(self, wandb_log: bool, tensorboard_log: bool): + pass + + def resolve_saving_stage(self, selection_stage): + """This function is responsible to determine if the best model selection + is going to be on the validation or test datasets + """ + log.info( + "Available stage selection datasets: {} {} {}".format( + COLORS.IPurple, self.available_stage_names, COLORS.END_NO_TOKEN + ) + ) + + if self.num_test_datasets > 1 and not self._contains_dataset_name: + msg = "If you want to have better trackable names for your test datasets, add a " + msg += COLORS.IPurple + "name" + COLORS.END_NO_TOKEN + msg += " attribute to them" + log.info(msg) + + if selection_stage == "": + if self.has_val_loader: + selection_stage = self.val_dataset.name + elif self.has_test_loader: + selection_stage = self.test_dataset[0].name + else: + selection_stage = self.train_dataset.name + log.info( + "The models will be selected using the metrics on following 
dataset: {} {} {}".format( + COLORS.IPurple, selection_stage, COLORS.END_NO_TOKEN + ) + ) + return selection_stage + + def add_weights(self, dataset_name="train", class_weight_method="sqrt"): + """ Add class weights to a given dataset that are then accessible using the `class_weights` attribute + """ + L = self.num_classes + weights = torch.ones(L) + dataset = self.get_dataset(dataset_name) + idx_classes, counts = torch.unique(dataset.data.y, return_counts=True) + + dataset.idx_classes = torch.arange(L).long() + weights[idx_classes] = counts.float() + weights = weights.float() + weights = weights.mean() / weights + if class_weight_method == "sqrt": + weights = torch.sqrt(weights) + elif str(class_weight_method).startswith("log"): + weights = torch.log(1.1 + weights / weights.sum()) + else: + raise ValueError("Method %s not supported" % class_weight_method) + + weights /= torch.sum(weights) + log.info("CLASS WEIGHT : {}".format([np.round(weight.item(), 4) for weight in weights])) + setattr(dataset, "weight_classes", weights) + + return dataset + + def __repr__(self): + message = "Dataset: %s \n" % self.__class__.__name__ + for attr in self.__dict__: + if "transform" in attr: + message += "{}{} {}= {}\n".format(COLORS.IPurple, attr, COLORS.END_NO_TOKEN, getattr(self, attr)) + for attr in self.__dict__: + if attr.endswith("_dataset"): + dataset = getattr(self, attr) + if isinstance(dataset, list): + if len(dataset) > 1: + size = ", ".join([str(len(d)) for d in dataset]) + else: + size = len(dataset[0]) + elif dataset: + size = len(dataset) + else: + size = 0 + if attr.startswith("_"): + attr = attr[1:] + message += "Size of {}{} {}= {}\n".format(COLORS.IPurple, attr, COLORS.END_NO_TOKEN, size) + for key, attr in self.__dict__.items(): + if key.endswith("_sampler") and attr: + message += "{}{} {}= {}\n".format(COLORS.IPurple, key, COLORS.END_NO_TOKEN, attr) + message += "{}Batch size ={} {}".format(COLORS.IPurple, COLORS.END_NO_TOKEN, self.batch_size) + return message 
diff --git a/torch-points3d/torch_points3d/datasets/batch.py b/torch-points3d/torch_points3d/datasets/batch.py new file mode 100644 index 0000000..51a1536 --- /dev/null +++ b/torch-points3d/torch_points3d/datasets/batch.py @@ -0,0 +1,58 @@ +import torch +from torch_geometric.data import Data + + +class SimpleBatch(Data): + r""" A classic batch object wrapper with :class:`torch_geometric.data.Data` being the + base class, all its methods can also be used here. + """ + + def __init__(self, batch=None, **kwargs): + super(SimpleBatch, self).__init__(**kwargs) + + self.batch = batch + self.__data_class__ = Data + + @staticmethod + def from_data_list(data_list): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + """ + keys = [set(data.keys) for data in data_list] + keys = list(set.union(*keys)) + + # Check if all dimensions matches and we can concatenate data + # if len(data_list) > 0: + # for data in data_list[1:]: + # for key in keys: + # assert data_list[0][key].shape == data[key].shape + + batch = SimpleBatch() + batch.__data_class__ = data_list[0].__class__ + + for key in keys: + batch[key] = [] + + for _, data in enumerate(data_list): + for key in data.keys: + item = data[key] + batch[key].append(item) + + for key in batch.keys: + item = batch[key][0] + if ( + torch.is_tensor(item) + or isinstance(item, int) + or isinstance(item, float) + ): + batch[key] = torch.stack(batch[key]) + else: + raise ValueError("Unsupported attribute type") + + return batch.contiguous() + # return [batch.x.transpose(1, 2).contiguous(), batch.pos, batch.y.view(-1)] + + @property + def num_graphs(self): + """Returns the number of graphs in the batch.""" + return self.batch[-1].item() + 1 diff --git a/torch-points3d/torch_points3d/datasets/dataset_factory.py b/torch-points3d/torch_points3d/datasets/dataset_factory.py new file mode 100644 index 0000000..c39149d --- /dev/null +++ 
b/torch-points3d/torch_points3d/datasets/dataset_factory.py @@ -0,0 +1,48 @@ +import importlib +import copy +import hydra +import logging + +from torch_points3d.datasets.base_dataset import BaseDataset + +log = logging.getLogger(__name__) + + +def get_dataset_class(dataset_config): + task = dataset_config.task + # Find and create associated dataset + try: + dataset_config.dataroot = hydra.utils.to_absolute_path( + dataset_config.dataroot) + except Exception: + log.error("This should happen only during testing") + dataset_class = getattr(dataset_config, "class") + dataset_paths = dataset_class.split(".") + module = ".".join(dataset_paths[:-1]) + class_name = dataset_paths[-1] + dataset_module = ".".join(["torch_points3d.datasets", task, module]) + datasetlib = importlib.import_module(dataset_module) + + target_dataset_name = class_name + + for name, cls in datasetlib.__dict__.items(): + if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset): + dataset_cls = cls + + if dataset_cls is None: + raise NotImplementedError( + "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." + % (module, class_name) + ) + return dataset_cls + + +def instantiate_dataset(dataset_config) -> BaseDataset: + """Import the module "data/[module].py". + In the file, the class called {class_name}() will + be instantiated. It has to be a subclass of BaseDataset, + and it is case-insensitive. 
+ """ + dataset_cls = get_dataset_class(dataset_config) + dataset = dataset_cls(dataset_config) + return dataset diff --git a/torch-points3d/torch_points3d/datasets/instance/__init__.py b/torch-points3d/torch_points3d/datasets/instance/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch-points3d/torch_points3d/datasets/instance/las_dataset.py b/torch-points3d/torch_points3d/datasets/instance/las_dataset.py new file mode 100644 index 0000000..5153048 --- /dev/null +++ b/torch-points3d/torch_points3d/datasets/instance/las_dataset.py @@ -0,0 +1,1074 @@ +import logging +import os +from collections import OrderedDict +from functools import partial +from glob import glob +from itertools import chain, product +from pathlib import Path +from typing import Sized, Iterator + +import geopandas as gpd +import laspy +import numpy as np +import pandas as pd +import pyproj +import scipy.stats as scstats +import torch +from omegaconf import OmegaConf +from plyfile import PlyData +from shapely.geometry import Point +from sklearn.neighbors import KDTree +from torch.utils.data import Sampler +from torch_geometric.data import Dataset, Data +from tqdm.auto import tqdm + +from torch_points3d.datasets.base_dataset import BaseDataset, save_used_properties +from torch_points3d.metrics.instance_tracker import InstanceTracker +from torch_points3d.models import model_interface + +log = logging.getLogger(__name__) + + +def read_pt(pt_file, feature_cols, delimiter: str): + crs = None + has_features = len(feature_cols) > 0 + if Path(pt_file).suffix in [".las", ".laz"]: + backend = laspy.compression.LazBackend(0) + if not backend.is_available(): + backend = laspy.compression.LazBackend(1) + if not backend.is_available(): + backend = laspy.compression.LazBackend(2) + loaded_file = laspy.read(pt_file, laz_backend=backend) + pos = np.stack([loaded_file.x, loaded_file.y, loaded_file.z], 1) + if has_features: + features = np.stack([getattr(loaded_file, feature) for feature in 
feature_cols], 1) + else: + features = None + + # get crs + for vlr in loaded_file.header.vlrs: + if isinstance(vlr, laspy.vlrs.known.WktCoordinateSystemVlr): + # read general CRS (ignores specific parameters) + crs = pyproj.CRS(vlr.string) + elif Path(pt_file).suffix in [".ply"]: + loaded_file = PlyData.read(pt_file) + pos = np.stack([loaded_file.elements[0]["x"], loaded_file.elements[0]["y"], loaded_file.elements[0]["z"]], 1) + if has_features: + features = np.stack([loaded_file.elements[0][feat] for feat in feature_cols], 1) + else: + features = None + else: + # try to read as csv + loaded_file = pd.read_csv( + pt_file, header=None, engine="pyarrow", delimiter=delimiter, dtype=np.float32, skip_blank_lines=True + ) + pos = loaded_file.values[:, :3] # assumes first 3 values are positions + if has_features: + features = loaded_file[feature_cols] + else: + features = None + + return pos, features, crs + + +class Las(Dataset): + """loads all las files into memory and creates samples based on a label_df""" + + def __init__( + self, root, areas: dict, split: str, stats=None, + xy_radius=15., + transform=None, targets=None, feature_cols=None, feature_scaling_dict: dict = None, + pre_transform=None, pre_filter=None, save_local_stats: bool = True, + min_pts_outer: int = 500, min_pts_inner: int = 250, + save_processed: bool = True, processed_folder="processed", in_memory: bool = False, + pos_dict: dict = None, features_dict: dict = None, pos_tree_dict: dict = None, crs_dict: dict = None + ): + self.root = root + self.split = split + + self.min_pts_outer = min_pts_outer + self.min_pts_inner = min_pts_inner + + # useful for double batch detection + self.prev_idx = None + + assert save_processed or in_memory, "Samples are neither saved to processed folder or kept in memory! 
" \ + "(set either save_processed or in_memory to True)" + self.in_memory = in_memory + if in_memory: + self.memory = {} + + if not save_processed and in_memory: + log.info("Not saving any samples, storing areas in memory if not present on disk") + + self.save_processed = save_processed + self.processed_folder = processed_folder + + self.save_local_stats = save_local_stats + + if pos_dict is not None or pos_tree_dict is not None: + assert pos_dict is not None and pos_tree_dict is not None, \ + "if any of pos or pos_tree are given, both need to be there" + + assert (len(feature_cols) > 0 and (features_dict is not None)) or len(feature_cols) == 0, \ + "need to give features, if pos is given and there are features" + self.pos_dict = {} if pos_dict is None else pos_dict + self.features_dict = {} if features_dict is None else features_dict + self.pos_tree_dict = {} if pos_tree_dict is None else pos_tree_dict + self.crs_dict = {} if crs_dict is None else crs_dict + + self.areas = areas + + self.targets = targets + self.feature_cols = [] if feature_cols is None else feature_cols + self.stats = [] if stats is None else stats + # difference between measurement and pointclouds taken + self.radius = xy_radius + + # different types of targets + self.reg_targets = [target for target in self.targets if self.targets[target]["task"] == "regression"] + self.cls_targets = [target for target in self.targets if self.targets[target]["task"] == "classification"] + self.cls_targets_ = [f"{target}_" for target in self.targets if + self.targets[target]["task"] == "classification"] + self.mol_targets = [target for target in self.targets if self.targets[target]["task"] == "mol"] + + # if not give, calculate on given data + if feature_scaling_dict is None: + feature_scaling_dict = { + area_name: + { # feature: (center, scale) + "num_returns": (0., 5.), + "return_num": (0., 5.), } + for area_name in areas + } + self.feature_scaling_dict = feature_scaling_dict + + super().__init__( + root, 
transform, pre_transform, pre_filter + ) + # check if all areas are actually processed when using saves + if self.save_processed: + for area_name in areas: + area = areas[area_name] + labels = area["labels"].query(f"{area['split_col']} == '{self.split}'") + if len(labels) > 0 and not (Path(self.processed_dir) / self.split / area_name / "done.flag").exists(): + log.info(f'Resuming processing, since {area_name} is not complete!') + self.process() + else: + self.process() + + # pre-load into memory if not already done during processing + if self.in_memory: + log.info("Pre-loading into memory") + pbar = tqdm(range(len(self)), total=len(self)) + [self.get(idx) for idx in pbar] + + + @property + def processed_dir(self) -> str: + return os.path.join(self.root, self.processed_folder) + + @property + def raw_file_names(self): + files = list(chain(*[[area["pt_files"]] for area in self.areas.values()])) + return files + + @property + def has_labels(self) -> bool: + return self.split in ["val", "test"] + + @property + def processed_file_names(self): + path = Path(self.processed_dir) / self.split + files = glob(str(path / f"*/*.pt")) + return files + + @property + def num_samples(self): + n = 0 + if self.in_memory and not self.save_processed: + n = len(self.memory) + if n == 0: # memory not initialized yet + n = sum([len(area["labels"].query(f"{area['split_col']} == '{self.split}'")) for area in self.areas]) + return n + + for area_name in self.areas: + area = self.areas[area_name] + + if (Path(self.processed_dir) / self.split / area_name / "done.flag").exists(): + n += len(list((Path(self.processed_dir) / self.split / area_name).glob("*.pt"))) + else: + n += len(area["labels"].query(f"{area['split_col']} == '{self.split}'")) + + return n + + def process(self): + file_idx = 0 + + for area_name in self.areas: + flag = (Path(self.processed_dir) / self.split / area_name / "done.flag") + area = self.areas[area_name] + + log.info(f"### start processing area: '{area_name}'") + if not 
flag.exists(): + + labels = area["labels"].query(f"{area['split_col']} == '{self.split}'") + if len(labels) == 0: + continue + + if area["type"] == "scene": + # can prepare this beforehand + pos, features, inner_label_point_idx, label_point_idx, labels = \ + self.process_scene_area_(area_name, labels) + + ### TODO reintroduce feature scaling + # if feature in feature_scaling: + # center, scale = feature_scaling.get(feature, (0., 1.)) + # else: + # # fill with iqr scaling + # center = np.median(feat) + # scale = (np.quantile(feat, 0.75) - np.quantile(feat, 0.25)) * 1.349 + # feature_scaling[feature] = (center, scale) + # features_sample.append((feat - center) / scale) + + log.info("Saving samples and calculating stats") + if self.save_processed: + (Path(self.processed_dir) / self.split).mkdir(exist_ok=True) + (Path(self.processed_dir) / self.split / area_name).mkdir(exist_ok=True) + missing_idx = [] + for idx, index in tqdm(enumerate(labels.index.values)): + sample = labels.iloc[idx] + file = Path(self.processed_dir) / self.split / area_name / f"{file_idx}.pt" + if file.exists(): + file_idx += 1 + continue + + if area["type"] == "object": + # only load objects here instead of bulk loading before to avoid memory issues + pos, features, crs = read_pt(sample["pt_file"], self.feature_cols, area["delimiter"]) + + if area.get("check_pt_crs", True) and crs is not None and \ + not pyproj.CRS.is_exact_same(labels.crs, crs): + sample = labels.to_crs(crs).iloc[idx] + + # find points + label_centers = [[sample.geometry.x, sample.geometry.y]] + tree = KDTree(pos[:, :2]) + point_idxs = tree.query_radius(label_centers, self.radius)[0] + inner_point_idx = tree.query_radius(label_centers, self.radius / 2.)[0] + del tree + + elif area["type"] == "scene": + point_idxs = label_point_idx[idx] + inner_point_idx = inner_label_point_idx[idx] + else: + raise NotImplementedError("Only 'scence' and 'object' area types are implemented") + + data = self.save_data_( + area_name, index, sample, 
pos, features, + point_idxs, inner_point_idx + ) + if data is not None: + if self.in_memory: + self.memory[file_idx] = data + if self.save_processed: + torch.save(data, file) + file_idx += 1 + else: + missing_idx.append(index) + area["labels"].drop(index=missing_idx, inplace=True) + if self.save_processed: + flag.touch() + else: + file_idx += len(list((Path(self.processed_dir) / self.split / area_name).glob("*.pt"))) + + def process_scene_area_(self, area_name, labels): + area = self.areas[area_name] + pos_tree = self.pos_tree_dict.get(area_name, None) + + if not pos_tree: + log.info(f"Loading Las files") + pt = [read_pt(las_file, self.feature_cols, area["delimiter"]) for las_file in area["pt_files"]] + + pos = np.concatenate([p[0] for p in pt], 0) + if len(self.feature_cols) > 0: + features = np.concatenate([p[1] for p in pt], 0) + else: + features = None + + crs = np.stack([p[2] for p in pt], 0) + assert np.all(crs[0] == crs_ for crs_ in crs), "pt_files of an area need to be in same crs currently" + crs = crs[0] + + # fit this into a KDTree + log.info("Creating KDTree") + pos_tree = KDTree(pos[:, :2]) + + self.pos_dict[area_name] = pos + self.pos_tree_dict[area_name] = pos_tree + self.features_dict[area_name] = features + self.crs_dict[area_name] = crs + log.info("Querying KDTree") + # restrict to bounds + crs = self.crs_dict[area_name] + if area.get("check_pt_crs", True) and crs is not None and not pyproj.CRS.is_exact_same(labels.crs, crs): + labels = labels.to_crs(crs) + + label_centers = np.stack([labels.geometry.x, labels.geometry.y], 1) + radius = self.radius + label_point_idx = self.pos_tree_dict[area_name].query_radius(label_centers, radius) + inner_label_point_idx = self.pos_tree_dict[area_name].query_radius(label_centers, radius / 2.) 
+ return self.pos_dict[area_name], self.features_dict[area_name], inner_label_point_idx, label_point_idx, labels + + @property + def num_classes(self) -> int: + if not hasattr(self, "num_classes_"): + num_reg_classes = 0 + num_mol_classes = 0 + num_cls_classes = [] + if self.targets: + for target in self.targets: + task = self.targets[target]["task"] + if task == "classification": + num_cls_classes.append(len(self.targets[target]["class_names"])) + elif task == "regression": + num_reg_classes += 1 + elif task.lower() == "mol": + num_mixtures = self.targets[target].get("num_mixtures", 1) + num_mol_classes += num_mixtures * 3 + + self.num_reg_classes_ = num_reg_classes + self.num_mol_classes_ = num_mol_classes + self.num_cls_classes_ = num_cls_classes + + self.num_classes_ = self.num_reg_classes + self.num_mol_classes + int(np.sum(self.num_cls_classes)) + + return self.num_classes_ + + @property + def num_reg_classes(self) -> int: + if not hasattr(self, "num_reg_classes_"): + # init by calling num_classes + _ = self.num_classes + + return self.num_reg_classes_ + + @property + def num_mol_classes(self) -> int: + if not hasattr(self, "num_mol_classes_"): + # init by calling num_classes + _ = self.num_classes + + return self.num_mol_classes_ + + @property + def num_cls_classes(self) -> []: + if not hasattr(self, "num_cls_classes_"): + # init by calling num_classes + _ = self.num_classes + + return self.num_cls_classes_ + + def len(self): + return self.num_samples + + @staticmethod + def get_local_stats(points, postfix=""): + stats = {} + z = points[:, 2] + + z_stats = { + "h_mean": np.mean, + "h_std": np.std, + "h_coov": scstats.variation, + "h_kur": scstats.kurtosis, + "h_skew": scstats.skew, + } + + quantiles = [5, 10, 25, 50, 75, 90, 95, 99] + z_stats.update({f"h_q{i}": partial(np.quantile, q=i / 100) for i in quantiles}) + + def density_q(z, q): + # the proportion of points above the height percentiles + quant = np.quantile(z, q=q) + return len(z[z > quant]) / 
len(z) + + z_stats.update({f"d_q{i}": partial(density_q, q=i / 100) for i in quantiles}) + + tree = KDTree(points) + # create 1m grid spanning extend + xx = np.arange(points[:, 0].min(), points[:, 0].max(), 1) + yy = np.arange(points[:, 1].min(), points[:, 1].max(), 1) + zz = np.arange(points[:, 2].min(), points[:, 2].max(), 1) + grid = [[x, y, z] for x, y, z in product(xx, yy, zz)] + # get highest density in grid + if len(grid) > len(points): # use points directly if only few points present + grid = points + density = tree.kernel_density(grid, 1, kernel="gaussian").max() + stats["kde_h1"] = density + + for key in z_stats.keys(): + try: + value = z_stats[key](z) + except IndexError: + # return -1 if not enough values in quantiles + value = -1 + + stats[key + postfix] = value + + return stats + + def get(self, idx): + if self.in_memory: + if idx in self.memory.keys(): + data = self.memory[idx].clone() + else: + data = torch.load(self.processed_file_names[idx]) + self.memory[idx] = data.clone() + else: + data = torch.load(self.processed_file_names[idx]) + + del data.local_stats_keys + data["is_double"] = self.prev_idx == idx + self.prev_idx = idx + + return data + + def save_data_(self, area_name: str, idx, sample, pos_: np.array, features_: np.array, + point_idxs: np.array, inner_point_idxs: np.array): + + if len(point_idxs) < self.min_pts_outer: + log.warning(f"only {len(point_idxs)} in total, skipping") + return None + elif len(inner_point_idxs) < self.min_pts_inner: + log.warning(f"only {len(inner_point_idxs)} in inner circle, skipping") + return None + + # only coordinates for now + x = pos_[point_idxs] + inner_x = pos_[inner_point_idxs] + + if features_ is not None: + features = features_[point_idxs] + else: + features = None + + # normalize + inner_x, x = self.center_pos(x, inner_x, sample) + + # get local and df stats + local_stats, local_stats_keys, stats = self.get_stats(x, inner_x, sample) + + # target + if self.targets: + y_reg = sample[self.reg_targets] 
+ y_reg_mask = ~y_reg.isna() + y_mol = sample[self.mol_targets] + y_mol_mask = ~y_mol.isna() + y_cls = sample[self.cls_targets_] + y_cls_mask = ~y_cls.isna() + + else: + y_reg = y_reg_mask = y_mol = y_mol_mask = y_cls = y_cls_mask = [] + + data = self.covert_to_data_( + x, y_reg, y_reg_mask, y_mol, y_mol_mask, y_cls, y_cls_mask, + features, idx, area_name, local_stats, local_stats_keys, stats + ) + + return data + + def covert_to_data_( + self, x, y_reg, y_reg_mask, y_mol, y_mol_mask, y_cls, y_cls_mask, features, idx, area_name, local_stats, + local_stats_keys, stats + ): + x = torch.tensor(x, dtype=torch.float32) + y_reg = torch.tensor(y_reg, dtype=torch.float32) + y_reg_mask = torch.tensor(y_reg_mask, dtype=torch.bool) + y_mol = torch.tensor(y_mol, dtype=torch.float32) + y_mol_mask = torch.tensor(y_mol_mask, dtype=torch.bool) + y_cls[~y_cls_mask] = - 1 + y_cls = torch.tensor(y_cls, dtype=torch.long) + y_cls_mask = torch.tensor(y_cls_mask, dtype=torch.bool) + features = features if features is None else torch.tensor(features, dtype=torch.float32) + stats = torch.tensor(stats, dtype=torch.float32) + local_stats = torch.tensor(local_stats, dtype=torch.float32) + data = Data( + x=features, + y_reg=y_reg, y_reg_mask=y_reg_mask, + y_mol=y_mol, y_mol_mask=y_mol_mask, + y_cls=y_cls, y_cls_mask=y_cls_mask, + pos=x, stats=stats, label_idx=[idx], area_name=area_name, + local_stats=local_stats, local_stats_keys=local_stats_keys + ) + + # apply pre_transform + if self.pre_transform is not None: + data = self.pre_transform(data) + if data.pos.shape[0] == 0: + log.warning(f"Pre transform reduced sample to 0 points, skipping") + return None + + return data + + def get_stats(self, x, inner_x, df): + # local stats + if self.save_local_stats: + local_stats = self.get_local_stats(x) + local_stats.update(self.get_local_stats(inner_x, "_inner")) + local_stats_keys = list(local_stats.keys()) + local_stats = list(local_stats.values()) + else: + local_stats = local_stats_keys = [] + # 
global stats + stats = df[self.stats] + return local_stats, local_stats_keys, stats + + def center_pos(self, x, inner_x, df): + x_center = np.amin(x, axis=0, keepdims=True) + x_center[:, 0] = df.geometry.x + x_center[:, 1] = df.geometry.y + x -= x_center + inner_x -= x_center + return inner_x, x + + +class LasDataset(BaseDataset): + def __init__(self, dataset_opt): + super().__init__(dataset_opt) + self.dataset_opt = dataset_opt + self.targets = dataset_opt.get("targets", None) + self.target_keys = list(self.targets.keys()) if self.targets is not None else None + self.features = dataset_opt.features + self.stats = dataset_opt.stats + self.xy_radius = dataset_opt.xy_radius + self.x_scale = dataset_opt.x_scale + self.y_scale = dataset_opt.y_scale + self.z_scale = dataset_opt.z_scale + self.transform_type = dataset_opt.transform_type + self.double_batch = dataset_opt.get(self.transform_type).get("double_batch", False) + self.log_train_metrics = dataset_opt.get("log_train_metrics", True) + + self.reg_targets = [target for target in self.targets if self.targets[target]["task"] == "regression"] + self.reg_targets_idx = [self.targets[target]["task"] == "regression" for target in self.targets] + self.cls_targets = [target for target in self.targets if self.targets[target]["task"] == "classification"] + self.cls_targets_idx = [self.targets[target]["task"] == "classification" for target in self.targets] + self.cls_targets_ = [f"{target}_" for target in self.cls_targets] + self.mol_targets = [target for target in self.targets if self.targets[target]["task"] == "mol"] + self.mol_targets_idx = [self.targets[target]["task"] == "mol" for target in self.targets] + + self.min_pts_outer = dataset_opt.get("min_pts_outer", 500) + self.min_pts_inner = dataset_opt.get("min_pts_inner", 250) + + in_memory = dataset_opt.get("in_memory", False) + save_processed = dataset_opt.get("save_processed", True) + save_local_stats = dataset_opt.get("save_local_stats", True) + train_subset = 
dataset_opt.get("train_subset", False) + + processed_folder = dataset_opt.get("processed_folder", "processed") + + areas_file = (self._data_path / (Path(processed_folder)) / "areas.pt") + self.areas: dict = OrderedDict(OmegaConf.to_container(dataset_opt.areas)) + if areas_file.exists(): + self.areas.update(torch.load(areas_file)) + self.process_area_labels(dataset_opt) + train_set_avail = any( + [len(area["labels"].query(f"{area['split_col']} == 'train'")) > 0 for area in self.areas.values()]) + val_set_avail = any( + [len(area["labels"].query(f"{area['split_col']} == 'val'")) > 0 for area in self.areas.values()]) + test_set_avail = any( + [len(area["labels"].query(f"{area['split_col']} == 'test'")) > 0 for area in self.areas.values()]) + + if save_processed: + (self._data_path / (Path(processed_folder))).mkdir(exist_ok=True) + + feature_scaling_file = self._data_path / (Path(processed_folder) / "features_scaling.pt") + feature_scaling_dict = torch.load(feature_scaling_file) if feature_scaling_file.exists() else None + + assert train_set_avail or val_set_avail or test_set_avail, "Apparently no data available" + + pos_dict = {} + pos_tree_dict = {} + features_dict = {} + crs_dict = {} + if train_set_avail: + if train_subset: + train_subset_remove = 1 - train_subset + for area in self.areas.values(): + idx = area["labels"].query(f"{area['split_col']} == 'train'").index + idx = np.random.choice(idx, int(len(idx) * train_subset_remove), replace=False) + area["labels"].drop(index=idx, inplace=True) + + log.info("Init train dataset") + self.train_dataset = Las( + self._data_path, areas=self.areas, split="train", + targets=self.targets, feature_cols=self.features, feature_scaling_dict=feature_scaling_dict, + stats=dataset_opt.stats, transform=self.train_transform, pre_transform=self.pre_transform, + save_processed=save_processed, processed_folder=processed_folder, in_memory=in_memory, + xy_radius=self.xy_radius, save_local_stats=save_local_stats, + 
min_pts_outer=self.min_pts_outer, min_pts_inner=self.min_pts_inner + ) + if not feature_scaling_file.exists(): + feature_scaling_dict = self.train_dataset.feature_scaling_dict + torch.save(feature_scaling_dict, feature_scaling_file) + + pos_dict.update(self.train_dataset.pos_dict) + pos_tree_dict.update(self.train_dataset.pos_tree_dict) + features_dict.update(self.train_dataset.features_dict) + crs_dict.update(self.train_dataset.crs_dict) + + if val_set_avail: + log.info("Init val dataset") + self.val_dataset = Las( + self._data_path, areas=self.areas, split="val", + targets=self.targets, feature_cols=self.features, feature_scaling_dict=feature_scaling_dict, + stats=dataset_opt.stats, transform=self.val_transform, pre_transform=self.pre_transform, + save_processed=save_processed, processed_folder=processed_folder, in_memory=in_memory, + xy_radius=self.xy_radius, save_local_stats=save_local_stats, + min_pts_outer=self.min_pts_outer, min_pts_inner=self.min_pts_inner, + pos_dict=pos_dict, features_dict=features_dict, + pos_tree_dict=pos_tree_dict, crs_dict=crs_dict + ) + + pos_dict.update(self.val_dataset.pos_dict) + pos_tree_dict.update(self.val_dataset.pos_tree_dict) + features_dict.update(self.val_dataset.features_dict) + crs_dict.update(self.val_dataset.crs_dict) + + if test_set_avail: + log.info("Init test dataset") + self.test_dataset = Las( + self._data_path, areas=self.areas, split="test", + targets=self.targets, feature_cols=self.features, feature_scaling_dict=feature_scaling_dict, + stats=dataset_opt.stats, transform=self.test_transform, pre_transform=self.pre_transform, + save_processed=save_processed, processed_folder=processed_folder, in_memory=in_memory, + xy_radius=self.xy_radius, save_local_stats=save_local_stats, + min_pts_outer=self.min_pts_outer, min_pts_inner=self.min_pts_inner, + pos_dict=pos_dict, features_dict=features_dict, + pos_tree_dict=pos_tree_dict, crs_dict=crs_dict + ) + + del pos_dict, pos_tree_dict, features_dict, crs_dict + + # save 
areas in preprocessed file + if save_processed: + torch.save(self.areas, areas_file) + + self.set_label_stats_(save_processed) + + self.has_reg_targets = len(self.reg_targets) > 0 + self.has_mol_targets = len(self.mol_targets) > 0 + self.has_cls_targets = len(self.cls_targets) > 0 + + def process_area_labels(self, dataset_opt): + for area_name in self.areas: + area = self.areas[area_name] + + # assume that if the labels are set, the area was already processed + if area.get("labels", None) is not None: + continue + + # set some standard params + area["delimiter"] = area.get("delimiter", dataset_opt.get("delimiter", ",")) + + # processing file lists + pt_files = area["pt_files"] + if isinstance(pt_files, (str, Path)): + pt_files = glob(str(Path(self._data_path) / "raw" / pt_files)) + elif isinstance(pt_files, list): + # iterating to list of files + unpacked_files = [] + for f in pt_files: + unpacked_files.extend(glob(str(Path(self._data_path) / "raw" / f))) + pt_files = unpacked_files + else: + raise Exception("pt_files need to be a str or a list of str (can use * expression)") + + labels = self.process_label_files_(area, area_name) + + labels.geometry = labels.centroid + + if area["type"] == "object": + # check if each label has a pt_file + def find_pt_file(id): + for ptf in pt_files: + # return first occurrence + if id in ptf: + return ptf + return "None" + + labels["pt_file"] = labels[area["pt_identifier"]].apply(find_pt_file) + + # removing sample without pt_file + n_samples = len(labels) + labels.query("pt_file != 'None'", inplace=True) + if len(labels) != n_samples: + log.warning(f"{n_samples - len(labels)} removed due to missing pt_file") + + pt_files = labels["pt_file"].values.tolist() + + area["pt_files"] = pt_files + + split_col = area.get("split_col", dataset_opt.get("split_col", "split")) + area["split_col"] = split_col + # create split if fully labeled data available + if split_col not in labels.columns: + targets_must_be_present = 
np.array(area.get("targets_must_be_present", [True] * len(self.target_keys))) + lb = labels[np.array(self.target_keys)[targets_must_be_present]] + + val_ratio = area.get("val_ratio", .1) + test_ratio = area.get("test_ratio", .1) + + # if no targets are fully available, only use this area for training + if (lb.shape[1] > 0 and lb.isna().all().all()) or val_ratio == test_ratio == 0.0: + labels.loc[:, split_col] = "train" + else: + # no split available, create own + # only select those that have labels others are for training + if any(targets_must_be_present): + partly_missing = lb.isna().all(axis=1) + lables_partly_missing = labels[partly_missing] + lables_partly_missing[split_col] = "train" + + lables_full = labels[~partly_missing] + else: + lables_partly_missing = pd.DataFrame() + lables_full = labels + index = lables_full.index.values + + rs = np.random.RandomState(42) + + rs.shuffle(index) + + train_end = int(len(index) * (1 - (val_ratio + test_ratio))) + val_end = int(len(index) * (1 - test_ratio)) + train_idx = index[:train_end] + val_idx = index[train_end:val_end] + test_idx = index[val_end:] + + lables_full.loc[train_idx, split_col] = "train" + if val_ratio != 0 and len(val_idx) > 0: + lables_full.loc[val_idx, split_col] = "val" + if test_ratio != 0 and len(test_idx) > 0: + lables_full.loc[test_idx, split_col] = "test" + + labels = pd.concat([lables_partly_missing, lables_full]) + + if len(labels.query(f"['val', 'test'] in {split_col}")) == 0: + log.warning(f"neither val nor test set present for {area_name}") + + area["labels"] = labels + + def process_label_files_(self, area: dict, area_name: str): + label_files = area["label_files"] + # ensure labels file follows schemata: + # [file_1, ..., file_n] + if isinstance(label_files, (str, Path)): + label_files = [label_files] + + assert len(label_files) > 0, f"no labels given, check area {area_name}" + + labels = None + for lf in label_files: + lb = gpd.read_file(Path(self._data_path) / "raw" / lf) + + # put 
dummy point if no position exists (usually true for csv data) + lb.geometry = lb.geometry.apply(lambda g: Point(0, 0) if g is None else g) + + alias_targets = area.get("alias_targets", self.targets) + assert len(alias_targets) == len(self.targets), f"given target aliases for '{area_name}' have " \ + f"different lengths: {alias_targets} vs {self.targets}" + + target_metric_factor = area.get("target_metric_factor", None) + + # add targets if present else set to nan + for ori_target, alias_target in zip(self.targets, alias_targets): + task = self.targets[ori_target]["task"] + if alias_target in lb: + lb[ori_target] = lb[alias_target] + # assumes that classification targets will be not necessarily be numbers, but everything else is + if task in ["regression", "mol"]: + lb[ori_target] = pd.to_numeric(lb[ori_target], errors="coerce") + if target_metric_factor is not None: + lb[ori_target] *= target_metric_factor.get(ori_target, 1.0) + else: + lb[ori_target] = np.nan + + if task == "classification": + # also save numerical values according to given classes + lb[f"{ori_target}_"] = lb[ori_target].map( + self.targets[ori_target]["class_mapping"] + ).astype(float) + + # crs comparison + if labels is None: + labels = lb + crs = lb.crs + else: + if crs != lb.crs: + Warning("CRS of label files do not match, have to convert") + lb = lb.to_crs(crs) + labels = pd.concat([labels, lb]) + + # indicate fully/partly missing targets in label sample + n_labels = len(labels) + nans_allowed = area.get("nans_allowed", True) + fully_missing = labels[self.targets].isna().all(axis=1).sum() + partly_missing = labels[self.targets].isna().any(axis=1).sum() + partly_missing = abs(partly_missing - fully_missing) + if fully_missing > 0: + log.info(f"{fully_missing} of {n_labels} labels fully missing in {area_name}") + if fully_missing == n_labels: + area["has_labels"] = False + if partly_missing > 0: + log.info(f"{partly_missing} of {n_labels} labels partly missing in {area_name}") + if 
fully_missing + partly_missing == n_labels and not nans_allowed: + area["has_labels"] = False + + if not nans_allowed: + labels.dropna(axis=0, how="any", subset=self.targets, inplace=True) + log.info( + f"Removing all missing or partly missing samples as indicated by 'nans_allowed' in {area_name}" + ) + + # apply filter query + query = area.get("label_query", None) + if query is not None: + labels.query(query, inplace=True) + if n_labels > len(labels): + log.warning(f"({n_labels - len(labels)} sample were " + f"filtered out according to: {query})") + + labels.set_index(np.arange(len(labels)), inplace=True) + return labels + + def set_label_stats_(self, save_processed: bool): + processed_dir = Path(os.path.join(self._data_path, self.dataset_opt.processed_folder)) + if save_processed: + processed_dir.mkdir(exist_ok=True) + means_file = processed_dir / "mean_targets.pt" + std_file = processed_dir / "std_targets.pt" + min_file = processed_dir / "min_targets.pt" + max_file = processed_dir / "max_targets.pt" + corr_file = processed_dir / "corr_targets.pt" + + self.mean_targets_ = torch.load(means_file) if means_file.exists() else \ + self.get_stat_targets_(np.nanmean, means_file if save_processed else None) + self.std_targets_ = torch.load(std_file) if std_file.exists() else \ + self.get_stat_targets_(np.nanstd, std_file if save_processed else None) + self.min_targets_ = torch.load(min_file) if min_file.exists() else \ + self.get_stat_targets_(np.nanmin, min_file if save_processed else None) + self.max_targets_ = torch.load(max_file) if max_file.exists() else \ + self.get_stat_targets_(np.nanmax, max_file if save_processed else None) + + self.corr_targets_ = torch.load(corr_file) if corr_file.exists() else \ + self.get_corr_targets_(corr_file if save_processed else None) + + def create_dataloaders( + self, + model: model_interface.DatasetInterface, + batch_size: int, + shuffle: bool, + drop_last: bool, + num_workers: int, + precompute_multi_scale: bool, + ): + if 
self.train_dataset and shuffle: + self.train_sampler = RandomSampler(self.train_dataset, batch_size, self.double_batch) + if drop_last is False: + log.warning("Cannot disable 'drop_last' with RandomSampler.") + super().create_dataloaders(model, batch_size, shuffle, drop_last, num_workers, precompute_multi_scale) + + def get_std_targets(self): + return self.std_targets_ + + def get_mean_targets(self): + return self.mean_targets_ + + def get_min_targets(self): + return self.min_targets_ + + def get_max_targets(self): + return self.max_targets_ + + def get_stat_targets_(self, stat_fn, file_name: (str, Path) = None): + dict = OrderedDict() + targets = [f"{target}_" if self.targets[target]["task"] == "classification" else target for target in + self.targets] + + dict["total"] = {} + if self.train_dataset is not None: + dict["total"].update({"train": [], }) + if self.val_dataset is not None: + dict["total"].update({"val": [], }) + if self.test_dataset is not None: + dict["total"].update({"test": [], }) + + for area_name in self.areas: + # TODO also uses labels that were not used due to too few points + sc = self.areas[area_name]["split_col"] + labels = self.areas[area_name]["labels"] + area_dict = {} + if self.train_dataset is not None and labels.query(f"{sc} == 'val'").shape[0] > 1: + values = labels.query(f"{sc} == 'train'")[targets].values + area_dict.update({"train": stat_fn(values, 0), }) + dict["total"]["train"].append(values) + if self.val_dataset is not None and labels.query(f"{sc} == 'val'").shape[0] > 1: + values = labels.query(f"{sc} == 'val'")[targets].values + area_dict.update({"val": stat_fn(values, 0), }) + dict["total"]["val"].append(values) + if self.test_dataset is not None and labels.query(f"{sc} == 'test'").shape[0] > 1: + values = labels.query(f"{sc} == 'test'")[targets].values + area_dict.update({"test": stat_fn(values, 0), }) + dict["total"]["test"].append(values) + + if len(area_dict) > 0: + dict[area_name] = area_dict + + if self.train_dataset is 
not None: + dict["total"]["train"] = stat_fn(np.concatenate(dict["total"]["train"], 0), 0) + if self.val_dataset is not None: + dict["total"]["val"] = stat_fn(np.concatenate(dict["total"]["val"], 0), 0) + if self.test_dataset is not None: + dict["total"]["test"] = stat_fn(np.concatenate(dict["total"]["test"], 0), 0) + + if file_name is not None: + torch.save(dict, file_name) + + return dict + + def get_corr_targets(self): + return self.corr_targets_ + + def get_corr_targets_(self, file_name: (str, Path) = None): + dict = OrderedDict() + targets = [f"{target}_" if self.targets[target]["task"] == "classification" else target for target in + self.targets] + + for area_name in self.areas: + sc = self.areas[area_name]["split_col"] + labels = self.areas[area_name]["labels"] + area_dict = {} + if self.train_dataset is not None and labels.query(f"{sc} == 'train'").shape[0] > 1: + area_dict.update({"train": labels.query(f"{sc} == 'train'")[targets].corr().values, }) + if self.val_dataset is not None and labels.query(f"{sc} == 'val'").shape[0] > 1: + area_dict.update( + {"val": labels.query(f"{sc} == 'val'")[targets].corr().values, } + ) + if self.test_dataset is not None and labels.query(f"{sc} == 'test'").shape[0] > 1: + area_dict.update( + {"test": labels.query(f"{sc} == 'test'")[targets].corr().values, } + ) + + if len(area_dict) > 0: + dict[area_name] = area_dict + + if file_name is not None: + torch.save(dict, file_name) + return dict + + def get_tracker(self, wandb_log: bool, tensorboard_log: bool): + """Factory method for the tracker + Arguments: + wandb_log - Log using weight and biases + tensorboard_log - Log using tensorboard + Returns: + [BaseTracker] -- tracker + """ + return InstanceTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log, + log_train_metrics=self.log_train_metrics) + + @property # type: ignore + @save_used_properties + def num_reg_classes(self) -> int: + if self.train_dataset: + return self.train_dataset.num_reg_classes + elif 
self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_reg_classes + else: + return self.test_dataset.num_reg_classes + elif self.val_dataset is not None: + return self.val_dataset.num_reg_classes + else: + raise NotImplementedError() + + @property # type: ignore + @save_used_properties + def num_mol_classes(self) -> int: + if self.train_dataset: + return self.train_dataset.num_mol_classes + elif self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_mol_classes + else: + return self.test_dataset.num_mol_classes + elif self.val_dataset is not None: + return self.val_dataset.num_mol_classes + else: + raise NotImplementedError() + + @property # type: ignore + @save_used_properties + def num_cls_classes(self) -> int: + if self.train_dataset: + return self.train_dataset.num_cls_classes + elif self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_cls_classes + else: + return self.test_dataset.num_cls_classes + elif self.val_dataset is not None: + return self.val_dataset.num_cls_classes + else: + raise NotImplementedError() + + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. + + Args: + data_source (Dataset): dataset to sample from + batch_size (int): number of samples in a mini-batch + double_batch (bool): if each sample should in a batch should be returned twice (e.g., for self-supervision) + generator (Generator): Generator used in sampling. 
+ """ + data_source: Sized + + def __init__(self, data_source: Sized, batch_size: int, double_batch: bool, generator=None) -> None: + super().__init__(data_source) + self.data_source = data_source + self.generator = generator + self._num_samples = None + self.batch_size = batch_size + self.double_batch = double_batch + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + iterator = torch.randperm(self.num_samples, generator=generator).tolist() + if self.double_batch: + iterator = np.array([[k, k] for k in iterator]).flatten().tolist() + iterator = iterator[:(self.num_samples // self.batch_size) * self.batch_size] + + yield from iterator + + def __len__(self) -> int: + return self.num_samples + + def __repr__(self): + return "{}(batch_size={},double_batch={},generator={})".format( + self.__class__.__name__, self.batch_size, self.double_batch, self.generator, + ) diff --git a/torch-points3d/torch_points3d/datasets/instance/las_dataset_.py b/torch-points3d/torch_points3d/datasets/instance/las_dataset_.py new file mode 100644 index 0000000..4728bfd --- /dev/null +++ b/torch-points3d/torch_points3d/datasets/instance/las_dataset_.py @@ -0,0 +1,894 @@ +import os +import sys +from collections import OrderedDict +from functools import partial +from glob import glob +from itertools import chain, product +from pathlib import Path +from typing import Sized, Iterator + +import geopandas as gpd +import laspy +import numpy as np +import pandas as pd +import pyproj +import scipy.stats as scstats +import torch +from omegaconf import OmegaConf +from plyfile import PlyData +from shapely.geometry import Point +from 
sklearn.neighbors import KDTree +from torch.utils.data import Sampler +from torch_geometric.data import Dataset, Data +from tqdm.auto import tqdm + +from torch_points3d.datasets.base_dataset import BaseDataset, save_used_properties +from torch_points3d.metrics.instance_tracker import InstanceTracker +from torch_points3d.models import model_interface + + +def read_pt(pt_file, feature_cols, delimiter: str): + crs = None + has_features = len(feature_cols) > 0 + if Path(pt_file).suffix in [".las", ".laz"]: + backend = laspy.compression.LazBackend(0) + if not backend.is_available(): + backend = laspy.compression.LazBackend(1) + if not backend.is_available(): + backend = laspy.compression.LazBackend(2) + loaded_file = laspy.read(pt_file, laz_backend=backend) + pos = np.stack([loaded_file.x, loaded_file.y, loaded_file.z], 1) + if has_features: + features = np.stack([getattr(loaded_file, feature) for feature in feature_cols], 1) + else: + features = None + + # get crs + for vlr in loaded_file.header.vlrs: + if isinstance(vlr, laspy.vlrs.known.WktCoordinateSystemVlr): + crs = vlr.string + elif Path(pt_file).suffix in [".ply"]: + loaded_file = PlyData.read(pt_file) + pos = np.stack([loaded_file.elements[0]["x"], loaded_file.elements[0]["y"], loaded_file.elements[0]["z"]], 1) + if has_features: + features = np.stack([loaded_file.elements[0][feat] for feat in feature_cols], 1) + else: + features = None + else: + # try to read as csv + loaded_file = pd.read_csv( + pt_file, header=None, engine="pyarrow", delimiter=delimiter, dtype=np.float32, skip_blank_lines=True + ) + pos = loaded_file.values[:, :3] # assumes first 3 values are positions + if has_features: + features = loaded_file[feature_cols] + else: + features = None + + return pos, features, crs + + +class Las(Dataset): + """loads all las files into memory and creates samples based on a label_df""" + + def __init__( + self, root, areas: dict, split: str, stats=None, + xy_radius=15., + transform=None, targets=None, 
feature_cols=None, feature_scaling_dict: dict = None, + pre_transform=None, pre_filter=None, processed_folder="processed", + in_memory: bool = False, pos_dict: dict = None, features_dict: dict = None, pos_tree_dict: dict = None, + ): + self.root = root + self.split = split + + self.processed_folder = processed_folder + + if pos_dict is not None or pos_tree_dict is not None: + assert pos_dict is not None and pos_tree_dict is not None, \ + "if any of pos or pos_tree are given, both need to be there" + + assert (len(feature_cols) > 0 and (features_dict is not None)) or len(feature_cols) == 0, \ + "need to give features, if pos is given and there are features" + self.pos_dict = {} if pos_dict is None else pos_dict + self.features_dict = {} if features_dict is None else features_dict + self.pos_tree_dict = {} if pos_tree_dict is None else pos_tree_dict + + self.areas = areas + + self.targets = targets + self.feature_cols = [] if feature_cols is None else feature_cols + self.stats = [] if stats is None else stats + # difference between measurement and pointclouds taken + self.radius = xy_radius + + # different types of targets + self.reg_targets = [target for target in self.targets if self.targets[target]["task"] == "regression"] + self.cls_targets = [target for target in self.targets if self.targets[target]["task"] == "classification"] + self.cls_targets_ = [f"{target}_" for target in self.targets if + self.targets[target]["task"] == "classification"] + self.mol_targets = [target for target in self.targets if self.targets[target]["task"] == "mol"] + + # if not give, calculate on given data + if feature_scaling_dict is None: + feature_scaling_dict = { + area_name: + { # feature: (center, scale) + "num_returns": (0., 5.), + "return_num": (0., 5.), } + for area_name in areas + } + self.feature_scaling_dict = feature_scaling_dict + + self.in_memory = in_memory + if in_memory: + self.memory = {} + + super().__init__( + root, transform, pre_transform, pre_filter + ) + # check 
@property
def raw_file_names(self):
    """Return the point cloud files of all areas as one flat list.

    Each area's ``"pt_files"`` entry is normally a list of paths; a single
    ``str``/``Path`` entry is accepted as well and contributes one file.
    Areas are visited in the iteration order of ``self.areas``.
    """
    files = []
    for area in self.areas.values():
        pt_files = area["pt_files"]
        if isinstance(pt_files, (str, Path)):
            files.append(pt_files)
        else:
            # flatten instead of nesting the per-area list inside the result
            files.extend(pt_files)
    return files
def process_scene_area_(self, area_name, labels):
    """Load a scene area (or reuse its cached KD-tree) and query all labels.

    On the first call for an area every file in ``area["pt_files"]`` is read,
    the points are concatenated and a 2D KD-tree over x/y is built. Points,
    features and tree are stored in ``self.pos_dict``, ``self.features_dict``
    and ``self.pos_tree_dict``; the CRS of the point files is stored in the
    area dict under ``"pt_crs_"``. Later calls, including calls from other
    instances that were handed the same dicts, reuse these entries.

    Args:
        area_name: key of the area in ``self.areas``.
        labels: label frame of the area; it is reprojected to the point
            cloud CRS when ``check_pt_crs`` is enabled (default), the CRS is
            known and it differs from ``labels.crs``.

    Returns:
        ``(pos, features, inner_label_point_idx, label_point_idx, labels)``
        where the two index arrays hold, per label, the points within
        ``self.radius / 2`` and ``self.radius`` of the label centre.

    Raises:
        ValueError: if the point files of the area report different CRSs.
    """
    area = self.areas[area_name]
    crs_key = "pt_crs_"

    if self.pos_tree_dict.get(area_name, None) is None:
        print(f"Load Las files")
        pt = [read_pt(pt_file, self.feature_cols, area["delimiter"]) for pt_file in area["pt_files"]]

        pos = np.concatenate([p[0] for p in pt], 0)
        features = np.concatenate([p[1] for p in pt], 0) if len(self.feature_cols) > 0 else None

        # all files that report a CRS have to agree; files without CRS
        # information cannot be checked and are accepted as before
        crs_values = [p[2] for p in pt]
        known_crs = [value for value in crs_values if value is not None]
        for other in known_crs[1:]:
            if other != known_crs[0] and pyproj.CRS(other) != pyproj.CRS(known_crs[0]):
                raise ValueError("pt_files of an area need to be in same crs currently")
        crs = crs_values[0]

        # fit this into a KDTree
        print("Creating KDTree")
        self.pos_dict[area_name] = pos
        self.pos_tree_dict[area_name] = KDTree(pos[:, :2])
        self.features_dict[area_name] = features
        # remember the CRS next to the cached tree so that calls which reuse
        # the tree can still reproject their labels
        area[crs_key] = crs
    else:
        crs = area.get(crs_key, None)
        if crs is None and area.get("check_pt_crs", True):
            print(f"Warning: no point cloud CRS known for cached area '{area_name}', "
                  f"labels are not reprojected", file=sys.stderr)

    print("Query KDTree")
    # restrict to bounds
    if area.get("check_pt_crs", True) and crs is not None and not pyproj.CRS.is_exact_same(labels.crs, crs):
        labels = labels.to_crs(crs)

    label_centers = np.stack([labels.geometry.x, labels.geometry.y], 1)
    tree = self.pos_tree_dict[area_name]
    label_point_idx = tree.query_radius(label_centers, self.radius)
    inner_label_point_idx = tree.query_radius(label_centers, self.radius / 2.)
    return self.pos_dict[area_name], self.features_dict[area_name], inner_label_point_idx, label_point_idx, labels
@staticmethod
def get_local_stats(points, postfix=""):
    """Compute height statistics and a peak density value for a point set.

    Args:
        points: array with one row per point; column 2 is used as height,
            columns 0-2 span the density grid.
        postfix: string appended to the key of every height statistic.

    Returns:
        dict with ``"kde_h1"`` (maximum of the gaussian kernel density
        evaluated with ``kernel_density(grid, 1, ...)``) followed by the
        height statistics ``h_mean``, ``h_std``, ``h_coov``, ``h_kur``,
        ``h_skew``, the height quantiles ``h_q*`` and the share of points
        above each quantile ``d_q*`` (each key with ``postfix`` appended).
        A statistic that raises ``IndexError`` is reported as ``-1``.
    """
    z = points[:, 2]

    def density_q(z, q):
        # the proportion of points above the height percentiles
        quant = np.quantile(z, q=q)
        return len(z[z > quant]) / len(z)

    quantiles = [5, 10, 25, 50, 75, 90, 95, 99]
    z_stats = {
        "h_mean": np.mean,
        "h_std": np.std,
        "h_coov": scstats.variation,
        "h_kur": scstats.kurtosis,
        "h_skew": scstats.skew,
    }
    z_stats.update({f"h_q{i}": partial(np.quantile, q=i / 100) for i in quantiles})
    z_stats.update({f"d_q{i}": partial(density_q, q=i / 100) for i in quantiles})

    tree = KDTree(points)
    # create 1m grid spanning extend
    axes = [np.arange(points[:, dim].min(), points[:, dim].max(), 1) for dim in range(3)]
    n_cells = len(axes[0]) * len(axes[1]) * len(axes[2])
    if n_cells == 0 or n_cells > len(points):
        # use points directly if only few points are present; this also
        # covers an extent below one grid step, where the grid is empty.
        # Deciding on the cell count avoids building a huge throw-away grid.
        grid = points
    else:
        grid = list(product(*axes))

    # get highest density in grid
    # NOTE(review): this key does not carry `postfix`, so merging the result
    # of a postfixed call into an un-postfixed one overwrites "kde_h1".
    # Left unchanged because a new key would shift the layout of the stats.
    stats = {"kde_h1": tree.kernel_density(grid, 1, kernel="gaussian").max()}

    for key, stat_fn in z_stats.items():
        try:
            value = stat_fn(z)
        except IndexError:
            # return -1 if not enough values in quantiles
            value = -1

        stats[key + postfix] = value

    return stats
def center_pos(self, x, df):
    """Shift points so that the label position becomes the horizontal origin.

    Columns 0 and 1 are shifted by ``df.geometry.x`` and ``df.geometry.y``;
    every further column (e.g. the height) is shifted by its own minimum.
    The subtraction is done in place, so the array passed in is modified
    and returned.

    Args:
        x: array with one row per point and x, y as the first two columns.
        df: label record providing ``geometry.x`` and ``geometry.y``.

    Returns:
        The shifted array ``x``. An empty array is returned unchanged,
        because there is no minimum to shift by.
    """
    if len(x) == 0:
        # np.amin has no identity for zero-size input and would raise
        return x

    x_center = np.amin(x, axis=0, keepdims=True)
    x_center[:, 0] = df.geometry.x
    x_center[:, 1] = df.geometry.y
    x -= x_center
    return x
self.target_keys = list(self.targets.keys()) if self.targets is not None else None + self.features = dataset_opt.features + self.stats = dataset_opt.stats + self.xy_radius = dataset_opt.xy_radius + self.x_scale = dataset_opt.x_scale + self.y_scale = dataset_opt.y_scale + self.z_scale = dataset_opt.z_scale + self.double_batch = dataset_opt.get("double_batch", False) + self.log_train_metrics = dataset_opt.get("log_train_metrics", True) + + self.areas: dict = OrderedDict(OmegaConf.to_container(dataset_opt.areas)) + + self.reg_targets = [target for target in self.targets if self.targets[target]["task"] == "regression"] + self.reg_targets_idx = [self.targets[target]["task"] == "regression" for target in self.targets] + self.cls_targets = [target for target in self.targets if self.targets[target]["task"] == "classification"] + self.cls_targets_idx = [self.targets[target]["task"] == "classification" for target in self.targets] + self.cls_targets_ = [f"{target}_" for target in self.cls_targets] + self.mol_targets = [target for target in self.targets if self.targets[target]["task"] == "mol"] + self.mol_targets_idx = [self.targets[target]["task"] == "mol" for target in self.targets] + + processed_folder = dataset_opt.get("processed_folder", "processed") + + for area_name in self.areas: + area = self.areas[area_name] + + # set some standard params + area["delimiter"] = area.get("delimiter", dataset_opt.get("delimiter", ",")) + + # processing file lists + pt_files = area["pt_files"] + if isinstance(pt_files, (str, Path)): + pt_files = glob(str(Path(self._data_path) / "raw" / pt_files)) + elif isinstance(pt_files, list): + # iterating to list of files + unpacked_files = [] + for f in pt_files: + unpacked_files.extend(glob(str(Path(self._data_path) / "raw" / f))) + pt_files = unpacked_files + else: + raise Exception("pt_files need to be a str or a list of str (can use * expression)") + + labels = self.process_label_files_(area, area_name) + + labels.geometry = labels.centroid + 
+ if area["type"] == "object": + # check if each label has a pt_file + def find_pt_file(id): + for ptf in pt_files: + # return first occurrence + if id in ptf: + return ptf + return "None" + + labels["pt_file"] = labels[area["pt_identifier"]].apply(find_pt_file) + + # removing sample without pt_file + n_samples = len(labels) + labels.query("pt_file != 'None'", inplace=True) + if len(labels) != n_samples: + print(f"Warning: {n_samples - len(labels)} removed due to missing pt_file") + + pt_files = labels["pt_file"].values.tolist() + + area["pt_files"] = pt_files + + split_col = area.get("split_col", dataset_opt.get("split_col", "split")) + area["split_col"] = split_col + # create split if fully labeled data available + if split_col not in labels.columns: + targets_must_be_present = np.array(area.get("targets_must_be_present", [True] * len(self.target_keys))) + lb = labels[np.array(self.target_keys)[targets_must_be_present]] + + # if no targets are fully available, only use this area for training + if (~targets_must_be_present).all() or lb.isna().all().all(): + labels[split_col] = "train" + else: + # no split available, create own + # only select those that have labels others are for training + partly_missing = lb.isna().all(1) + lables_partly_missing = labels[partly_missing] + lables_partly_missing[split_col] = "train" + + lables_full = labels[~partly_missing] + index = lables_full.index.values + + rs = np.random.RandomState(42) + val_ratio = area.get("val_ratio", .1) + test_ratio = area.get("test_ratio", .1) + + rs.shuffle(index) + + train_end = int(len(index) * (1 - (val_ratio + test_ratio))) + val_end = int(len(index) * (1 - test_ratio)) + train_idx = index[:train_end] + val_idx = index[train_end:val_end] + test_idx = index[val_end:] + + lables_full.loc[train_idx, split_col] = "train" + if val_ratio != 0 and len(val_idx) > 0: + lables_full.loc[val_idx, split_col] = "val" + if test_ratio != 0 and len(test_idx) > 0: + lables_full.loc[test_idx, split_col] = "test" + 
+ labels = pd.concat([lables_partly_missing, lables_full]) + + if len(labels.query(f"['val', 'test'] in {split_col}")) == 0: + print(f"Warning: neither val nor test set present for {area_name}") + + area["labels"] = labels + val_set_avail = any( + [len(area["labels"].query(f"{area['split_col']} == 'val'")) > 0 for area in self.areas.values()]) + test_set_avail = any( + [len(area["labels"].query(f"{area['split_col']} == 'test'")) > 0 for area in self.areas.values()]) + + (self._data_path / (Path(processed_folder))).mkdir(exist_ok=True) + + feature_scaling_file = self._data_path / (Path(processed_folder) / "features_scaling.pt") + feature_scaling_dict = torch.load(feature_scaling_file) if feature_scaling_file.exists() else None + + in_memory = dataset_opt.get("in_memory", False) + + print("init train dataset") + self.train_dataset = Las( + self._data_path, areas=self.areas, split="train", + targets=self.targets, feature_cols=self.features, feature_scaling_dict=feature_scaling_dict, + stats=dataset_opt.stats, transform=self.train_transform, pre_transform=self.pre_transform, + processed_folder=processed_folder, + xy_radius=self.xy_radius, + in_memory=in_memory + ) + if not feature_scaling_file.exists(): + feature_scaling_dict = self.train_dataset.feature_scaling_dict + torch.save(feature_scaling_dict, feature_scaling_file) + + if val_set_avail: + print("init val dataset") + self.val_dataset = Las( + self._data_path, areas=self.areas, split="val", + targets=self.targets, feature_cols=self.features, feature_scaling_dict=feature_scaling_dict, + stats=dataset_opt.stats, transform=self.val_transform, pre_transform=self.pre_transform, + processed_folder=processed_folder, + xy_radius=self.xy_radius, + in_memory=in_memory, + pos_dict=self.train_dataset.pos_dict, features_dict=self.train_dataset.features_dict, + pos_tree_dict=self.train_dataset.pos_tree_dict + ) + + if test_set_avail: + print("init test dataset") + self.test_dataset = Las( + self._data_path, areas=self.areas, 
def process_label_files_(self, area: dict, area_name: str):
    """Read, harmonise and filter the label files of one area.

    Every file in ``area["label_files"]`` (a single path or a list of paths,
    relative to ``<data_path>/raw``) is read with geopandas. Target columns
    are copied from their area specific alias, numeric targets are coerced
    to numbers and optionally multiplied by ``area["target_metric_factor"]``,
    missing targets are filled with NaN, and classification targets
    additionally get a numeric ``"<target>_"`` column built from the
    configured ``class_mapping``. All files are concatenated in the CRS of
    the first file.

    Side effects: sets ``area["has_labels"] = False`` when no usable target
    values remain and prints a summary of missing labels.

    Args:
        area: configuration dict of the area.
        area_name: name of the area, used in messages only.

    Returns:
        The concatenated label frame with a fresh ``0..n-1`` index, after
        applying ``nans_allowed`` and the optional ``label_query`` filter.
    """
    # local import keeps this method self-contained
    import warnings

    label_files = area["label_files"]
    # ensure labels file follows schemata:
    # [file_1, ..., file_n]
    if isinstance(label_files, (str, Path)):
        label_files = [label_files]

    assert len(label_files) > 0, f"no labels given, check area {area_name}"

    target_names = list(self.targets)
    labels = None
    crs = None
    for lf in label_files:
        lb = gpd.read_file(Path(self._data_path) / "raw" / lf)

        # put dummy point if no position exists (usually true for csv data)
        lb.geometry = lb.geometry.apply(lambda g: Point(0, 0) if g is None else g)

        alias_targets = area.get("alias_targets", self.targets)
        assert len(alias_targets) == len(self.targets), f"given target aliases for '{area_name}' have " \
                                                        f"different lengths: {alias_targets} vs {self.targets}"

        target_metric_factor = area.get("target_metric_factor", None)

        # add targets if present else set to nan
        for ori_target, alias_target in zip(self.targets, alias_targets):
            task = self.targets[ori_target]["task"]
            if alias_target in lb:
                lb[ori_target] = lb[alias_target]
                # assumes that classification targets will be not necessarily be numbers, but everything else is
                if task in ["regression", "mol"]:
                    lb[ori_target] = pd.to_numeric(lb[ori_target], errors="coerce")
                    if target_metric_factor is not None:
                        lb[ori_target] *= target_metric_factor.get(ori_target, 1.0)
            else:
                lb[ori_target] = np.nan

            if task == "classification":
                # also save numerical values according to given classes
                lb[f"{ori_target}_"] = lb[ori_target].map(
                    self.targets[ori_target]["class_mapping"]
                ).astype(float)

        # crs comparison
        if labels is None:
            labels = lb
            crs = lb.crs
        else:
            if crs != lb.crs:
                # actually emit the warning instead of only building the object
                warnings.warn("CRS of label files do not match, have to convert")
                lb = lb.to_crs(crs)
            labels = pd.concat([labels, lb])

    # indicate fully/partly missing targets in label sample
    n_labels = len(labels)
    nans_allowed = area.get("nans_allowed", True)
    target_is_nan = labels[target_names].isna()
    # axis is passed by keyword: newer pandas versions reject a positional axis here
    fully_missing = target_is_nan.all(axis=1).sum()
    partly_missing = target_is_nan.any(axis=1).sum()
    partly_missing = abs(partly_missing - fully_missing)
    if fully_missing > 0:
        print(f"Info: {fully_missing} of {n_labels} labels fully missing in {area_name}")
        if fully_missing == n_labels:
            area["has_labels"] = False
    if partly_missing > 0:
        print(f"Info: {partly_missing} of {n_labels} labels partly missing in {area_name}")
        if fully_missing + partly_missing == n_labels and not nans_allowed:
            area["has_labels"] = False

    if not nans_allowed:
        labels.dropna(axis=0, how="any", subset=target_names, inplace=True)
        print(f"Info: Removing all missing or partly missing samples as indicated by 'nans_allowed' in {area_name}")

    # apply filter query
    query = area.get("label_query", None)
    if query is not None:
        labels.query(query, inplace=True)
        if n_labels > len(labels):
            warnings.warn(f"Warning: ({n_labels - len(labels)} sample were "
                          f"filtered out according to: {query})")

    labels.set_index(np.arange(len(labels)), inplace=True)
    return labels
def set_label_stats_(self): + processed_dir = Path(os.path.join(self._data_path, self.dataset_opt.processed_folder)) + processed_dir.mkdir(exist_ok=True) + means_file = processed_dir / "mean_targets.pt" + std_file = processed_dir / "std_targets.pt" + min_file = processed_dir / "min_targets.pt" + max_file = processed_dir / "max_targets.pt" + corr_file = processed_dir / "corr_targets.pt" + + self.mean_targets_ = torch.load(means_file) if means_file.exists() else \ + self.get_stat_targets_(np.nanmean, means_file) + self.std_targets_ = torch.load(std_file) if std_file.exists() else \ + self.get_stat_targets_(np.nanstd, std_file) + self.min_targets_ = torch.load(min_file) if min_file.exists() else \ + self.get_stat_targets_(np.nanmin, min_file) + self.max_targets_ = torch.load(max_file) if max_file.exists() else \ + self.get_stat_targets_(np.nanmax, max_file) + + self.corr_targets_ = torch.load(corr_file) if corr_file.exists() else self.get_corr_targets_(corr_file) + + def create_dataloaders( + self, + model: model_interface.DatasetInterface, + batch_size: int, + shuffle: bool, + num_workers: int, + precompute_multi_scale: bool, + ): + self.train_sampler = RandomSampler(self.train_dataset, True, batch_size, self.double_batch) + super().create_dataloaders(model, batch_size, shuffle, num_workers, precompute_multi_scale) + + def get_std_targets(self): + return self.std_targets_ + + def get_mean_targets(self): + return self.mean_targets_ + + def get_min_targets(self): + return self.min_targets_ + + def get_max_targets(self): + return self.max_targets_ + + def get_stat_targets_(self, stat_fn, file_name: (str, Path) = None): + dict = OrderedDict() + targets = [f"{target}_" if self.targets[target]["task"] == "classification" else target for target in + self.targets] + + for area_name in self.areas: + # TODO also uses labels that were not used due to too few points + sc = self.areas[area_name]["split_col"] + labels = self.areas[area_name]["labels"] + dict[area_name] = {} + + 
dict[area_name] = {"train": stat_fn(labels.query(f"{sc} == 'train'")[targets].values, 0), } + if self.val_dataset is not None and labels.query(f"{sc} == 'val'").shape[0] > 1: + dict[area_name].update( + {"val": stat_fn(labels.query(f"{sc} == 'val'")[targets].values, 0), } + ) + if self.test_dataset is not None and labels.query(f"{sc} == 'test'").shape[0] > 1: + dict[area_name].update( + {"test": stat_fn(labels.query(f"{sc} == 'test'")[targets].values, 0), } + ) + + if file_name is not None: + torch.save(dict, file_name) + + return dict + + def get_corr_targets(self): + return self.corr_targets_ + + def get_corr_targets_(self, file_name: (str, Path) = None): + dict = OrderedDict() + targets = [f"{target}_" if self.targets[target]["task"] == "classification" else target for target in + self.targets] + + for area_name in self.areas: + sc = self.areas[area_name]["split_col"] + labels = self.areas[area_name]["labels"] + + dict[area_name] = {"train": labels.query(f"{sc} == 'train'")[targets].corr().values, } + if self.val_dataset is not None and labels.query(f"{sc} == 'val'").shape[0] > 1: + dict[area_name].update( + {"val": labels.query(f"{sc} == 'val'")[targets].corr().values, } + ) + if self.test_dataset is not None and labels.query(f"{sc} == 'test'").shape[0] > 1: + dict[area_name].update( + {"test": labels.query(f"{sc} == 'test'")[targets].corr().values, } + ) + + if file_name is not None: + torch.save(dict, file_name) + return dict + + def get_tracker(self, wandb_log: bool, tensorboard_log: bool): + """Factory method for the tracker + Arguments: + wandb_log - Log using weight and biases + tensorboard_log - Log using tensorboard + Returns: + [BaseTracker] -- tracker + """ + return InstanceTracker(self, wandb_log=wandb_log, use_tensorboard=tensorboard_log, + log_train_metrics=self.log_train_metrics) + + @property # type: ignore + @save_used_properties + def num_reg_classes(self) -> int: + if self.train_dataset: + return self.train_dataset.num_reg_classes + elif 
self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_reg_classes + else: + return self.test_dataset.num_reg_classes + elif self.val_dataset is not None: + return self.val_dataset.num_reg_classes + else: + raise NotImplementedError() + + @property # type: ignore + @save_used_properties + def num_mol_classes(self) -> int: + if self.train_dataset: + return self.train_dataset.num_mol_classes + elif self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_mol_classes + else: + return self.test_dataset.num_mol_classes + elif self.val_dataset is not None: + return self.val_dataset.num_mol_classes + else: + raise NotImplementedError() + + @property # type: ignore + @save_used_properties + def num_cls_classes(self) -> int: + if self.train_dataset: + return self.train_dataset.num_cls_classes + elif self.test_dataset is not None: + if isinstance(self.test_dataset, list): + return self.test_dataset[0].num_cls_classes + else: + return self.test_dataset.num_cls_classes + elif self.val_dataset is not None: + return self.val_dataset.num_cls_classes + else: + raise NotImplementedError() + + +class RandomSampler(Sampler[int]): + r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. + If with replacement, then user can specify :attr:`num_samples` to draw. + + Args: + data_source (Dataset): dataset to sample from + num_samples (int): number of samples to draw, default=`len(dataset)`. + generator (Generator): Generator used in sampling. 
+ """ + data_source: Sized + + def __init__(self, data_source: Sized, drop_last: bool, batch_size: int, double_batch: bool) -> None: + super().__init__(data_source) + self.data_source = data_source + self.generator = None + self._num_samples = None + self.batch_size = batch_size + self.drop_last = drop_last + self.double_batch = double_batch + + @property + def num_samples(self) -> int: + # dataset size might change at runtime + if self._num_samples is None: + return len(self.data_source) + return self._num_samples + + def __iter__(self) -> Iterator[int]: + if self.generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator() + generator.manual_seed(seed) + else: + generator = self.generator + + iterator = torch.randperm(self.num_samples, generator=generator).tolist() + if self.double_batch: + iterator = np.array([[k, k] for k in iterator]).flatten().tolist() + iterator = iterator[:(self.num_samples // self.batch_size) * self.batch_size] + + yield from iterator + + def __len__(self) -> int: + return self.num_samples diff --git a/torch-points3d/torch_points3d/datasets/multiscale_data.py b/torch-points3d/torch_points3d/datasets/multiscale_data.py new file mode 100644 index 0000000..c66a746 --- /dev/null +++ b/torch-points3d/torch_points3d/datasets/multiscale_data.py @@ -0,0 +1,165 @@ +from typing import List, Optional +import torch +import copy +import torch_geometric +from torch_geometric.data import Data +from torch_geometric.data import Batch + + +class MultiScaleData(Data): + def __init__( + self, + x=None, + y=None, + pos=None, + multiscale: Optional[List[Data]] = None, + upsample: Optional[List[Data]] = None, + **kwargs, + ): + super().__init__(x=x, y=y, pos=pos, multiscale=multiscale, upsample=upsample, **kwargs) + + def apply(self, func, *keys): + r"""Applies the function :obj:`func` to all tensor and Data attributes + :obj:`*keys`. 
If :obj:`*keys` is not given, :obj:`func` is applied to + all present attributes. + """ + for key, item in self(*keys): + if torch.is_tensor(item): + self[key] = func(item) + for scale in range(self.num_scales): + self.multiscale[scale] = self.multiscale[scale].apply(func) + + for up in range(self.num_upsample): + self.upsample[up] = self.upsample[up].apply(func) + return self + + @property + def num_scales(self): + """ Number of scales in the multiscale array + """ + return len(self.multiscale) if hasattr(self, "multiscale") and self.multiscale else 0 + + @property + def num_upsample(self): + """ Number of upsample operations + """ + return len(self.upsample) if hasattr(self, "upsample") and self.upsample else 0 + + @classmethod + def from_data(cls, data): + ms_data = cls() + for k, item in data: + ms_data[k] = item + return ms_data + + +class MultiScaleBatch(MultiScaleData): + @staticmethod + def from_data_list(data_list, follow_batch=[]): + r"""Constructs a batch object from a python list holding + :class:`torch_geometric.data.Data` objects. + The assignment vector :obj:`batch` is created on the fly. 
+ Additionally, creates assignment batch vectors for each key in + :obj:`follow_batch`.""" + for data in data_list: + assert isinstance(data, MultiScaleData) + num_scales = data_list[0].num_scales + for data_entry in data_list: + assert data_entry.num_scales == num_scales, "All data objects should contain the same number of scales" + num_upsample = data_list[0].num_upsample + for data_entry in data_list: + assert data_entry.num_upsample == num_upsample, "All data objects should contain the same number of scales" + + # Build multiscale batches + multiscale = [] + for scale in range(num_scales): + ms_scale = [] + for data_entry in data_list: + ms_scale.append(data_entry.multiscale[scale]) + multiscale.append(from_data_list_token(ms_scale)) + + # Build upsample batches + upsample = [] + for scale in range(num_upsample): + upsample_scale = [] + for data_entry in data_list: + upsample_scale.append(data_entry.upsample[scale]) + upsample.append(from_data_list_token(upsample_scale)) + + # Create batch from non multiscale data + for data_entry in data_list: + del data_entry.multiscale + del data_entry.upsample + batch = Batch.from_data_list(data_list) + batch = MultiScaleBatch.from_data(batch) + batch.multiscale = multiscale + batch.upsample = upsample + + if torch_geometric.is_debug_enabled(): + batch.debug() + + return batch + + +def from_data_list_token(data_list, follow_batch=[]): + """ This is pretty a copy paste of the from data list of pytorch geometric + batch object with the difference that indexes that are negative are not incremented + """ + + keys = [set(data.keys) for data in data_list] + keys = list(set.union(*keys)) + assert "batch" not in keys + + batch = Batch() + batch.__data_class__ = data_list[0].__class__ + batch.__slices__ = {key: [0] for key in keys} + + for key in keys: + batch[key] = [] + + for key in follow_batch: + batch["{}_batch".format(key)] = [] + + cumsum = {key: 0 for key in keys} + batch.batch = [] + for i, data in enumerate(data_list): + 
for key in data.keys: + item = data[key] + if torch.is_tensor(item) and item.dtype != torch.bool and cumsum[key] > 0: + mask = item >= 0 + item[mask] = item[mask] + cumsum[key] + if torch.is_tensor(item): + size = item.size(data.__cat_dim__(key, data[key])) + else: + size = 1 + batch.__slices__[key].append(size + batch.__slices__[key][-1]) + cumsum[key] += data.__inc__(key, item) + batch[key].append(item) + + if key in follow_batch: + item = torch.full((size,), i, dtype=torch.long) + batch["{}_batch".format(key)].append(item) + + num_nodes = data.num_nodes + if num_nodes is not None: + item = torch.full((num_nodes,), i, dtype=torch.long) + batch.batch.append(item) + + if num_nodes is None: + batch.batch = None + + for key in batch.keys: + item = batch[key][0] + if torch.is_tensor(item): + batch[key] = torch.cat( + batch[key], dim=data_list[0].__cat_dim__(key, item)) + elif isinstance(item, int) or isinstance(item, float): + batch[key] = torch.tensor(batch[key]) + else: + raise ValueError( + "Unsupported attribute type {} : {}".format(type(item), item)) + + if torch_geometric.is_debug_enabled(): + batch.debug() + + return batch.contiguous() diff --git a/torch-points3d/torch_points3d/datasets/samplers.py b/torch-points3d/torch_points3d/datasets/samplers.py new file mode 100644 index 0000000..e7c6877 --- /dev/null +++ b/torch-points3d/torch_points3d/datasets/samplers.py @@ -0,0 +1,31 @@ +import torch +import numpy as np +from torch.utils.data import Sampler + +class BalancedRandomSampler(Sampler): + r"""This sampler is responsible for creating balanced batch based on the class distribution. 
+ It is implementing a replacement=True strategy for indices selection + """ + def __init__(self, labels, replacement=True): + + self.num_samples = len(labels) + + self.idx_classes, self.counts = np.unique(labels, return_counts=True) + self.indices = { + idx: np.argwhere(labels == idx).flatten() for idx in self.idx_classes + } + + def __iter__(self): + indices = [] + for _ in range(self.num_samples): + idx = np.random.choice(self.idx_classes) + indice = int(np.random.choice(self.indices[idx])) + indices.append(indice) + return iter(indices) + + def __len__(self): + return self.num_samples + + def __repr__(self): + return "{}(num_samples={})".format(self.__class__.__name__, self.num_samples) + diff --git a/torch-points3d/torch_points3d/metrics/__init__.py b/torch-points3d/torch_points3d/metrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch-points3d/torch_points3d/metrics/base_tracker.py b/torch-points3d/torch_points3d/metrics/base_tracker.py new file mode 100644 index 0000000..9da85e3 --- /dev/null +++ b/torch-points3d/torch_points3d/metrics/base_tracker.py @@ -0,0 +1,138 @@ +import logging +import os +from typing import Dict, Any + +import torch +import torchnet as tnt +import wandb +from torch.utils.tensorboard import SummaryWriter + +from torch_points3d.models import model_interface + +log = logging.getLogger(__name__) + + +def meter_value(meter, dim=0): + return float(meter.value()[dim]) if meter.n > 0 else 0.0 + + +class BaseTracker: + def __init__(self, stage: str, wandb_log: bool, use_tensorboard: bool): + self._wandb = wandb_log + self._use_tensorboard = use_tensorboard + self._tensorboard_dir = os.path.join(os.getcwd(), "tensorboard") + self._n_iter = 0 + self._finalised = False + self._conv_type = None + + if self._use_tensorboard: + log.info( + "Access tensorboard with the following command ".format(self._tensorboard_dir) + ) + self._writer = SummaryWriter(log_dir=self._tensorboard_dir) + + def reset(self, stage="train"): + 
self._stage = stage + self._loss_meters = {} + self._finalised = False + + def get_loss(self) -> Dict[str, Any]: + metrics = {} + for key, loss_meter in self._loss_meters.items(): + value = meter_value(loss_meter, dim=0) + if value: + metrics[key] = meter_value(loss_meter, dim=0) + return metrics + + def get_metrics(self, verbose=False) -> Dict[str, Any]: + return self.get_loss() + + @property + def metric_func(self): + self._metric_func = {"loss": min} + return self._metric_func + + def track(self, model: model_interface.TrackerInterface, **kwargs): + if self._finalised: + raise RuntimeError("Cannot track new values with a finalised tracker, you need to reset it first") + losses = self._convert(model.get_current_losses()) + self._append_losses(losses) + + def finalise(self, *args, **kwargs): + """Lifecycle method that is called at the end of an epoch. Use this to compute + end of epoch metrics. + """ + self._finalised = True + + def _append_losses(self, losses): + for key, loss in losses.items(): + if loss is None: + continue + loss_key = "%s_%s" % (self._stage, key) + if loss_key not in self._loss_meters: + self._loss_meters[loss_key] = tnt.meter.AverageValueMeter() + self._loss_meters[loss_key].add(loss) + + @staticmethod + def _convert(x): + if torch.is_tensor(x): + return x.detach().cpu().numpy() + else: + return x + + def publish_to_tensorboard(self, metrics, step): + for metric_name, metric_value in metrics.items(): + if isinstance(metric_value, (wandb.Table, wandb.viz.CustomChart)): # don't add table to postfix + continue + metric_name = "{}/{}".format(metric_name.replace(self._stage + "_", ""), self._stage) + self._writer.add_scalar(metric_name, metric_value, step) + + @staticmethod + def _remove_stage_from_metric_keys(stage, metrics): + new_metrics = {} + for metric_name, metric_value in metrics.items(): + if isinstance(metric_value, (wandb.Table, wandb.viz.CustomChart)): # don't add table to postfix + continue + new_metrics[metric_name.replace(stage + 
"_", "")] = metric_value + return new_metrics + + def publish_to_wandb(self, metrics, epoch): + wandb_metrics = metrics.copy() + wandb_metrics["epoch"] = epoch + wandb.log(wandb_metrics) + + def publish_metrics(self, metrics, epoch): + if self._wandb: + self.publish_to_wandb(metrics, epoch) + + if self._use_tensorboard: + self.publish_to_tensorboard(metrics, epoch) + + def get_publish_metrics(self, epoch): + """Publishes the current metrics to wandb and tensorboard + Arguments: + step: current epoch + """ + metrics = self.get_metrics() + + return { + "stage": self._stage, + "epoch": epoch, + "current_metrics": self._remove_stage_from_metric_keys(self._stage, metrics), + "all_metrics": metrics + } + + def print_summary(self): + metrics = self.get_loss() + log.info("".join(["=" for i in range(50)])) + for key, value in metrics.items(): + log.info(" {} = {}".format(key, value)) + log.info("".join(["=" for i in range(50)])) + + @staticmethod + def _dict_to_str(dictionnary): + string = "{" + for key, value in dictionnary.items(): + string += "%s: %.2f," % (str(key), value) + string += "}" + return string diff --git a/torch-points3d/torch_points3d/metrics/colored_tqdm.py b/torch-points3d/torch_points3d/metrics/colored_tqdm.py new file mode 100644 index 0000000..d91676d --- /dev/null +++ b/torch-points3d/torch_points3d/metrics/colored_tqdm.py @@ -0,0 +1,40 @@ +from tqdm.auto import tqdm +from collections import OrderedDict +from numbers import Number +import numpy as np + +from torch_points3d.utils.colors import COLORS + + +class Coloredtqdm(tqdm): + def set_postfix(self, ordered_dict=None, refresh=True, color=None, round=4, **kwargs): + postfix = OrderedDict([] if ordered_dict is None else ordered_dict) + + for key in sorted(kwargs.keys()): + postfix[key] = kwargs[key] + + for key in postfix.keys(): + if isinstance(postfix[key], Number): + postfix[key] = self.format_num_to_k(np.round(postfix[key], round), k=round + 1) + if isinstance(postfix[key], str): + postfix[key] = 
str(postfix[key]) + if len(postfix[key]) != round: + postfix[key] += (round - len(postfix[key])) * " " + + if color is not None: + self.postfix = color + else: + self.postfix = "" + + self.postfix += ", ".join(key + "=" + postfix[key] for key in postfix.keys()) + if color is not None: + self.postfix += COLORS.END_TOKEN + + if refresh: + self.refresh() + + def format_num_to_k(self, seq, k=4): + seq = str(seq) + length = len(seq) + out = seq + " " * (k - length) if length < k else seq + return out if length < k else seq[:k] diff --git a/torch-points3d/torch_points3d/metrics/confusion_matrix.py b/torch-points3d/torch_points3d/metrics/confusion_matrix.py new file mode 100644 index 0000000..139415a --- /dev/null +++ b/torch-points3d/torch_points3d/metrics/confusion_matrix.py @@ -0,0 +1,118 @@ +import os + +import numpy as np +import torch + + +class ConfusionMatrix: + """Streaming interface to allow for any source of predictions. + Initialize it, count predictions one by one, then print confusion matrix and intersection-union score""" + + def __init__(self, cls_names): + self.cls_names = np.array(cls_names) + self.n_cls = len(cls_names) + self.confusion_matrix = None + + @staticmethod + def create_from_matrix(confusion_matrix): + assert confusion_matrix.shape[0] == confusion_matrix.shape[1] + matrix = ConfusionMatrix(confusion_matrix.shape[0]) + matrix.confusion_matrix = confusion_matrix + return matrix + + def count_predicted_batch(self, ground_truth_vec, predicted): + assert predicted.max() < self.n_cls + batch_confusion = torch.bincount( + self.n_cls * ground_truth_vec.int() + predicted, minlength=self.n_cls ** 2 + ).reshape(self.n_cls, self.n_cls) + if self.confusion_matrix is None: + self.confusion_matrix = batch_confusion + else: + self.confusion_matrix += batch_confusion + + def get_count(self, ground_truth, predicted): + """labels are integers from 0 to number_of_labels-1""" + return self.confusion_matrix[ground_truth][predicted] + + def 
get_confusion_matrix(self): + """returns list of lists of integers; use it as result[ground_truth][predicted] + to know how many samples of class ground_truth were reported as class predicted""" + return self.confusion_matrix + + def get_stats(self): + cmat = self.confusion_matrix + stats = {} + class_stats = {} + numel = cmat.sum(1) + mask = numel > 0 + if mask.sum() == 0: # nothing to log + return stats + tp = torch.diag(cmat)[mask] + stats["tp"] = tp.sum().item() + fp = (cmat.sum(0)[mask] - tp) + stats["fp"] = fp.sum().item() + fn = (cmat.sum(1)[mask] - tp) + stats["acc"] = (tp.sum() / numel.sum()).item() + + # macro statistics + acc = (tp / numel[mask]) + stats["macc"] = acc.mean().item() + + precision = tp / (tp + fp + torch.finfo(torch.float32).eps) + stats["precision"] = precision.mean().item() + + recall = tp / (tp + fn + torch.finfo(torch.float32).eps) + stats["recall"] = recall.mean().item() + + f1 = 2 * ((precision * recall) / (precision + recall + torch.finfo(torch.float32).eps)) + stats["f1"] = f1.mean().item() + + # class stats + for i, cls_name in enumerate(self.cls_names[mask.cpu()]): + class_stats["acc", cls_name] = acc[i].item() + class_stats["tp", cls_name] = tp[i].item() + class_stats["recall", cls_name] = recall[i].item() + class_stats["precision", cls_name] = precision[i].item() + class_stats["f1", cls_name] = f1[i].item() + + """ + # normalize conf matrix + cmat_sum = cmat.sum(axis=1, keepdim=True) + cmat_sum += cmat_sum == 0 # avoid nans by displaying 0 + cmatn = cmat / cmat_sum + """ + return stats, class_stats, cmat + + +def save_confusion_matrix(cm, path2save, ordered_names): + import seaborn as sns + import matplotlib.pyplot as plt + + sns.set(font_scale=5) + + template_path = os.path.join(path2save, "{}.svg") + # PRECISION + cmn = cm.astype("float") / cm.sum(axis=-1)[:, np.newaxis] + cmn[np.isnan(cmn) | np.isinf(cmn)] = 0 + fig, ax = plt.subplots(figsize=(31, 31)) + sns.heatmap( + cmn, annot=True, fmt=".2f", xticklabels=ordered_names, 
yticklabels=ordered_names, annot_kws={"size": 20} + ) + # g.set_xticklabels(g.get_xticklabels(), rotation = 35, fontsize = 20) + plt.ylabel("Actual") + plt.xlabel("Predicted") + path_precision = template_path.format("precision") + plt.savefig(path_precision, format="svg") + + # RECALL + cmn = cm.astype("float") / cm.sum(axis=0)[np.newaxis, :] + cmn[np.isnan(cmn) | np.isinf(cmn)] = 0 + fig, ax = plt.subplots(figsize=(31, 31)) + sns.heatmap( + cmn, annot=True, fmt=".2f", xticklabels=ordered_names, yticklabels=ordered_names, annot_kws={"size": 20} + ) + # g.set_xticklabels(g.get_xticklabels(), rotation = 35, fontsize = 20) + plt.ylabel("Actual") + plt.xlabel("Predicted") + path_recall = template_path.format("recall") + plt.savefig(path_recall, format="svg") diff --git a/torch-points3d/torch_points3d/metrics/instance_tracker.py b/torch-points3d/torch_points3d/metrics/instance_tracker.py new file mode 100644 index 0000000..e476d6f --- /dev/null +++ b/torch-points3d/torch_points3d/metrics/instance_tracker.py @@ -0,0 +1,264 @@ +import logging +from collections import OrderedDict +from typing import Dict, Any + +import numpy as np +import torch +import wandb +from torchnet.meter import MSEMeter + +from torch_points3d.metrics.base_tracker import BaseTracker +from torch_points3d.metrics.confusion_matrix import ConfusionMatrix +from torch_points3d.metrics.meters.maemeter import MAEMeter +from torch_points3d.metrics.meters.r2meter import R2Meter +from torch_points3d.models import model_interface + + +class InstanceTracker(BaseTracker): + def __init__(self, dataset, stage="train", wandb_log=False, use_tensorboard: bool = False, + log_train_metrics: bool = True): + """ This is a generic tracker for instance prediction tasks. + It uses a confusion matrix in the back-end to track results. + Use the tracker to track an epoch. 
+ You can use the reset function before you start a new epoch + Arguments: + dataset -- dataset to track (used for the number of classes) + Keyword Arguments: + stage {str} -- current stage. (train, validation, test, etc...) (default: {"train"}) + wandb_log {str} -- Log using weight and biases + """ + super(InstanceTracker, self).__init__(stage, wandb_log, use_tensorboard) + self.has_reg_targets = dataset.has_reg_targets + self.reg_targets_idx = dataset.reg_targets_idx + self.reg_targets = dataset.reg_targets + + self.has_mol_targets = dataset.has_mol_targets + self.mol_targets_idx = dataset.mol_targets_idx + self.mol_targets = dataset.mol_targets + + self.has_cls_targets = dataset.has_cls_targets + self.cls_targets = dataset.cls_targets + self.cls_targets_idx = dataset.cls_targets_idx + self.cls_names = OrderedDict({ + target_name: dataset.targets[target_name]["class_names"] for target_name in dataset.targets + if target_name in self.cls_targets + }) + + self.area_names = dataset.areas.keys() + self.area_name_map = OrderedDict({area_name: i for i, area_name in enumerate(self.area_names)}) + + self.n_targets = dataset.num_classes + + # for r2 score + self.target_means = dataset.get_mean_targets() + self.log_train_metrics = log_train_metrics + + self.reset(stage) + # Those map subsentences to their optimization functions + self._metric_goals = { + "loss": "minimize", + } + self._metric_func = { + "loss": min, + } + if self.has_reg_targets or self.has_mol_targets: + self._metric_goals.update({ + "_rmse": "minimize", + "_mae": "minimize", + "_r2": "maximize", + }) + self._metric_func.update({ + "_rmse": min, + # "mae": min, + # "r2": max, + }) + if self.has_reg_targets: + self._metric_func.update({"loss_reg": min}) + if self.has_mol_targets: + self._metric_func.update({"loss_mol": min}) + if self.has_cls_targets: + self._metric_goals.update({ + "acc": "maximize", + "macc": "maximize", + "_f1": "maximize", + }) + self._metric_func.update({ + # "acc": max, + # "macc": 
max, + "_f1": max, + "loss_cls": min, + }) + + if wandb_log: + self.wandb_metrics = [] + + def reset(self, stage="train"): + super().reset(stage=stage) + if (stage == "train" and self.log_train_metrics) or stage != "train": + area_names = [area_name for area_name in self.area_names + if self.target_means[area_name].get(stage, None) is not None] + area_names.append("total") + if self.has_reg_targets or self.has_mol_targets: + targets = self.reg_targets + self.mol_targets + targets_idx = np.logical_or(self.reg_targets_idx, self.mol_targets_idx) + self._rmse = {area_name: {} for area_name in area_names} + self._mae = {area_name: {} for area_name in area_names} + self._r2 = {area_name: {} for area_name in area_names} + for i, target_name in enumerate(targets): + for area_name in area_names: + if np.isnan(self.target_means[area_name][stage][targets_idx][i]).all(): + continue + self._rmse[area_name][target_name] = MSEMeter(root=True) + self._mae[area_name][target_name] = MAEMeter() + self._r2[area_name][target_name] = R2Meter(self.target_means[area_name][stage][targets_idx][i]) + + if self.has_cls_targets: + self._confusion_matrix = {area_name: {} for area_name in area_names} + + for i, target_name in enumerate(self.cls_targets): + for area_name in area_names: + if np.isnan(self.target_means[area_name][stage][self.cls_targets_idx][i]).all(): + continue + self._confusion_matrix[area_name][target_name] = ConfusionMatrix(self.cls_names[target_name]) + + @staticmethod + def detach_tensor(tensor): + if torch.torch.is_tensor(tensor): + tensor = tensor.detach() + return tensor + + def track(self, model: model_interface.InstanceTrackerInterface, **kwargs): + """ Add current model predictions (usually the result of a batch) to the tracking + """ + super().track(model) + + if (self._stage == "train" and self.log_train_metrics) or self._stage != "train": + areas = model.data_visual["area_name"] + areas = torch.tensor([self.area_name_map[an] for an in areas]) + + # regression + if 
self.has_reg_targets: + outputs = model.get_reg_output() + targets = model.get_reg_input() + + track_stats = self.track_numerical_stats + target_names = self.reg_targets + + self.track_iterate_areas_targets(areas, outputs, target_names, targets, track_stats) + + if self.has_mol_targets: + outputs = model.get_mol_output() + targets = model.get_mol_input() + + track_stats = self.track_numerical_stats + target_names = self.mol_targets + + self.track_iterate_areas_targets(areas, outputs, target_names, targets, track_stats) + + if self.has_cls_targets: + targets = model.get_cls_input() + outputs = torch.stack([cls_out.argmax(1) for cls_out in model.get_cls_output()], 1) + + track_stats = self.track_classification_stats + target_names = self.cls_names + + self.track_iterate_areas_targets(areas, outputs, target_names, targets, track_stats) + + def track_iterate_areas_targets(self, areas, outputs, target_names, targets, track_stats): + # ignore nan values + targets_nan = torch.isnan(targets) if targets.dtype == torch.float else targets == -1 + no_nans = ~targets_nan # ~(outputs_nan | targets_nan) + if no_nans.any(): + for i, target_name in enumerate(target_names): + no_nan = no_nans[:, i] + # skip if no real values are present + if not no_nan.any(): + continue + out = outputs[:, i][no_nan] + target = targets[:, i][no_nan] + area = areas[no_nan.cpu()] + + for area_name in self.area_names: + area_idx = area == self.area_name_map[area_name] + if area_idx.any(): + track_stats(area_idx, area_name, out, target, target_name) + track_stats(torch.ones_like(area_idx), "total", out, target, target_name) + + def track_classification_stats(self, area_idx, area_name, out, target, target_name): + self._confusion_matrix[area_name][target_name].count_predicted_batch(target[area_idx], out[area_idx]) + + def track_numerical_stats(self, area_idx, area_name, out, target, target_name): + self._rmse[area_name][target_name].add(out[area_idx], target[area_idx]) + 
self._mae[area_name][target_name].add(out[area_idx], target[area_idx]) + self._r2[area_name][target_name].add(out[area_idx], target[area_idx]) + + def get_metrics(self, verbose=False) -> Dict[str, Any]: + """ Returns a dictionary of all metrics and losses being tracked + """ + metrics = super().get_loss() + if (self._stage == "train" and self.log_train_metrics) or self._stage != "train": + area_names = list(self.area_names) + area_names.append("total") + for area_name in area_names: + if self.has_reg_targets or self.has_mol_targets: + if self._r2.get(area_name, None) is not None: + for target_name in self.reg_targets + self.mol_targets: + if self._r2[area_name].get(target_name, None) is None: + continue + metrics[f"{self._stage}_{area_name}_{target_name}_rmse"] = \ + self._rmse[area_name][target_name].value() + metrics[f"{self._stage}_{area_name}_{target_name}_mae"] = \ + self._mae[area_name][target_name].value() + metrics[f"{self._stage}_{area_name}_{target_name}_r2"] = \ + self._r2[area_name][target_name].value() + if self.has_cls_targets: + if self._confusion_matrix.get(area_name, None) is not None: + for target_name in self.cls_targets: + cmat_obj = self._confusion_matrix[area_name].get(target_name, None) + if cmat_obj is None or cmat_obj.confusion_matrix is None: + continue + + stats, class_stats, cmat = cmat_obj.get_stats() + for metric in stats: + metrics[f"{self._stage}_{area_name}_{target_name}_{metric}"] = stats[metric] + + for metric, cls_name in class_stats: + metrics[f"{self._stage}_{area_name}_{target_name}_{cls_name}:{metric}"] = \ + class_stats[metric, cls_name] + + if self._wandb: + data = [] + for i in range(cmat_obj.n_cls): + for j in range(cmat_obj.n_cls): + data.append([cmat_obj.cls_names[i], cmat_obj.cls_names[j], cmat[i, j]]) + + cmat_table = wandb.Table(columns=["Actual", "Predicted", "nPredictions"], data=data) + fields = {"Actual": "Actual", "Predicted": "Predicted", "nPredictions": "nPredictions"} + cmat_plot = wandb.plot_table( + 
"wandb/confusion_matrix/v1", + cmat_table, + fields, + {"title": f"{self._stage}; {area_name}; {target_name}"}, + ) + metrics[f"{self._stage}_{area_name}_{target_name}_cmat"] = cmat_plot + + if self._wandb: + # add metric to wandb if not there already + new_metrics = [metric for metric in metrics if metric not in self.wandb_metrics] + for metric in new_metrics: + m_func = [m for m in self._metric_goals if m in metric] + if len(m_func) == 0: + m_func = goal = None + else: + try: + m_func, goal = self._metric_goals[m_func[0]][:3], self._metric_goals[m_func[0]] + except Exception as e: + logging.warning(f"{str(e)}\n Something went wrong during wandb metric collection") + wandb.define_metric(metric, step_metric="epoch", summary=m_func, goal=goal) + self.wandb_metrics.append(metric) + + return metrics + + @property + def metric_func(self): + return self._metric_func diff --git a/torch-points3d/torch_points3d/metrics/meters.py b/torch-points3d/torch_points3d/metrics/meters.py new file mode 100644 index 0000000..e3aa634 --- /dev/null +++ b/torch-points3d/torch_points3d/metrics/meters.py @@ -0,0 +1,155 @@ +import math +import torch + + +class Meter(object): + """Meters provide a way to keep track of important statistics in an online manner. + This class is abstract, but provides a standard interface for all meters to follow. + """ + + def reset(self): + """Resets the meter to default settings.""" + + def add(self, value): + """Log a new value to the meter + Args: + value: Next restult to include. + """ + + def value(self): + """Get the value of the meter in the current state.""" + + +# This code has been taken from Pytorch Torchnet has it contains a bug with an assert +# https://github.com/pytorch/tnt/issues/131 +class APMeter(Meter): + """ + The APMeter measures the average precision per class. 
+ The APMeter is designed to operate on `NxK` Tensors `output` and + `target`, and optionally a `Nx1` Tensor weight where (1) the `output` + contains model output scores for `N` examples and `K` classes that ought to + be higher when the model is more convinced that the example should be + positively labeled, and smaller when the model believes the example should + be negatively labeled (for instance, the output of a sigmoid function); (2) + the `target` contains only values 0 (for negative examples) and 1 + (for positive examples); and (3) the `weight` ( > 0) represents weight for + each sample. + """ + + def __init__(self): + super(APMeter, self).__init__() + self.reset() + + def reset(self): + """Resets the meter with empty member variables""" + self.scores = torch.FloatTensor(torch.FloatStorage()) + self.targets = torch.LongTensor(torch.LongStorage()) + self.weights = torch.FloatTensor(torch.FloatStorage()) + + def add(self, output, target, weight=None): + """Add a new observation + Args: + output (Tensor): NxK tensor that for each of the N examples + indicates the probability of the example belonging to each of + the K classes, according to the model. 
The probabilities should + sum to one over all classes + target (Tensor): binary NxK tensort that encodes which of the K + classes are associated with the N-th input + (eg: a row [0, 1, 0, 1] indicates that the example is + associated with classes 2 and 4) + weight (optional, Tensor): Nx1 tensor representing the weight for + each example (each weight > 0) + """ + if not torch.is_tensor(output): + output = torch.from_numpy(output) + if not torch.is_tensor(target): + target = torch.from_numpy(target) + + if weight is not None: + if not torch.is_tensor(weight): + weight = torch.from_numpy(weight) + weight = weight.squeeze() + if output.dim() == 1: + output = output.view(-1, 1) + else: + assert ( + output.dim() == 2 + ), "wrong output size (should be 1D or 2D with one column \ + per class)" + if target.dim() == 1: + target = target.view(-1, 1) + else: + assert ( + target.dim() == 2 + ), "wrong target size (should be 1D or 2D with one column \ + per class)" + if weight is not None: + assert weight.dim() == 1, "Weight dimension should be 1" + assert weight.numel() == target.size(0), "Weight dimension 1 should be the same as that of target" + assert torch.min(weight) >= 0, "Weight should be non-negative only" + if self.scores.numel() > 0: + assert target.size(1) == self.targets.size( + 1 + ), "dimensions for output should match previously added examples." 
+ + # make sure storage is of sufficient size + if self.scores.storage().size() < self.scores.numel() + output.numel(): + new_size = math.ceil(self.scores.storage().size() * 1.5) + new_weight_size = math.ceil(self.weights.storage().size() * 1.5) + self.scores.storage().resize_(int(new_size + output.numel())) + self.targets.storage().resize_(int(new_size + output.numel())) + if weight is not None: + self.weights.storage().resize_(int(new_weight_size + output.size(0))) + + # store scores and targets + offset = self.scores.size(0) if self.scores.dim() > 0 else 0 + self.scores.resize_(offset + output.size(0), output.size(1)) + self.targets.resize_(offset + target.size(0), target.size(1)) + self.scores.narrow(0, offset, output.size(0)).copy_(output) + self.targets.narrow(0, offset, target.size(0)).copy_(target) + + if weight is not None: + self.weights.resize_(offset + weight.size(0)) + self.weights.narrow(0, offset, weight.size(0)).copy_(weight) + + def value(self): + """Returns the model's average precision for each class + Return: + ap (FloatTensor): 1xK tensor, with avg precision for each class k + """ + + if self.scores.numel() == 0: + return 0 + ap = torch.zeros(self.scores.size(1)) + if hasattr(torch, "arange"): + rg = torch.arange(1, self.scores.size(0) + 1).float() + else: + rg = torch.range(1, self.scores.size(0)).float() + if self.weights.numel() > 0: + weight = self.weights.new(self.weights.size()) + weighted_truth = self.weights.new(self.weights.size()) + + # compute average precision for each class + for k in range(self.scores.size(1)): + # sort scores + scores = self.scores[:, k] + targets = self.targets[:, k] + _, sortind = torch.sort(scores, 0, True) + truth = targets[sortind] + if self.weights.numel() > 0: + weight = self.weights[sortind] + weighted_truth = truth.float() * weight + rg = weight.cumsum(0) + + # compute true positive sums + if self.weights.numel() > 0: + tp = weighted_truth.cumsum(0) + else: + tp = truth.float().cumsum(0) + + # compute 
class APPRXMeter:
    """Running relative error between summed outputs and summed targets.

    ``value()`` is ``|1 - sum(outputs) / sum(targets)|`` over everything passed
    to ``add`` since the last ``reset``.
    """

    def __init__(self):
        super(APPRXMeter, self).__init__()
        self.reset()

    def reset(self):
        """Discard all accumulated samples."""
        self.n = 0
        self.target_sum = 0.0
        self.output_sum = 0.0

    def add(self, output, target):
        """Accumulate one batch.

        Args:
            output: predictions; a tensor or anything ``torch.as_tensor`` accepts
                (e.g. a numpy array).
            target: ground truth; same accepted types as ``output``.
        """
        # Convert each argument independently so a numpy/tensor mix also works.
        if not torch.is_tensor(output):
            output = torch.as_tensor(output)
        if not torch.is_tensor(target):
            target = torch.as_tensor(target)
        self.n += output.numel()
        self.target_sum += torch.sum(target).item()
        self.output_sum += torch.sum(output).item()

    def value(self):
        """Return the current relative error.

        Returns:
            float: ``0.0`` when nothing was added. When the targets sum to
            zero the ratio is undefined: ``0.0`` is returned if the outputs
            also sum to zero, ``inf`` otherwise.
        """
        if self.n <= 0:
            return 0.0
        if self.target_sum == 0:
            return 0.0 if self.output_sum == 0 else float("inf")
        return abs(1 - self.output_sum / self.target_sum)
class R2Meter:
    """Running coefficient of determination against a fixed target mean.

    Accumulates the residual sum of squares ``sum((output - target)^2)`` and
    the total sum of squares ``sum((target - target_mean)^2)``; ``value()`` is
    ``1 - ressum / totsum``.
    """

    def __init__(self, target_mean):
        """
        Args:
            target_mean: reference mean subtracted from every target when
                accumulating the total sum of squares.
        """
        super(R2Meter, self).__init__()
        self.target_mean = target_mean
        self.reset()

    def reset(self):
        """Discard all accumulated samples (``target_mean`` is kept)."""
        self.n = 0
        self.ressum = 0.0
        self.totsum = 0.0

    def add(self, output, target):
        """Accumulate one batch.

        Args:
            output: predictions; a tensor or anything ``torch.as_tensor`` accepts
                (e.g. a numpy array).
            target: ground truth; same accepted types as ``output``.
        """
        # Convert each argument independently so a numpy/tensor mix also works.
        if not torch.is_tensor(output):
            output = torch.as_tensor(output)
        if not torch.is_tensor(target):
            target = torch.as_tensor(target)
        self.n += output.numel()
        self.ressum += torch.sum((output - target) ** 2).item()
        self.totsum += torch.sum((target - self.target_mean) ** 2).item()

    def value(self):
        """Return the current R^2.

        Returns:
            float: ``0.0`` when nothing was added or when the total sum of
            squares is not positive (the ratio would be undefined).
        """
        if self.n > 0 and self.totsum > 0:
            return 1 - (self.ressum / self.totsum)
        return 0.0
+ """ Checkpoint manager. Saves to working directory with check_name + Arguments + checkpoint_file {str} -- Path to the checkpoint + save_every_iter {bool} -- [description] (default: {True}) + """ + self._check_path = checkpoint_file + self._filled = False + self.run_config: Optional[Dict] = None + self.models: Dict[str, Any] = {} + self.stats: Dict[str, List[Any]] = {"train": [], "test": [], "val": []} + self.optimizer: Optional[Tuple[str, Any]] = None + self.grad_scale: Optional[Tuple[str, Any]] = None + self.schedulers: Dict[str, Any] = {} + self.dataset_properties: Dict = {} + + def save_objects(self, models_to_save: Dict[str, Any], stage, current_stat, optimizer, schedulers, grad_scale, + **kwargs): + """ Saves checkpoint with updated models for the given stage + """ + self.models = models_to_save + self.optimizer = (optimizer.__class__.__name__, optimizer.state_dict()) + self.schedulers = { + scheduler_name: [scheduler.scheduler_opt, scheduler.state_dict()] + for scheduler_name, scheduler in schedulers.items() + } + self.grad_scale = grad_scale.state_dict() + to_save = kwargs + for key, value in self.__dict__.items(): + if not key.startswith("_"): + to_save[key] = value + torch.save(to_save, self.path) + + @property + def path(self): + return self._check_path + + @staticmethod + def load(checkpoint_dir: str, checkpoint_name: str, run_config: Any, strict=False, resume=True): + """ Creates a new checkpoint object in the current working directory by loading the + checkpoint located at [checkpointdir]/[checkpoint_name].pt + """ + checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) + ".pt" + if not os.path.exists(checkpoint_file): + ckp = Checkpoint(checkpoint_file) + if strict or resume: + available_checkpoints = glob.glob(os.path.join(checkpoint_dir, "*.pt")) + message = "The provided path {} didn't contain the checkpoint_file {}".format( + checkpoint_dir, checkpoint_name + ".pt" + ) + if available_checkpoints: + message += "\nDid you mean 
def get_state_dict(self, weight_name):
    """Return the stored weights for ``weight_name``, falling back to the latest.

    The entry ``"best_<weight_name>"`` of ``self.models`` is preferred; when
    it is absent the entry stored under ``Checkpoint._LATEST`` is returned.

    Args:
        weight_name: metric name identifying the weights (without the
            ``"best_"`` prefix).

    Returns:
        The stored model state, or ``None`` when the checkpoint is empty.

    Raises:
        Exception: neither the requested nor the latest weights are stored.
    """
    if self.is_empty:
        return None

    models = self.models
    keys = [key.replace("best_", "") for key in models]
    log.info("Available weights : {}".format(keys))

    key_name = "best_{}".format(weight_name)
    if key_name in models:
        model = models[key_name]
        log.info("Model loaded from {}:{}.".format(self._check_path, key_name))
        return model

    # Requested weights are missing: fall back to the most recent ones.
    key_name = Checkpoint._LATEST
    try:
        model = models[key_name]
    except KeyError as err:
        # Only a missing key is translated; interrupts and unrelated errors
        # propagate unchanged, and the original cause stays attached.
        raise Exception("This weight name isn't within the checkpoint ") from err
    log.info("Model loaded from {}:{}".format(self._check_path, key_name))
    return model
ValueError("Checkpoint is empty") + + @property + def start_epoch(self): + if self._resume: + return self.get_starting_epoch() + else: + return 1 + + @property + def run_config(self): + return OmegaConf.create(self._checkpoint.run_config) + + @property + def data_config(self): + return OmegaConf.create(self._checkpoint.run_config).data + + @property + def selection_stage(self): + return self._selection_stage + + @selection_stage.setter + def selection_stage(self, value): + self._selection_stage = value + + @property + def is_empty(self): + return self._checkpoint.is_empty + + @property + def checkpoint_path(self): + return self._checkpoint.path + + @property + def dataset_properties(self) -> Dict: + return self._checkpoint.dataset_properties + + @dataset_properties.setter + def dataset_properties(self, dataset_properties: Union[Dict[str, Any], Dict]): + self._checkpoint.dataset_properties = dataset_properties + + def get_starting_epoch(self): + return len(self._checkpoint.stats["train"]) + 1 + + def _initialize_model(self, run_config: OmegaConf, model: BaseModel, weight_name): + if not self._checkpoint.is_empty: + state_dict = self._checkpoint.get_state_dict(weight_name) + model.load_state_dict(state_dict, strict=False) + model.init_optim(run_config) + model.init_schedulers(run_config) + model.init_grad_scaler(run_config) + + def find_func_from_metric_name(self, metric_name, default_metrics_func): + for token_name, func in default_metrics_func.items(): + if token_name in metric_name: + return func + raise Exception( + 'The metric name {} doesn t have a func to measure which one is best in {}. 
Example: For best_train_iou, {{"iou":max}}'.format( + metric_name, default_metrics_func + ) + ) + + def save_best_models_under_current_metrics( + self, model: BaseModel, metrics_holder: dict, metric_func_dict: dict, wandb_log: bool, **kwargs + ): + """[This function is responsible to save checkpoint under the current metrics and their associated DEFAULT_METRICS_FUNC] + Arguments: + model {[CheckpointInterface]} -- [Model] + metrics_holder {[Dict]} -- [Need to contain stage, epoch, current_metrics] + """ + metrics = metrics_holder["current_metrics"] + stage = metrics_holder["stage"] + epoch = metrics_holder["epoch"] + p_metrics = metrics_holder["all_metrics"] + + stats = self._checkpoint.stats + state_dict = copy.deepcopy(model.state_dict()) + + # if multi GPU, remove DataParallel part + if isinstance(model.model, DataParallel): + new_state_dict = OrderedDict() + for key in state_dict: + name = copy.copy(key) + if key[6:12] == "module": + name = name[:5] + name[12:] + new_state_dict[name] = state_dict[key] + state_dict = new_state_dict + + current_stat = {"epoch": epoch} + + log_metrics = {} + + models_to_save = self._checkpoint.models + if stage not in stats: + stats[stage] = [] + + if stage == "train": + models_to_save[Checkpoint._LATEST] = state_dict + else: + latest_stats = None if len(stats[stage]) == 0 else stats[stage][-1] + + msg = "" + improved_metric = 0 + if wandb_log: + log_metrics = {"epoch": epoch, } + + for metric_name, current_metric_value in metrics.items(): + if all(key not in metric_name for key in ["total_", "loss_"]): + continue + + current_stat[metric_name] = current_metric_value + + try: + metric_func = self.find_func_from_metric_name(metric_name, metric_func_dict) + except Exception: + continue # no metric function was defined, so it is only used for logging + + if latest_stats is None: + current_stat["best_{}".format(metric_name)] = current_metric_value + models_to_save["best_{}".format(metric_name)] = state_dict + else: + 
def validate(self, dataset_config):
    """Return True when the stored run config can recreate the model.

    A checkpoint is considered valid if it can recreate the model from a
    dataset config only.

    Args:
        dataset_config: optional mapping of overrides written on top of the
            checkpoint's ``data`` config before the model is instantiated;
            ``None`` uses the stored data config unchanged.

    Returns:
        bool: ``False`` if ``instantiate_model`` raised, ``True`` otherwise.
    """
    # `self.data_config` builds a new OmegaConf object on every access, so the
    # overrides have to be applied to one local object that is then passed on;
    # assigning through the property would write to a discarded temporary.
    data_config = self.data_config
    if dataset_config is not None:
        for k, v in dataset_config.items():
            data_config[k] = v
    try:
        instantiate_model(OmegaConf.create(self.run_config), data_config)
    except Exception:
        # Best-effort probe: report failure to the caller, but keep the reason
        # discoverable instead of dropping it silently.
        logging.getLogger(__name__).debug("Checkpoint validation failed", exc_info=True)
        return False

    return True
labeled "data" object with the following attributes: + - id_scan: id of the scan to which the boxes belong to + - instance_box_cornerimport torchnet as tnts - gt box corners + - box_label_mask - mask for boxes (0 = no box) + - sem_cls_label - semantic label for each box + """ + super().track(model) + + outputs = model.get_output() + + total_num_proposal = outputs.objectness_label.shape[0] * outputs.objectness_label.shape[1] + pos_ratio = torch.sum(outputs.objectness_label.float()).item() / float(total_num_proposal) + self._pos_ratio.add(pos_ratio) + self._neg_ratio.add(torch.sum(outputs.objectness_mask.float()).item() / float(total_num_proposal) - pos_ratio) + + obj_pred_val = torch.argmax(outputs.objectness_scores, 2) # B,K + self._obj_acc.add( + torch.sum((obj_pred_val == outputs.objectness_label.long()).float() * outputs.objectness_mask).item() + / (torch.sum(outputs.objectness_mask) + 1e-6).item() + ) + + if data is None or self._stage == "train" or not track_boxes: + return + + self._add_box_pred(outputs, data, model.conv_type) + + def _add_box_pred(self, outputs, input_data, conv_type): + # Track box predictions + pred_boxes = outputs.get_boxes(self._dataset, apply_nms=True, duplicate_boxes=False) + if input_data.id_scan is None: + raise ValueError("Cannot track boxes without knowing in which scan they are") + + scan_ids = input_data.id_scan + assert len(scan_ids) == len(pred_boxes) + for idx, scan_id in enumerate(scan_ids): + # Predictions + self._pred_boxes[scan_id.item()] = pred_boxes[idx] + + # Ground truth + sample_mask = idx + gt_boxes = input_data.instance_box_corners[sample_mask] + gt_boxes = gt_boxes[input_data.box_label_mask[sample_mask]] + sample_labels = input_data.sem_cls_label[sample_mask] + gt_box_data = [BoxData(sample_labels[i].item(), gt_boxes[i]) for i in range(len(gt_boxes))] + self._gt_boxes[scan_id.item()] = gt_box_data + + def get_metrics(self, verbose=False) -> Dict[str, Any]: + """ Returns a dictionnary of all metrics and losses being 
def finalise(self, track_boxes=False, overlap_thresholds=(0.25, 0.5), **kwargs):
    """Compute box-detection recall and average precision for the tracked boxes.

    Does nothing unless ``track_boxes`` is true and ground-truth boxes were
    collected. Otherwise ``self._ap`` and ``self._rec`` are rebuilt, keyed by
    ``str(threshold)``, each holding a per-key ordered mapping taken from
    ``eval_detection``.

    Args:
        track_boxes: whether box metrics should be computed.
        overlap_thresholds: iterable of overlap thresholds passed to
            ``eval_detection`` as ``ovthresh``. The default is an immutable
            tuple so it cannot be altered between calls.
        **kwargs: accepted and ignored.
    """
    if not track_boxes or len(self._gt_boxes) == 0:
        return

    # Compute box detection metrics
    self._ap = {}
    self._rec = {}
    for thresh in overlap_thresholds:
        rec, _, ap = eval_detection(self._pred_boxes, self._gt_boxes, ovthresh=thresh)
        thresh_key = str(thresh)
        self._ap[thresh_key] = OrderedDict(sorted(ap.items()))
        self._rec[thresh_key] = OrderedDict({})
        for key, val in sorted(rec.items()):
            # Keep the last element when the entry is indexable,
            # otherwise keep the value itself.
            try:
                value = val[-1]
            except TypeError:
                value = val
            self._rec[thresh_key][key] = value
b/torch-points3d/torch_points3d/models/base_architectures/backbone.py new file mode 100644 index 0000000..e184182 --- /dev/null +++ b/torch-points3d/torch_points3d/models/base_architectures/backbone.py @@ -0,0 +1,138 @@ +import logging + +from omegaconf.dictconfig import DictConfig +from torch import nn + +from torch_points3d.datasets.base_dataset import BaseDataset +from torch_points3d.models.base_architectures import BaseFactory +from torch_points3d.models.base_model import BaseModel +from torch_points3d.utils.config import is_list + +log = logging.getLogger(__name__) + +SPECIAL_NAMES = ["radius", "max_num_neighbors", "block_names"] + + +############################# Backbone Base ################################### + + +class BackboneBasedModel(BaseModel): + """ + create a backbone-based generator: + This is simply an encoder + (can be used in classification, regression, metric learning and so one) + """ + + def _save_sampling_and_search(self, down_conv): + sampler = getattr(down_conv, "sampler", None) + if is_list(sampler): + self._spatial_ops_dict["sampler"] = sampler + self._spatial_ops_dict["sampler"] + else: + self._spatial_ops_dict["sampler"] = [sampler] + self._spatial_ops_dict["sampler"] + + neighbour_finder = getattr(down_conv, "neighbour_finder", None) + if is_list(neighbour_finder): + self._spatial_ops_dict["neighbour_finder"] = neighbour_finder + self._spatial_ops_dict["neighbour_finder"] + else: + self._spatial_ops_dict["neighbour_finder"] = [neighbour_finder] + self._spatial_ops_dict["neighbour_finder"] + + def __init__(self, opt, model_type, dataset: BaseDataset, modules_lib): + + """Construct a backbone generator (It is a simple down module) + Parameters: + opt - options for the network generation + model_type - type of the model to be generated + modules_lib - all modules that can be used in the backbone + + + opt is expected to contains the following keys: + * down_conv + """ + + super(BackboneBasedModel, self).__init__(opt) + 
self._spatial_ops_dict = {"neighbour_finder": [], "sampler": []} + + # detect which options format has been used to define the model + if is_list(opt.down_conv) or "down_conv_nn" not in opt.down_conv: + raise NotImplementedError + else: + self._init_from_compact_format(opt, model_type, dataset, modules_lib) + + def _get_from_kwargs(self, kwargs, name): + module = kwargs[name] + kwargs.pop(name) + return module + + def _init_from_compact_format(self, opt, model_type, dataset, modules_lib): + """Create a backbonebasedmodel from the compact options format - where the + same convolution is given for each layer, and arguments are given + in lists + """ + num_convs = len(opt.down_conv.down_conv_nn) + self.down_modules = nn.ModuleList() + factory_module_cls = self._get_factory(model_type, modules_lib) + down_conv_cls_name = opt.down_conv.module_name + self._factory_module = factory_module_cls(down_conv_cls_name, None, modules_lib) + # Down modules + for i in range(num_convs): + args = self._fetch_arguments(opt.down_conv, i, "DOWN") + conv_cls = self._get_from_kwargs(args, "conv_cls") + down_module = conv_cls(**args) + self._save_sampling_and_search(down_module) + self.down_modules.append(down_module) + + self.metric_loss_module, self.miner_module = BaseModel.get_metric_loss_and_miner( + getattr(opt, "metric_loss", None), getattr(opt, "miner", None) + ) + + def _get_factory(self, model_name, modules_lib) -> BaseFactory: + factory_module_cls = getattr(modules_lib, "{}Factory".format(model_name), None) + if factory_module_cls is None: + factory_module_cls = BaseFactory + return factory_module_cls + + def _fetch_arguments_from_list(self, opt, index): + """Fetch the arguments for a single convolution from multiple lists + of arguments - for models specified in the compact format. 
class BaseFactory:
    """Look up the up/down convolution classes of a model by name.

    The names are resolved as attributes of ``modules_lib``.
    """

    def __init__(self, module_name_down, module_name_up, modules_lib):
        """
        Args:
            module_name_down: attribute name of the down module, or ``None``.
            module_name_up: attribute name of the up module, or ``None``.
            modules_lib: object (typically a module) holding the classes.
        """
        self.module_name_down = module_name_down
        self.module_name_up = module_name_up
        self.modules_lib = modules_lib

    def get_module(self, flow):
        """Return the class registered for ``flow``.

        Args:
            flow: ``"UP"`` (case-insensitive) selects the up module; any other
                value selects the down module.

        Returns:
            The attribute of ``modules_lib`` with the configured name, or
            ``None`` when no name is configured or the attribute is missing.
        """
        if flow.upper() == "UP":
            name = self.module_name_up
        else:
            name = self.module_name_down
        # getattr() rejects a non-string name, so treat an unset name
        # like a missing attribute instead of raising TypeError.
        if name is None:
            return None
        return getattr(self.modules_lib, name, None)
+ + opt is expected to contains the following keys: + * down_conv + * up_conv + * OPTIONAL: innermost + """ + opt = copy.deepcopy(opt) + super(UnetBasedModel, self).__init__(opt) + self._spatial_ops_dict = {"neighbour_finder": [], "sampler": [], "upsample_op": []} + # detect which options format has been used to define the model + if type(opt.down_conv) is ListConfig or "down_conv_nn" not in opt.down_conv: + self._init_from_layer_list_format(opt, model_type, dataset, modules_lib) + else: + self._init_from_compact_format(opt, model_type, dataset, modules_lib) + + def _init_from_compact_format(self, opt, model_type, dataset, modules_lib): + """Create a unetbasedmodel from the compact options format - where the + same convolution is given for each layer, and arguments are given + in lists + """ + num_convs = len(opt.down_conv.down_conv_nn) + + # Factory for creating up and down modules + factory_module_cls = self._get_factory(model_type, modules_lib) + down_conv_cls_name = opt.down_conv.module_name + up_conv_cls_name = opt.up_conv.module_name + self._factory_module = factory_module_cls( + down_conv_cls_name, up_conv_cls_name, modules_lib + ) # Create the factory object + # construct unet structure + contains_global = hasattr(opt, "innermost") and opt.innermost is not None + if contains_global: + assert len(opt.down_conv.down_conv_nn) + 1 == len(opt.up_conv.up_conv_nn) + + args_up = self._fetch_arguments_from_list(opt.up_conv, 0) + args_up["up_conv_cls"] = self._factory_module.get_module("UP") + + unet_block = UnetSkipConnectionBlock( + args_up=args_up, + args_innermost=opt.innermost, + modules_lib=modules_lib, + submodule=None, + innermost=True, + ) # add the innermost layer + else: + unet_block = Identity() + + if num_convs > 1: + for index in range(num_convs - 1, 0, -1): + args_up, args_down = self._fetch_arguments_up_and_down(opt, index) + unet_block = UnetSkipConnectionBlock(args_up=args_up, args_down=args_down, submodule=unet_block) + 
def _init_from_layer_list_format(self, opt, model_type, dataset, modules_lib):
    """Create a unetbasedmodel from the layer list options format - where
    each layer of the unet is specified separately.

    Parameters:
        opt - options; ``down_conv`` and ``up_conv`` are read, ``innermost`` is optional
        model_type - type of the model to be generated
        dataset - not used by this method
        modules_lib - object on which every layer's ``module_name`` is looked up

    The blocks are nested from the deepest layer outwards and the outermost
    block is stored in ``self.model``.
    """
    # The returned factory class is not used in this format; the call is kept as in the original.
    self._get_factory(model_type, modules_lib)

    down_conv_layers = (
        opt.down_conv if type(opt.down_conv) is ListConfig else self._flatten_compact_options(opt.down_conv)
    )
    up_conv_layers = opt.up_conv if type(opt.up_conv) is ListConfig else self._flatten_compact_options(opt.up_conv)
    num_convs = len(down_conv_layers)

    contains_global = hasattr(opt, "innermost") and opt.innermost is not None
    if contains_global:
        assert len(down_conv_layers) + 1 == len(up_conv_layers)

        up_layer = dict(up_conv_layers[0])
        up_layer["up_conv_cls"] = getattr(modules_lib, up_layer["module_name"])

        unet_block = UnetSkipConnectionBlock(
            args_up=up_layer,
            args_innermost=opt.innermost,
            modules_lib=modules_lib,
            innermost=True,
        )
    else:
        # The deepest block calls its submodule in forward(), so it must be
        # callable: use a pass-through module (as the compact format does)
        # instead of the former empty list.
        unet_block = Identity()

    # NOTE(review): when there is no innermost block and up_conv has as many
    # entries as down_conv, up_conv_layers[num_convs - 1] is used both for
    # index == 1 below and as up_conv_layers[-1] for the outermost block.
    # Confirm the expected length of up_conv in that configuration.
    for index in range(num_convs - 1, 0, -1):
        down_layer = dict(down_conv_layers[index])
        up_layer = dict(up_conv_layers[num_convs - index])

        down_layer["down_conv_cls"] = getattr(modules_lib, down_layer["module_name"])
        up_layer["up_conv_cls"] = getattr(modules_lib, up_layer["module_name"])

        unet_block = UnetSkipConnectionBlock(
            args_up=up_layer,
            args_down=down_layer,
            modules_lib=modules_lib,
            submodule=unet_block,
        )

    up_layer = dict(up_conv_layers[-1])
    down_layer = dict(down_conv_layers[0])
    down_layer["down_conv_cls"] = getattr(modules_lib, down_layer["module_name"])
    up_layer["up_conv_cls"] = getattr(modules_lib, up_layer["module_name"])
    self.model = UnetSkipConnectionBlock(
        args_up=up_layer, args_down=down_layer, submodule=unet_block, outermost=True
    )

    self._save_sampling_and_search(self.model)
def forward(self, data, *args, **kwargs):
    """Run the block and hand ``(processed, data)`` to the up module.

    The innermost block processes ``data`` with ``self.inner``; every other
    block applies ``self.down`` followed by ``self.submodule``. In both
    cases the original ``data`` is paired with the processed result as the
    skip connection. Positional ``*args`` are accepted but not forwarded;
    ``**kwargs`` are passed to every module that is called.
    """
    if self.innermost:
        processed = self.inner(data, **kwargs)
    else:
        downsampled = self.down(data, **kwargs)
        processed = self.submodule(downsampled, **kwargs)
    return self.up((processed, data), **kwargs)
_save_sampling_and_search(self, down_conv): + sampler = getattr(down_conv, "sampler", None) + if is_list(sampler): + self._spatial_ops_dict["sampler"] += sampler + else: + self._spatial_ops_dict["sampler"].append(sampler) + + neighbour_finder = getattr(down_conv, "neighbour_finder", None) + if is_list(neighbour_finder): + self._spatial_ops_dict["neighbour_finder"] += neighbour_finder + else: + self._spatial_ops_dict["neighbour_finder"].append(neighbour_finder) + + def _save_upsample(self, up_conv): + upsample_op = getattr(up_conv, "upsample_op", None) + if upsample_op: + self._spatial_ops_dict["upsample_op"].append(upsample_op) + + def __init__(self, opt, model_type, dataset: BaseDataset, modules_lib): + """Construct a Unet unwrapped generator + + The layers will be appended within lists with the following names + * down_modules : Contains all the down module + * inner_modules : Contain one or more inner modules + * up_modules: Contains all the up module + + Parameters: + opt - options for the network generation + model_type - type of the model to be generated + num_class - output of the network + modules_lib - all modules that can be used in the UNet + + For a recursive implementation. See UnetBaseModel. 
+ + opt is expected to contains the following keys: + * down_conv + * up_conv + * OPTIONAL: innermost + + """ + opt = copy.deepcopy(opt) + super(UnwrappedUnetBasedModel, self).__init__(opt) + # detect which options format has been used to define the model + self._spatial_ops_dict = {"neighbour_finder": [], "sampler": [], "upsample_op": []} + + if is_list(opt.down_conv) or "down_conv_nn" not in opt.down_conv: + raise NotImplementedError + else: + self._init_from_compact_format(opt, model_type, dataset, modules_lib) + + def _collect_sampling_ids(self, list_data): + def extract_matching_key(keys, start_token): + for key in keys: + if key.startswith(start_token): + return key + return None + + d = {} + if self.save_sampling_id: + for idx, data in enumerate(list_data): + key = extract_matching_key(data.keys, "sampling_id") + if key: + d[key] = getattr(data, key) + return d + + def _get_from_kwargs(self, kwargs, name): + module = kwargs[name] + kwargs.pop(name) + return module + + def _create_inner_modules(self, args_innermost, modules_lib): + inners = [] + if is_list(args_innermost): + for inner_opt in args_innermost: + module_name = self._get_from_kwargs(inner_opt, "module_name") + inner_module_cls = getattr(modules_lib, module_name) + inners.append(inner_module_cls(**inner_opt)) + + else: + module_name = self._get_from_kwargs(args_innermost, "module_name") + inner_module_cls = getattr(modules_lib, module_name) + inners.append(inner_module_cls(**args_innermost)) + + return inners + + def _init_from_compact_format(self, opt, model_type, dataset, modules_lib): + """Create a unetbasedmodel from the compact options format - where the + same convolution is given for each layer, and arguments are given + in lists + """ + + self.down_modules = nn.ModuleList() + self.inner_modules = nn.ModuleList() + self.up_modules = nn.ModuleList() + + self.save_sampling_id = opt.down_conv.get('save_sampling_id') + + # Factory for creating up and down modules + factory_module_cls = 
def _get_factory(self, model_name, modules_lib) -> BaseFactory:
    """Return the factory class to use for ``model_name``.

    ``modules_lib`` is searched for an attribute named
    ``"<model_name>Factory"``; ``BaseFactory`` is the fallback when that
    attribute is missing or None.
    """
    factory_attr = "{}Factory".format(model_name)
    found = getattr(modules_lib, factory_attr, None)
    return BaseFactory if found is None else found
+ """ + args = {} + for o, v in opt.items(): + name = str(o) + if is_list(v) and len(getattr(opt, o)) > 0: + if name[-1] == "s" and name not in SPECIAL_NAMES: + name = name[:-1] + v_index = v[index] + if is_list(v_index): + v_index = list(v_index) + args[name] = v_index + else: + if is_list(v): + v = list(v) + args[name] = v + return args + + def _fetch_arguments(self, conv_opt, index, flow): + """Fetches arguments for building a convolution (up or down) + + Arguments: + conv_opt + index in sequential order (as they come in the config) + flow "UP" or "DOWN" + """ + args = self._fetch_arguments_from_list(conv_opt, index) + args["conv_cls"] = self._factory_module.get_module(flow) + args["index"] = index + return args + + def _flatten_compact_options(self, opt): + """Converts from a dict of lists, to a list of dicts""" + flattenedOpts = [] + + for index in range(int(1e6)): + try: + flattenedOpts.append(DictConfig(self._fetch_arguments_from_list(opt, index))) + except IndexError: + break + + return flattenedOpts + + def forward(self, data, precomputed_down=None, precomputed_up=None, **kwargs): + """This method does a forward on the Unet assuming symmetrical skip connections + + Parameters + ---------- + data: torch.geometric.Data + Data object that contains all info required by the modules + precomputed_down: torch.geometric.Data + Precomputed data that will be passed to the down convs + precomputed_up: torch.geometric.Data + Precomputed data that will be passed to the up convs + """ + stack_down = [] + for i in range(len(self.down_modules) - 1): + data = self.down_modules[i](data, precomputed=precomputed_down) + stack_down.append(data) + data = self.down_modules[-1](data, precomputed=precomputed_down) + + if not isinstance(self.inner_modules[0], Identity): + stack_down.append(data) + data = self.inner_modules[0](data) + + sampling_ids = self._collect_sampling_ids(stack_down) + + for i in range(len(self.up_modules)): + data = self.up_modules[i]((data, 
def __init__(self, opt):
    """Initialize the BaseModel class.
    Parameters:
        opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
    When creating your custom class, you need to implement your own initialization.
    In this function, you should first call BaseModel.__init__(self, opt).
    Then, you need to define four lists:
        -- self.loss_names (str list):          specify the training losses that you want to plot and save.
        -- self.model_names (str list):         define networks used in our training.
        -- self.visual_names (str list):        specify the images that you want to display and save.
        -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
    """
    super(BaseModel, self).__init__()
    self.opt = opt
    self.loss_names = []
    self.visual_names = []
    self.output = None
    self.model = None
    self._conv_type = opt.conv_type if hasattr(opt, 'conv_type') else None  # Update to OmegaConf 2.0
    self._optimizer: Optional[Optimizer] = None
    # Was annotated ``Optimizer[_LRScheduler]``; the attribute holds a scheduler or None.
    self._lr_scheduler: Optional[_LRScheduler] = None
    self._bn_scheduler = None
    self._spatial_ops_dict: Dict = {}
    # Training progress counters, updated by optimize_parameters.
    self._num_epochs = 0
    self._num_batches = 0
    self._num_samples = -1
    self._schedulers = {}
    # Defaults: no gradient accumulation (step of 1) and no clipping (negative value).
    self._accumulated_gradient_step = 1
    self._grad_clip = -1
    self._grad_scale = None
    # Mixed precision is only active when both flags are true.
    self._supports_mixed = False
    self._enable_mixed = False
    self._update_lr_scheduler_on = "on_epoch"
    self._update_bn_scheduler_on = "on_epoch"
@property
def learning_rate(self):
    """Learning rate of the optimizer's first parameter group.

    Returns None when the optimizer has no parameter group.
    """
    first_group = next(iter(self.optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]
def set_pretrained_weights(self):
    """Load pretrained weights when ``opt.path_pretrained`` is set.

    The checkpoint is read on CPU and the entry ``["models"][weight_name]``
    (``opt.weight_name``, "latest" by default) is loaded non-strictly via
    ``load_state_dict_with_same_shape``. Nothing happens when no path is
    configured.

    Raises:
        FileNotFoundError - if the configured path does not exist
    """
    path_pretrained = getattr(self.opt, "path_pretrained", None)
    weight_name = getattr(self.opt, "weight_name", "latest")

    if path_pretrained is None:
        return
    if not os.path.exists(path_pretrained):
        raise FileNotFoundError("The path does not exist, it will not load any model")

    log.info("load pretrained weights from {}".format(path_pretrained))
    checkpoint = torch.load(path_pretrained, map_location="cpu")
    self.load_state_dict_with_same_shape(checkpoint["models"][weight_name], strict=False)
def _do_scheduler_update(self, update_scheduler_on, scheduler, epoch, batch_size, num_batches):
    """Step ``scheduler`` according to the update mode stored on the model.

    Parameters:
        update_scheduler_on - name of the attribute of ``self`` that holds the update mode
        scheduler - object whose ``step`` method is called
        epoch - current epoch
        batch_size - size of the current batch
        num_batches - number of batches, used as divisor in the per-batch mode

    Raises:
        Exception - if ``self`` has no attribute with the given name, or if that attribute is None
    """
    if hasattr(self, update_scheduler_on):
        # From here on the variable holds the mode value instead of the attribute name.
        update_scheduler_on = getattr(self, update_scheduler_on)
        if update_scheduler_on is None:
            raise Exception("The optimizer does not seems to be instantiated (instantiate_optimizers).")

        # No step is made when the mode matches none of the branches below.
        num_steps = 0
        step_size = epoch
        if update_scheduler_on == SchedulerUpdateOn.ON_EPOCH.value:
            # One step per epoch elapsed since the epoch stored in self._num_epochs.
            num_steps = epoch - self._num_epochs
        elif update_scheduler_on == SchedulerUpdateOn.ON_NUM_BATCH.value:
            # A single step whose argument is the ratio of batches seen to num_batches.
            num_steps = 1
            step_size = self._num_batches / num_batches
        elif update_scheduler_on == SchedulerUpdateOn.ON_NUM_SAMPLE.value:
            # One step per sample of the batch; step_size stays equal to epoch.
            num_steps = batch_size

        for _ in range(num_steps):
            scheduler.step(step_size)
    else:
        raise Exception("The attributes {} should be defined within self".format(update_scheduler_on))
def get_current_losses(self):
    """Return training losses / errors. train.py will print out these errors on console.

    Returns:
        An OrderedDict mapping every string entry of ``self.loss_names``
        that exists as an attribute of the model to its value converted to
        float, or to None when the conversion fails. Entries that are not
        strings or not attributes of the model are skipped.
    """
    errors_ret = OrderedDict()
    for name in self.loss_names:
        if not isinstance(name, str) or not hasattr(self, name):
            continue
        try:
            errors_ret[name] = float(getattr(self, name))
        except Exception:
            # Best effort: a loss that cannot be converted is reported as None.
            # ``Exception`` (rather than a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            errors_ret[name] = None
    return errors_ret
def init_grad_scaler(self, config):
    """Read the mixed precision option and create the gradient scaler.

    ``training.enable_mixed`` is read from ``config`` (False by default).
    When it is requested but the model does not support it, the option is
    turned off and a warning is logged. The scaler stored in
    ``self._grad_scale`` is enabled only if ``is_mixed_precision()`` is true.
    """
    requested = self.get_from_opt(config, ["training", "enable_mixed"], default_value=False)
    self._enable_mixed = bool(requested)

    unsupported = self._enable_mixed and not self._supports_mixed
    if unsupported:
        self._enable_mixed = False
        log.warning("Mixed precision is not supported on this model, using default precision...")
    if not unsupported and self.is_mixed_precision():
        log.info("Model will use mixed precision")

    self._grad_scale = torch.cuda.amp.GradScaler(enabled=self.is_mixed_precision())
def get_from_opt(self, opt, keys=(), default_value=None, msg_err=None, silent=True):
    """Look up a nested value in ``opt`` by following ``keys`` in order.

    Parameters:
        opt - nested container indexed successively with each key
        keys - non-empty sequence of keys
        default_value - returned when the lookup fails and ``msg_err`` is not set
        msg_err - when set, a failed lookup raises an Exception with this message
        silent - when False, a failed lookup is logged before the default is returned

    Returns:
        The value found at the end of the key path, or ``default_value``.

    Raises:
        Exception - if ``keys`` is empty, or if the lookup fails and ``msg_err`` is set
    """
    if len(keys) == 0:
        raise Exception("Keys should not be empty")

    value = opt
    try:
        for key in keys:
            value = value[key]
    except Exception as e:
        # Any failure while walking the path is treated as "not found".
        if msg_err:
            # Chain the original error so the failing key stays visible.
            raise Exception(str(msg_err)) from e
        if not silent:
            log.exception(e)
        return default_value
    return value
def verify_data(self, data, forward_only=False):
    """Goes through the __REQUIRED_DATA__ and __REQUIRED_LABELS__ attribute of the model
    and verifies that the passed data object contains all required members.
    If something is missing it raises a KeyError exception.

    Parameters:
        data - object supporting both attribute access and indexing by name
        forward_only - when True, only __REQUIRED_DATA__ is checked

    Raises:
        KeyError - listing every required name that is absent from ``data`` or set to None
    """
    # Work on a copy: extending __REQUIRED_DATA__ in place would append the
    # labels to the shared list on every call.
    required_attributes = list(self.__REQUIRED_DATA__)
    if not forward_only:
        required_attributes.extend(self.__REQUIRED_LABELS__)

    missing_keys = [attr for attr in required_attributes if not hasattr(data, attr) or data[attr] is None]
    if missing_keys:
        raise KeyError(
            "Missing attributes in your data object: {}. The model will fail to forward.".format(missing_keys)
        )
def mape(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Absolute percentage error ``|y - x| / |y|`` of prediction ``x`` vs. target ``y``.

    Positions where ``y == 0`` are left at an error of 0 (no division is
    performed there). Returns the mean if ``reduce`` else the per-element tensor.
    """
    nonzero = y != 0
    rel_err = torch.zeros_like(y)
    rel_err[nonzero] = torch.abs((y[nonzero] - x[nonzero]) / y[nonzero])
    return rel_err.mean() if reduce else rel_err


def smape(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Symmetric percentage error ``|y - x| / (|x| + |y| + eps)``.

    ``eps`` is the float16 machine epsilon and keeps the denominator non-zero.
    Returns the mean if ``reduce`` else the per-element tensor.
    """
    denominator = torch.abs(x) + torch.abs(y) + torch.finfo(torch.float16).eps
    sym_err = (y - x).abs() / denominator
    return sym_err.mean() if reduce else sym_err


def smoothl1_zero(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Smooth-L1 loss in which targets equal to 0 are regressed towards -1 instead.

    Returns the mean if ``reduce`` else the per-element tensor.
    """
    is_zero = y == 0
    not_zero = ~is_zero
    per_elem = torch.zeros_like(y)
    # zero targets are replaced by -1
    per_elem[is_zero] = F.smooth_l1_loss(x[is_zero], torch.full_like(y, -1)[is_zero], reduction="none")
    per_elem[not_zero] = F.smooth_l1_loss(x[not_zero], y[not_zero], reduction="none")
    return per_elem.mean() if reduce else per_elem


def smoothl1_zero10(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Smooth-L1 loss in which targets equal to 0 are regressed towards -10 instead.

    Returns the mean if ``reduce`` else the per-element tensor.
    """
    is_zero = y == 0
    not_zero = ~is_zero
    per_elem = torch.zeros_like(y)
    # zero targets are replaced by -10
    per_elem[is_zero] = F.smooth_l1_loss(x[is_zero], torch.full_like(y, -10)[is_zero], reduction="none")
    per_elem[not_zero] = F.smooth_l1_loss(x[not_zero], y[not_zero], reduction="none")
    return per_elem.mean() if reduce else per_elem
def smoothl1_zero_db5(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Double-batch smooth-L1 loss with a 0.01 margin on the augmented predictions.

    ``x`` is read as interleaved pairs: even rows are the original predictions,
    odd rows the predictions for the paired (augmented) sample; only the even
    rows of ``y`` are used. The loss is twice the zero-aware smooth-L1 term
    (targets equal to 0 are regressed to -1) plus two one-sided Huber
    penalties that are non-zero only where the augmented prediction is below
    ``y + 0.01`` / below the original prediction ``+ 0.01``.

    With ``reduce`` the Huber penalties are averaged over their non-zero
    entries only; otherwise the per-pair loss tensor is returned.
    """
    base_pred, paired_pred = x[::2], x[1::2]
    y = y[::2]

    is_zero = y == 0
    not_zero = ~is_zero
    error = torch.zeros_like(y)
    # zero targets are replaced by -1
    error[is_zero] = F.smooth_l1_loss(base_pred[is_zero], -torch.ones_like(y)[is_zero], reduction="none")
    error[not_zero] = F.smooth_l1_loss(base_pred[not_zero], y[not_zero], reduction="none")

    # one-sided huber penalties
    beta = 1.0
    t = 0.01
    gap_to_label = F.relu(y + t - paired_pred)
    pen_label = torch.where(gap_to_label < beta, 0.5 * gap_to_label ** 2 / beta, gap_to_label - 0.5 * beta)

    gap_to_base = F.relu(base_pred + t - paired_pred)
    pen_base = torch.where(gap_to_base < beta, 0.5 * gap_to_base ** 2 / beta, gap_to_base - 0.5 * beta)

    if not reduce:
        return 2 * error + .5 * pen_label + .5 * pen_base

    eps = torch.finfo(torch.float16).eps
    return 2 * error.mean() + .5 * pen_label.sum() / ((gap_to_label != 0).sum() + eps) \
        + .5 * pen_base.sum() / ((gap_to_base != 0).sum() + eps)


def smoothl1_zero_db4(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Double-batch smooth-L1 loss; same as ``smoothl1_zero_db5`` but without margin.

    ``x`` is read as interleaved pairs (even rows: original predictions, odd
    rows: paired predictions); only the even rows of ``y`` are used. The
    one-sided Huber penalties are non-zero only where the paired prediction is
    below the label / below the original prediction.

    With ``reduce`` the Huber penalties are averaged over their non-zero
    entries only; otherwise the per-pair loss tensor is returned.
    """
    base_pred, paired_pred = x[::2], x[1::2]
    y = y[::2]

    is_zero = y == 0
    not_zero = ~is_zero
    error = torch.zeros_like(y)
    # zero targets are replaced by -1
    error[is_zero] = F.smooth_l1_loss(base_pred[is_zero], -torch.ones_like(y)[is_zero], reduction="none")
    error[not_zero] = F.smooth_l1_loss(base_pred[not_zero], y[not_zero], reduction="none")

    # one-sided huber penalties
    beta = 1.0
    gap_to_label = F.relu(y - paired_pred)
    pen_label = torch.where(gap_to_label < beta, 0.5 * gap_to_label ** 2 / beta, gap_to_label - 0.5 * beta)

    gap_to_base = F.relu(base_pred - paired_pred)
    pen_base = torch.where(gap_to_base < beta, 0.5 * gap_to_base ** 2 / beta, gap_to_base - 0.5 * beta)

    if not reduce:
        return 2 * error + .5 * pen_label + .5 * pen_base

    eps = torch.finfo(torch.float16).eps
    return 2 * error.mean() + .5 * pen_label.sum() / ((gap_to_label != 0).sum() + eps) \
        + .5 * pen_base.sum() / ((gap_to_base != 0).sum() + eps)
def smoothl1_zero_db7(x: torch.Tensor, y: torch.Tensor, reduce: bool = True):
    """Double-batch smooth-L1 loss (no margin) reduced by summation.

    ``x`` is read as interleaved pairs: even rows are the original
    predictions, odd rows the paired predictions; only the even rows of ``y``
    are used. Per pair the loss is ``2 * zero-aware smooth-L1`` (targets equal
    to 0 are regressed to -1) plus half of each one-sided Huber penalty, which
    is non-zero only where the paired prediction is below the label / below
    the original prediction.

    Returns the sum over pairs if ``reduce`` else the per-pair tensor.
    """
    base_pred, paired_pred = x[::2], x[1::2]
    y = y[::2]

    is_zero = y == 0
    not_zero = ~is_zero
    error = torch.zeros_like(y)
    # zero targets are replaced by -1
    error[is_zero] = F.smooth_l1_loss(base_pred[is_zero], -torch.ones_like(y)[is_zero], reduction="none")
    error[not_zero] = F.smooth_l1_loss(base_pred[not_zero], y[not_zero], reduction="none")

    # one-sided huber penalties
    beta = 1.0
    gap_to_label = F.relu(y - paired_pred)
    pen_label = torch.where(gap_to_label < beta, 0.5 * gap_to_label ** 2 / beta, gap_to_label - 0.5 * beta)

    gap_to_base = F.relu(base_pred - paired_pred)
    pen_base = torch.where(gap_to_base < beta, 0.5 * gap_to_base ** 2 / beta, gap_to_base - 0.5 * beta)

    per_pair = 2 * error + .5 * pen_label + .5 * pen_base
    return per_pair.sum() if reduce else per_pair
def __init__(self, option, model_type, dataset, modules):
    """Set up the per-task state shared by instance-level models.

    For every task family the dataset declares (regression ``reg``,
    mixture-of-logistics ``mol``, classification ``cls``) this registers the
    loss name, the target weights / normalisation buffers and the list of
    loss callables in ``self.loss_fns[<task>]``.

    Args:
        option: model options; read via ``option.get(...)`` for
            ``reg_out_activation``, ``reg_out_report_activation``,
            ``reg_loss_fn``, ``mol_loss_fn``, ``cls_loss_fn`` (comma separated
            loss names) and ``double_batch``.
        model_type: unused here.
        dataset: provides the ``has_*_targets`` / ``*_targets_idx`` /
            ``num_*_classes`` attributes, the ``targets`` mapping and
            ``double_batch``.
        modules: unused here.

    Raises:
        KeyError: if a configured loss or activation name is unknown.
    """
    super().__init__(option)
    self.visual_names = ["data_visual"]

    self.loss_fns = {}
    self.has_reg_targets = dataset.has_reg_targets
    self.has_mol_targets = dataset.has_mol_targets
    self.has_cls_targets = dataset.has_cls_targets
    self.reg_targets_idx = dataset.reg_targets_idx
    self.mol_targets_idx = dataset.mol_targets_idx
    self.cls_targets_idx = dataset.cls_targets_idx

    if self.has_reg_targets:
        self.loss_names.append("loss_reg")

        self.get_task_weights_scale_center(
            dataset, task="regression", short_task="reg", default_norm="standard", targets_idx=self.reg_targets_idx
        )

        self.reg_out_act = OUT_ACT[option.get("reg_out_activation", "linear").lower()]
        self.reg_report_out_act = OUT_ACT[option.get("reg_out_report_activation", "linear").lower()]

        loss_strs = option.get("reg_loss_fn", "smoothl1")
        # Always create the list (as the cls branch does): compute_reg_loss
        # indexes loss_fns["reg"] and would raise KeyError for an empty
        # loss string otherwise.
        self.loss_fns["reg"] = []
        if len(loss_strs) > 0:
            for loss_str in loss_strs.split(","):
                self.loss_fns["reg"].append(REG_LOSSES[loss_str])

    if self.has_mol_targets:
        self.loss_names.append("loss_mol")

        self.get_task_weights_scale_center(
            dataset, task="mol", short_task="mol", default_norm="min-max", targets_idx=self.mol_targets_idx
        )

        self.num_mixtures = [dataset.targets[target].get("num_mixtures", 1) for target in dataset.targets
                             if dataset.targets[target]["task"] == "mol"]
        self.num_mol_intervals = np.array(
            [dataset.targets[target].get("class_tol", .1) for target in dataset.targets
             if dataset.targets[target]["task"] == "mol"])
        # number of intervals = target range / tolerance per interval
        self.num_mol_intervals = np.round(self.mol_scale_targets[0] / self.num_mol_intervals)
        # make even
        self.num_mol_intervals += self.num_mol_intervals % 2

        loss_strs = option.get("mol_loss_fn", "dml")
        # Same reasoning as for "reg": compute_mol_loss indexes loss_fns["mol"].
        self.loss_fns["mol"] = []
        if len(loss_strs) > 0:
            for loss_str in loss_strs.split(","):
                self.loss_fns["mol"].append(MOL_LOSSES[loss_str])
    else:
        self.num_mixtures = []

    if self.has_cls_targets:
        self.loss_names.append("loss_cls")
        loss_strs = option.get("cls_loss_fn", "ce")

        weights = [dataset.targets[target].get("weight", 1) for target in dataset.targets
                   if dataset.targets[target]["task"] == "classification"]
        self.register_buffer("cls_weights", torch.tensor(weights, dtype=torch.float))

        self.loss_fns["cls"] = []
        if len(loss_strs) > 0:
            for loss_str in loss_strs.split(","):
                self.loss_fns["cls"].append(CLS_LOSSES[loss_str])

    self.num_reg_classes = dataset.num_reg_classes
    self.num_mol_classes = dataset.num_mol_classes
    self.num_cls_classes = dataset.num_cls_classes

    # model overrides dataset settings
    self.double_batch = option.get("double_batch", dataset.double_batch)
def get_dataset_avg_stat(self, dataset, stat, default, feat_idx):
    """Average a training-split target statistic over all areas.

    ``dataset.get_<stat>_targets()`` is called and, for every area that has a
    ``"train"`` entry, ``area["train"][feat_idx]`` is collected. NaNs are
    ignored when averaging over areas.

    Args:
        dataset: object exposing ``get_<stat>_targets()``.
        stat: statistic name used to build the getter name.
        default: value used where no area provides a valid number.
        feat_idx: index/mask applied to each area's ``"train"`` array.

    Returns:
        ``np.nanmean`` over areas; if no valid value exists at all, an array
        holding ``default`` repeated ``len(feat_idx)`` times.
    """
    value = np.array([
        area["train"][feat_idx] for area in
        getattr(dataset, f"get_{stat}_targets")().values() if "train" in area
    ])
    nans = np.isnan(value)
    # per selected target: True when *every* area is NaN
    missing = nans.all(0)

    if nans.all():
        # nothing usable at all (also covers "no training area")
        value = np.array([default] * len(feat_idx))
        log.warning(f"All training area with no valid {stat} value, setting to {default}. "
                    "This is fine if reloading amodel overrides this.")
        return value
    elif missing.any():
        # Only some targets lack a value: fill just those and keep the rest.
        # (Previously this branch repeated the condition above and was dead.)
        value[:, missing] = default
        log.warning(f"Some training area with no valid {stat} value, setting the missing to {default}. "
                    "This is fine if reloading amodel overrides this.")

    return np.nanmean(value, 0)


def set_input(self, data, device):
    """Unpack a batch; must be provided by subclasses."""
    # ``NotImplemented`` is a comparison sentinel, not an exception class.
    raise NotImplementedError


def convert_outputs(self, outputs):
    """Split the raw head output column-wise into (reg, mol, cls) parts.

    Columns ``[0, num_reg_classes)`` are the regression outputs (passed
    through ``self.reg_out_act``), the next ``num_mol_classes`` columns the
    mixture outputs and all remaining columns the classification outputs.
    A part is ``None`` when its task is disabled or ``outputs`` is ``None``.
    """
    reg_out = mol_out = cls_out = None
    if outputs is not None:

        if self.has_reg_targets:
            reg_out = self.reg_out_act(outputs[:, :self.num_reg_classes])
        if self.has_mol_targets:
            mol_out = outputs[:, self.num_reg_classes: self.num_reg_classes + self.num_mol_classes]
        if self.has_cls_targets:
            cls_out = outputs[:, self.num_reg_classes + self.num_mol_classes:]

    return reg_out, mol_out, cls_out


def forward(self, *args, **kwargs) -> Any:
    """Run the forward pass; must be provided by subclasses."""
    raise NotImplementedError


def compute_loss(self):
    """Compute the training loss; must be provided by subclasses."""
    raise NotImplementedError
def compute_mol_loss(self):
    """Add the mixture-of-logistics loss of every mol task to ``self.loss``.

    Does nothing unless mol targets exist, at least one mol loss is
    configured and ``self.mol_y_mask`` marks at least one label. Each task
    ``i`` owns ``num_mixtures * 3`` consecutive columns of ``self.mol_out``;
    labels are normalised with the mol center/scale buffers and mapped to
    ``[-1, 1]``. The per-task loss is weighted by ``self.mol_weights[i]``; the
    unweighted sum is stored in ``self.loss_mol``.
    """
    if self.has_mol_targets and len(self.loss_fns["mol"]) > 0 and self.mol_y_mask.any():
        # iterate through each mol task
        i_mixtures = 0
        self.loss_mol = 0
        for i, (num_mixtures, num_classes) in enumerate(zip(self.num_mixtures, self.num_mol_intervals)):
            # Column span of this task. The running offset is advanced right
            # away so that skipping a task below cannot shift the columns of
            # the tasks that follow.
            col_start = i_mixtures * 3
            i_mixtures += num_mixtures
            col_stop = i_mixtures * 3

            mask = torch.zeros_like(self.mol_y_mask)
            # only set mask for current task
            mask[:, i: i + 1] = self.mol_y_mask[:, i: i + 1]
            if (~mask).all():
                continue
            out_mask = torch.zeros_like(self.mol_out).bool()
            out_mask[:, col_start: col_stop] = \
                mask[:, i: i + 1].repeat_interleave(num_mixtures * 3, 1)

            output = self.mol_out[out_mask].reshape(-1, num_mixtures * 3)
            labels = (self.mol_y[mask].reshape(-1, 1) - self.mol_center_targets[:, [i]]) / self.mol_scale_targets[:, [i]]
            labels = labels * 2 - 1  # between -1 and 1

            if self.training and self.double_batch:
                output2 = self.mol_out2[out_mask].reshape(-1, num_mixtures * 3)

            loss_mol = 0
            for loss_fn in self.loss_fns["mol"]:
                if self.training and self.double_batch:
                    # average the losses of both forward passes per sample
                    loss_mol += (
                            (0.5 * loss_fn(output, labels, num_classes=num_classes, reduce=False)) +
                            (0.5 * loss_fn(output2, labels, num_classes=num_classes, reduce=False))
                    ).mean()
                else:
                    loss_mol += loss_fn(output, labels, num_classes=num_classes, reduce=True)
            self.loss += self.mol_weights[i] * loss_mol
            self.loss_mol += loss_mol


def compute_cls_loss(self):
    """Add the classification loss of every cls task to ``self.loss``.

    Does nothing unless cls targets exist, at least one cls loss is
    configured and ``self.cls_y_mask`` marks at least one label. Task ``i``
    owns ``num_cls_classes[i]`` consecutive columns of ``self.cls_out``. The
    per-task loss is weighted by ``self.cls_weights[i]``; the unweighted sum
    is stored in ``self.loss_cls``.
    """
    if self.has_cls_targets and len(self.loss_fns["cls"]) > 0 and self.cls_y_mask.any():
        # iterate through each classification task
        i_classes = 0
        self.loss_cls = 0
        for i, num_classes in enumerate(self.num_cls_classes):
            # Advance the running offset before any ``continue`` so a task
            # without labels does not shift the columns of later tasks.
            col_start = i_classes
            i_classes = i_classes + num_classes

            mask = torch.zeros_like(self.cls_y_mask)
            # only set mask for current task
            mask[:, i: i + 1] = self.cls_y_mask[:, i: i + 1]
            if (~mask).all():
                continue
            out_mask = torch.zeros_like(self.cls_out).bool()
            out_mask[:, col_start: col_start + num_classes] = mask[:, i: i + 1].repeat_interleave(num_classes, 1)

            output = self.cls_out[out_mask].reshape(-1, num_classes)
            labels = self.cls_y[mask]
            if self.training and self.double_batch:
                output2 = self.cls_out2[out_mask].reshape(-1, num_classes)

            loss_cls = 0
            for loss_fn in self.loss_fns["cls"]:
                if self.training and self.double_batch:
                    # average the losses of both forward passes per sample
                    loss_cls += (
                            (0.5 * loss_fn(output, labels, reduction="none")) +
                            (0.5 * loss_fn(output2, labels, reduction="none"))
                    ).mean()
                else:
                    loss_cls += loss_fn(output, labels)

            self.loss += self.cls_weights[i] * loss_cls
            self.loss_cls += loss_cls
def get_cls_output(self):
    """Return the classification output split per task.

    ``self.cls_out`` is cut column-wise into consecutive chunks whose widths
    are given by ``self.num_cls_classes``; the chunks are returned as a list
    in task order.
    """
    chunks = []
    start = 0
    for width in self.num_cls_classes:
        stop = start + width
        chunks.append(self.cls_out[:, start: stop])
        start = stop
    return chunks


def get_reg_input(self):
    """Return the stored regression targets (``self.reg_y``)."""
    return self.reg_y


def get_mol_input(self):
    """Return the stored mixture-of-logistics targets (``self.mol_y``)."""
    return self.mol_y


def get_cls_input(self):
    """Return the stored classification targets (``self.cls_y``)."""
    return self.cls_y


def compute_instance_loss(self):
    """Run the regression, mixture and classification loss steps, in that order."""
    for compute_step in (self.compute_reg_loss, self.compute_mol_loss, self.compute_cls_loss):
        compute_step()


def load_state_dict(self, state_dict: dict, strict: bool = True):
    """Load ``state_dict``, optionally keeping the dataset-derived target stats.

    When the option ``override_target_stats`` is falsy, the target scale /
    center / weight entries are removed from ``state_dict`` (in place) before
    it is handed to the parent implementation.
    """
    keep_dataset_stats = not self.opt.get("override_target_stats", True)
    if keep_dataset_stats:
        for buffer_name in (
                "reg_scale_targets",
                "reg_center_targets",
                "mol_scale_targets",
                "mol_center_targets",
                "reg_weights",
                "mol_weights",
                "cls_weights",
        ):
            remove_dict_entry(state_dict, buffer_name)

    super().load_state_dict(state_dict, strict)
def __init__(self, option, model_type, dataset, modules_lib):
    """Register loss names, target statistics and loss callables per task.

    For each task family present in ``dataset`` (``reg``, ``mol``, ``cls``)
    the loss name is appended to ``self.loss_names`` and the configured loss
    callables are stored in the list ``self.loss_fns[<task>]``.

    Args:
        option: model options; ``reg_loss_fn`` / ``mol_loss_fn`` /
            ``cls_loss_fn`` hold comma separated loss names.
        model_type, modules_lib: forwarded to the parent constructor.
        dataset: provides ``has_*_targets`` and ``get_{std,min,max}_targets()``.

    Raises:
        KeyError: if a configured loss name is unknown.
    """
    super().__init__(option, model_type, dataset, modules_lib)
    self.visual_names = ["data_visual"]

    self.loss_fns = {}
    self.has_reg_targets = dataset.has_reg_targets
    self.has_mol_targets = dataset.has_mol_targets
    self.has_cls_targets = dataset.has_cls_targets

    if self.has_reg_targets or self.has_mol_targets or self.has_cls_targets:

        if dataset.has_reg_targets:
            self.loss_names.append("loss_reg")
            scale = np.nanmean([area["train"] for area in dataset.get_std_targets().values()])
            self.register_buffer("reg_scale_targets", torch.tensor(scale.reshape(1, -1), dtype=torch.float))
            loss_strs = option.get("reg_loss_fn", "smoothl1")
            # The list has to exist before appending; ``loss_fns`` starts
            # out as an empty dict, so ``loss_fns["reg"].append`` raised KeyError.
            self.loss_fns["reg"] = []
            if len(loss_strs) > 0:
                for loss_str in loss_strs.split(","):
                    self.loss_fns["reg"].append(REG_LOSSES[loss_str])

        if dataset.has_mol_targets:
            self.loss_names.append("loss_mol")
            # local names chosen so the builtins ``min`` / ``max`` are not shadowed
            min_target = np.nanmean([area["train"] for area in dataset.get_min_targets().values()])
            max_target = np.nanmean([area["train"] for area in dataset.get_max_targets().values()])
            self.register_buffer("min_targets", torch.tensor(min_target.reshape(1, -1), dtype=torch.float))
            self.register_buffer("max_targets", torch.tensor(max_target.reshape(1, -1), dtype=torch.float))
            loss_strs = option.get("mol_loss_fn", "dml")
            self.loss_fns["mol"] = []
            if len(loss_strs) > 0:
                for loss_str in loss_strs.split(","):
                    self.loss_fns["mol"].append(MOL_LOSSES[loss_str])

        if dataset.has_cls_targets:
            self.loss_names.append("loss_cls")
            loss_strs = option.get("cls_loss_fn", "ce")
            self.loss_fns["cls"] = []
            if len(loss_strs) > 0:
                for loss_str in loss_strs.split(","):
                    self.loss_fns["cls"].append(CLS_LOSSES[loss_str])
def compute_loss(self):
    """Sum all configured losses between ``self.output`` and ``self.labels``.

    ``self.labels`` is viewed with the shape of ``self.output``. The result is
    stored in both ``self.loss_regr`` and ``self.loss``.
    """
    loss_fns = self.loss_fns
    if isinstance(loss_fns, dict):
        # ``loss_fns`` maps a task name to a list of callables; iterating the
        # dict itself would yield the string keys, which are not callable.
        loss_fns = [loss_fn for task_fns in loss_fns.values() for loss_fn in task_fns]

    self.loss_regr = 0
    labels = self.labels.view(self.output.shape)
    for loss_fn in loss_fns:
        self.loss_regr += loss_fn(self.output, labels)

    self.loss = self.loss_regr
class SeparateLinear(torch.nn.Module):
    """Bank of independent linear heads whose outputs are concatenated.

    ``out_channels`` selects the heads:

    * ``int`` n  -- n heads with one output each;
    * ``dict``   -- ``num_reg_classes`` single-output heads, then one head of
      width ``3 * m`` per entry ``m`` of ``num_mixtures``, then one head of
      width ``c`` per entry ``c`` of ``num_cls_classes``;
    * anything else -- a single one-output head.
    """

    def __init__(self, in_channel, out_channels):
        super().__init__()
        if isinstance(out_channels, int):
            head_widths = [1 for _ in range(out_channels)]
        elif isinstance(out_channels, dict):
            n_reg = out_channels.get("num_reg_classes", 0)
            mixtures = out_channels.get("num_mixtures", [])
            cls_sizes = out_channels.get("num_cls_classes", [])

            head_widths = []
            if n_reg > 0:
                head_widths.extend(1 for _ in range(n_reg))
            if len(mixtures) > 0:
                head_widths.extend(n_mix * 3 for n_mix in mixtures)
            if len(cls_sizes) > 0:
                head_widths.extend(cls_sizes)
        else:
            head_widths = [1]

        # heads are created in order, one nn.Linear per requested width
        self.linears = nn.ModuleList(
            [nn.Linear(in_channel, width, bias=True) for width in head_widths]
        )

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the results along dim 1."""
        head_outputs = []
        for head in self.linears:
            head_outputs.append(head(x))
        return torch.cat(head_outputs, 1)
def set_input(self, data, device):
    """Build the KPConv network input and the target tensors from a batch.

    Stores the batch in ``self.data_visual``, optionally re-samples the
    points, converts points/features/lengths to numpy and passes them to
    ``self.prepare_inputs``; the result is wrapped in an ``EasyDict`` and
    stored in ``self.input``. When losses are configured, the label tensors
    and their masks of the enabled tasks are moved to ``device`` and reshaped
    to ``(len(data), -1)``.

    Args:
        data: batch providing ``data['pos']``, ``data['x']``, ``batch``,
            ``ptr`` and the ``y_*`` / ``y_*_mask`` attributes.
        device: device the network input and the targets are moved to.
    """
    self.data_visual = data

    points = data['pos']
    features = data['x']

    if self.should_sample:  # point resampling strategy if same number of points
        # reshape the flat batch to (n_samples, dataset_num_points, channels)
        points = points.view(-1, self.dataset_num_points, points.shape[-1]).to(device)
        features = features.view(-1, self.dataset_num_points, features.shape[-1]).to(device)
        # sample at most ``self.point_all`` candidates, then randomly keep
        # ``self.model_num_points`` of them (without replacement)
        point_all = points.size(1) if points.size(1) < self.point_all else self.point_all
        fps_idx = self.furthest_point_sample(points[:, :, :3].contiguous(), point_all)
        fps_idx = fps_idx[:, np.random.choice(point_all, self.model_num_points, False)]
        points = torch.gather(points, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, points.shape[-1]))
        features = torch.gather(features, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, features.shape[-1]))

    self.batch_idx = data.batch
    # per-sample point counts from the batch pointer
    # NOTE(review): these come from the un-sampled batch; after the
    # re-sampling above each sample presumably holds model_num_points
    # points -- confirm that prepare_inputs receives matching lengths.
    lengths = data.ptr[1:] - data.ptr[:-1]

    # TODO could to this in batch pre collate
    self.input = EasyDict(self.prepare_inputs(
        points.view(-1, 3).cpu().numpy(),
        features.view(-1, features.shape[-1]).cpu().numpy(),
        lengths.numpy().astype(np.int32),
        device
    ))

    if len(self.loss_fns) > 0:
        bs = len(data)
        if self.has_reg_targets and data.y_reg is not None:
            self.reg_y_mask = data.y_reg_mask.to(device).view(bs, -1)
            self.reg_y = data.y_reg.to(device).view(bs, -1)
        if self.has_mol_targets and data.y_mol is not None:
            self.mol_y_mask = data.y_mol_mask.to(device).view(bs, -1)
            self.mol_y = data.y_mol.to(device).view(bs, -1)
        if self.has_cls_targets and data.y_cls is not None:
            self.cls_y_mask = data.y_cls_mask.to(device).view(bs, -1)
            self.cls_y = data.y_cls.to(device).view(bs, -1)
Limit is set to keep XX% of the neighborhoods untouched. + Limit is computed at initialization + """ + + # crop neighbors matrix + if len(self.neighborhood_limits) > 0: + return neighbors[:, :self.neighborhood_limits[layer]] + else: + return neighbors + + def prepare_inputs(self, stacked_points, stacked_features, stack_lengths, device): + + # Starting radius of convolutions + r_normal = self.config.first_subsampling_dl * self.config.conv_radius + + # Starting layer + layer_blocks = [] + + # Lists of inputs + input_points = [] + input_neighbors = [] + input_pools = [] + input_stack_lengths = [] + deform_layers = [] + + ###################### + # Loop over the blocks + ###################### + + arch = self.config.architecture + L = 0 + for block_i, block in enumerate(arch): + + # Get all blocks of the layer + if not ('pool' in block or 'strided' in block or 'global' in block or 'upsample' in block): + layer_blocks += [block] + continue + L += 1 + # Convolution neighbors indices + # ***************************** + + deform_layer = False + if layer_blocks: + # Convolutions are done in this layer, compute the neighbors with the good radius + if np.any(['deformable' in blck for blck in layer_blocks]): + r = r_normal * self.config.deform_radius / self.config.conv_radius + deform_layer = True + else: + r = r_normal + conv_i = batch_neighbors(stacked_points, stacked_points, stack_lengths, stack_lengths, r) + + else: + # This layer only perform pooling, no neighbors required + conv_i = np.zeros((0, 1), dtype=np.int32) + + # Pooling neighbors indices + # ************************* + + # If end of layer is a pooling operation + if 'pool' in block or 'strided' in block: + + # New subsampling length + dl = 2 * r_normal / self.config.conv_radius + + # Subsampled points + pool_p, pool_b = batch_grid_subsampling(stacked_points, stack_lengths, sampleDl=dl) + + # Radius of pooled neighbors + if 'deformable' in block: + r = r_normal * self.config.deform_radius / 
self.config.conv_radius + deform_layer = True + else: + r = r_normal + + # Subsample indices + pool_i = batch_neighbors(pool_p, stacked_points, pool_b, stack_lengths, r) + + else: + # No pooling in the end of this layer, no pooling indices required + pool_i = np.zeros((0, 1), dtype=np.int32) + pool_p = np.zeros((0, 1), dtype=np.float32) + pool_b = np.zeros((0,), dtype=np.int32) + + # Reduce size of neighbors matrices by eliminating the farthest point + conv_i = self.big_neighborhood_filter(conv_i, len(input_points)) + pool_i = self.big_neighborhood_filter(pool_i, len(input_points)) + + # Updating input lists + input_points += [stacked_points] + input_neighbors += [conv_i.astype(np.int64)] + input_pools += [pool_i.astype(np.int64)] + input_stack_lengths += [stack_lengths] + deform_layers += [deform_layer] + + # New points for next layer + stacked_points = pool_p + stack_lengths = pool_b + + # Update radius and reset blocks + r_normal *= 2 + layer_blocks = [] + + # Stop when meeting a global pooling or upsampling + if 'global' in block or 'upsample' in block: + break + + ############### + # Return inputs + ############### + + # Save deform layers + + # list of network inputs + li = input_points + input_neighbors + input_pools + input_stack_lengths + li += [stacked_features, ] + + # Extract input tensors from the list of numpy array + input = {} + ind = 0 + input["points"] = [torch.from_numpy(nparray).to(device) for nparray in li[ind:ind + L]] + ind += L + input["neighbors"] = [torch.from_numpy(nparray).to(device) for nparray in li[ind:ind + L]] + ind += L + input["pools"] = [torch.from_numpy(nparray).to(device) for nparray in li[ind:ind + L]] + ind += L + input["lengths"] = [torch.from_numpy(nparray).to(device) for nparray in li[ind:ind + L]] + ind += L + input["features"] = torch.from_numpy(li[ind]).to(device) + + return input + + def compute_loss(self): + self.loss = 0 + self.compute_instance_loss() + + def forward(self, *args, **kwargs): + out = 
class SeparateLinear(torch.nn.Module):
    """Prediction head built from one independent linear layer per target.

    Layers are created in a fixed order: ``num_reg_classes`` single-output
    layers, then one layer with ``3 * n`` outputs for every ``n`` in
    ``num_mixtures``, then one layer with ``c`` outputs for every ``c`` in
    ``num_cls_classes``. ``forward`` concatenates their outputs along
    dimension 1 in that same order.
    """

    def __init__(self, in_channel, num_reg_classes, num_mixtures, num_cls_classes):
        super().__init__()
        heads = [torch.nn.Linear(in_channel, 1, bias=True) for _ in range(num_reg_classes)]
        heads += [torch.nn.Linear(in_channel, n_components * 3, bias=True) for n_components in num_mixtures]
        heads += [torch.nn.Linear(in_channel, n_classes) for n_classes in num_cls_classes]
        # Registered under ``linears`` so parameter names stay ``linears.<i>.*``.
        self.linears = torch.nn.ModuleList(heads)

    def forward(self, x):
        # ``x`` must expose its feature matrix as ``x.F``.
        features = x.F
        return torch.cat([lin(features) for lin in self.linears], 1)
class MinkowskiBaselineModel(InstanceBase):
    """Minkowski network whose final layer is replaced by a ``SeparateLinear`` head."""

    def __init__(self, option, model_type, dataset, modules):
        super().__init__(option, model_type, dataset, modules)

        backbone_kwargs = {
            "activation": option.activation,
            "first_stride": option.first_stride,
            "global_pool": option.global_pool,
            "bias": option.get("bias", True),
            "bn_momentum": option.get("bn_momentum", 0.1),
            "norm_type": option.get("norm_type", "bn"),
            "dropout": option.get("dropout", 0.0),
            "drop_path": option.get("drop_path", 0.0),
        }
        self.model = initialize_minkowski_unet(
            option.model_name, dataset.feature_dimension, dataset.num_classes,
            **backbone_kwargs, **option.get("extra_options", {})
        )

        # Swap the backbone's final linear layer for one linear layer per target,
        # keeping its input width.
        in_channel = self.model.final.linear.weight.shape[1]
        self._supports_mixed = True
        self.model.final = SeparateLinear(
            in_channel, self.num_reg_classes, self.num_mixtures, self.num_cls_classes
        )
        for linear in self.model.final.linears:
            torch.nn.init.trunc_normal_(linear.weight, std=.02)
            if linear.bias is not None:
                torch.nn.init.constant_(linear.bias, 0)

        self.head_namespace = option.get("head_namespace", "final.linears")
        self.head_optim_settings = option.get("head_optim_settings", {})
        self.backbone_optim_settings = option.get("backbone_optim_settings", {})

        self.add_pos = option.get("add_pos", False)

    def get_parameter_list(self) -> List[dict]:
        """Return two parameter groups: head first, backbone second.

        A parameter belongs to the head when ``self.head_namespace`` occurs in
        its name inside ``self.model``.
        """
        named = list(self.model.named_parameters())
        head = [param for name, param in named if self.head_namespace in name]
        backbone = [param for name, param in named if self.head_namespace not in name]
        return [
            {"params": head, **self.head_optim_settings},
            {"params": backbone, **self.backbone_optim_settings},
        ]

    def set_input(self, data, device):
        """Turn ``data`` into a sparse tensor and store the available targets."""
        self.batch_idx = data.batch.squeeze()
        # Coordinates are prefixed with the batch index.
        coords = torch.cat([data.batch.unsqueeze(-1).int(), data.coords.int()], -1)
        self.data_visual = data
        features = torch.cat([data.pos, data.x], 1) if self.add_pos else data.x
        self.input = ME.SparseTensor(features=features, coordinates=coords, device=device)

        if len(self.loss_fns) > 0:
            bs = len(data)
            for kind in ("reg", "mol", "cls"):
                if not getattr(self, f"has_{kind}_targets"):
                    continue
                targets = getattr(data, f"y_{kind}")
                if targets is None:
                    continue
                mask = getattr(data, f"y_{kind}_mask")
                setattr(self, f"{kind}_y_mask", mask.to(device).view(bs, -1))
                setattr(self, f"{kind}_y", targets.to(device).view(bs, -1))

    def compute_loss(self):
        self.loss = 0
        self.compute_instance_loss()

    def forward(self, *args, **kwargs):
        """Run the model on ``self.input``, split the outputs and compute the loss."""
        self.output = self.model(self.input)
        self.reg_out, self.mol_out, self.cls_out = self.convert_outputs(self.output)
        self.compute_loss()
class MinkowskiVAEm2(InstanceBase):
    """Sparse VAE whose decoder also receives the labels.

    similar to Kingma et al.
    https://proceedings.neurips.cc/paper/2014/file/d523773c6b194f37b938d340d5d02232-Paper.pdf
    """

    def __init__(self, option, model_type, dataset, modules):
        super().__init__(option, model_type, dataset, modules)
        self.model = initialize_minkowski_unet(
            option.model_name, dataset.feature_dimension, dataset.num_classes, activation=option.activation,
            z_channels=option.z_channels, dropout=option.dropout, backbone=option.backbone,
            resolution=int(1 / dataset.dataset_opt.first_subsampling), first_stride=option.first_stride,
            global_pool=option.global_pool, **option.get("extra_options", {})
        )
        self.loss_names.extend(
            ["loss_vae", "loss_BCE", "loss_KLD", "loss_rec", "loss_r_entropy", "loss_r_cross_entropy"]
        )
        self.KLD_beta = option.KLD_beta
        self.reconstruction_beta = option.reconstruction_beta
        self.regression_beta = option.regression_beta
        self.num_reg_classes = dataset.num_reg_classes
        self.num_mol_classes = dataset.num_mol_classes
        self.num_cls_classes = dataset.num_cls_classes

    def set_input(self, data, device):
        """Build the sparse input tensor and store labels and label mask."""
        self.batch_idx = data.batch.squeeze()
        coords = torch.cat([data.batch.unsqueeze(-1).int(), data.coords.int()], -1)
        self.data_visual = data
        self.input = ME.SparseTensor(features=data.x, coordinates=coords, device=device)
        self.input_target = self.input.coordinate_map_key
        self.labels_mask = data.y_mask.to(device).view(-1, self.model.out_channels)
        self.labels = data.y.to(device).view(-1, self.model.out_channels)

    def compute_loss(self):
        """Compute the VAE loss and the masked regression loss; return their sum."""
        # VAE loss
        # loss to check if correct pruning was applied
        n_scales = len(self.out_cls)
        self.loss_BCE = 0
        for out_cl, target in zip(self.out_cls, self.rec_targets):
            logits = out_cl.F
            bce = F.binary_cross_entropy_with_logits(logits.squeeze(), target.type(logits.dtype))
            self.loss_BCE += bce / n_scales

        self.loss_KLD = self.KLD_beta * 0.5 * torch.mean(
            torch.mean(self.z_logvar.F.exp() + self.z_mean.F.pow(2) - 1 - self.z_logvar.F, 1)
        )

        # feature reconstruction error (removes last dim as it is assumed to be 1)
        rec = self.reconstruction
        rec.F[:, -1] = 1

        residual = (rec - self.input).F
        # Keep only rows whose last feature cancelled out, then drop that column.
        residual = residual[residual[:, -1] == 0][:, :-1]
        # Squared error for small residuals, absolute error for large ones. The
        # switch is made on the magnitude: comparing the signed residual with 1.0
        # sent large negative residuals down the squared branch.
        magnitude = residual.abs()
        self.loss_rec = torch.where(magnitude > 1.0, magnitude, residual.pow(2)).mean()

        self.loss_vae = self.loss_KLD + self.reconstruction_beta * (self.loss_BCE + self.loss_rec)

        self.loss_regr = 0
        self.loss_r_cross_entropy = 0
        self.loss_r_entropy = 0
        # only calculate loss if labels are set
        if self.labels is not None and self.labels_mask.any():
            # scaling by std to have equal grads
            cond = (self.cond.F / self.reg_scale_targets)[self.labels_mask]
            labels = (self.labels / self.reg_scale_targets)[self.labels_mask]
            for loss_fn in self.loss_fns:
                self.loss_regr += loss_fn(cond, labels) / len(self.loss_fns)

        self.loss = self.regression_beta * (self.loss_regr + self.loss_r_entropy + self.loss_r_cross_entropy) \
            + self.loss_vae

        return self.loss

    def forward(self, *args, **kwargs):
        """Run the model with inputs and labels, then compute the loss."""
        (self.out_cls, self.rec_targets, self.reconstruction,
         self.zs, self.z_mean, self.z_logvar,
         self.cond_norm, self.cond) = self.model(
            self.input, self.input_target, self.labels, self.labels_mask
        )
        self.output = self.cond.F
        self.compute_loss()

        self.data_visual.pred = self.output
class MinkowskiVAE(MinkowskiVAEm2):
    """Sparse VAE with a probabilistic regressor (mean and log-variance per target).

    similar to VAE for regression paper https://arxiv.org/abs/1904.05948
    """

    def compute_loss(self):
        """Compute the VAE loss plus the regression, cross-entropy and entropy terms."""
        # VAE loss
        # loss to check if correct pruning was applied
        n_scales = len(self.out_cls)
        self.loss_BCE = 0
        for out_cl, target in zip(self.out_cls, self.rec_targets):
            logits = out_cl.F
            bce = F.binary_cross_entropy_with_logits(logits.squeeze(), target.type(logits.dtype))
            self.loss_BCE += bce / n_scales

        self.loss_KLD = self.KLD_beta * 0.5 * torch.mean(
            torch.mean(self.z_logvar.F.exp() + self.z_mean.F.pow(2) - 1 - self.z_logvar.F, 1)
        )

        # feature reconstruction error (removes last dim as it is assumed to be 1)
        rec = self.reconstruction
        rec.F[:, -1] = 1

        residual = (rec - self.input).F  # TODO maybe just include intersected points
        residual = residual[residual[:, -1] == 0][:, :-1]
        # Squared error for small residuals, absolute error for large ones. The
        # switch is made on the magnitude: comparing the signed residual with 1.0
        # sent large negative residuals down the squared branch.
        magnitude = residual.abs()
        self.loss_rec = torch.where(magnitude > 1.0, magnitude, residual.pow(2)).mean()

        self.loss_vae = self.loss_KLD + self.reconstruction_beta * (self.loss_BCE + self.loss_rec)

        self.loss_regr = 0
        self.loss_r_cross_entropy = 0
        self.loss_r_entropy = 0
        # only calculate loss if labels are set
        if self.labels is not None:
            mask = self.labels_mask
            n_rows = self.r_mean.F.shape[0]
            if mask.any():
                r_mean = (self.r_mean.F / self.reg_scale_targets)[mask]
                # NOTE(review): a variance rescaled by scale**2 corresponds to
                # *subtracting* log(scale**2) from the log-variance; the division
                # below is also undefined for scale == 1. Left unchanged pending
                # confirmation of the intended normalisation.
                r_logvar = (self.r_logvar.F / self.reg_scale_targets.pow(2).log())[mask]
                labels = (self.labels / self.reg_scale_targets)[mask]
                for loss_fn in self.loss_fns:
                    # scaling by std to have equal grads
                    per_element = loss_fn(r_mean, labels, reduction="none") / (r_logvar.detach().exp() ** 0.5)
                    self.loss_regr += per_element.mean() / len(self.loss_fns)

                # cross entropy
                self.loss_r_cross_entropy += 0.5 * (
                    (F.mse_loss(r_mean, labels - 1e-6 / n_rows, reduction="none") / torch.exp(r_logvar))
                    + r_logvar
                ).mean()

            # use entropy if nans are present (assumes univariate Gaussians)
            if not mask.all():
                r_logvar = self.r_logvar.F[~mask]
                # Only the log-variance part of the Gaussian entropy is used. An
                # earlier variant also included an r_mean term as a regulariser
                # (keeping the prediction constant while letting the decoder move it):
                #   0.5 * (smooth_l1(r_mean, r_mean.detach() + 1e-6) / exp(r_logvar) + r_logvar).mean()
                # Closed forms, for reference:
                #   http://gregorygundersen.com/blog/2020/09/01/gaussian-entropy
                #   0.5 * (r_logvar + log(2) + log(pi)) + 0.5
                #   r_logvar + log(sqrt(2 * pi * e))
                self.loss_r_entropy += 0.5 * r_logvar.mean()

        self.loss = self.regression_beta * (self.loss_regr + self.loss_r_entropy + self.loss_r_cross_entropy) \
            + self.loss_vae

        return self.loss

    def forward(self, *args, **kwargs):
        """Run the model (no labels are passed in) and compute the loss."""
        (self.out_cls, self.rec_targets, self.reconstruction,
         self.zs, self.z_mean, self.z_logvar,
         self.rs, self.r_mean, self.r_logvar
         ) = self.model(self.input, self.input_target)
        self.output = self.r_mean.F
        self.compute_loss()

        self.data_visual.pred = self.output
class MinkowskiBarlowTwins(InstanceBase):
    """Minkowski model trained with an added Barlow-Twins loss on two views.

    ``option.mode`` selects the regime: in ``"finetune"`` and ``"freeze"`` only
    the instance losses are used; otherwise the self-supervised loss is added.
    """

    def __init__(self, option, model_type, dataset, modules):
        super().__init__(option, model_type, dataset, modules)
        model_version = option.get("model_version", "standard")
        self.reset_output = option.get("reset_output", True)

        head_sizes = {
            "num_reg_classes": self.num_reg_classes,
            "num_mixtures": self.num_mixtures,
            "num_cls_classes": self.num_cls_classes
        }
        self.model = initialize_minkowski_unet(
            option.model_name, dataset.feature_dimension, head_sizes,
            activation=option.activation,
            first_stride=option.first_stride, dropout=option.dropout, global_pool=option.global_pool,
            mode=option.mode, model_version=model_version, proj_activation=option.proj_activation,
            proj_layers=option.proj_layers, proj_last_norm=option.proj_last_norm, backbone=option.backbone,
            detach_classifier=option.mode != "finetune" and model_version == "standard",
            **option.get("extra_options", {})
        )

        self.mode = option.mode
        if self.mode not in ["finetune", "freeze"]:
            self.loss_names.extend(["loss_self_supervised"])
        self.scale_loss = option.scale_loss
        self.backbone_lr = option.backbone_lr
        self._supports_mixed = True

    def get_parameter_list(self) -> List[dict]:
        """Return the classifier group and, in finetune/train mode, the remaining parameters."""
        classifier_parameters, model_parameters = [], []
        for name, param in self.model.named_parameters():
            is_classifier = "encoder.final.classifier.linears" in name
            (classifier_parameters if is_classifier else model_parameters).append(param)

        params_list = [{"params": classifier_parameters}]
        if self.mode in ["finetune", "train"]:
            model_dict = {"params": model_parameters}
            # "base_lr" means: do not override the learning rate for this group.
            if self.backbone_lr != "base_lr":
                model_dict["lr"] = self.backbone_lr
            params_list.append(model_dict)

        return params_list

    def set_pretrained_weights(self):
        """Load pretrained weights, then optionally re-initialise the classifier layers."""
        super().set_pretrained_weights()
        if not (self.mode in ["finetune", "freeze"] and self.reset_output):
            return
        log.info(f"resetting weights for final prediction layer (since we are in {self.mode} mode)")
        for linear in self.model.encoder.final.classifier.linears:
            linear.weight.data.normal_(mean=0.0, std=0.01)
            linear.bias.data.zero_()

    @staticmethod
    def _to_sparse(batch, device):
        # Coordinates are prefixed with the batch index.
        coords = torch.cat([batch.batch.unsqueeze(-1).int(), batch.coords.int()], -1)
        return ME.SparseTensor(features=batch.x, coordinates=coords, device=device)

    def set_input(self, data, device):
        """Build the sparse input(s) and store the targets of the first view.

        When training with ``self.double_batch``, even-indexed samples form the
        first view and odd-indexed samples the second one.
        """
        self.batch_idx = data.batch.squeeze()

        if self.training and self.double_batch:
            # augment data twice
            samples = data.to_data_list()
            x1 = Batch.from_data_list(samples[::2])
            x2 = Batch.from_data_list(samples[1::2])
            self.input2 = self._to_sparse(x2, device)
        else:
            x1 = data
            self.input2 = None

        bs = len(x1)
        self.data_visual = x1
        self.input = self._to_sparse(x1, device)

        if len(self.loss_fns) == 0:
            return
        if self.has_reg_targets and x1.y_reg is not None:
            self.reg_y_mask = x1.y_reg_mask.to(device).view(bs, -1)
            self.reg_y = x1.y_reg.to(device).view(bs, -1)
        if self.has_mol_targets and x1.y_mol is not None:
            self.mol_y_mask = x1.y_mol_mask.to(device).view(bs, -1)
            self.mol_y = x1.y_mol.to(device).view(bs, -1)
        if self.has_cls_targets and x1.y_cls is not None:
            self.cls_y_mask = x1.y_cls_mask.to(device).view(bs, -1)
            self.cls_y = x1.y_cls.to(device).view(bs, -1)

    def compute_loss(self):
        self.loss = 0
        self.compute_instance_loss()
        if self.mode not in ["finetune", "freeze"]:
            self.compute_self_supervised_loss()

    def compute_self_supervised_loss(self):
        """Add the scaled Barlow loss of the two projections to ``self.loss``."""
        self.loss_self_supervised = 0
        if self.training and self.double_batch:
            self.loss_self_supervised += barlow_loss(
                self.z1, self.z2, self.scale_loss["lambda"]
            )
            self.loss += self.scale_loss["all"] * self.loss_self_supervised

    def compute_instance_loss(self):
        self.compute_reg_loss()
        self.compute_mol_loss()
        self.compute_cls_loss()

    def forward(self, *args, **kwargs):
        """Run both views through the model, split the outputs and compute the loss."""
        self.set_mode()
        self.output, self.output2, self.z1, self.z2 = self.model(self.input, self.input2)
        self.reg_out, self.mol_out, self.cls_out = self.convert_outputs(self.output)
        self.reg_out2, self.mol_out2, self.cls_out2 = self.convert_outputs(self.output2)

        self.compute_loss()
        self.data_visual.pred = self.output

    def set_mode(self):
        """In training with ``mode == "freeze"``, train only the classifier."""
        if not (self.training and self.mode == "freeze"):
            return
        self.model.requires_grad_(False)
        self.model.encoder.final.classifier.requires_grad_(True)
        self.model.encoder.eval()
        self.model.encoder.final.classifier.train()
        # NOTE(review): nesting of this call was reconstructed; it is placed in
        # the freeze branch, right after the encoder is switched to eval mode.
        self.enable_dropout_in_eval()
class MinkowskiVICReg(MinkowskiBarlowTwins):
    """Variant of ``MinkowskiBarlowTwins`` that uses the VICReg loss terms."""

    def __init__(self, option, model_type, dataset, modules):
        super().__init__(option, model_type, dataset, modules)

        if self.mode not in ["finetune", "freeze"]:
            self.loss_names.extend(
                ["loss_invariance", "loss_variance", "loss_covariance"]
            )

    def compute_self_supervised_loss(self):
        """Add the weighted invariance, variance and covariance losses to ``self.loss``.

        All four logged values are reset on every call so that a call outside
        the training branch does not keep the values of an earlier training
        step. The terms are computed only when a second projection exists.
        """
        self.loss_self_supervised = 0
        self.loss_invariance = 0
        self.loss_variance = 0
        self.loss_covariance = 0
        if self.training and self.mode == "train" and self.z2 is not None:
            # from https://github.com/vturrisi/solo-learn/blob/6f19d5dc38fb6521e7fdd6aed5ac4a30ef8f3bd8/solo/losses/vicreg.py#L83
            z1, z2 = self.z1, self.z2
            # invariance loss
            self.loss_invariance = invariance_loss(z1, z2)

            # vicreg's official code gathers the tensors here
            # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
            z1, z2 = gather(z1), gather(z2)

            # variance and covariance losses
            self.loss_variance = variance_loss(z1, z2)
            self.loss_covariance = covariance_loss(z1, z2)
            loss = self.scale_loss["invariance"] * self.loss_invariance + \
                self.scale_loss["variance"] * self.loss_variance + \
                self.scale_loss["covariance"] * self.loss_covariance

            self.loss_self_supervised += loss
            self.loss += self.loss_self_supervised

    # def calc_VICReg_loss(self, z1, z2, G=None):
    #     # following https://arxiv.org/pdf/2205.11508.pdf
    #     N, D = z1.size()
    #     V = 2
    #     if G is None:
    #         G = torch.zeros(V*N, V*N)  # X′ ∈ R N ′×D′, V is the number of views
    #         i = torch.arange(0, N * V).repeat_interleave(V - 1)  # row indices
    #         j= (i + torch.arange(1, V).repeat(N * V) * N).remainder(N * V)  # column indices
    #         G[i, j] = 1  # unweighted graph connecting the rows of View_1(X′ ), . . . , View_V (X′)
    #
    #     C = torch.cov(z.t())
    #     eps = 1e-4
    #     self.loss_variance += D - torch.diag(C).clamp(eps).sqrt().sum()
    #     i, j = G.nonzero(as_tuple=True)
    #     self.loss_invariance += (z[i] - z[j]).square().sum().inner(G[i, j]) / N
    #     self.loss_covariance += 2 * torch.triu(C, diagonal=1).square().sum()
class SeparateLinear(torch.nn.Module):
    """Prediction head built from one independent linear layer per target.

    ``out_channels`` selects the layout:

    * ``int`` -- that many single-output layers;
    * ``dict`` -- ``num_reg_classes`` single-output layers, then one layer
      with ``3 * n`` outputs per entry ``n`` of ``num_mixtures``, then one
      layer with ``c`` outputs per entry ``c`` of ``num_cls_classes``
      (missing keys default to ``0`` / ``[]``);
    * anything else -- a single single-output layer.

    ``forward`` concatenates all layer outputs along dimension 1.
    """

    def __init__(self, in_channel, out_channels):
        super().__init__()
        if isinstance(out_channels, int):
            heads = [nn.Linear(in_channel, 1, bias=True) for _ in range(out_channels)]
        elif isinstance(out_channels, dict):
            num_reg_classes = out_channels.get("num_reg_classes", 0)
            num_mixtures = out_channels.get("num_mixtures", [])
            num_cls_classes = out_channels.get("num_cls_classes", [])

            heads = [nn.Linear(in_channel, 1, bias=True) for _ in range(num_reg_classes)]
            heads += [nn.Linear(in_channel, n_components * 3, bias=True) for n_components in num_mixtures]
            heads += [nn.Linear(in_channel, n_classes) for n_classes in num_cls_classes]
        else:
            heads = [nn.Linear(in_channel, 1, bias=True)]
        # Registered under ``linears`` so parameter names stay ``linears.<i>.*``.
        self.linears = nn.ModuleList(heads)

    def forward(self, x):
        return torch.cat([lin(x) for lin in self.linears], 1)
class PointNext(InstanceBase):
    """PointNet / PointNeXt classifier from ``openpoints`` with a per-target linear head.

    ``option.arch`` selects one of ``"pointnet"``, ``"pointnext_s"`` or
    ``"pointnext_b"``. The last layer of the built model's prediction head is
    replaced by the module returned from :meth:`init_head`.
    """

    # Number of points kept by furthest point sampling, keyed by
    # ``option.num_points``; ``set_input`` then draws ``num_points`` of them.
    _FPS_POOL_SIZE = {
        1024: 1200,
        4096: 4800,
        6144: 6900,
        8192: 8192,
        12288: 12288,
        16384: 16384,
    }

    def __init__(self, option, model_type, dataset, modules):
        super().__init__(option, model_type, dataset, modules)
        from openpoints.models import build_model_from_cfg
        from openpoints.utils import EasyConfig
        from openpoints.models.layers import furthest_point_sample
        self.furthest_point_sample = furthest_point_sample
        stride = option.stride
        use_mlps = option.get("use_mlps", True)
        radius_scaling = option.get("radius_scaling", 2.)
        radius = option.get("radius", 0.1)
        nsample = option.get("nsample", 32)
        act = option.get("activation", "relu")
        act_args = EasyConfig({'act': act})
        if act in ["elu", "celu"]:
            act_args["alpha"] = 0.54

        def pointnet_cfg():
            return EasyConfig({
                'NAME': 'BaseCls',
                'encoder_args': EasyConfig({
                    'NAME': 'PointNetEncoder',
                    'in_channels': dataset.feature_dimension,
                    'is_seg': False,
                    'input_transform': False,
                }),
                'cls_args': EasyConfig({
                    'NAME': 'ClsHead',
                    'num_classes': dataset.num_classes,
                    'act_args': act_args,
                    'mlps': [512, 256, 128, 128] if use_mlps else [],
                })
            })

        def pointnext_cfg(blocks, sa_layers, sa_use_res):
            # The two PointNeXt variants differ only in these three settings.
            return EasyConfig({
                'NAME': 'BaseCls',
                'encoder_args': EasyConfig({
                    'NAME': 'PointNextEncoder',
                    'blocks': blocks,
                    'strides': [1, stride, stride, stride, stride, 1],
                    'width': 32,
                    'in_channels': dataset.feature_dimension,
                    'radius': radius,
                    'radius_scaling': radius_scaling,
                    'sa_layers': sa_layers,
                    'sa_use_res': sa_use_res,
                    'nsample': nsample,
                    'expansion': 4,
                    'aggr_args': EasyConfig({'feature_type': 'dp_fj', 'reduction': 'max'}),
                    'group_args': EasyConfig({'NAME': 'ballquery', 'normalize_dp': True}),
                    'conv_args': EasyConfig({'order': 'conv-norm-act'}),
                    'act_args': act_args,
                    'norm_args': EasyConfig({'norm': 'bn'})
                }),
                'cls_args': EasyConfig({
                    'NAME': 'ClsHead',
                    'num_classes': dataset.num_classes,
                    'act_args': act_args,
                    'mlps': [512, 256] if use_mlps else [],
                    'norm_args': EasyConfig({'norm': 'bn1d'})
                })
            })

        # Only the requested configuration is built; an unknown ``option.arch``
        # raises ``KeyError`` as before.
        config_builders = {
            "pointnet": pointnet_cfg,
            "pointnext_s": lambda: pointnext_cfg([1, 1, 1, 1, 1, 1], sa_layers=2, sa_use_res=True),
            "pointnext_b": lambda: pointnext_cfg([1, 2, 3, 2, 1, 1], sa_layers=1, sa_use_res=False),
        }
        cfg = EasyConfig(config_builders[option.arch]())
        self.model = build_model_from_cfg(cfg)

        # Replace the last layer of the prediction head, keeping its input width.
        in_channel = self.model.prediction.head[-1][0].weight.shape[1]
        self.model.prediction.head[-1] = self.init_head(in_channel)

        self.dataset_num_points = dataset.dataset_opt.fixed.num_points
        self.model_num_points = option.num_points
        try:
            self.point_all = self._FPS_POOL_SIZE[self.model_num_points]
        except KeyError:
            raise NotImplementedError(
                f"num_points={self.model_num_points!r} is not supported; "
                f"expected one of {sorted(self._FPS_POOL_SIZE)}"
            ) from None
        self.should_sample = self.model_num_points < self.dataset_num_points
        self._supports_mixed = True

        self.head_namespace = option.get("head_namespace", "linears")
        self.head_optim_settings = option.get("head_optim_settings", {})
        self.backbone_optim_settings = option.get("backbone_optim_settings", {})

    def get_parameter_list(self) -> List[dict]:
        """Return two parameter groups: head first, backbone second.

        A parameter belongs to the head when ``self.head_namespace`` occurs in
        its name inside ``self.model``.
        """
        head_parameters, backbone_parameters = [], []
        for name, param in self.model.named_parameters():
            if self.head_namespace in name:
                head_parameters.append(param)
            else:
                backbone_parameters.append(param)
        return [
            {"params": head_parameters, **self.head_optim_settings},
            {"params": backbone_parameters, **self.backbone_optim_settings},
        ]

    def init_head(self, in_channel):
        """Create the module that replaces the last prediction layer."""
        return SeparateLinear(
            in_channel, {
                "num_reg_classes": self.num_reg_classes,
                "num_mixtures": self.num_mixtures,
                "num_cls_classes": self.num_cls_classes
            }
        )

    def set_input(self, data, device):
        """Reshape points/features per sample, optionally subsample, and store targets."""
        self.data_visual = data

        points = data['pos'].to(device)
        points = points.view(-1, self.dataset_num_points, points.shape[-1])

        features = data['x'].to(device)
        features = features.view(-1, self.dataset_num_points, features.shape[-1])

        if self.should_sample:  # point resampling strategy
            # Furthest point sampling down to ``point_all`` candidates, then a
            # random draw (without replacement) of ``model_num_points`` of them;
            # the same column indices are used for every sample in the batch.
            point_all = points.size(1) if points.size(1) < self.point_all else self.point_all
            fps_idx = self.furthest_point_sample(points[:, :, :3].contiguous(), point_all)
            fps_idx = fps_idx[:, np.random.choice(point_all, self.model_num_points, False)]
            points = torch.gather(points, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, points.shape[-1]))
            features = torch.gather(features, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, features.shape[-1]))

        self.input = {"pos": points, "x": features.transpose(1, 2).contiguous()}
        self.batch_idx = data.batch

        if len(self.loss_fns) > 0:
            bs = len(data)
            if self.has_reg_targets and data.y_reg is not None:
                self.reg_y_mask = data.y_reg_mask.to(device).view(bs, -1)
                self.reg_y = data.y_reg.to(device).view(bs, -1)
            if self.has_mol_targets and data.y_mol is not None:
                self.mol_y_mask = data.y_mol_mask.to(device).view(bs, -1)
                self.mol_y = data.y_mol.to(device).view(bs, -1)
            if self.has_cls_targets and data.y_cls is not None:
                self.cls_y_mask = data.y_cls_mask.to(device).view(bs, -1)
                self.cls_y = data.y_cls.to(device).view(bs, -1)

    def compute_loss(self):
        self.loss = 0
        self.compute_instance_loss()

    def forward(self, *args, **kwargs):
        """Run the model on ``self.input``, split the outputs and compute the loss."""
        self.output = self.model(self.input)
        self.reg_out, self.mol_out, self.cls_out = self.convert_outputs(self.output)
        self.compute_loss()

        self.data_visual.pred = self.output
class ProjClassifier(nn.Module):
    """Head with two branches on the same features: a classifier and a projector.

    The projector is an MLP over ``[hidden_dim, *proj_layers]``: every layer
    but the last is ``Linear -> FastBatchNorm1d -> act_fn``; the last is a
    plain ``Linear``, optionally followed by a non-affine ``FastBatchNorm1d``
    when ``last_norm`` is set. ``proj_layers`` must contain at least one size.

    Args:
        hidden_dim: width of the incoming features.
        proj_layers: output widths of the projector layers.
        out_channels: layout of the classifier, forwarded to ``SeparateLinear``.
        detach_classifier: if true, the classifier sees ``x.detach()``.
        act_fn: activation module; the same instance is inserted after every
            hidden projector layer.
        last_norm: whether to normalise the projector output.
    """

    def __init__(self, hidden_dim: int, proj_layers, out_channels: "int | dict", detach_classifier: bool, act_fn,
                 last_norm: bool):
        super().__init__()
        sizes = [hidden_dim, *proj_layers]
        layers = []
        # Hidden layers: all consecutive size pairs except the last one.
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            layers += [nn.Linear(fan_in, fan_out), FastBatchNorm1d(fan_out), act_fn]
        layers.append(nn.Linear(sizes[-2], sizes[-1]))
        if last_norm:
            layers.append(FastBatchNorm1d(sizes[-1], affine=False))
        self.projector = nn.Sequential(*layers)
        self.detach_classifier = detach_classifier

        self.classifier = SeparateLinear(hidden_dim, out_channels)

    def forward(self, x: torch.Tensor):
        """Return ``(classifier(x or x.detach()), projector(x))``."""
        classifier_in = x.detach() if self.detach_classifier else x
        return self.classifier(classifier_in), self.projector(x)
class PointNextBarlowTwin(PointNext):
    """``PointNext`` with a projection head and an added Barlow-Twins loss.

    ``option.mode`` selects the regime: in ``"finetune"`` and ``"freeze"`` only
    the instance losses are used; otherwise the self-supervised loss is added.
    """

    def __init__(self, option, model_type, dataset, modules):
        model_version = option.get("model_version", "standard")
        # These are read by ``init_head``, which the parent constructor calls,
        # so they have to be set before ``super().__init__``.
        self.proj_layers = option.proj_layers
        self.proj_last_norm = option.proj_last_norm
        self.proj_activation = option.get("proj_activation", None)
        if self.proj_activation is None:
            self.proj_activation = option.get("activation", "relu")
        self.detach_classifier = option.mode != "finetune" and model_version == "standard"
        self.reset_output = option.get("reset_output", True)

        super().__init__(option, model_type, dataset, modules)

        self.mode = option.mode
        if self.mode not in ["finetune", "freeze"]:
            self.loss_names.extend(
                ["loss_self_supervised"]
            )
        self.scale_loss = option.scale_loss
        self.backbone_lr = option.backbone_lr

    def init_head(self, in_channel):
        """Create the classifier + projector head that replaces the last prediction layer."""
        self.act_fn = create_act(self.proj_activation)
        return ProjClassifier(
            in_channel, self.proj_layers,
            {
                "num_reg_classes": self.num_reg_classes,
                "num_mixtures": self.num_mixtures,
                "num_cls_classes": self.num_cls_classes
            }, self.detach_classifier, self.act_fn, self.proj_last_norm
        )

    def get_parameter_list(self) -> List[dict]:
        """Return the ``prediction.head`` group and, in finetune/train mode, the remaining parameters."""
        classifier_parameters, model_parameters = [], []
        for name, param in self.model.named_parameters():
            if "prediction.head" in name:
                classifier_parameters.append(param)
            else:
                model_parameters.append(param)

        params_list = [{"params": classifier_parameters}]
        if self.mode in ["finetune", "train"]:
            model_dict = {"params": model_parameters}
            # "base_lr" means: do not override the learning rate for this group.
            if self.backbone_lr != "base_lr":
                model_dict["lr"] = self.backbone_lr
            params_list.append(model_dict)

        return params_list

    def set_pretrained_weights(self):
        """Load pretrained weights, then optionally re-initialise the classifier layers."""
        super().set_pretrained_weights()
        if self.mode in ["finetune", "freeze"] and self.reset_output:
            log.info(f"resetting weights for final prediction layer (since we are in {self.mode} mode)")
            for m in self.model.prediction.head[-1].classifier.linears:
                m.weight.data.normal_(mean=0.0, std=0.01)
                m.bias.data.zero_()

    def set_input(self, data, device):
        """Prepare one or two views of the batch and store the targets of the first view.

        When training with ``self.double_batch``, even-indexed samples form the
        first view and odd-indexed samples the second one; otherwise
        ``self.input2`` is ``None``.
        """
        points = data['pos'].to(device)
        points = points.view(-1, self.dataset_num_points, points.shape[-1])

        features = data['x'].to(device)
        features = features.view(-1, self.dataset_num_points, features.shape[-1])

        if self.should_sample:  # point resampling strategy
            point_all = points.size(1) if points.size(1) < self.point_all else self.point_all
            fps_idx = self.furthest_point_sample(points[:, :, :3].contiguous(), point_all)
            fps_idx = fps_idx[:, np.random.choice(point_all, self.model_num_points, False)]
            points = torch.gather(points, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, points.shape[-1]))
            features = torch.gather(features, 1, fps_idx.unsqueeze(-1).long().expand(-1, -1, features.shape[-1]))

        self.batch_idx = data.batch
        if self.training and self.double_batch:
            self.input = {"pos": points[::2].contiguous(), "x": features[::2].transpose(1, 2).contiguous()}
            self.input2 = {"pos": points[1::2].contiguous(), "x": features[1::2].transpose(1, 2).contiguous()}
            data = Batch.from_data_list(data.to_data_list()[::2])
        else:
            self.input = {"pos": points, "x": features.transpose(1, 2).contiguous()}
            self.input2 = None

        self.data_visual = data

        if len(self.loss_fns) > 0:
            bs = len(data)
            if self.has_reg_targets and data.y_reg is not None:
                self.reg_y_mask = data.y_reg_mask.to(device).view(bs, -1)
                self.reg_y = data.y_reg.to(device).view(bs, -1)
            if self.has_mol_targets and data.y_mol is not None:
                self.mol_y_mask = data.y_mol_mask.to(device).view(bs, -1)
                self.mol_y = data.y_mol.to(device).view(bs, -1)
            if self.has_cls_targets and data.y_cls is not None:
                self.cls_y_mask = data.y_cls_mask.to(device).view(bs, -1)
                self.cls_y = data.y_cls.to(device).view(bs, -1)

    def compute_loss(self):
        self.loss = 0
        self.compute_instance_loss()

        if self.mode not in ["finetune", "freeze"]:
            self.compute_self_supervised_loss()

    def compute_self_supervised_loss(self):
        """Add the scaled Barlow loss of the two projections to ``self.loss``."""
        self.loss_self_supervised = 0
        if self.training and self.double_batch:
            self.loss_self_supervised += barlow_loss(
                self.z1, self.z2, self.scale_loss["lambda"]
            )
            self.loss += self.scale_loss["all"] * self.loss_self_supervised

    def compute_instance_loss(self):
        self.compute_reg_loss()
        self.compute_mol_loss()
        self.compute_cls_loss()

    def forward_(self, input1, input2):
        """Run the model on the first view and, when available, on the second.

        The second pass needs training in ``"train"`` mode *and* a second view:
        ``set_input`` leaves ``input2`` as ``None`` unless ``double_batch`` is
        active, and the model cannot be called on ``None``.
        """
        class_out_1, z1 = self.model(input1)
        if input2 is not None and self.training and self.mode == "train":
            class_out_2, z2 = self.model(input2)
        else:
            class_out_2, z2 = None, None

        return class_out_1, class_out_2, z1, z2

    def forward(self, *args, **kwargs):
        """Run both views, split the outputs and compute the loss."""
        self.set_mode()
        self.output, self.output2, self.z1, self.z2 = self.forward_(self.input, self.input2)
        self.reg_out, self.mol_out, self.cls_out = self.convert_outputs(self.output)
        self.reg_out2, self.mol_out2, self.cls_out2 = self.convert_outputs(self.output2)

        self.compute_loss()
        self.data_visual.pred = self.output

    def set_mode(self):
        """In training with ``mode == "freeze"``, train only the last prediction layer."""
        if self.training and self.mode == "freeze":
            self.model.requires_grad_(False)
            self.model.prediction.head[-1].requires_grad_(True)
            self.model.eval()
            self.model.prediction.head[-1].train()
class PointNextVICReg(PointNextBarlowTwin):
    """Variant of ``PointNextBarlowTwin`` that uses the VICReg loss terms."""

    def __init__(self, option, model_type, dataset, modules):
        super().__init__(option, model_type, dataset, modules)

        if self.mode not in ["finetune", "freeze"]:
            self.loss_names.extend(
                ["loss_invariance", "loss_variance", "loss_covariance"]
            )

    def compute_self_supervised_loss(self):
        """Add the weighted invariance, variance and covariance losses to ``self.loss``.

        All four logged values are reset on every call so that a call outside
        the training branch does not keep the values of an earlier training
        step. The terms are computed only when a second projection exists.
        """
        self.loss_self_supervised = 0
        self.loss_invariance = 0
        self.loss_variance = 0
        self.loss_covariance = 0
        if self.training and self.mode == "train" and self.z2 is not None:
            # from https://github.com/vturrisi/solo-learn/blob/6f19d5dc38fb6521e7fdd6aed5ac4a30ef8f3bd8/solo/losses/vicreg.py#L83
            z1, z2 = self.z1, self.z2
            # invariance loss
            self.loss_invariance = invariance_loss(z1, z2)

            # vicreg's official code gathers the tensors here
            # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
            z1, z2 = gather(z1), gather(z2)

            # variance and covariance losses
            self.loss_variance = variance_loss(z1, z2)
            self.loss_covariance = covariance_loss(z1, z2)
            loss = self.scale_loss["invariance"] * self.loss_invariance + \
                self.scale_loss["variance"] * self.loss_variance + \
                self.scale_loss["covariance"] * self.loss_covariance

            self.loss_self_supervised += loss
            self.loss += self.loss_self_supervised
self.scale_loss["variance"] * self.loss_variance + \ + self.scale_loss["covariance"] * self.loss_covariance + + self.loss_self_supervised += loss + self.loss += self.loss_self_supervised diff --git a/torch-points3d/torch_points3d/models/instance/simplestnet.py b/torch-points3d/torch_points3d/models/instance/simplestnet.py new file mode 100755 index 0000000..1248c04 --- /dev/null +++ b/torch-points3d/torch_points3d/models/instance/simplestnet.py @@ -0,0 +1,107 @@ +import logging +from typing import List + +import torch +import torch.nn.functional as F +from torch import nn + +from torch_points3d.core.common_modules import FastBatchNorm1d +from torch_points3d.models.instance.base import InstanceBase + +log = logging.getLogger(__name__) + + +class SeparateLinear(torch.nn.Module): + + def __init__(self, in_channel, num_reg_classes, num_mixtures, num_cls_classes): + super(SeparateLinear, self).__init__() + self.linears = [] + if num_reg_classes > 0: + self.linears += [torch.nn.Linear(in_channel, 1, bias=True) for i in range(num_reg_classes)] + if len(num_mixtures) > 0: + self.linears += [ + torch.nn.Linear(in_channel, num_mixtures * 3, bias=True) for i, num_mixtures in + enumerate(num_mixtures) + ] + if len(num_cls_classes) > 0: + self.linears += [ + torch.nn.Linear(in_channel, num_classes) for num_classes in num_cls_classes + ] + + self.linears = torch.nn.ModuleList(self.linears) + + def forward(self, x): + return torch.cat([lin(x) for lin in self.linears], 1) + + +class SimplestNet(InstanceBase): + def __init__(self, option, model_type, dataset, modules): + super(SimplestNet, self).__init__(option, model_type, dataset, modules) + self.model = nn.Sequential( + nn.Conv1d(dataset.feature_dimension + 3, 64, 1), + nn.GELU(), + nn.BatchNorm1d(64), + nn.Conv1d(64, 128, 1), + nn.GELU(), + nn.BatchNorm1d(128), + nn.Conv1d(128, 128, 1), + nn.GELU(), + nn.BatchNorm1d(128), + ) + self.head = SeparateLinear(128, self.num_reg_classes, self.num_mixtures, self.num_cls_classes) + 
self.dataset_num_points = dataset.dataset_opt.fixed.num_points + self._supports_mixed = True + + self.head_namespace = option.get("head_namespace", "head.linears") + self.head_optim_settings = option.get("head_optim_settings", {}) + self.backbone_optim_settings = option.get("backbone_optim_settings", {}) + + def get_parameter_list(self) -> List[dict]: + params_list = [] + head_parameters, backbone_parameters = [], [] + for name, param in self.model.named_parameters(): + if self.head_namespace in name: + head_parameters.append(param) + else: + backbone_parameters.append(param) + params_list.append({"params": head_parameters, **self.head_optim_settings}) + params_list.append({"params": backbone_parameters, **self.backbone_optim_settings}) + + return params_list + + def set_input(self, data, device): + self.data_visual = data + points = data['pos'].to(device) + points = points.view(-1, self.dataset_num_points, points.shape[-1]) + + features = data['x'].to(device) + features = features.view(-1, self.dataset_num_points, features.shape[-1]) + + self.input = torch.cat([features, points], 2).moveaxis(2, 1) + self.batch_idx = data.batch + + if len(self.loss_fns) > 0: + bs = len(data) + if self.has_reg_targets and data.y_reg is not None: + self.reg_y_mask = data.y_reg_mask.to(device).view(bs, -1) + self.reg_y = data.y_reg.to(device).view(bs, -1) + if self.has_mol_targets and data.y_mol is not None: + self.mol_y_mask = data.y_mol_mask.to(device).view(bs, -1) + self.mol_y = data.y_mol.to(device).view(bs, -1) + if self.has_cls_targets and data.y_cls is not None: + self.cls_y_mask = data.y_cls_mask.to(device).view(bs, -1) + self.cls_y = data.y_cls.to(device).view(bs, -1) + + def compute_loss(self): + self.loss = 0 + self.compute_instance_loss() + + def forward(self, *args, **kwargs): + x = self.model(self.input) + x = F.adaptive_avg_pool1d(x, 1).squeeze(2) + self.output = self.head(x) + + self.reg_out, self.mol_out, self.cls_out = self.convert_outputs(self.output) + 
self.compute_loss() + + self.data_visual.pred = self.output diff --git a/torch-points3d/torch_points3d/models/instance/sparseconv3d.py b/torch-points3d/torch_points3d/models/instance/sparseconv3d.py new file mode 100644 index 0000000..b09517b --- /dev/null +++ b/torch-points3d/torch_points3d/models/instance/sparseconv3d.py @@ -0,0 +1,60 @@ +import logging + +from torch import nn +import torch + +from torch_points3d.models.instance.base import InstanceBase +from torch_points3d.models.regression.minkowski import SeparateLinear +from torch_points3d.modules.SparseConv3d.SENet import ResNetBase, NETWORK_CONFIGS + +log = logging.getLogger(__name__) + + +class SeparateLinear(nn.Module): + def __init__(self, in_channel, out_channels): + super(SeparateLinear, self).__init__() + self.linears = nn.ModuleList([nn.Linear(in_channel, 1, bias=True) for i in range(out_channels)]) + + def forward(self, x): + return torch.cat([lin(x) for lin in self.linears], 1) + + +class ResNetModel(InstanceBase): + def __init__(self, option, model_type, dataset, modules): + # call the initialization method + super().__init__(option, model_type, dataset, modules) + self.model = ResNetBase( + dataset.feature_dimension, dataset.num_classes, activation=option.activation, + first_stride=option.first_stride, dropout=option.dropout, global_pool=option.global_pool, + backend=option.backend, **NETWORK_CONFIGS[option.model_name]) + + in_channel = self.model.final.weight.shape[1] + out_channel = self.model.final.weight.shape[0] + self.model.final = SeparateLinear(in_channel, out_channel) + self._supports_mixed = self.model.snn.name == "torchsparse" + + def set_input(self, data, device): + self.batch_idx = data.batch.squeeze() + self.input = self.model.snn.SparseTensor(data.x, data.coords, data.batch, device) + if data.y is not None: + self.labels = data.y.to(device) + else: + self.labels = None + + self.data_visual = data + + def compute_loss(self): + self.loss_regr = 0 + labels = 
self.labels.view(self.output.shape) + for loss_fn in self.loss_fns: + self.loss_regr += (loss_fn(self.output, labels, reduction="none") / self.scale_targets).mean() + + self.loss_regr += self.get_internal_loss() + self.loss = self.loss_regr + + def forward(self, *args, **kwargs): + self.output = self.model(self.input) + if self.labels is not None: + self.compute_loss() + + self.data_visual.pred = self.output diff --git a/torch-points3d/torch_points3d/models/model_factory.py b/torch-points3d/torch_points3d/models/model_factory.py new file mode 100644 index 0000000..188c273 --- /dev/null +++ b/torch-points3d/torch_points3d/models/model_factory.py @@ -0,0 +1,44 @@ +import importlib + +from torch_points3d.utils.model_building_utils.model_definition_resolver import resolve_model +from .base_model import BaseModel + + +def instantiate_model(config, dataset) -> BaseModel: + """ Creates a model given a dataset and a training config. The config should contain the following: + - config.data.task: task that will be evaluated + - config.model_name: model to instantiate + - config.models: All models available + """ + + # Get task and model_name + task = config.data.task + tested_model_name = config.model_name + + # Find configs + models = config.get('models') + model_config = getattr(models, tested_model_name, None) + if model_config is None: + models_keys = models.keys() if models is not None else "" + raise Exception("The model_name {} isn t within {}".format(tested_model_name, list(models_keys))) + resolve_model(model_config, dataset, task) + + model_class = getattr(model_config, "class") + model_paths = model_class.split(".") + module = ".".join(model_paths[:-1]) + class_name = model_paths[-1] + model_module = ".".join(["torch_points3d.models", task, module]) + modellib = importlib.import_module(model_module) + + model_cls = None + for name, cls in modellib.__dict__.items(): + if name.lower() == class_name.lower(): + model_cls = cls + + if model_cls is None: + raise 
NotImplementedError( + "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." + % (model_module, class_name) + ) + model = model_cls(model_config, "dummy", dataset, modellib) + return model diff --git a/torch-points3d/torch_points3d/models/model_interface.py b/torch-points3d/torch_points3d/models/model_interface.py new file mode 100644 index 0000000..d5c44b7 --- /dev/null +++ b/torch-points3d/torch_points3d/models/model_interface.py @@ -0,0 +1,111 @@ +from abc import abstractmethod, abstractproperty, ABC + + +class CheckpointInterface(ABC): + """This class is a minimal interface class for models. + """ + + @abstractproperty # type: ignore + def schedulers(self): + pass + + @schedulers.setter + def schedulers(self, schedulers): + pass + + @abstractproperty # type: ignore + def optimizer(self): + pass + + @optimizer.setter + def optimizer(self, optimizer): + pass + + @abstractmethod + def state_dict(self): + pass + + @abstractmethod + def load_state_dict(self, state, strict=False): + pass + + +class DatasetInterface(ABC): + @abstractproperty + def conv_type(self): + pass + + def get_spatial_ops(self): + pass + + +class TrackerInterface(ABC): + @property + @abstractmethod + def conv_type(self): + pass + + @abstractmethod + def get_labels(self): + """ returns a tensor of size ``[N_points]`` where each value is the label of a point + """ + + @abstractmethod + def get_batch(self): + """ returns a tensor of size ``[N_points]`` where each value is the batch index of a point + """ + + @abstractmethod + def get_output(self): + """ returns a tensor of size ``[N_points,...]`` where each value is the output + of the network for a point (output of the last layer in general) + """ + + @abstractmethod + def get_input(self): + """ returns the last input that was given to the model or raises error + """ + + @abstractmethod + def get_current_losses(self): + """Return training losses / errors. 
train.py will print out these errors on console""" + + @abstractproperty + def device(self): + """ Returns the device onto which the model leaves (cpu or gpu) + """ + + +class InstanceTrackerInterface(TrackerInterface): + + @abstractmethod + def get_reg_output(self): + """ returns a tensor of size ``[N_points,...]`` where each value is the regression output + of the network for a point (output of the last layer in general) + """ + + @abstractmethod + def get_mol_output(self): + """ returns a tensor of size ``[N_points,...]`` where each value is the mixture of logits output + of the network for a point (output of the last layer in general) + """ + @abstractmethod + def get_cls_output(self): + """ returns a tensor of size ``[N_points,...]`` where each value is the classification output + of the network for a point (output of the last layer in general) + """ + + @abstractmethod + def get_reg_input(self): + """ returns the last regression input that was given to the model or raises error + """ + + @abstractmethod + def get_mol_input(self): + """ returns the last mixture of logits input that was given to the model or raises error + """ + + @abstractmethod + def get_cls_input(self): + """ returns the last classification input that was given to the model or raises error + """ \ No newline at end of file diff --git a/torch-points3d/torch_points3d/modules/KPConv/__init__.py b/torch-points3d/torch_points3d/modules/KPConv/__init__.py new file mode 100644 index 0000000..86e781c --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/__init__.py @@ -0,0 +1 @@ +# from https://github.com/HuguesTHOMAS/KPConv-PyTorch \ No newline at end of file diff --git a/torch-points3d/torch_points3d/modules/KPConv/architectures.py b/torch-points3d/torch_points3d/modules/KPConv/architectures.py new file mode 100644 index 0000000..d14450d --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/architectures.py @@ -0,0 +1,335 @@ +# +# +# 0=================================0 +# | 
Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Define network architectures +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 06/03/2020 +# +from functools import partial + +from .blocks import * +import numpy as np + +ACTIVATIONS = { + "relu": partial(nn.ReLU, inplace=True), + "celu": partial(nn.CELU, inplace=True, alpha=0.54), + "silu": partial(nn.SiLU, inplace=True), + "swish": partial(nn.SiLU, inplace=True), + "elu": partial(nn.ELU, inplace=True, alpha=0.54), + "sigmoid": partial(nn.Sigmoid), + "tanh": partial(nn.Tanh), + "gelu": partial(nn.GELU), +} + +def p2p_fitting_regularizer(net): + fitting_loss = 0 + repulsive_loss = 0 + + for m in net.modules(): + + if isinstance(m, KPConv) and m.deformable: + + ############## + # Fitting loss + ############## + + # Get the distance to closest input point and normalize to be independant from layers + KP_min_d2 = m.min_d2 / (m.KP_extent ** 2) + + # Loss will be the square distance to closest input point. 
We use L1 because dist is already squared + fitting_loss += net.l1(KP_min_d2, torch.zeros_like(KP_min_d2)) + + ################ + # Repulsive loss + ################ + + # Normalized KP locations + KP_locs = m.deformed_KP / m.KP_extent + + # Point should not be close to each other + for i in range(net.K): + other_KP = torch.cat([KP_locs[:, :i, :], KP_locs[:, i + 1:, :]], dim=1).detach() + distances = torch.sqrt(torch.sum((other_KP - KP_locs[:, i:i + 1, :]) ** 2, dim=2)) + rep_loss = torch.sum(torch.clamp_max(distances - net.repulse_extent, max=0.0) ** 2, dim=1) + repulsive_loss += net.l1(rep_loss, torch.zeros_like(rep_loss)) / net.K + + return net.deform_fitting_power * (2 * fitting_loss + repulsive_loss) + + +class KPCNN(nn.Module): + """ + Class defining KPCNN + """ + + def __init__(self, config): + super(KPCNN, self).__init__() + + ##################### + # Network operations + ##################### + + # Current radius of convolution and feature dimension + layer = 0 + r = config.first_subsampling_dl * config.conv_radius + in_dim = config.in_features_dim + out_dim = config.first_features_dim + act_fn = ACTIVATIONS[config.activation] + self.act_fn = act_fn + self.K = config.num_kernel_points + + # Save all block operations in a list of modules + self.block_ops = nn.ModuleList() + + # Loop over consecutive blocks + block_in_layer = 0 + for block_i, block in enumerate(config.architecture): + + # Check equivariance + if ('equivariant' in block) and (not out_dim % 3 == 0): + raise ValueError('Equivariant block but features dimension is not a factor of 3') + + # Detect upsampling block to stop + if 'upsample' in block: + break + + # Apply the good block function defining tf ops + self.block_ops.append(block_decider( + block, r, in_dim, out_dim, layer, act_fn, config + )) + + # Index of block in this layer + block_in_layer += 1 + + # Update dimension of input from output + if 'simple' in block: + in_dim = out_dim // 2 + else: + in_dim = out_dim + + # Detect change to 
a subsampled layer + if 'pool' in block or 'strided' in block: + # Update radius and feature dimension for next layer + layer += 1 + r *= 2 + out_dim *= 2 + block_in_layer = 0 + + self.head_mlp = UnaryBlock(out_dim, 1024, act_fn, False, 0) + + ################ + # Network Losses + ################ + + self.deform_fitting_mode = config.deform_fitting_mode + self.deform_fitting_power = config.deform_fitting_power + self.deform_lr_factor = config.deform_lr_factor + self.repulse_extent = config.repulse_extent + self.reg_loss = 0 + self.l1 = nn.L1Loss() + + return + + def forward(self, batch): + # Save all block operations in a list of modules + x = batch.features.clone().detach() + + # Loop over consecutive blocks + for block_op in self.block_ops: + x = block_op(x, batch) + + # Head of network + x = self.head_mlp(x, batch) + + return x + + def internal_loss(self): + """ + Runs the internal loss on outputs of the model + :return: loss + """ + + # Regularization of deformable offsets + if self.deform_fitting_mode == 'point2point': + self.reg_loss = p2p_fitting_regularizer(self) + elif self.deform_fitting_mode == 'point2plane': + raise ValueError('point2plane fitting mode not implemented yet.') + else: + raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode) + + # Combined loss + return self.reg_loss + + +class KPFCNN(nn.Module): + """ + Class defining KPFCNN + """ + + def __init__(self, config, lbl_values, ign_lbls): + super(KPFCNN, self).__init__() + + ############ + # Parameters + ############ + + # Current radius of convolution and feature dimension + layer = 0 + r = config.first_subsampling_dl * config.conv_radius + in_dim = config.in_features_dim + out_dim = config.first_features_dim + act_fn = config.act_fn + self.act_fn = act_fn + self.K = config.num_kernel_points + self.C = len(lbl_values) - len(ign_lbls) + + ##################### + # List Encoder blocks + ##################### + + # Save all block operations in a list of modules + 
self.encoder_blocks = nn.ModuleList() + self.encoder_skip_dims = [] + self.encoder_skips = [] + + # Loop over consecutive blocks + for block_i, block in enumerate(config.architecture): + + # Check equivariance + if ('equivariant' in block) and (not out_dim % 3 == 0): + raise ValueError('Equivariant block but features dimension is not a factor of 3') + + # Detect change to next layer for skip connection + if np.any([tmp in block for tmp in ['pool', 'strided', 'upsample', 'global']]): + self.encoder_skips.append(block_i) + self.encoder_skip_dims.append(in_dim) + + # Detect upsampling block to stop + if 'upsample' in block: + break + + # Apply the good block function defining tf ops + self.encoder_blocks.append(block_decider( + block, r, in_dim, out_dim, layer, act_fn, config + )) + + # Update dimension of input from output + if 'simple' in block: + in_dim = out_dim // 2 + else: + in_dim = out_dim + + # Detect change to a subsampled layer + if 'pool' in block or 'strided' in block: + # Update radius and feature dimension for next layer + layer += 1 + r *= 2 + out_dim *= 2 + + ##################### + # List Decoder blocks + ##################### + + # Save all block operations in a list of modules + self.decoder_blocks = nn.ModuleList() + self.decoder_concats = [] + + # Find first upsampling block + start_i = 0 + for block_i, block in enumerate(config.architecture): + if 'upsample' in block: + start_i = block_i + break + + # Loop over consecutive blocks + for block_i, block in enumerate(config.architecture[start_i:]): + + # Add dimension of skip connection concat + if block_i > 0 and 'upsample' in config.architecture[start_i + block_i - 1]: + in_dim += self.encoder_skip_dims[layer] + self.decoder_concats.append(block_i) + + # Apply the good block function defining tf ops + self.decoder_blocks.append(block_decider( + block, r, in_dim, out_dim, layer, act_fn, config + )) + + # Update dimension of input from output + in_dim = out_dim + + # Detect change to a subsampled 
layer + if 'upsample' in block: + # Update radius and feature dimension for next layer + layer -= 1 + r *= 0.5 + out_dim = out_dim // 2 + + self.head_mlp = UnaryBlock(out_dim, config.first_features_dim, act_fn, False, 0) + + ################ + # Network Losses + ################ + + # List of valid labels (those not ignored in loss) + self.valid_labels = np.sort([c for c in lbl_values if c not in ign_lbls]) + + # Choose segmentation loss + if len(config.class_w) > 0: + class_w = torch.from_numpy(np.array(config.class_w, dtype=np.float32)) + self.criterion = torch.nn.CrossEntropyLoss(weight=class_w, ignore_index=-1) + else: + self.criterion = torch.nn.CrossEntropyLoss(ignore_index=-1) + self.deform_fitting_mode = config.deform_fitting_mode + self.deform_fitting_power = config.deform_fitting_power + self.deform_lr_factor = config.deform_lr_factor + self.repulse_extent = config.repulse_extent + self.reg_loss = 0 + self.l1 = nn.L1Loss() + + return + + def forward(self, batch): + + # Get input features + x = batch.features.clone().detach() + + # Loop over consecutive blocks + skip_x = [] + for block_i, block_op in enumerate(self.encoder_blocks): + if block_i in self.encoder_skips: + skip_x.append(x) + x = block_op(x, batch) + + for block_i, block_op in enumerate(self.decoder_blocks): + if block_i in self.decoder_concats: + x = torch.cat([x, skip_x.pop()], dim=1) + x = block_op(x, batch) + + # Head of network + x = self.head_mlp(x, batch) + + return x + + def internal_loss(self, ): + """ + Runs the internal loss on outputs of the model + :return: loss + """ + + # Regularization of deformable offsets + if self.deform_fitting_mode == 'point2point': + self.reg_loss = p2p_fitting_regularizer(self) + elif self.deform_fitting_mode == 'point2plane': + raise ValueError('point2plane fitting mode not implemented yet.') + else: + raise ValueError('Unknown fitting mode: ' + self.deform_fitting_mode) + + # Combined loss + return self.reg_loss diff --git 
a/torch-points3d/torch_points3d/modules/KPConv/blocks.py b/torch-points3d/torch_points3d/modules/KPConv/blocks.py new file mode 100644 index 0000000..aba4a47 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/blocks.py @@ -0,0 +1,738 @@ +# +# +# 0=================================0 +# | Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Define network blocks +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Hugues THOMAS - 06/03/2020 +# + + +import math + +import torch +import torch.nn as nn +from torch.nn.init import kaiming_uniform_ +from torch.nn.parameter import Parameter + +from .kernel_points import load_kernels + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# Simple functions +# \**********************/ +# + + +@torch.jit.script +def gather(x: torch.Tensor, idx: torch.Tensor, method: int = 2): + """ + implementation of a custom gather operation for faster backwards. + :param x: input with shape [N, D_1, ... D_d] + :param idx: indexing with shape [n_1, ..., n_m] + :param method: Choice of the method + :return: x[idx] with shape [n_1, ..., n_m, D_1, ... 
D_d] + """ + + if method == 0: + return x[idx] + elif method == 1: + x = x.unsqueeze(1) + x = x.expand((-1, idx.shape[-1], -1)) + idx = idx.unsqueeze(2) + idx = idx.expand((-1, -1, x.shape[-1])) + return x.gather(0, idx) + elif method == 2: + for i, ni in enumerate(idx.size()[1:]): + x = x.unsqueeze(i + 1) + new_s = list(x.size()) + new_s[i + 1] = ni + x = x.expand(new_s) + n = len(idx.size()) + for i, di in enumerate(x.size()[n:]): + idx = idx.unsqueeze(i + n) + new_s = list(idx.size()) + new_s[i + n] = di + idx = idx.expand(new_s) + return x.gather(0, idx) + else: + raise ValueError('Unkown method') + + +@torch.jit.script +def radius_gaussian(sq_r: torch.Tensor, sig: float, eps: float = 1e-9): + """ + Compute a radius gaussian (gaussian of distance) + :param sq_r: input radiuses [dn, ..., d1, d0] + :param sig: extents of gaussians [d1, d0] or [d0] or float + :return: gaussian of sq_r [dn, ..., d1, d0] + """ + return torch.exp(-sq_r / (2 * sig ** 2 + eps)) + + +@torch.jit.script +def closest_pool(x: torch.Tensor, inds: torch.Tensor): + """ + Pools features from the closest neighbors. WARNING: this function assumes the neighbors are ordered. + :param x: [n1, d] features matrix + :param inds: [n2, max_num] Only the first column is used for pooling + :return: [n2, d] pooled features matrix + """ + + # Add a last row with minimum features for shadow pools + x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) + + # Get features for each pooling location [n2, d] + return gather(x, inds[:, 0]) + + +@torch.jit.script +def max_pool(x: torch.Tensor, inds: torch.Tensor): + """ + Pools features with the maximum values. 
+ :param x: [n1, d] features matrix + :param inds: [n2, max_num] pooling indices + :return: [n2, d] pooled features matrix + """ + + # Add a last row with minimum features for shadow pools + x = torch.cat((x, torch.zeros_like(x[:1, :])), 0) + + # Get all features for each pooling location [n2, max_num, d] + pool_features = gather(x, inds) + + # Pool the maximum [n2, d] + max_features, _ = torch.max(pool_features, 1) + return max_features + + +@torch.jit.script +def global_average(x: torch.Tensor, batch_lengths: torch.Tensor): + """ + Block performing a global average over batch pooling + :param x: [N, D] input features + :param batch_lengths: [B] list of batch lengths + :return: [B, D] averaged features + """ + + # Loop over the clouds of the batch + averaged_features = [] + i0 = 0 + for b_i, length in enumerate(batch_lengths): + # Average features for each batch cloud + averaged_features.append(torch.mean(x[i0:i0 + length], dim=0)) + + # Increment for next cloud + i0 += length + + # Average features in each batch + return torch.stack(averaged_features) + + +@torch.jit.script +def global_sum(x: torch.Tensor, batch_lengths: torch.Tensor): + """ + Block performing a global average over batch pooling + :param x: [N, D] input features + :param batch_lengths: [B] list of batch lengths + :return: [B, D] averaged features + """ + + # Loop over the clouds of the batch + averaged_features = [] + i0 = 0 + for b_i, length in enumerate(batch_lengths): + # Average features for each batch cloud + averaged_features.append(torch.sum(x[i0:i0 + length], dim=0)) + + # Increment for next cloud + i0 += length + + # Average features in each batch + return torch.stack(averaged_features) + + +# ---------------------------------------------------------------------------------------------------------------------- +# +# KPConv class +# \******************/ +# + + +class KPConv(nn.Module): + + def __init__(self, kernel_size, p_dim, in_channels, out_channels, KP_extent, radius, + 
fixed_kernel_points='center', KP_influence='linear', aggregation_mode='sum', + deformable=False, modulated=False): + """ + Initialize parameters for KPConvDeformable. + :param kernel_size: Number of kernel points. + :param p_dim: dimension of the point space. + :param in_channels: dimension of input features. + :param out_channels: dimension of output features. + :param KP_extent: influence radius of each kernel point. + :param radius: radius used for kernel point init. Even for deformable, use the config.conv_radius + :param fixed_kernel_points: fix position of certain kernel points ('none', 'center' or 'verticals'). + :param KP_influence: influence function of the kernel points ('constant', 'linear', 'gaussian'). + :param aggregation_mode: choose to sum influences, or only keep the closest ('closest', 'sum'). + :param deformable: choose deformable or not + :param modulated: choose if kernel weights are modulated in addition to deformed + """ + super(KPConv, self).__init__() + + # Save parameters + self.K = kernel_size + self.p_dim = p_dim + self.in_channels = in_channels + self.out_channels = out_channels + self.radius = radius + self.KP_extent = KP_extent + self.fixed_kernel_points = fixed_kernel_points + self.KP_influence = KP_influence + self.aggregation_mode = aggregation_mode + self.deformable = deformable + self.modulated = modulated + + # Running variable containing deformed KP distance to input points. 
(used in regularization loss) + self.min_d2 = None + self.deformed_KP = None + self.offset_features = None + + # Initialize weights + self.weights = Parameter(torch.zeros((self.K, in_channels, out_channels), dtype=torch.float32), + requires_grad=True) + + # Initiate weights for offsets + if deformable: + if modulated: + self.offset_dim = (self.p_dim + 1) * self.K + else: + self.offset_dim = self.p_dim * self.K + self.offset_conv = KPConv(self.K, + self.p_dim, + self.in_channels, + self.offset_dim, + KP_extent, + radius, + fixed_kernel_points=fixed_kernel_points, + KP_influence=KP_influence, + aggregation_mode=aggregation_mode) + self.offset_bias = Parameter(torch.zeros(self.offset_dim, dtype=torch.float32), requires_grad=True) + + else: + self.offset_dim = None + self.offset_conv = None + self.offset_bias = None + + # Reset parameters + self.reset_parameters() + + # Initialize kernel points + self.kernel_points = self.init_KP() + + return + + def reset_parameters(self): + kaiming_uniform_(self.weights, a=math.sqrt(5)) + if self.deformable: + nn.init.zeros_(self.offset_bias) + return + + def init_KP(self): + """ + Initialize the kernel point positions in a sphere + :return: the tensor of kernel points + """ + + # Create one kernel disposition (as numpy array). 
def forward(self, q_pts, s_pts, neighb_inds, x):
    """
    Apply the kernel point convolution to a set of neighborhoods.

    Shapes quoted in the inline comments are the ones documented by the original
    author; they are not enforced here.

    :param q_pts: query points, one neighborhood is centered on each of them
    :param s_pts: support points, indexed by ``neighb_inds``
    :param neighb_inds: neighbor indices into ``s_pts`` for every query point. An index
                        equal to the original ``len(s_pts)`` selects the shadow point
                        (and the zero shadow feature) appended inside this method
    :param x: support features, indexed the same way as ``s_pts``
    :return: the kernel outputs summed over the kernel point axis
    :raises ValueError: if ``self.KP_influence`` or ``self.aggregation_mode`` is unknown

    When ``self.deformable`` is set, this call also stores ``self.offset_features``,
    ``self.deformed_KP`` and ``self.min_d2`` on the module (the original comment says
    the latter is kept for a loss).
    """

    ###################
    # Offset generation
    ###################

    if self.deformable:

        # Get offsets with a KPConv that only takes part of the features
        self.offset_features = self.offset_conv(q_pts, s_pts, neighb_inds, x) + self.offset_bias

        if self.modulated:

            # Get offset (in normalized scale) from features:
            # the first p_dim * K columns are offsets, the remaining ones are modulations
            unscaled_offsets = self.offset_features[:, :self.p_dim * self.K]
            unscaled_offsets = unscaled_offsets.view(-1, self.K, self.p_dim)

            # Get modulations (sigmoid scaled to the range (0, 2))
            modulations = 2 * torch.sigmoid(self.offset_features[:, self.p_dim * self.K:])

        else:

            # Get offset (in normalized scale) from features
            unscaled_offsets = self.offset_features.view(-1, self.K, self.p_dim)

            # No modulations
            modulations = None

        # Rescale offset for this layer
        offsets = unscaled_offsets * self.KP_extent

    else:
        offsets = None
        modulations = None

    ######################
    # Deformed convolution
    ######################

    # Add a fake point in the last row for shadow neighbors
    # (placed at 1e6 on every axis so it is far away from any kernel point)
    s_pts = torch.cat((s_pts, torch.zeros_like(s_pts[:1, :]) + 1e6), 0)

    # Get neighbor points [n_points, n_neighbors, dim]
    neighbors = s_pts[neighb_inds, :]

    # Center every neighborhood
    neighbors = neighbors - q_pts.unsqueeze(1)

    # Apply offsets to kernel points [n_points, n_kpoints, dim]
    if self.deformable:
        self.deformed_KP = offsets + self.kernel_points
        deformed_K_points = self.deformed_KP.unsqueeze(1)
    else:
        deformed_K_points = self.kernel_points

    # Get all difference matrices [n_points, n_neighbors, n_kpoints, dim]
    # (in-place unsqueeze: `neighbors` itself gains the kernel point axis)
    neighbors.unsqueeze_(2)
    differences = neighbors - deformed_K_points

    # Get the square distances [n_points, n_neighbors, n_kpoints]
    sq_distances = torch.sum(differences ** 2, dim=3)

    # Optimization by ignoring points outside a deformed KP range
    if self.deformable:

        # Save distances for loss
        self.min_d2, _ = torch.min(sq_distances, dim=1)

        # Boolean of the neighbors in range of a kernel point [n_points, n_neighbors]
        in_range = torch.any(sq_distances < self.KP_extent ** 2, dim=2).type(torch.int32)

        # New value of max neighbors
        new_max_neighb = torch.max(torch.sum(in_range, dim=1))

        # For each row of neighbors, indices of the ones that are in range [n_points, new_max_neighb]
        # (topk on the 0/1 mask brings the in-range neighbors to the front of each row)
        neighb_row_bool, neighb_row_inds = torch.topk(in_range, new_max_neighb.item(), dim=1)

        # Gather new neighbor indices [n_points, new_max_neighb]
        new_neighb_inds = neighb_inds.gather(1, neighb_row_inds, sparse_grad=False)

        # Gather new distances to KP [n_points, new_max_neighb, n_kpoints]
        neighb_row_inds.unsqueeze_(2)
        neighb_row_inds = neighb_row_inds.expand(-1, -1, self.K)
        sq_distances = sq_distances.gather(1, neighb_row_inds, sparse_grad=False)

        # New shadow neighbors have to point to the last shadow point:
        # entries whose mask is 0 are zeroed, then receive the index of the last row of the extended s_pts
        new_neighb_inds *= neighb_row_bool
        new_neighb_inds -= (neighb_row_bool.type(torch.int64) - 1) * int(s_pts.shape[0] - 1)
    else:
        new_neighb_inds = neighb_inds

    # Get Kernel point influences [n_points, n_kpoints, n_neighbors]
    if self.KP_influence == 'constant':
        # Every point get an influence of 1.
        all_weights = torch.ones_like(sq_distances)
        all_weights = torch.transpose(all_weights, 1, 2)

    elif self.KP_influence == 'linear':
        # Influence decrease linearly with the distance, and get to zero when d = KP_extent.
        all_weights = torch.clamp(1 - torch.sqrt(sq_distances) / self.KP_extent, min=0.0)
        all_weights = torch.transpose(all_weights, 1, 2)

    elif self.KP_influence == 'gaussian':
        # Influence in gaussian of the distance.
        sigma = self.KP_extent * 0.3
        all_weights = radius_gaussian(sq_distances, sigma)
        all_weights = torch.transpose(all_weights, 1, 2)
    else:
        raise ValueError('Unknown influence function type (config.KP_influence)')

    # In case of closest mode, only the closest KP can influence each point
    if self.aggregation_mode == 'closest':
        neighbors_1nn = torch.argmin(sq_distances, dim=2)
        all_weights *= torch.transpose(nn.functional.one_hot(neighbors_1nn, self.K), 1, 2)

    elif self.aggregation_mode != 'sum':
        raise ValueError("Unknown convolution mode. Should be 'closest' or 'sum'")

    # Add a zero feature for shadow neighbors
    x = torch.cat((x, torch.zeros_like(x[:1, :])), 0)

    # Get the features of each neighborhood [n_points, n_neighbors, in_fdim]
    neighb_x = gather(x, new_neighb_inds)

    # Apply distance weights [n_points, n_kpoints, in_fdim]
    weighted_features = torch.matmul(all_weights, neighb_x)

    # Apply modulations
    if self.deformable and self.modulated:
        weighted_features *= modulations.unsqueeze(2)

    # Apply network weights [n_kpoints, n_points, out_fdim]
    weighted_features = weighted_features.permute((1, 0, 2))
    kernel_outputs = torch.matmul(weighted_features, self.weights)

    # Convolution sum [n_points, out_fdim]
    return torch.sum(kernel_outputs, dim=0)

def __repr__(self):
    # Short summary built from the radius and the in/out channel counts stored on the module
    return 'KPConv(radius: {:.2f}, in_feat: {:d}, out_feat: {:d})'.format(self.radius,
                                                                          self.in_channels,
                                                                          self.out_channels)
# Architecture block names handled by SimpleBlock
_SIMPLE_BLOCK_NAMES = ('simple',
                       'simple_deformable',
                       'simple_invariant',
                       'simple_equivariant',
                       'simple_strided',
                       'simple_deformable_strided',
                       'simple_invariant_strided',
                       'simple_equivariant_strided')

# Architecture block names handled by ResnetBottleneckBlock
_RESNET_BLOCK_NAMES = ('resnetb',
                       'resnetb_invariant',
                       'resnetb_equivariant',
                       'resnetb_deformable',
                       'resnetb_strided',
                       'resnetb_deformable_strided',
                       'resnetb_equivariant_strided',
                       'resnetb_invariant_strided')


def block_decider(block_name,
                  radius,
                  in_dim,
                  out_dim,
                  layer_ind,
                  act_fn,
                  config):
    """
    Build the network block designated by an architecture block name.

    :param block_name: name of the block in the architecture definition
    :param radius: current radius of convolution (used by the convolution blocks)
    :param in_dim: dimension of the input features
    :param out_dim: dimension of the output features
    :param layer_ind: index of the layer the block operates on
    :param act_fn: activation function (nn.Module class) instantiated by the block
    :param config: parameters object read by the blocks
    :return: the instantiated block
    :raises ValueError: if ``block_name`` is not a known block name
    """
    if block_name == 'unary':
        return UnaryBlock(in_dim, out_dim, act_fn, config.use_batch_norm, config.batch_norm_momentum)

    if block_name in _SIMPLE_BLOCK_NAMES:
        return SimpleBlock(block_name, in_dim, out_dim, radius, layer_ind, act_fn, config)

    if block_name in _RESNET_BLOCK_NAMES:
        return ResnetBottleneckBlock(block_name, in_dim, out_dim, radius, layer_ind, act_fn, config)

    if block_name in ('max_pool', 'max_pool_wide'):
        return MaxPoolBlock(layer_ind)

    if block_name == 'global_average':
        return GlobalAverageBlock()

    if block_name == 'global_sum':
        return GlobalSumBlock()

    if block_name == 'nearest_upsample':
        return NearestUpsampleBlock(layer_ind)

    # format() instead of string concatenation: a non-string name (e.g. None) must still
    # produce the documented ValueError rather than a TypeError from the concatenation
    raise ValueError('Unknown block name in the architecture definition : {}'.format(block_name))
class BatchNormBlock(nn.Module):
    """
    Batch normalization of point features, or a learned additive bias when batch
    normalization is disabled.
    """

    def __init__(self, in_dim, use_bn, bn_momentum):
        """
        Initialize a batch normalization block. If network does not use batch normalization, replace with biases.
        :param in_dim: dimension input features
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        """
        super(BatchNormBlock, self).__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.in_dim = in_dim
        if self.use_bn:
            self.batch_norm = nn.BatchNorm1d(in_dim, momentum=bn_momentum)
        else:
            self.bias = Parameter(torch.zeros(in_dim, dtype=torch.float32), requires_grad=True)

    def reset_parameters(self):
        """
        Reset the parameters of the active branch: the batch norm layer when it is used,
        the bias otherwise.
        """
        if self.use_bn:
            # No bias exists in this mode; delegate to the batch norm layer
            self.batch_norm.reset_parameters()
        else:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        :param x: features with ``in_dim`` columns (presumably [n_points, in_dim])
        :return: normalized (or biased) features with the same shape as ``x``
        """
        if self.use_bn:
            # Present all points as a single sequence [1, in_dim, n_points] so that
            # BatchNorm1d computes its statistics over the points
            x = x.unsqueeze(2)
            x = x.transpose(0, 2)
            x = self.batch_norm(x)
            x = x.transpose(0, 2)
            # Remove only the axis added above. A bare squeeze() would also drop the point
            # axis (single point) or the feature axis (in_dim == 1)
            return x.squeeze(2)
        return x + self.bias

    def __repr__(self):
        return 'BatchNormBlock(in_feat: {:d}, momentum: {:.3f}, only_bias: {:s})'.format(self.in_dim,
                                                                                         self.bn_momentum,
                                                                                         str(not self.use_bn))
class UnaryBlock(nn.Module):
    """
    Point-wise linear layer (no bias) followed by a BatchNormBlock and, unless
    disabled, an activation.
    """

    def __init__(self, in_dim, out_dim, act_fn, use_bn, bn_momentum, no_relu=False):
        """
        :param in_dim: dimension of the input features
        :param out_dim: dimension of the output features
        :param act_fn: activation function (nn.Module class); instantiated unless ``no_relu``
        :param use_bn: boolean indicating if we use Batch Norm
        :param bn_momentum: Batch norm momentum
        :param no_relu: if True, no activation is created nor applied
        """
        super(UnaryBlock, self).__init__()
        self.bn_momentum = bn_momentum
        self.use_bn = use_bn
        self.no_relu = no_relu
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.mlp = nn.Linear(in_dim, out_dim, bias=False)
        self.batch_norm = BatchNormBlock(out_dim, self.use_bn, self.bn_momentum)
        if not no_relu:
            self.act = act_fn()

    def forward(self, x, batch=None):
        # `batch` is accepted so the block can be called like the other blocks; it is not used
        normalized = self.batch_norm(self.mlp(x))
        if self.no_relu:
            return normalized
        return self.act(normalized)

    def __repr__(self):
        template = 'UnaryBlock(in_feat: {:d}, out_feat: {:d}, BN: {:s}, ReLU: {:s})'
        has_relu = not self.no_relu
        return template.format(self.in_dim, self.out_dim, str(self.use_bn), str(has_relu))
class SimpleBlock(nn.Module):
    """
    KPConv followed by a BatchNormBlock and an activation. The convolution produces
    ``out_dim // 2`` features.
    """

    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, act_fn, config):
        """
        :param block_name: architecture name of the block; 'deform' in the name enables the
                           deformable convolution, 'strided' selects the pooling neighborhoods
        :param in_dim: dimension of the input features
        :param out_dim: nominal output dimension (the block outputs ``out_dim // 2`` features)
        :param radius: current radius of convolution
        :param layer_ind: index of the layer whose points/neighbors are read from the batch
        :param act_fn: activation function (nn.Module) to initialize
        :param config: parameters
        """
        super(SimpleBlock, self).__init__()

        # KP_extent scaled from the configured convolution radius to the current one
        kp_extent = radius * config.KP_extent / config.conv_radius

        self.bn_momentum = config.batch_norm_momentum
        self.use_bn = config.use_batch_norm
        self.layer_ind = layer_ind
        self.block_name = block_name
        self.in_dim = in_dim
        self.out_dim = out_dim

        conv_out_dim = out_dim // 2
        self.KPConv = KPConv(config.num_kernel_points,
                             config.in_points_dim,
                             in_dim,
                             conv_out_dim,
                             kp_extent,
                             radius,
                             fixed_kernel_points=config.fixed_kernel_points,
                             KP_influence=config.KP_influence,
                             aggregation_mode=config.aggregation_mode,
                             deformable='deform' in block_name,
                             modulated=config.modulated)

        self.batch_norm = BatchNormBlock(conv_out_dim, self.use_bn, self.bn_momentum)
        self.act = act_fn()

    def forward(self, x, batch):
        level = self.layer_ind
        if 'strided' in self.block_name:
            # Strided: query the points of the next layer through the pooling indices
            queries = batch.points[level + 1]
            supports = batch.points[level]
            indices = batch.pools[level]
        else:
            queries = batch.points[level]
            supports = batch.points[level]
            indices = batch.neighbors[level]

        convolved = self.KPConv(queries, supports, indices, x)
        return self.act(self.batch_norm(convolved))
class ResnetBottleneckBlock(nn.Module):
    """
    Residual bottleneck: optional unary reduction to ``out_dim // 4``, KPConv, unary
    expansion to ``out_dim``, added to a (possibly pooled and projected) shortcut.
    """

    def __init__(self, block_name, in_dim, out_dim, radius, layer_ind, act_fn, config):
        """
        :param block_name: architecture name of the block; 'deform' in the name enables the
                           deformable convolution, 'strided' selects the pooling neighborhoods
        :param in_dim: dimension of the input features
        :param out_dim: dimension of the output features
        :param radius: current radius of convolution
        :param layer_ind: index of the layer whose points/neighbors are read from the batch
        :param act_fn: activation function (nn.Module) to initialize
        :param config: parameters
        """
        super(ResnetBottleneckBlock, self).__init__()

        # KP_extent scaled from the configured convolution radius to the current one
        kp_extent = radius * config.KP_extent / config.conv_radius

        self.bn_momentum = config.batch_norm_momentum
        self.use_bn = config.use_batch_norm
        self.block_name = block_name
        self.layer_ind = layer_ind
        self.in_dim = in_dim
        self.out_dim = out_dim

        mid_dim = out_dim // 4

        # Downscaling mlp, skipped when the input already has the bottleneck width
        if in_dim == mid_dim:
            self.unary1 = nn.Identity()
        else:
            self.unary1 = UnaryBlock(in_dim, mid_dim, act_fn, self.use_bn, self.bn_momentum)

        self.KPConv = KPConv(config.num_kernel_points,
                             config.in_points_dim,
                             mid_dim,
                             mid_dim,
                             kp_extent,
                             radius,
                             fixed_kernel_points=config.fixed_kernel_points,
                             KP_influence=config.KP_influence,
                             aggregation_mode=config.aggregation_mode,
                             deformable='deform' in block_name,
                             modulated=config.modulated)
        self.batch_norm_conv = BatchNormBlock(mid_dim, self.use_bn, self.bn_momentum)

        # Upscaling mlp, without activation (it is applied after the residual sum)
        self.unary2 = UnaryBlock(mid_dim, out_dim, act_fn, self.use_bn, self.bn_momentum, no_relu=True)

        # Shortcut projection, only needed when the feature width changes
        if in_dim == out_dim:
            self.unary_shortcut = nn.Identity()
        else:
            self.unary_shortcut = UnaryBlock(in_dim, out_dim, act_fn, self.use_bn, self.bn_momentum, no_relu=True)

        self.act = act_fn()

    def forward(self, features, batch):
        level = self.layer_ind
        strided = 'strided' in self.block_name
        if strided:
            queries = batch.points[level + 1]
            supports = batch.points[level]
            indices = batch.pools[level]
        else:
            queries = batch.points[level]
            supports = batch.points[level]
            indices = batch.neighbors[level]

        # Main branch: reduce, convolve, normalize + activate, expand
        main = self.unary1(features)
        main = self.KPConv(queries, supports, indices, main)
        main = self.act(self.batch_norm_conv(main))
        main = self.unary2(main)

        # Shortcut branch: pooled onto the query points when the block is strided
        residual = max_pool(features, indices) if strided else features
        residual = self.unary_shortcut(residual)

        return self.act(main + residual)
class GlobalAverageBlock(nn.Module):
    """Parameter-free block delegating to ``global_average``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        # Only the last entry of batch.lengths is used
        last_lengths = batch.lengths[-1]
        return global_average(x, last_lengths)


class GlobalSumBlock(nn.Module):
    """Parameter-free block delegating to ``global_sum``."""

    def __init__(self):
        super().__init__()

    def forward(self, x, batch):
        # Only the last entry of batch.lengths is used
        last_lengths = batch.lengths[-1]
        return global_sum(x, last_lengths)


class NearestUpsampleBlock(nn.Module):
    """Parameter-free block delegating to ``closest_pool`` with the upsampling indices."""

    def __init__(self, layer_ind):
        """
        :param layer_ind: index of the layer the input features belong to
        """
        super().__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        # The indices of the previous layer are used
        target_layer = self.layer_ind - 1
        return closest_pool(x, batch.upsamples[target_layer])

    def __repr__(self):
        source_layer = self.layer_ind
        return 'NearestUpsampleBlock(layer: {:d} -> {:d})'.format(source_layer, source_layer - 1)


class MaxPoolBlock(nn.Module):
    """Parameter-free block delegating to ``max_pool`` with the pooling indices."""

    def __init__(self, layer_ind):
        """
        :param layer_ind: layer index; the pooling indices of ``layer_ind + 1`` are used
        """
        super().__init__()
        self.layer_ind = layer_ind

    def forward(self, x, batch):
        pool_inds = batch.pools[self.layer_ind + 1]
        return max_pool(x, pool_inds)
def grid_subsampling(points, features=None, labels=None, sampleDl=0.1, verbose=0):
    """
    CPP wrapper for a grid subsampling (method = barycenter for points and features)
    :param points: (N, 3) matrix of input points
    :param features: optional (N, d) matrix of features (floating number)
    :param labels: optional (N,) matrix of integer labels
    :param sampleDl: parameter defining the size of grid voxels
    :param verbose: 1 to display
    :return: subsampled points, with features and/or labels depending on the input
    """
    # Forward only the optional arrays that were actually given; labels are passed to the
    # extension under the keyword `classes`
    optional = {}
    if features is not None:
        optional['features'] = features
    if labels is not None:
        optional['classes'] = labels

    return cpp_subsampling.subsample(points,
                                     sampleDl=sampleDl,
                                     verbose=verbose,
                                     **optional)
def _rotate_stacked_clouds(points, lengths, rotations, transpose=False):
    """
    Rotate, in place, every cloud of a stacked point array with its own matrix.
    :param points: stacked points, modified in place
    :param lengths: number of points of each stacked cloud
    :param rotations: one rotation matrix per cloud
    :param transpose: use the transposed matrices instead
    """
    start = 0
    for bi, length in enumerate(lengths):
        rot = rotations[bi].T if transpose else rotations[bi]
        stop = start + length
        points[start:stop, :] = np.sum(np.expand_dims(points[start:stop, :], 2) * rot, axis=1)
        start = stop


def batch_grid_subsampling(points, batches_len, features=None, labels=None,
                           sampleDl=0.1, max_p=0, verbose=0, random_grid_orient=True):
    """
    CPP wrapper for a grid subsampling (method = barycenter for points and features)
    applied to a batch of stacked clouds.
    :param points: (N, 3) matrix of input points
    :param batches_len: lengths of the stacked clouds in `points`
    :param features: optional (N, d) matrix of features (floating number)
    :param labels: optional (N,) matrix of integer labels
    :param sampleDl: parameter defining the size of grid voxels
    :param max_p: forwarded unchanged to the C++ extension
    :param verbose: 1 to display
    :param random_grid_orient: rotate every cloud randomly before subsampling and rotate
                               the subsampled points back with the transposed matrices
    :return: subsampled points and lengths, followed by the subsampled features and/or
             labels depending on the input
    """
    rotations = None
    n_clouds = len(batches_len)
    if random_grid_orient:
        # Random unit axis from two polar angles, plus a random angle around that axis
        theta = np.random.rand(n_clouds) * 2 * np.pi
        phi = (np.random.rand(n_clouds) - 0.5) * np.pi
        u = np.vstack([np.cos(theta) * np.cos(phi), np.sin(theta) * np.cos(phi), np.sin(phi)])
        alpha = np.random.rand(n_clouds) * 2 * np.pi
        rotations = create_3D_rotations(u.T, alpha).astype(np.float32)

        # Rotate a copy so the caller's array is left untouched
        points = points.copy()
        _rotate_stacked_clouds(points, batches_len, rotations)

    # Forward only the optional arrays that were actually given
    optional = {}
    if features is not None:
        optional['features'] = features
    if labels is not None:
        optional['classes'] = labels

    outputs = cpp_subsampling.subsample_batch(points,
                                              batches_len,
                                              sampleDl=sampleDl,
                                              max_p=max_p,
                                              verbose=verbose,
                                              **optional)
    s_points, s_len = outputs[0], outputs[1]

    if random_grid_orient:
        # Realign the subsampled points
        _rotate_stacked_clouds(s_points, s_len, rotations, transpose=True)

    return (s_points, s_len) + tuple(outputs[2:])


def batch_neighbors(queries, supports, q_batches, s_batches, radius):
    """
    Computes neighbors for a batch of queries and supports
    :param queries: (N1, 3) the query points
    :param supports: (N2, 3) the support points
    :param q_batches: (B) the list of lengths of batch elements in queries
    :param s_batches: (B) the list of lengths of batch elements in supports
    :param radius: float32
    :return: neighbors indices
    """
    return cpp_neighbors.batch_query(queries, supports, q_batches, s_batches, radius=radius)
b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/build.bat @@ -0,0 +1,5 @@ +@echo off +py setup.py build_ext --inplace + + +pause \ No newline at end of file diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp new file mode 100644 index 0000000..bf22af8 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.cpp @@ -0,0 +1,333 @@ + +#include "neighbors.h" + + +void brute_neighbors(vector& queries, vector& supports, vector& neighbors_indices, float radius, int verbose) +{ + + // Initialize variables + // ****************** + + // square radius + float r2 = radius * radius; + + // indices + int i0 = 0; + + // Counting vector + int max_count = 0; + vector> tmp(queries.size()); + + // Search neigbors indices + // *********************** + + for (auto& p0 : queries) + { + int i = 0; + for (auto& p : supports) + { + if ((p0 - p).sq_norm() < r2) + { + tmp[i0].push_back(i); + if (tmp[i0].size() > max_count) + max_count = tmp[i0].size(); + } + i++; + } + i0++; + } + + // Reserve the memory + neighbors_indices.resize(queries.size() * max_count); + i0 = 0; + for (auto& inds : tmp) + { + for (int j = 0; j < max_count; j++) + { + if (j < inds.size()) + neighbors_indices[i0 * max_count + j] = inds[j]; + else + neighbors_indices[i0 * max_count + j] = -1; + } + i0++; + } + + return; +} + +void ordered_neighbors(vector& queries, + vector& supports, + vector& neighbors_indices, + float radius) +{ + + // Initialize variables + // ****************** + + // square radius + float r2 = radius * radius; + + // indices + int i0 = 0; + + // Counting vector + int max_count = 0; + float d2; + vector> tmp(queries.size()); + vector> dists(queries.size()); + + // Search neigbors indices + // *********************** + + for (auto& p0 : queries) + { + int i = 0; + 
for (auto& p : supports) + { + d2 = (p0 - p).sq_norm(); + if (d2 < r2) + { + // Find order of the new point + auto it = std::upper_bound(dists[i0].begin(), dists[i0].end(), d2); + int index = std::distance(dists[i0].begin(), it); + + // Insert element + dists[i0].insert(it, d2); + tmp[i0].insert(tmp[i0].begin() + index, i); + + // Update max count + if (tmp[i0].size() > max_count) + max_count = tmp[i0].size(); + } + i++; + } + i0++; + } + + // Reserve the memory + neighbors_indices.resize(queries.size() * max_count); + i0 = 0; + for (auto& inds : tmp) + { + for (int j = 0; j < max_count; j++) + { + if (j < inds.size()) + neighbors_indices[i0 * max_count + j] = inds[j]; + else + neighbors_indices[i0 * max_count + j] = -1; + } + i0++; + } + + return; +} + +void batch_ordered_neighbors(vector& queries, + vector& supports, + vector& q_batches, + vector& s_batches, + vector& neighbors_indices, + float radius) +{ + + // Initialize variables + // ****************** + + // square radius + float r2 = radius * radius; + + // indices + int i0 = 0; + + // Counting vector + int max_count = 0; + float d2; + vector> tmp(queries.size()); + vector> dists(queries.size()); + + // batch index + int b = 0; + int sum_qb = 0; + int sum_sb = 0; + + + // Search neigbors indices + // *********************** + + for (auto& p0 : queries) + { + // Check if we changed batch + if (i0 == sum_qb + q_batches[b]) + { + sum_qb += q_batches[b]; + sum_sb += s_batches[b]; + b++; + } + + // Loop only over the supports of current batch + vector::iterator p_it; + int i = 0; + for(p_it = supports.begin() + sum_sb; p_it < supports.begin() + sum_sb + s_batches[b]; p_it++ ) + { + d2 = (p0 - *p_it).sq_norm(); + if (d2 < r2) + { + // Find order of the new point + auto it = std::upper_bound(dists[i0].begin(), dists[i0].end(), d2); + int index = std::distance(dists[i0].begin(), it); + + // Insert element + dists[i0].insert(it, d2); + tmp[i0].insert(tmp[i0].begin() + index, sum_sb + i); + + // Update max count + if 
(tmp[i0].size() > max_count) + max_count = tmp[i0].size(); + } + i++; + } + i0++; + } + + // Reserve the memory + neighbors_indices.resize(queries.size() * max_count); + i0 = 0; + for (auto& inds : tmp) + { + for (int j = 0; j < max_count; j++) + { + if (j < inds.size()) + neighbors_indices[i0 * max_count + j] = inds[j]; + else + neighbors_indices[i0 * max_count + j] = supports.size(); + } + i0++; + } + + return; +} + + +void batch_nanoflann_neighbors(vector& queries, + vector& supports, + vector& q_batches, + vector& s_batches, + vector& neighbors_indices, + float radius) +{ + + // Initialize variables + // ****************** + + // indices + int i0 = 0; + + // Square radius + float r2 = radius * radius; + + // Counting vector + int max_count = 0; + float d2; + vector>> all_inds_dists(queries.size()); + + // batch index + int b = 0; + int sum_qb = 0; + int sum_sb = 0; + + // Nanoflann related variables + // *************************** + + // CLoud variable + PointCloud current_cloud; + + // Tree parameters + nanoflann::KDTreeSingleIndexAdaptorParams tree_params(10 /* max leaf */); + + // KDTree type definition + typedef nanoflann::KDTreeSingleIndexAdaptor< nanoflann::L2_Simple_Adaptor , + PointCloud, + 3 > my_kd_tree_t; + + // Pointer to trees + my_kd_tree_t* index; + + // Build KDTree for the first batch element + current_cloud.pts = vector(supports.begin() + sum_sb, supports.begin() + sum_sb + s_batches[b]); + index = new my_kd_tree_t(3, current_cloud, tree_params); + index->buildIndex(); + + + // Search neigbors indices + // *********************** + + // Search params + nanoflann::SearchParams search_params; + search_params.sorted = true; + + for (auto& p0 : queries) + { + + // Check if we changed batch + if (i0 == sum_qb + q_batches[b]) + { + sum_qb += q_batches[b]; + sum_sb += s_batches[b]; + b++; + + // Change the points + current_cloud.pts.clear(); + current_cloud.pts = vector(supports.begin() + sum_sb, supports.begin() + sum_sb + s_batches[b]); + + // 
Build KDTree of the current element of the batch + delete index; + index = new my_kd_tree_t(3, current_cloud, tree_params); + index->buildIndex(); + } + + // Initial guess of neighbors size + all_inds_dists[i0].reserve(max_count); + + // Find neighbors + float query_pt[3] = { p0.x, p0.y, p0.z}; + size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params); + + // Update max count + if (nMatches > max_count) + max_count = nMatches; + + // Increment query idx + i0++; + } + + // Reserve the memory + neighbors_indices.resize(queries.size() * max_count); + i0 = 0; + sum_sb = 0; + sum_qb = 0; + b = 0; + for (auto& inds_dists : all_inds_dists) + { + // Check if we changed batch + if (i0 == sum_qb + q_batches[b]) + { + sum_qb += q_batches[b]; + sum_sb += s_batches[b]; + b++; + } + + for (int j = 0; j < max_count; j++) + { + if (j < inds_dists.size()) + neighbors_indices[i0 * max_count + j] = inds_dists[j].first + sum_sb; + else + neighbors_indices[i0 * max_count + j] = supports.size(); + } + i0++; + } + + delete index; + + return; +} + diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.h b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.h new file mode 100644 index 0000000..ff612b0 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_neighbors/neighbors/neighbors.h @@ -0,0 +1,29 @@ + + +#include "../../cpp_utils/cloud/cloud.h" +#include "../../cpp_utils/nanoflann/nanoflann.hpp" + +#include +#include + +using namespace std; + + +void ordered_neighbors(vector& queries, + vector& supports, + vector& neighbors_indices, + float radius); + +void batch_ordered_neighbors(vector& queries, + vector& supports, + vector& q_batches, + vector& s_batches, + vector& neighbors_indices, + float radius); + +void batch_nanoflann_neighbors(vector& queries, + vector& supports, + vector& q_batches, + vector& s_batches, + vector& 
"""Build script of the ``radius_neighbors`` C++ extension.

Usage (see build.bat): ``python setup.py build_ext --inplace``
"""
# distutils is no longer shipped with recent Python versions: prefer setuptools and only
# fall back to distutils on installations that do not provide it
try:
    from setuptools import setup, Extension
except ImportError:
    from distutils.core import setup, Extension

import numpy

# Adding sources of the project
# *****************************

SOURCES = ["../cpp_utils/cloud/cloud.cpp",
           "neighbors/neighbors.cpp",
           "wrapper.cpp"]

module = Extension(name="radius_neighbors",
                   sources=SOURCES,
                   extra_compile_args=['-std=c++11',
                                       '-D_GLIBCXX_USE_CXX11_ABI=0'])


# numpy.get_include() replaces numpy.distutils.misc_util.get_numpy_include_dirs(),
# which is not available anymore in recent NumPy releases
setup(ext_modules=[module], include_dirs=[numpy.get_include()])
batch_query_docstring }, + {NULL, NULL, 0, NULL} +}; + + +// Initialize the module +// ********************* + +static struct PyModuleDef moduledef = +{ + PyModuleDef_HEAD_INIT, + "radius_neighbors", // m_name + module_docstring, // m_doc + -1, // m_size + module_methods, // m_methods + NULL, // m_reload + NULL, // m_traverse + NULL, // m_clear + NULL, // m_free +}; + +PyMODINIT_FUNC PyInit_radius_neighbors(void) +{ + import_array(); + return PyModule_Create(&moduledef); +} + + +// Definition of the batch_subsample method +// ********************************** + +static PyObject* batch_neighbors(PyObject* self, PyObject* args, PyObject* keywds) +{ + + // Manage inputs + // ************* + + // Args containers + PyObject* queries_obj = NULL; + PyObject* supports_obj = NULL; + PyObject* q_batches_obj = NULL; + PyObject* s_batches_obj = NULL; + + // Keywords containers + static char* kwlist[] = { "queries", "supports", "q_batches", "s_batches", "radius", NULL }; + float radius = 0.1; + + // Parse the input + if (!PyArg_ParseTupleAndKeywords(args, keywds, "OOOO|$f", kwlist, &queries_obj, &supports_obj, &q_batches_obj, &s_batches_obj, &radius)) + { + PyErr_SetString(PyExc_RuntimeError, "Error parsing arguments"); + return NULL; + } + + + // Interpret the input objects as numpy arrays. + PyObject* queries_array = PyArray_FROM_OTF(queries_obj, NPY_FLOAT, NPY_IN_ARRAY); + PyObject* supports_array = PyArray_FROM_OTF(supports_obj, NPY_FLOAT, NPY_IN_ARRAY); + PyObject* q_batches_array = PyArray_FROM_OTF(q_batches_obj, NPY_INT, NPY_IN_ARRAY); + PyObject* s_batches_array = PyArray_FROM_OTF(s_batches_obj, NPY_INT, NPY_IN_ARRAY); + + // Verify data was load correctly. 
+ if (queries_array == NULL) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting query points to numpy arrays of type float32"); + return NULL; + } + if (supports_array == NULL) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting support points to numpy arrays of type float32"); + return NULL; + } + if (q_batches_array == NULL) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting query batches to numpy arrays of type int32"); + return NULL; + } + if (s_batches_array == NULL) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting support batches to numpy arrays of type int32"); + return NULL; + } + + // Check that the input array respect the dims + if ((int)PyArray_NDIM(queries_array) != 2 || (int)PyArray_DIM(queries_array, 1) != 3) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : query.shape is not (N, 3)"); + return NULL; + } + if ((int)PyArray_NDIM(supports_array) != 2 || (int)PyArray_DIM(supports_array, 1) != 3) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : support.shape is not (N, 3)"); + return NULL; + } + if ((int)PyArray_NDIM(q_batches_array) > 1) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + 
PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : queries_batches.shape is not (B,) "); + return NULL; + } + if ((int)PyArray_NDIM(s_batches_array) > 1) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : supports_batches.shape is not (B,) "); + return NULL; + } + if ((int)PyArray_DIM(q_batches_array, 0) != (int)PyArray_DIM(s_batches_array, 0)) + { + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong number of batch elements: different for queries and supports "); + return NULL; + } + + // Number of points + int Nq = (int)PyArray_DIM(queries_array, 0); + int Ns= (int)PyArray_DIM(supports_array, 0); + + // Number of batches + int Nb = (int)PyArray_DIM(q_batches_array, 0); + + // Call the C++ function + // ********************* + + // Convert PyArray to Cloud C++ class + vector queries; + vector supports; + vector q_batches; + vector s_batches; + queries = vector((PointXYZ*)PyArray_DATA(queries_array), (PointXYZ*)PyArray_DATA(queries_array) + Nq); + supports = vector((PointXYZ*)PyArray_DATA(supports_array), (PointXYZ*)PyArray_DATA(supports_array) + Ns); + q_batches = vector((int*)PyArray_DATA(q_batches_array), (int*)PyArray_DATA(q_batches_array) + Nb); + s_batches = vector((int*)PyArray_DATA(s_batches_array), (int*)PyArray_DATA(s_batches_array) + Nb); + + // Create result containers + vector neighbors_indices; + + // Compute results + //batch_ordered_neighbors(queries, supports, q_batches, s_batches, neighbors_indices, radius); + batch_nanoflann_neighbors(queries, supports, q_batches, s_batches, neighbors_indices, radius); + + // Check result + if (neighbors_indices.size() < 1) + { + PyErr_SetString(PyExc_RuntimeError, "Error"); + return NULL; + } + + // Manage outputs + // ************** + + // Maximal number of 
neighbors + int max_neighbors = neighbors_indices.size() / Nq; + + // Dimension of output containers + npy_intp* neighbors_dims = new npy_intp[2]; + neighbors_dims[0] = Nq; + neighbors_dims[1] = max_neighbors; + + // Create output array + PyObject* res_obj = PyArray_SimpleNew(2, neighbors_dims, NPY_INT); + PyObject* ret = NULL; + + // Fill output array with values + size_t size_in_bytes = Nq * max_neighbors * sizeof(int); + memcpy(PyArray_DATA(res_obj), neighbors_indices.data(), size_in_bytes); + + // Merge results + ret = Py_BuildValue("N", res_obj); + + // Clean up + // ******** + + Py_XDECREF(queries_array); + Py_XDECREF(supports_array); + Py_XDECREF(q_batches_array); + Py_XDECREF(s_batches_array); + + return ret; +} diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp new file mode 100644 index 0000000..24276bb --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.cpp @@ -0,0 +1,211 @@ + +#include "grid_subsampling.h" + + +void grid_subsampling(vector& original_points, + vector& subsampled_points, + vector& original_features, + vector& subsampled_features, + vector& original_classes, + vector& subsampled_classes, + float sampleDl, + int verbose) { + + // Initialize variables + // ****************** + + // Number of points in the cloud + size_t N = original_points.size(); + + // Dimension of the features + size_t fdim = original_features.size() / N; + size_t ldim = original_classes.size() / N; + + // Limits of the cloud + PointXYZ minCorner = min_point(original_points); + PointXYZ maxCorner = max_point(original_points); + PointXYZ originCorner = floor(minCorner * (1/sampleDl)) * sampleDl; + + // Dimensions of the grid + size_t sampleNX = (size_t)floor((maxCorner.x - originCorner.x) / sampleDl) + 1; + size_t 
sampleNY = (size_t)floor((maxCorner.y - originCorner.y) / sampleDl) + 1; + //size_t sampleNZ = (size_t)floor((maxCorner.z - originCorner.z) / sampleDl) + 1; + + // Check if features and classes need to be processed + bool use_feature = original_features.size() > 0; + bool use_classes = original_classes.size() > 0; + + + // Create the sampled map + // ********************** + + // Verbose parameters: progress is printed once every nDisp points (1% of the cloud). + // nDisp is 0 for clouds of fewer than 100 points; the display test below skips that case + // instead of evaluating i % 0. + int i = 0; + int nDisp = N / 100; + + // Initialize variables + size_t iX, iY, iZ, mapIdx; + unordered_map data; + + for (auto& p : original_points) + { + // Position of point in sample map (cell indices relative to originCorner) + iX = (size_t)floor((p.x - originCorner.x) / sampleDl); + iY = (size_t)floor((p.y - originCorner.y) / sampleDl); + iZ = (size_t)floor((p.z - originCorner.z) / sampleDl); + mapIdx = iX + sampleNX*iY + sampleNX*sampleNY*iZ; + + // If not already created, create key + if (data.count(mapIdx) < 1) + data.emplace(mapIdx, SampledData(fdim, ldim)); + + // Fill the sample map (i is the index of the current point in the flat feature/class arrays) + if (use_feature && use_classes) + data[mapIdx].update_all(p, original_features.begin() + i * fdim, original_classes.begin() + i * ldim); + else if (use_feature) + data[mapIdx].update_features(p, original_features.begin() + i * fdim); + else if (use_classes) + data[mapIdx].update_classes(p, original_classes.begin() + i * ldim); + else + data[mapIdx].update_points(p); + + // Display (guarded: nDisp == 0 would be a modulo by zero) + i++; + if (verbose > 1 && nDisp > 0 && i%nDisp == 0) + std::cout << "\rSampled Map : " << std::setw(3) << i / nDisp << "%"; + + } + + // Divide for barycentre and transfer to a vector + subsampled_points.reserve(data.size()); + if (use_feature) + subsampled_features.reserve(data.size() * fdim); + if (use_classes) + subsampled_classes.reserve(data.size() * ldim); + for (auto& v : data) + { + subsampled_points.push_back(v.second.point * (1.0 / v.second.count)); + if (use_feature) + { + float count = (float)v.second.count; + transform(v.second.features.begin(), + v.second.features.end(), + v.second.features.begin(), + [count](float f) { return f / 
count;}); + subsampled_features.insert(subsampled_features.end(),v.second.features.begin(),v.second.features.end()); + } + if (use_classes) + { + // For each label dimension keep the label value with the highest count in this cell + for (size_t i = 0; i < ldim; i++) + subsampled_classes.push_back(max_element(v.second.labels[i].begin(), v.second.labels[i].end(), + [](const pair&a, const pair&b){return a.second < b.second;})->first); + } + } + + return; +} + + +// Grid-subsample a stack of clouds, one batch at a time. +// original_batches[b] is the number of consecutive points (and feature/class rows) of batch b; +// each batch is subsampled with grid_subsampling using cell size sampleDl, at most max_p points +// are kept per batch (max_p < 1 means no limit), and the per-batch output sizes are appended to +// subsampled_batches. +void batch_grid_subsampling(vector& original_points, + vector& subsampled_points, + vector& original_features, + vector& subsampled_features, + vector& original_classes, + vector& subsampled_classes, + vector& original_batches, + vector& subsampled_batches, + float sampleDl, + int max_p) +{ + // Initialize variables + // ****************** + + int b = 0; + int sum_b = 0; + + // Number of points in the cloud + size_t N = original_points.size(); + + // Nothing to subsample; returning here also avoids dividing by N == 0 below + if (N == 0) + return; + + // Dimension of the features + size_t fdim = original_features.size() / N; + size_t ldim = original_classes.size() / N; + + // Handle max_p = 0 + if (max_p < 1) + max_p = N; + + // Loop over batches + // ***************** + + for (b = 0; b < original_batches.size(); b++) + { + + // Extract batch points features and labels + vector b_o_points = vector(original_points.begin () + sum_b, + original_points.begin () + sum_b + original_batches[b]); + + vector b_o_features; + if (original_features.size() > 0) + { + b_o_features = vector(original_features.begin () + sum_b * fdim, + original_features.begin () + (sum_b + original_batches[b]) * fdim); + } + + vector b_o_classes; + if (original_classes.size() > 0) + { + // End offset is (start row + batch length) * ldim, same form as the features above + b_o_classes = vector(original_classes.begin () + sum_b * ldim, + original_classes.begin () + (sum_b + original_batches[b]) * ldim); + } + + + // Create result containers + vector b_s_points; + vector b_s_features; + vector b_s_classes; + + // Compute subsampling on current batch. + // grid_subsampling divides by its number of points, so an empty batch is skipped + // and simply contributes zero subsampled points. + if (b_o_points.size() > 0) + grid_subsampling(b_o_points, + b_s_points, + b_o_features, + b_s_features, + b_o_classes, + b_s_classes, + sampleDl, + 0); + + // Stack batches points features and labels + // 
**************************************** + + // If too many points remove some + if (b_s_points.size() <= max_p) + { + subsampled_points.insert(subsampled_points.end(), b_s_points.begin(), b_s_points.end()); + + if (original_features.size() > 0) + subsampled_features.insert(subsampled_features.end(), b_s_features.begin(), b_s_features.end()); + + if (original_classes.size() > 0) + subsampled_classes.insert(subsampled_classes.end(), b_s_classes.begin(), b_s_classes.end()); + + subsampled_batches.push_back(b_s_points.size()); + } + else + { + subsampled_points.insert(subsampled_points.end(), b_s_points.begin(), b_s_points.begin() + max_p); + + if (original_features.size() > 0) + subsampled_features.insert(subsampled_features.end(), b_s_features.begin(), b_s_features.begin() + max_p * fdim); + + if (original_classes.size() > 0) + subsampled_classes.insert(subsampled_classes.end(), b_s_classes.begin(), b_s_classes.begin() + max_p * ldim); + + subsampled_batches.push_back(max_p); + } + + // Stack new batch lengths + sum_b += original_batches[b]; + } + + return; +} diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h new file mode 100644 index 0000000..37f775d --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/grid_subsampling/grid_subsampling.h @@ -0,0 +1,101 @@ + + +#include "../../cpp_utils/cloud/cloud.h" + +#include +#include + +using namespace std; + +class SampledData +{ +public: + + // Elements + // ******** + + int count; + PointXYZ point; + vector features; + vector> labels; + + + // Methods + // ******* + + // Constructor + SampledData() + { + count = 0; + point = PointXYZ(); + } + + SampledData(const size_t fdim, const size_t ldim) + { + count = 0; + point = PointXYZ(); + features = vector(fdim); + labels = vector>(ldim); + } + + // Method 
Update + void update_all(const PointXYZ p, vector::iterator f_begin, vector::iterator l_begin) + { + count += 1; + point += p; + transform (features.begin(), features.end(), f_begin, features.begin(), plus()); + int i = 0; + for(vector::iterator it = l_begin; it != l_begin + labels.size(); ++it) + { + labels[i][*it] += 1; + i++; + } + return; + } + void update_features(const PointXYZ p, vector::iterator f_begin) + { + count += 1; + point += p; + transform (features.begin(), features.end(), f_begin, features.begin(), plus()); + return; + } + void update_classes(const PointXYZ p, vector::iterator l_begin) + { + count += 1; + point += p; + int i = 0; + for(vector::iterator it = l_begin; it != l_begin + labels.size(); ++it) + { + labels[i][*it] += 1; + i++; + } + return; + } + void update_points(const PointXYZ p) + { + count += 1; + point += p; + return; + } +}; + +void grid_subsampling(vector& original_points, + vector& subsampled_points, + vector& original_features, + vector& subsampled_features, + vector& original_classes, + vector& subsampled_classes, + float sampleDl, + int verbose); + +void batch_grid_subsampling(vector& original_points, + vector& subsampled_points, + vector& original_features, + vector& subsampled_features, + vector& original_classes, + vector& subsampled_classes, + vector& original_batches, + vector& subsampled_batches, + float sampleDl, + int max_p); + diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/setup.py b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/setup.py new file mode 100644 index 0000000..3206299 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/setup.py @@ -0,0 +1,28 @@ +from distutils.core import setup, Extension +import numpy.distutils.misc_util + +# Adding OpenCV to project +# ************************ + +# Adding sources of the project +# ***************************** + +SOURCES = ["../cpp_utils/cloud/cloud.cpp", + 
"grid_subsampling/grid_subsampling.cpp", + "wrapper.cpp"] + +module = Extension(name="grid_subsampling", + sources=SOURCES, + extra_compile_args=['-std=c++11', + '-D_GLIBCXX_USE_CXX11_ABI=0']) + + +setup(ext_modules=[module], include_dirs=numpy.distutils.misc_util.get_numpy_include_dirs()) + + + + + + + + diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/wrapper.cpp b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/wrapper.cpp new file mode 100644 index 0000000..8a92aaa --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_subsampling/wrapper.cpp @@ -0,0 +1,566 @@ +#include +#include +#include "grid_subsampling/grid_subsampling.h" +#include + + + +// docstrings for our module +// ************************* + +static char module_docstring[] = "This module provides an interface for the subsampling of a batch of stacked pointclouds"; + +static char subsample_docstring[] = "function subsampling a pointcloud"; + +static char subsample_batch_docstring[] = "function subsampling a batch of stacked pointclouds"; + + +// Declare the functions +// ********************* + +static PyObject *cloud_subsampling(PyObject* self, PyObject* args, PyObject* keywds); +static PyObject *batch_subsampling(PyObject *self, PyObject *args, PyObject *keywds); + + +// Specify the members of the module +// ********************************* + +static PyMethodDef module_methods[] = +{ + { "subsample", (PyCFunction)cloud_subsampling, METH_VARARGS | METH_KEYWORDS, subsample_docstring }, + { "subsample_batch", (PyCFunction)batch_subsampling, METH_VARARGS | METH_KEYWORDS, subsample_batch_docstring }, + {NULL, NULL, 0, NULL} +}; + + +// Initialize the module +// ********************* + +static struct PyModuleDef moduledef = +{ + PyModuleDef_HEAD_INIT, + "grid_subsampling", // m_name + module_docstring, // m_doc + -1, // m_size + module_methods, // m_methods + NULL, // m_reload + NULL, // m_traverse + NULL, // 
m_clear + NULL, // m_free +}; + +PyMODINIT_FUNC PyInit_grid_subsampling(void) +{ + import_array(); + return PyModule_Create(&moduledef); +} + + +// Definition of the batch_subsample method +// ********************************** + +static PyObject* batch_subsampling(PyObject* self, PyObject* args, PyObject* keywds) +{ + + // Manage inputs + // ************* + + // Args containers + PyObject* points_obj = NULL; + PyObject* features_obj = NULL; + PyObject* classes_obj = NULL; + PyObject* batches_obj = NULL; + + // Keywords containers + static char* kwlist[] = { "points", "batches", "features", "classes", "sampleDl", "method", "max_p", "verbose", NULL }; + float sampleDl = 0.1; + const char* method_buffer = "barycenters"; + int verbose = 0; + int max_p = 0; + + // Parse the input + if (!PyArg_ParseTupleAndKeywords(args, keywds, "OO|$OOfsii", kwlist, &points_obj, &batches_obj, &features_obj, &classes_obj, &sampleDl, &method_buffer, &max_p, &verbose)) + { + PyErr_SetString(PyExc_RuntimeError, "Error parsing arguments"); + return NULL; + } + + // Get the method argument + string method(method_buffer); + + // Interpret method + if (method.compare("barycenters") && method.compare("voxelcenters")) + { + PyErr_SetString(PyExc_RuntimeError, "Error parsing method. Valid method names are \"barycenters\" and \"voxelcenters\" "); + return NULL; + } + + // Check if using features or classes + bool use_feature = true, use_classes = true; + if (features_obj == NULL) + use_feature = false; + if (classes_obj == NULL) + use_classes = false; + + // Interpret the input objects as numpy arrays. 
+ PyObject* points_array = PyArray_FROM_OTF(points_obj, NPY_FLOAT, NPY_IN_ARRAY); + PyObject* batches_array = PyArray_FROM_OTF(batches_obj, NPY_INT, NPY_IN_ARRAY); + PyObject* features_array = NULL; + PyObject* classes_array = NULL; + if (use_feature) + features_array = PyArray_FROM_OTF(features_obj, NPY_FLOAT, NPY_IN_ARRAY); + if (use_classes) + classes_array = PyArray_FROM_OTF(classes_obj, NPY_INT, NPY_IN_ARRAY); + + // Verify data was load correctly. + if (points_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input points to numpy arrays of type float32"); + return NULL; + } + if (batches_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input batches to numpy arrays of type int32"); + return NULL; + } + if (use_feature && features_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input features to numpy arrays of type float32"); + return NULL; + } + if (use_classes && classes_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input classes to numpy arrays of type int32"); + return NULL; + } + + // Check that the input array respect the dims + if ((int)PyArray_NDIM(points_array) != 2 || (int)PyArray_DIM(points_array, 1) != 3) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : points.shape is not (N, 3)"); + return NULL; + } + if ((int)PyArray_NDIM(batches_array) > 1) + { + 
Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : batches.shape is not (B,) "); + return NULL; + } + if (use_feature && ((int)PyArray_NDIM(features_array) != 2)) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : features.shape is not (N, d)"); + return NULL; + } + + if (use_classes && (int)PyArray_NDIM(classes_array) > 2) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : classes.shape is not (N,) or (N, d)"); + return NULL; + } + + // Number of points + int N = (int)PyArray_DIM(points_array, 0); + + // Number of batches + int Nb = (int)PyArray_DIM(batches_array, 0); + + // Dimension of the features + int fdim = 0; + if (use_feature) + fdim = (int)PyArray_DIM(features_array, 1); + + //Dimension of labels + int ldim = 1; + if (use_classes && (int)PyArray_NDIM(classes_array) == 2) + ldim = (int)PyArray_DIM(classes_array, 1); + + // Check that the input array respect the number of points + if (use_feature && (int)PyArray_DIM(features_array, 0) != N) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : features.shape is not (N, d)"); + return NULL; + } + if (use_classes && (int)PyArray_DIM(classes_array, 0) != N) + { + Py_XDECREF(points_array); + Py_XDECREF(batches_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : classes.shape is not (N,) or (N, d)"); + return NULL; + } + + + // Call the C++ function + // ********************* + + // Create pyramid + if (verbose > 0) + cout << "Computing cloud pyramid 
with support points: " << endl; + + + // Convert PyArray to Cloud C++ class + vector original_points; + vector original_batches; + vector original_features; + vector original_classes; + original_points = vector((PointXYZ*)PyArray_DATA(points_array), (PointXYZ*)PyArray_DATA(points_array) + N); + original_batches = vector((int*)PyArray_DATA(batches_array), (int*)PyArray_DATA(batches_array) + Nb); + if (use_feature) + original_features = vector((float*)PyArray_DATA(features_array), (float*)PyArray_DATA(features_array) + N * fdim); + if (use_classes) + original_classes = vector((int*)PyArray_DATA(classes_array), (int*)PyArray_DATA(classes_array) + N * ldim); + + // Subsample + vector subsampled_points; + vector subsampled_features; + vector subsampled_classes; + vector subsampled_batches; + batch_grid_subsampling(original_points, + subsampled_points, + original_features, + subsampled_features, + original_classes, + subsampled_classes, + original_batches, + subsampled_batches, + sampleDl, + max_p); + + // Check result + if (subsampled_points.size() < 1) + { + PyErr_SetString(PyExc_RuntimeError, "Error"); + return NULL; + } + + // Manage outputs + // ************** + + // Dimension of input containers + npy_intp* point_dims = new npy_intp[2]; + point_dims[0] = subsampled_points.size(); + point_dims[1] = 3; + npy_intp* feature_dims = new npy_intp[2]; + feature_dims[0] = subsampled_points.size(); + feature_dims[1] = fdim; + npy_intp* classes_dims = new npy_intp[2]; + classes_dims[0] = subsampled_points.size(); + classes_dims[1] = ldim; + npy_intp* batches_dims = new npy_intp[1]; + batches_dims[0] = Nb; + + // Create output array + PyObject* res_points_obj = PyArray_SimpleNew(2, point_dims, NPY_FLOAT); + PyObject* res_batches_obj = PyArray_SimpleNew(1, batches_dims, NPY_INT); + PyObject* res_features_obj = NULL; + PyObject* res_classes_obj = NULL; + PyObject* ret = NULL; + + // Fill output array with values + size_t size_in_bytes = subsampled_points.size() * 3 * 
sizeof(float); + memcpy(PyArray_DATA(res_points_obj), subsampled_points.data(), size_in_bytes); + size_in_bytes = Nb * sizeof(int); + memcpy(PyArray_DATA(res_batches_obj), subsampled_batches.data(), size_in_bytes); + if (use_feature) + { + size_in_bytes = subsampled_points.size() * fdim * sizeof(float); + res_features_obj = PyArray_SimpleNew(2, feature_dims, NPY_FLOAT); + memcpy(PyArray_DATA(res_features_obj), subsampled_features.data(), size_in_bytes); + } + if (use_classes) + { + size_in_bytes = subsampled_points.size() * ldim * sizeof(int); + res_classes_obj = PyArray_SimpleNew(2, classes_dims, NPY_INT); + memcpy(PyArray_DATA(res_classes_obj), subsampled_classes.data(), size_in_bytes); + } + + + // Merge results + if (use_feature && use_classes) + ret = Py_BuildValue("NNNN", res_points_obj, res_batches_obj, res_features_obj, res_classes_obj); + else if (use_feature) + ret = Py_BuildValue("NNN", res_points_obj, res_batches_obj, res_features_obj); + else if (use_classes) + ret = Py_BuildValue("NNN", res_points_obj, res_batches_obj, res_classes_obj); + else + ret = Py_BuildValue("NN", res_points_obj, res_batches_obj); + + // Clean up + // ******** + + Py_DECREF(points_array); + Py_DECREF(batches_array); + Py_XDECREF(features_array); + Py_XDECREF(classes_array); + + return ret; +} + +// Definition of the subsample method +// **************************************** + +static PyObject* cloud_subsampling(PyObject* self, PyObject* args, PyObject* keywds) +{ + + // Manage inputs + // ************* + + // Args containers + PyObject* points_obj = NULL; + PyObject* features_obj = NULL; + PyObject* classes_obj = NULL; + + // Keywords containers + static char* kwlist[] = { "points", "features", "classes", "sampleDl", "method", "verbose", NULL }; + float sampleDl = 0.1; + const char* method_buffer = "barycenters"; + int verbose = 0; + + // Parse the input + if (!PyArg_ParseTupleAndKeywords(args, keywds, "O|$OOfsi", kwlist, &points_obj, &features_obj, &classes_obj, &sampleDl, 
&method_buffer, &verbose)) + { + PyErr_SetString(PyExc_RuntimeError, "Error parsing arguments"); + return NULL; + } + + // Get the method argument + string method(method_buffer); + + // Interpret method + if (method.compare("barycenters") && method.compare("voxelcenters")) + { + PyErr_SetString(PyExc_RuntimeError, "Error parsing method. Valid method names are \"barycenters\" and \"voxelcenters\" "); + return NULL; + } + + // Check if using features or classes + bool use_feature = true, use_classes = true; + if (features_obj == NULL) + use_feature = false; + if (classes_obj == NULL) + use_classes = false; + + // Interpret the input objects as numpy arrays. + PyObject* points_array = PyArray_FROM_OTF(points_obj, NPY_FLOAT, NPY_IN_ARRAY); + PyObject* features_array = NULL; + PyObject* classes_array = NULL; + if (use_feature) + features_array = PyArray_FROM_OTF(features_obj, NPY_FLOAT, NPY_IN_ARRAY); + if (use_classes) + classes_array = PyArray_FROM_OTF(classes_obj, NPY_INT, NPY_IN_ARRAY); + + // Verify data was load correctly. 
+ if (points_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input points to numpy arrays of type float32"); + return NULL; + } + if (use_feature && features_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input features to numpy arrays of type float32"); + return NULL; + } + if (use_classes && classes_array == NULL) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Error converting input classes to numpy arrays of type int32"); + return NULL; + } + + // Check that the input array respect the dims + if ((int)PyArray_NDIM(points_array) != 2 || (int)PyArray_DIM(points_array, 1) != 3) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : points.shape is not (N, 3)"); + return NULL; + } + if (use_feature && ((int)PyArray_NDIM(features_array) != 2)) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : features.shape is not (N, d)"); + return NULL; + } + + if (use_classes && (int)PyArray_NDIM(classes_array) > 2) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : classes.shape is not (N,) or (N, d)"); + return NULL; + } + + // Number of points + int N = (int)PyArray_DIM(points_array, 0); + + // Dimension of the features + int fdim = 0; + if (use_feature) + fdim = (int)PyArray_DIM(features_array, 1); + + //Dimension of labels + int ldim = 1; + if (use_classes && (int)PyArray_NDIM(classes_array) == 2) + ldim = (int)PyArray_DIM(classes_array, 1); + + // Check that the input array respect 
the number of points + if (use_feature && (int)PyArray_DIM(features_array, 0) != N) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : features.shape is not (N, d)"); + return NULL; + } + if (use_classes && (int)PyArray_DIM(classes_array, 0) != N) + { + Py_XDECREF(points_array); + Py_XDECREF(classes_array); + Py_XDECREF(features_array); + PyErr_SetString(PyExc_RuntimeError, "Wrong dimensions : classes.shape is not (N,) or (N, d)"); + return NULL; + } + + + // Call the C++ function + // ********************* + + // Create pyramid + if (verbose > 0) + cout << "Computing cloud pyramid with support points: " << endl; + + + // Convert PyArray to Cloud C++ class + vector original_points; + vector original_features; + vector original_classes; + original_points = vector((PointXYZ*)PyArray_DATA(points_array), (PointXYZ*)PyArray_DATA(points_array) + N); + if (use_feature) + original_features = vector((float*)PyArray_DATA(features_array), (float*)PyArray_DATA(features_array) + N * fdim); + if (use_classes) + original_classes = vector((int*)PyArray_DATA(classes_array), (int*)PyArray_DATA(classes_array) + N * ldim); + + // Subsample + vector subsampled_points; + vector subsampled_features; + vector subsampled_classes; + grid_subsampling(original_points, + subsampled_points, + original_features, + subsampled_features, + original_classes, + subsampled_classes, + sampleDl, + verbose); + + // Check result + if (subsampled_points.size() < 1) + { + PyErr_SetString(PyExc_RuntimeError, "Error"); + return NULL; + } + + // Manage outputs + // ************** + + // Dimension of input containers + npy_intp* point_dims = new npy_intp[2]; + point_dims[0] = subsampled_points.size(); + point_dims[1] = 3; + npy_intp* feature_dims = new npy_intp[2]; + feature_dims[0] = subsampled_points.size(); + feature_dims[1] = fdim; + npy_intp* classes_dims = new npy_intp[2]; + classes_dims[0] = 
subsampled_points.size(); + classes_dims[1] = ldim; + + // Create output array + PyObject* res_points_obj = PyArray_SimpleNew(2, point_dims, NPY_FLOAT); + PyObject* res_features_obj = NULL; + PyObject* res_classes_obj = NULL; + PyObject* ret = NULL; + + // Fill output array with values + size_t size_in_bytes = subsampled_points.size() * 3 * sizeof(float); + memcpy(PyArray_DATA(res_points_obj), subsampled_points.data(), size_in_bytes); + if (use_feature) + { + size_in_bytes = subsampled_points.size() * fdim * sizeof(float); + res_features_obj = PyArray_SimpleNew(2, feature_dims, NPY_FLOAT); + memcpy(PyArray_DATA(res_features_obj), subsampled_features.data(), size_in_bytes); + } + if (use_classes) + { + size_in_bytes = subsampled_points.size() * ldim * sizeof(int); + res_classes_obj = PyArray_SimpleNew(2, classes_dims, NPY_INT); + memcpy(PyArray_DATA(res_classes_obj), subsampled_classes.data(), size_in_bytes); + } + + + // Merge results + if (use_feature && use_classes) + ret = Py_BuildValue("NNN", res_points_obj, res_features_obj, res_classes_obj); + else if (use_feature) + ret = Py_BuildValue("NN", res_points_obj, res_features_obj); + else if (use_classes) + ret = Py_BuildValue("NN", res_points_obj, res_classes_obj); + else + ret = Py_BuildValue("N", res_points_obj); + + // Clean up + // ******** + + Py_DECREF(points_array); + Py_XDECREF(features_array); + Py_XDECREF(classes_array); + + return ret; +} \ No newline at end of file diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/cloud/cloud.cpp b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/cloud/cloud.cpp new file mode 100644 index 0000000..c285140 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/cloud/cloud.cpp @@ -0,0 +1,67 @@ +// +// +// 0==========================0 +// | Local feature test | +// 0==========================0 +// +// version 1.0 : +// > +// +//--------------------------------------------------- +// +// 
Cloud source : +// Define useful Functions/Methods +// +//---------------------------------------------------- +// +// Hugues THOMAS - 10/02/2017 +// + + +#include "cloud.h" + + +// Getters +// ******* + +// Return the component-wise maximum over all points. +// An empty cloud yields the default point (0, 0, 0) instead of reading points[0] out of bounds. +PointXYZ max_point(std::vector points) +{ + // Nothing to scan: do not index an empty vector + if (points.empty()) + return PointXYZ(); + + // Initialize limits + PointXYZ maxP(points[0]); + + // Loop over all points (by reference, no per-point copy) + for (const auto& p : points) + { + if (p.x > maxP.x) + maxP.x = p.x; + + if (p.y > maxP.y) + maxP.y = p.y; + + if (p.z > maxP.z) + maxP.z = p.z; + } + + return maxP; +} + +// Return the component-wise minimum over all points. +// An empty cloud yields the default point (0, 0, 0) instead of reading points[0] out of bounds. +PointXYZ min_point(std::vector points) +{ + // Nothing to scan: do not index an empty vector + if (points.empty()) + return PointXYZ(); + + // Initialize limits + PointXYZ minP(points[0]); + + // Loop over all points (by reference, no per-point copy) + for (const auto& p : points) + { + if (p.x < minP.x) + minP.x = p.x; + + if (p.y < minP.y) + minP.y = p.y; + + if (p.z < minP.z) + minP.z = p.z; + } + + return minP; +} \ No newline at end of file diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/cloud/cloud.h b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/cloud/cloud.h new file mode 100644 index 0000000..99d4e19 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/cloud/cloud.h @@ -0,0 +1,185 @@ +// +// +// 0==========================0 +// | Local feature test | +// 0==========================0 +// +// version 1.0 : +// > +// +//--------------------------------------------------- +// +// Cloud header +// +//---------------------------------------------------- +// +// Hugues THOMAS - 10/02/2017 +// + + +# pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + + + + +// Point class +// *********** + + +class PointXYZ +{ +public: + + // Elements + // ******** + + float x, y, z; + + + // Methods + // ******* + + // Constructor + PointXYZ() { x = 0; y = 0; z = 0; } + PointXYZ(float x0, float y0, float z0) { x = x0; y = y0; z = z0; } + + // array type accessor + float operator [] (int i) const + { + if (i == 0) return x; + else if (i == 1) return y; + else return z; + } + + // 
opperations + float dot(const PointXYZ P) const + { + return x * P.x + y * P.y + z * P.z; + } + + float sq_norm() + { + return x*x + y*y + z*z; + } + + PointXYZ cross(const PointXYZ P) const + { + return PointXYZ(y*P.z - z*P.y, z*P.x - x*P.z, x*P.y - y*P.x); + } + + PointXYZ& operator+=(const PointXYZ& P) + { + x += P.x; + y += P.y; + z += P.z; + return *this; + } + + PointXYZ& operator-=(const PointXYZ& P) + { + x -= P.x; + y -= P.y; + z -= P.z; + return *this; + } + + PointXYZ& operator*=(const float& a) + { + x *= a; + y *= a; + z *= a; + return *this; + } +}; + + +// Point Opperations +// ***************** + +inline PointXYZ operator + (const PointXYZ A, const PointXYZ B) +{ + return PointXYZ(A.x + B.x, A.y + B.y, A.z + B.z); +} + +inline PointXYZ operator - (const PointXYZ A, const PointXYZ B) +{ + return PointXYZ(A.x - B.x, A.y - B.y, A.z - B.z); +} + +inline PointXYZ operator * (const PointXYZ P, const float a) +{ + return PointXYZ(P.x * a, P.y * a, P.z * a); +} + +inline PointXYZ operator * (const float a, const PointXYZ P) +{ + return PointXYZ(P.x * a, P.y * a, P.z * a); +} + +inline std::ostream& operator << (std::ostream& os, const PointXYZ P) +{ + return os << "[" << P.x << ", " << P.y << ", " << P.z << "]"; +} + +inline bool operator == (const PointXYZ A, const PointXYZ B) +{ + return A.x == B.x && A.y == B.y && A.z == B.z; +} + +inline PointXYZ floor(const PointXYZ P) +{ + return PointXYZ(std::floor(P.x), std::floor(P.y), std::floor(P.z)); +} + + +PointXYZ max_point(std::vector points); +PointXYZ min_point(std::vector points); + + +struct PointCloud +{ + + std::vector pts; + + // Must return the number of data points + inline size_t kdtree_get_point_count() const { return pts.size(); } + + // Returns the dim'th component of the idx'th point in the class: + // Since this is inlined and the "dim" argument is typically an immediate value, the + // "if/else's" are actually solved at compile time. 
+ inline float kdtree_get_pt(const size_t idx, const size_t dim) const + { + if (dim == 0) return pts[idx].x; + else if (dim == 1) return pts[idx].y; + else return pts[idx].z; + } + + // Optional bounding-box computation: return false to default to a standard bbox computation loop. + // Return true if the BBOX was already computed by the class and returned in "bb" so it can be avoided to redo it again. + // Look at bb.size() to find out the expected dimensionality (e.g. 2 or 3 for point clouds) + template + bool kdtree_get_bbox(BBOX& /* bb */) const { return false; } + +}; + + + + + + + + + + + diff --git a/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/nanoflann/nanoflann.hpp b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/nanoflann/nanoflann.hpp new file mode 100644 index 0000000..8d2ab6c --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/cpp_wrappers/cpp_utils/nanoflann/nanoflann.hpp @@ -0,0 +1,2043 @@ +/*********************************************************************** + * Software License Agreement (BSD License) + * + * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. + * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. + * Copyright 2011-2016 Jose Luis Blanco (joseluisblancoc@gmail.com). + * All rights reserved. + * + * THE BSD LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * + * 1. Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES + * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. + * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, + * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT + * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF + * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + *************************************************************************/ + +/** \mainpage nanoflann C++ API documentation + * nanoflann is a C++ header-only library for building KD-Trees, mostly + * optimized for 2D or 3D point clouds. + * + * nanoflann does not require compiling or installing, just an + * #include in your code. 
+ * + * See: + * - C++ API organized by modules + * - Online README + * - Doxygen + * documentation + */ + +#ifndef NANOFLANN_HPP_ +#define NANOFLANN_HPP_ + +#include +#include +#include +#include // for abs() +#include // for fwrite() +#include // for abs() +#include +#include // std::reference_wrapper +#include +#include + +/** Library version: 0xMmP (M=Major,m=minor,P=patch) */ +#define NANOFLANN_VERSION 0x130 + +// Avoid conflicting declaration of min/max macros in windows headers +#if !defined(NOMINMAX) && \ + (defined(_WIN32) || defined(_WIN32_) || defined(WIN32) || defined(_WIN64)) +#define NOMINMAX +#ifdef max +#undef max +#undef min +#endif +#endif + +namespace nanoflann { +/** @addtogroup nanoflann_grp nanoflann C++ library for ANN + * @{ */ + +/** the PI constant (required to avoid MSVC missing symbols) */ +template T pi_const() { + return static_cast(3.14159265358979323846); +} + +/** + * Traits if object is resizable and assignable (typically has a resize | assign + * method) + */ +template struct has_resize : std::false_type {}; + +template +struct has_resize().resize(1), 0)> + : std::true_type {}; + +template struct has_assign : std::false_type {}; + +template +struct has_assign().assign(1, 0), 0)> + : std::true_type {}; + +/** + * Free function to resize a resizable object + */ +template +inline typename std::enable_if::value, void>::type +resize(Container &c, const size_t nElements) { + c.resize(nElements); +} + +/** + * Free function that has no effects on non resizable containers (e.g. 
+ * std::array) It raises an exception if the expected size does not match + */ +template +inline typename std::enable_if::value, void>::type +resize(Container &c, const size_t nElements) { + if (nElements != c.size()) + throw std::logic_error("Try to change the size of a std::array."); +} + +/** + * Free function to assign to a container + */ +template +inline typename std::enable_if::value, void>::type +assign(Container &c, const size_t nElements, const T &value) { + c.assign(nElements, value); +} + +/** + * Free function to assign to a std::array + */ +template +inline typename std::enable_if::value, void>::type +assign(Container &c, const size_t nElements, const T &value) { + for (size_t i = 0; i < nElements; i++) + c[i] = value; +} + +/** @addtogroup result_sets_grp Result set classes + * @{ */ +template +class KNNResultSet { +public: + typedef _DistanceType DistanceType; + typedef _IndexType IndexType; + typedef _CountType CountType; + +private: + IndexType *indices; + DistanceType *dists; + CountType capacity; + CountType count; + +public: + inline KNNResultSet(CountType capacity_) + : indices(0), dists(0), capacity(capacity_), count(0) {} + + inline void init(IndexType *indices_, DistanceType *dists_) { + indices = indices_; + dists = dists_; + count = 0; + if (capacity) + dists[capacity - 1] = (std::numeric_limits::max)(); + } + + inline CountType size() const { return count; } + + inline bool full() const { return count == capacity; } + + /** + * Called during search to add an element matching the criteria. + * @return true if the search should be continued, false if the results are + * sufficient + */ + inline bool addPoint(DistanceType dist, IndexType index) { + CountType i; + for (i = count; i > 0; --i) { +#ifdef NANOFLANN_FIRST_MATCH // If defined and two points have the same + // distance, the one with the lowest-index will be + // returned first. 
+ if ((dists[i - 1] > dist) || + ((dist == dists[i - 1]) && (indices[i - 1] > index))) { +#else + if (dists[i - 1] > dist) { +#endif + if (i < capacity) { + dists[i] = dists[i - 1]; + indices[i] = indices[i - 1]; + } + } else + break; + } + if (i < capacity) { + dists[i] = dist; + indices[i] = index; + } + if (count < capacity) + count++; + + // tell caller that the search shall continue + return true; + } + + inline DistanceType worstDist() const { return dists[capacity - 1]; } +}; + +/** operator "<" for std::sort() */ +struct IndexDist_Sorter { + /** PairType will be typically: std::pair */ + template + inline bool operator()(const PairType &p1, const PairType &p2) const { + return p1.second < p2.second; + } +}; + +/** + * A result-set class used when performing a radius based search. + */ +template +class RadiusResultSet { +public: + typedef _DistanceType DistanceType; + typedef _IndexType IndexType; + +public: + const DistanceType radius; + + std::vector> &m_indices_dists; + + inline RadiusResultSet( + DistanceType radius_, + std::vector> &indices_dists) + : radius(radius_), m_indices_dists(indices_dists) { + init(); + } + + inline void init() { clear(); } + inline void clear() { m_indices_dists.clear(); } + + inline size_t size() const { return m_indices_dists.size(); } + + inline bool full() const { return true; } + + /** + * Called during search to add an element matching the criteria. 
+ * @return true if the search should be continued, false if the results are + * sufficient + */ + inline bool addPoint(DistanceType dist, IndexType index) { + if (dist < radius) + m_indices_dists.push_back(std::make_pair(index, dist)); + return true; + } + + inline DistanceType worstDist() const { return radius; } + + /** + * Find the worst result (furtherest neighbor) without copying or sorting + * Pre-conditions: size() > 0 + */ + std::pair worst_item() const { + if (m_indices_dists.empty()) + throw std::runtime_error("Cannot invoke RadiusResultSet::worst_item() on " + "an empty list of results."); + typedef + typename std::vector>::const_iterator + DistIt; + DistIt it = std::max_element(m_indices_dists.begin(), m_indices_dists.end(), + IndexDist_Sorter()); + return *it; + } +}; + +/** @} */ + +/** @addtogroup loadsave_grp Load/save auxiliary functions + * @{ */ +template +void save_value(FILE *stream, const T &value, size_t count = 1) { + fwrite(&value, sizeof(value), count, stream); +} + +template +void save_value(FILE *stream, const std::vector &value) { + size_t size = value.size(); + fwrite(&size, sizeof(size_t), 1, stream); + fwrite(&value[0], sizeof(T), size, stream); +} + +template +void load_value(FILE *stream, T &value, size_t count = 1) { + size_t read_cnt = fread(&value, sizeof(value), count, stream); + if (read_cnt != count) { + throw std::runtime_error("Cannot read from file"); + } +} + +template void load_value(FILE *stream, std::vector &value) { + size_t size; + size_t read_cnt = fread(&size, sizeof(size_t), 1, stream); + if (read_cnt != 1) { + throw std::runtime_error("Cannot read from file"); + } + value.resize(size); + read_cnt = fread(&value[0], sizeof(T), size, stream); + if (read_cnt != size) { + throw std::runtime_error("Cannot read from file"); + } +} +/** @} */ + +/** @addtogroup metric_grp Metric (distance) classes + * @{ */ + +struct Metric {}; + +/** Manhattan distance functor (generic version, optimized for + * high-dimensionality 
data sets). Corresponding distance traits: + * nanoflann::metric_L1 \tparam T Type of the elements (e.g. double, float, + * uint8_t) \tparam _DistanceType Type of distance variables (must be signed) + * (e.g. float, double, int64_t) + */ +template +struct L1_Adaptor { + typedef T ElementType; + typedef _DistanceType DistanceType; + + const DataSource &data_source; + + L1_Adaptor(const DataSource &_data_source) : data_source(_data_source) {} + + inline DistanceType evalMetric(const T *a, const size_t b_idx, size_t size, + DistanceType worst_dist = -1) const { + DistanceType result = DistanceType(); + const T *last = a + size; + const T *lastgroup = last - 3; + size_t d = 0; + + /* Process 4 items with each loop for efficiency. */ + while (a < lastgroup) { + const DistanceType diff0 = + std::abs(a[0] - data_source.kdtree_get_pt(b_idx, d++)); + const DistanceType diff1 = + std::abs(a[1] - data_source.kdtree_get_pt(b_idx, d++)); + const DistanceType diff2 = + std::abs(a[2] - data_source.kdtree_get_pt(b_idx, d++)); + const DistanceType diff3 = + std::abs(a[3] - data_source.kdtree_get_pt(b_idx, d++)); + result += diff0 + diff1 + diff2 + diff3; + a += 4; + if ((worst_dist > 0) && (result > worst_dist)) { + return result; + } + } + /* Process last 0-3 components. Not needed for standard vector lengths. */ + while (a < last) { + result += std::abs(*a++ - data_source.kdtree_get_pt(b_idx, d++)); + } + return result; + } + + template + inline DistanceType accum_dist(const U a, const V b, const size_t) const { + return std::abs(a - b); + } +}; + +/** Squared Euclidean distance functor (generic version, optimized for + * high-dimensionality data sets). Corresponding distance traits: + * nanoflann::metric_L2 \tparam T Type of the elements (e.g. double, float, + * uint8_t) \tparam _DistanceType Type of distance variables (must be signed) + * (e.g. 
float, double, int64_t) + */ +template +struct L2_Adaptor { + typedef T ElementType; + typedef _DistanceType DistanceType; + + const DataSource &data_source; + + L2_Adaptor(const DataSource &_data_source) : data_source(_data_source) {} + + inline DistanceType evalMetric(const T *a, const size_t b_idx, size_t size, + DistanceType worst_dist = -1) const { + DistanceType result = DistanceType(); + const T *last = a + size; + const T *lastgroup = last - 3; + size_t d = 0; + + /* Process 4 items with each loop for efficiency. */ + while (a < lastgroup) { + const DistanceType diff0 = a[0] - data_source.kdtree_get_pt(b_idx, d++); + const DistanceType diff1 = a[1] - data_source.kdtree_get_pt(b_idx, d++); + const DistanceType diff2 = a[2] - data_source.kdtree_get_pt(b_idx, d++); + const DistanceType diff3 = a[3] - data_source.kdtree_get_pt(b_idx, d++); + result += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3; + a += 4; + if ((worst_dist > 0) && (result > worst_dist)) { + return result; + } + } + /* Process last 0-3 components. Not needed for standard vector lengths. */ + while (a < last) { + const DistanceType diff0 = *a++ - data_source.kdtree_get_pt(b_idx, d++); + result += diff0 * diff0; + } + return result; + } + + template + inline DistanceType accum_dist(const U a, const V b, const size_t) const { + return (a - b) * (a - b); + } +}; + +/** Squared Euclidean (L2) distance functor (suitable for low-dimensionality + * datasets, like 2D or 3D point clouds) Corresponding distance traits: + * nanoflann::metric_L2_Simple \tparam T Type of the elements (e.g. double, + * float, uint8_t) \tparam _DistanceType Type of distance variables (must be + * signed) (e.g. 
float, double, int64_t) + */ +template +struct L2_Simple_Adaptor { + typedef T ElementType; + typedef _DistanceType DistanceType; + + const DataSource &data_source; + + L2_Simple_Adaptor(const DataSource &_data_source) + : data_source(_data_source) {} + + inline DistanceType evalMetric(const T *a, const size_t b_idx, + size_t size) const { + DistanceType result = DistanceType(); + for (size_t i = 0; i < size; ++i) { + const DistanceType diff = a[i] - data_source.kdtree_get_pt(b_idx, i); + result += diff * diff; + } + return result; + } + + template + inline DistanceType accum_dist(const U a, const V b, const size_t) const { + return (a - b) * (a - b); + } +}; + +/** SO2 distance functor + * Corresponding distance traits: nanoflann::metric_SO2 + * \tparam T Type of the elements (e.g. double, float) + * \tparam _DistanceType Type of distance variables (must be signed) (e.g. + * float, double) orientation is constrained to be in [-pi, pi] + */ +template +struct SO2_Adaptor { + typedef T ElementType; + typedef _DistanceType DistanceType; + + const DataSource &data_source; + + SO2_Adaptor(const DataSource &_data_source) : data_source(_data_source) {} + + inline DistanceType evalMetric(const T *a, const size_t b_idx, + size_t size) const { + return accum_dist(a[size - 1], data_source.kdtree_get_pt(b_idx, size - 1), + size - 1); + } + + /** Note: this assumes that input angles are already in the range [-pi,pi] */ + template + inline DistanceType accum_dist(const U a, const V b, const size_t) const { + DistanceType result = DistanceType(), PI = pi_const(); + result = b - a; + if (result > PI) + result -= 2 * PI; + else if (result < -PI) + result += 2 * PI; + return result; + } +}; + +/** SO3 distance functor (Uses L2_Simple) + * Corresponding distance traits: nanoflann::metric_SO3 + * \tparam T Type of the elements (e.g. double, float) + * \tparam _DistanceType Type of distance variables (must be signed) (e.g. 
+ * float, double) + */ +template +struct SO3_Adaptor { + typedef T ElementType; + typedef _DistanceType DistanceType; + + L2_Simple_Adaptor distance_L2_Simple; + + SO3_Adaptor(const DataSource &_data_source) + : distance_L2_Simple(_data_source) {} + + inline DistanceType evalMetric(const T *a, const size_t b_idx, + size_t size) const { + return distance_L2_Simple.evalMetric(a, b_idx, size); + } + + template + inline DistanceType accum_dist(const U a, const V b, const size_t idx) const { + return distance_L2_Simple.accum_dist(a, b, idx); + } +}; + +/** Metaprogramming helper traits class for the L1 (Manhattan) metric */ +struct metric_L1 : public Metric { + template struct traits { + typedef L1_Adaptor distance_t; + }; +}; +/** Metaprogramming helper traits class for the L2 (Euclidean) metric */ +struct metric_L2 : public Metric { + template struct traits { + typedef L2_Adaptor distance_t; + }; +}; +/** Metaprogramming helper traits class for the L2_simple (Euclidean) metric */ +struct metric_L2_Simple : public Metric { + template struct traits { + typedef L2_Simple_Adaptor distance_t; + }; +}; +/** Metaprogramming helper traits class for the SO3_InnerProdQuat metric */ +struct metric_SO2 : public Metric { + template struct traits { + typedef SO2_Adaptor distance_t; + }; +}; +/** Metaprogramming helper traits class for the SO3_InnerProdQuat metric */ +struct metric_SO3 : public Metric { + template struct traits { + typedef SO3_Adaptor distance_t; + }; +}; + +/** @} */ + +/** @addtogroup param_grp Parameter structs + * @{ */ + +/** Parameters (see README.md) */ +struct KDTreeSingleIndexAdaptorParams { + KDTreeSingleIndexAdaptorParams(size_t _leaf_max_size = 10) + : leaf_max_size(_leaf_max_size) {} + + size_t leaf_max_size; +}; + +/** Search options for KDTreeSingleIndexAdaptor::findNeighbors() */ +struct SearchParams { + /** Note: The first argument (checks_IGNORED_) is ignored, but kept for + * compatibility with the FLANN interface */ + SearchParams(int 
checks_IGNORED_ = 32, float eps_ = 0, bool sorted_ = true) + : checks(checks_IGNORED_), eps(eps_), sorted(sorted_) {} + + int checks; //!< Ignored parameter (Kept for compatibility with the FLANN + //!< interface). + float eps; //!< search for eps-approximate neighbours (default: 0) + bool sorted; //!< only for radius search, require neighbours sorted by + //!< distance (default: true) +}; +/** @} */ + +/** @addtogroup memalloc_grp Memory allocation + * @{ */ + +/** + * Allocates (using C's malloc) a generic type T. + * + * Params: + * count = number of instances to allocate. + * Returns: pointer (of type T*) to memory buffer + */ +template inline T *allocate(size_t count = 1) { + T *mem = static_cast(::malloc(sizeof(T) * count)); + return mem; +} + +/** + * Pooled storage allocator + * + * The following routines allow for the efficient allocation of storage in + * small chunks from a specified pool. Rather than allowing each structure + * to be freed individually, an entire pool of storage is freed at once. + * This method has two advantages over just using malloc() and free(). First, + * it is far more efficient for allocating small objects, as there is + * no overhead for remembering all the information needed to free each + * object or consolidating fragmented memory. Second, the decision about + * how long to keep an object is made at the time of allocation, and there + * is no need to track down all the objects to free them. + * + */ + +const size_t WORDSIZE = 16; +const size_t BLOCKSIZE = 8192; + +class PooledAllocator { + /* We maintain memory alignment to word boundaries by requiring that all + allocations be in multiples of the machine wordsize. */ + /* Size of machine word in bytes. Must be power of 2. */ + /* Minimum number of bytes requested at a time from the system. Must be + * multiple of WORDSIZE. */ + + size_t remaining; /* Number of bytes left in current block of storage. */ + void *base; /* Pointer to base of current block of storage. 
*/ + void *loc; /* Current location in block to next allocate memory. */ + + void internal_init() { + remaining = 0; + base = NULL; + usedMemory = 0; + wastedMemory = 0; + } + +public: + size_t usedMemory; + size_t wastedMemory; + + /** + Default constructor. Initializes a new pool. + */ + PooledAllocator() { internal_init(); } + + /** + * Destructor. Frees all the memory allocated in this pool. + */ + ~PooledAllocator() { free_all(); } + + /** Frees all allocated memory chunks */ + void free_all() { + while (base != NULL) { + void *prev = + *(static_cast(base)); /* Get pointer to prev block. */ + ::free(base); + base = prev; + } + internal_init(); + } + + /** + * Returns a pointer to a piece of new memory of the given size in bytes + * allocated from the pool. + */ + void *malloc(const size_t req_size) { + /* Round size up to a multiple of wordsize. The following expression + only works for WORDSIZE that is a power of 2, by masking last bits of + incremented size to zero. + */ + const size_t size = (req_size + (WORDSIZE - 1)) & ~(WORDSIZE - 1); + + /* Check whether a new block must be allocated. Note that the first word + of a block is reserved for a pointer to the previous block. + */ + if (size > remaining) { + + wastedMemory += remaining; + + /* Allocate new storage. */ + const size_t blocksize = + (size + sizeof(void *) + (WORDSIZE - 1) > BLOCKSIZE) + ? size + sizeof(void *) + (WORDSIZE - 1) + : BLOCKSIZE; + + // use the standard C malloc to allocate memory + void *m = ::malloc(blocksize); + if (!m) { + fprintf(stderr, "Failed to allocate memory.\n"); + return NULL; + } + + /* Fill first word of new block with pointer to previous block. 
*/ + static_cast(m)[0] = base; + base = m; + + size_t shift = 0; + // int size_t = (WORDSIZE - ( (((size_t)m) + sizeof(void*)) & + // (WORDSIZE-1))) & (WORDSIZE-1); + + remaining = blocksize - sizeof(void *) - shift; + loc = (static_cast(m) + sizeof(void *) + shift); + } + void *rloc = loc; + loc = static_cast(loc) + size; + remaining -= size; + + usedMemory += size; + + return rloc; + } + + /** + * Allocates (using this pool) a generic type T. + * + * Params: + * count = number of instances to allocate. + * Returns: pointer (of type T*) to memory buffer + */ + template T *allocate(const size_t count = 1) { + T *mem = static_cast(this->malloc(sizeof(T) * count)); + return mem; + } +}; +/** @} */ + +/** @addtogroup nanoflann_metaprog_grp Auxiliary metaprogramming stuff + * @{ */ + +/** Used to declare fixed-size arrays when DIM>0, dynamically-allocated vectors + * when DIM=-1. Fixed size version for a generic DIM: + */ +template struct array_or_vector_selector { + typedef std::array container_t; +}; +/** Dynamic size version */ +template struct array_or_vector_selector<-1, T> { + typedef std::vector container_t; +}; + +/** @} */ + +/** kd-tree base-class + * + * Contains the member functions common to the classes KDTreeSingleIndexAdaptor + * and KDTreeSingleIndexDynamicAdaptor_. + * + * \tparam Derived The name of the class which inherits this class. + * \tparam DatasetAdaptor The user-provided adaptor (see comments above). + * \tparam Distance The distance metric to use, these are all classes derived + * from nanoflann::Metric \tparam DIM Dimensionality of data points (e.g. 3 for + * 3D points) \tparam IndexType Will be typically size_t or int + */ + +template +class KDTreeBaseClass { + +public: + /** Frees the previously-built index. Automatically called within + * buildIndex(). 
*/ + void freeIndex(Derived &obj) { + obj.pool.free_all(); + obj.root_node = NULL; + obj.m_size_at_index_build = 0; + } + + typedef typename Distance::ElementType ElementType; + typedef typename Distance::DistanceType DistanceType; + + /*--------------------- Internal Data Structures --------------------------*/ + struct Node { + /** Union used because a node can be either a LEAF node or a non-leaf node, + * so both data fields are never used simultaneously */ + union { + struct leaf { + IndexType left, right; //!< Indices of points in leaf node + } lr; + struct nonleaf { + int divfeat; //!< Dimension used for subdivision. + DistanceType divlow, divhigh; //!< The values used for subdivision. + } sub; + } node_type; + Node *child1, *child2; //!< Child nodes (both=NULL mean its a leaf node) + }; + + typedef Node *NodePtr; + + struct Interval { + ElementType low, high; + }; + + /** + * Array of indices to vectors in the dataset. + */ + std::vector vind; + + NodePtr root_node; + + size_t m_leaf_max_size; + + size_t m_size; //!< Number of current points in the dataset + size_t m_size_at_index_build; //!< Number of points in the dataset when the + //!< index was built + int dim; //!< Dimensionality of each data point + + /** Define "BoundingBox" as a fixed-size or variable-size container depending + * on "DIM" */ + typedef + typename array_or_vector_selector::container_t BoundingBox; + + /** Define "distance_vector_t" as a fixed-size or variable-size container + * depending on "DIM" */ + typedef typename array_or_vector_selector::container_t + distance_vector_t; + + /** The KD-tree used to find neighbours */ + + BoundingBox root_bbox; + + /** + * Pooled memory allocator. + * + * Using a pooled memory allocator is more efficient + * than allocating memory directly when there is a large + * number small of memory allocations. 
+ */ + PooledAllocator pool; + + /** Returns number of points in dataset */ + size_t size(const Derived &obj) const { return obj.m_size; } + + /** Returns the length of each point in the dataset */ + size_t veclen(const Derived &obj) { + return static_cast(DIM > 0 ? DIM : obj.dim); + } + + /// Helper accessor to the dataset points: + inline ElementType dataset_get(const Derived &obj, size_t idx, + int component) const { + return obj.dataset.kdtree_get_pt(idx, component); + } + + /** + * Computes the inde memory usage + * Returns: memory used by the index + */ + size_t usedMemory(Derived &obj) { + return obj.pool.usedMemory + obj.pool.wastedMemory + + obj.dataset.kdtree_get_point_count() * + sizeof(IndexType); // pool memory and vind array memory + } + + void computeMinMax(const Derived &obj, IndexType *ind, IndexType count, + int element, ElementType &min_elem, + ElementType &max_elem) { + min_elem = dataset_get(obj, ind[0], element); + max_elem = dataset_get(obj, ind[0], element); + for (IndexType i = 1; i < count; ++i) { + ElementType val = dataset_get(obj, ind[i], element); + if (val < min_elem) + min_elem = val; + if (val > max_elem) + max_elem = val; + } + } + + /** + * Create a tree node that subdivides the list of vecs from vind[first] + * to vind[last]. The routine is called recursively on each sublist. + * + * @param left index of the first vector + * @param right index of the last vector + */ + NodePtr divideTree(Derived &obj, const IndexType left, const IndexType right, + BoundingBox &bbox) { + NodePtr node = obj.pool.template allocate(); // allocate memory + + /* If too few exemplars remain, then make this a leaf node. */ + if ((right - left) <= static_cast(obj.m_leaf_max_size)) { + node->child1 = node->child2 = NULL; /* Mark as leaf node. */ + node->node_type.lr.left = left; + node->node_type.lr.right = right; + + // compute bounding-box of leaf points + for (int i = 0; i < (DIM > 0 ? 
DIM : obj.dim); ++i) { + bbox[i].low = dataset_get(obj, obj.vind[left], i); + bbox[i].high = dataset_get(obj, obj.vind[left], i); + } + for (IndexType k = left + 1; k < right; ++k) { + for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) { + if (bbox[i].low > dataset_get(obj, obj.vind[k], i)) + bbox[i].low = dataset_get(obj, obj.vind[k], i); + if (bbox[i].high < dataset_get(obj, obj.vind[k], i)) + bbox[i].high = dataset_get(obj, obj.vind[k], i); + } + } + } else { + IndexType idx; + int cutfeat; + DistanceType cutval; + middleSplit_(obj, &obj.vind[0] + left, right - left, idx, cutfeat, cutval, + bbox); + + node->node_type.sub.divfeat = cutfeat; + + BoundingBox left_bbox(bbox); + left_bbox[cutfeat].high = cutval; + node->child1 = divideTree(obj, left, left + idx, left_bbox); + + BoundingBox right_bbox(bbox); + right_bbox[cutfeat].low = cutval; + node->child2 = divideTree(obj, left + idx, right, right_bbox); + + node->node_type.sub.divlow = left_bbox[cutfeat].high; + node->node_type.sub.divhigh = right_bbox[cutfeat].low; + + for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) { + bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low); + bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high); + } + } + + return node; + } + + void middleSplit_(Derived &obj, IndexType *ind, IndexType count, + IndexType &index, int &cutfeat, DistanceType &cutval, + const BoundingBox &bbox) { + const DistanceType EPS = static_cast(0.00001); + ElementType max_span = bbox[0].high - bbox[0].low; + for (int i = 1; i < (DIM > 0 ? DIM : obj.dim); ++i) { + ElementType span = bbox[i].high - bbox[i].low; + if (span > max_span) { + max_span = span; + } + } + ElementType max_spread = -1; + cutfeat = 0; + for (int i = 0; i < (DIM > 0 ? 
DIM : obj.dim); ++i) { + ElementType span = bbox[i].high - bbox[i].low; + if (span > (1 - EPS) * max_span) { + ElementType min_elem, max_elem; + computeMinMax(obj, ind, count, i, min_elem, max_elem); + ElementType spread = max_elem - min_elem; + ; + if (spread > max_spread) { + cutfeat = i; + max_spread = spread; + } + } + } + // split in the middle + DistanceType split_val = (bbox[cutfeat].low + bbox[cutfeat].high) / 2; + ElementType min_elem, max_elem; + computeMinMax(obj, ind, count, cutfeat, min_elem, max_elem); + + if (split_val < min_elem) + cutval = min_elem; + else if (split_val > max_elem) + cutval = max_elem; + else + cutval = split_val; + + IndexType lim1, lim2; + planeSplit(obj, ind, count, cutfeat, cutval, lim1, lim2); + + if (lim1 > count / 2) + index = lim1; + else if (lim2 < count / 2) + index = lim2; + else + index = count / 2; + } + + /** + * Subdivide the list of points by a plane perpendicular on axe corresponding + * to the 'cutfeat' dimension at 'cutval' position. + * + * On return: + * dataset[ind[0..lim1-1]][cutfeat]cutval + */ + void planeSplit(Derived &obj, IndexType *ind, const IndexType count, + int cutfeat, DistanceType &cutval, IndexType &lim1, + IndexType &lim2) { + /* Move vector indices for left subtree to front of list. */ + IndexType left = 0; + IndexType right = count - 1; + for (;;) { + while (left <= right && dataset_get(obj, ind[left], cutfeat) < cutval) + ++left; + while (right && left <= right && + dataset_get(obj, ind[right], cutfeat) >= cutval) + --right; + if (left > right || !right) + break; // "!right" was added to support unsigned Index types + std::swap(ind[left], ind[right]); + ++left; + --right; + } + /* If either list is empty, it means that all remaining features + * are identical. Split in the middle to maintain a balanced tree. 
+ */ + lim1 = left; + right = count - 1; + for (;;) { + while (left <= right && dataset_get(obj, ind[left], cutfeat) <= cutval) + ++left; + while (right && left <= right && + dataset_get(obj, ind[right], cutfeat) > cutval) + --right; + if (left > right || !right) + break; // "!right" was added to support unsigned Index types + std::swap(ind[left], ind[right]); + ++left; + --right; + } + lim2 = left; + } + + DistanceType computeInitialDistances(const Derived &obj, + const ElementType *vec, + distance_vector_t &dists) const { + assert(vec); + DistanceType distsq = DistanceType(); + + for (int i = 0; i < (DIM > 0 ? DIM : obj.dim); ++i) { + if (vec[i] < obj.root_bbox[i].low) { + dists[i] = obj.distance.accum_dist(vec[i], obj.root_bbox[i].low, i); + distsq += dists[i]; + } + if (vec[i] > obj.root_bbox[i].high) { + dists[i] = obj.distance.accum_dist(vec[i], obj.root_bbox[i].high, i); + distsq += dists[i]; + } + } + return distsq; + } + + void save_tree(Derived &obj, FILE *stream, NodePtr tree) { + save_value(stream, *tree); + if (tree->child1 != NULL) { + save_tree(obj, stream, tree->child1); + } + if (tree->child2 != NULL) { + save_tree(obj, stream, tree->child2); + } + } + + void load_tree(Derived &obj, FILE *stream, NodePtr &tree) { + tree = obj.pool.template allocate(); + load_value(stream, *tree); + if (tree->child1 != NULL) { + load_tree(obj, stream, tree->child1); + } + if (tree->child2 != NULL) { + load_tree(obj, stream, tree->child2); + } + } + + /** Stores the index in a binary file. + * IMPORTANT NOTE: The set of data points is NOT stored in the file, so when + * loading the index object it must be constructed associated to the same + * source of data points used while building it. 
See the example: + * examples/saveload_example.cpp \sa loadIndex */ + void saveIndex_(Derived &obj, FILE *stream) { + save_value(stream, obj.m_size); + save_value(stream, obj.dim); + save_value(stream, obj.root_bbox); + save_value(stream, obj.m_leaf_max_size); + save_value(stream, obj.vind); + save_tree(obj, stream, obj.root_node); + } + + /** Loads a previous index from a binary file. + * IMPORTANT NOTE: The set of data points is NOT stored in the file, so the + * index object must be constructed associated to the same source of data + * points used while building the index. See the example: + * examples/saveload_example.cpp \sa loadIndex */ + void loadIndex_(Derived &obj, FILE *stream) { + load_value(stream, obj.m_size); + load_value(stream, obj.dim); + load_value(stream, obj.root_bbox); + load_value(stream, obj.m_leaf_max_size); + load_value(stream, obj.vind); + load_tree(obj, stream, obj.root_node); + } +}; + +/** @addtogroup kdtrees_grp KD-tree classes and adaptors + * @{ */ + +/** kd-tree static index + * + * Contains the k-d trees and other information for indexing a set of points + * for nearest-neighbor matching. + * + * The class "DatasetAdaptor" must provide the following interface (can be + * non-virtual, inlined methods): + * + * \code + * // Must return the number of data poins + * inline size_t kdtree_get_point_count() const { ... } + * + * + * // Must return the dim'th component of the idx'th point in the class: + * inline T kdtree_get_pt(const size_t idx, const size_t dim) const { ... } + * + * // Optional bounding-box computation: return false to default to a standard + * bbox computation loop. + * // Return true if the BBOX was already computed by the class and returned + * in "bb" so it can be avoided to redo it again. + * // Look at bb.size() to find out the expected dimensionality (e.g. 
2 or 3 + * for point clouds) template bool kdtree_get_bbox(BBOX &bb) const + * { + * bb[0].low = ...; bb[0].high = ...; // 0th dimension limits + * bb[1].low = ...; bb[1].high = ...; // 1st dimension limits + * ... + * return true; + * } + * + * \endcode + * + * \tparam DatasetAdaptor The user-provided adaptor (see comments above). + * \tparam Distance The distance metric to use: nanoflann::metric_L1, + * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \tparam DIM + * Dimensionality of data points (e.g. 3 for 3D points) \tparam IndexType Will + * be typically size_t or int + */ +template +class KDTreeSingleIndexAdaptor + : public KDTreeBaseClass< + KDTreeSingleIndexAdaptor, + Distance, DatasetAdaptor, DIM, IndexType> { +public: + /** Deleted copy constructor*/ + KDTreeSingleIndexAdaptor( + const KDTreeSingleIndexAdaptor + &) = delete; + + /** + * The dataset used by this index + */ + const DatasetAdaptor &dataset; //!< The source of our data + + const KDTreeSingleIndexAdaptorParams index_params; + + Distance distance; + + typedef typename nanoflann::KDTreeBaseClass< + nanoflann::KDTreeSingleIndexAdaptor, + Distance, DatasetAdaptor, DIM, IndexType> + BaseClassRef; + + typedef typename BaseClassRef::ElementType ElementType; + typedef typename BaseClassRef::DistanceType DistanceType; + + typedef typename BaseClassRef::Node Node; + typedef Node *NodePtr; + + typedef typename BaseClassRef::Interval Interval; + /** Define "BoundingBox" as a fixed-size or variable-size container depending + * on "DIM" */ + typedef typename BaseClassRef::BoundingBox BoundingBox; + + /** Define "distance_vector_t" as a fixed-size or variable-size container + * depending on "DIM" */ + typedef typename BaseClassRef::distance_vector_t distance_vector_t; + + /** + * KDTree constructor + * + * Refer to docs in README.md or online in + * https://github.com/jlblancoc/nanoflann + * + * The KD-Tree point dimension (the length of each point in the datase, e.g. 
3 + * for 3D points) is determined by means of: + * - The \a DIM template parameter if >0 (highest priority) + * - Otherwise, the \a dimensionality parameter of this constructor. + * + * @param inputData Dataset with the input features + * @param params Basically, the maximum leaf node size + */ + KDTreeSingleIndexAdaptor(const int dimensionality, + const DatasetAdaptor &inputData, + const KDTreeSingleIndexAdaptorParams ¶ms = + KDTreeSingleIndexAdaptorParams()) + : dataset(inputData), index_params(params), distance(inputData) { + BaseClassRef::root_node = NULL; + BaseClassRef::m_size = dataset.kdtree_get_point_count(); + BaseClassRef::m_size_at_index_build = BaseClassRef::m_size; + BaseClassRef::dim = dimensionality; + if (DIM > 0) + BaseClassRef::dim = DIM; + BaseClassRef::m_leaf_max_size = params.leaf_max_size; + + // Create a permutable array of indices to the input vectors. + init_vind(); + } + + /** + * Builds the index + */ + void buildIndex() { + BaseClassRef::m_size = dataset.kdtree_get_point_count(); + BaseClassRef::m_size_at_index_build = BaseClassRef::m_size; + init_vind(); + this->freeIndex(*this); + BaseClassRef::m_size_at_index_build = BaseClassRef::m_size; + if (BaseClassRef::m_size == 0) + return; + computeBoundingBox(BaseClassRef::root_bbox); + BaseClassRef::root_node = + this->divideTree(*this, 0, BaseClassRef::m_size, + BaseClassRef::root_bbox); // construct the tree + } + + /** \name Query methods + * @{ */ + + /** + * Find set of nearest neighbors to vec[0:dim-1]. Their indices are stored + * inside the result object. + * + * Params: + * result = the result object in which the indices of the + * nearest-neighbors are stored vec = the vector for which to search the + * nearest neighbors + * + * \tparam RESULTSET Should be any ResultSet + * \return True if the requested neighbors could be found. 
+ * \sa knnSearch, radiusSearch + */ + template + bool findNeighbors(RESULTSET &result, const ElementType *vec, + const SearchParams &searchParams) const { + assert(vec); + if (this->size(*this) == 0) + return false; + if (!BaseClassRef::root_node) + throw std::runtime_error( + "[nanoflann] findNeighbors() called before building the index."); + float epsError = 1 + searchParams.eps; + + distance_vector_t + dists; // fixed or variable-sized container (depending on DIM) + auto zero = static_cast(0); + assign(dists, (DIM > 0 ? DIM : BaseClassRef::dim), + zero); // Fill it with zeros. + DistanceType distsq = this->computeInitialDistances(*this, vec, dists); + + searchLevel(result, vec, BaseClassRef::root_node, distsq, dists, + epsError); // "count_leaf" parameter removed since was neither + // used nor returned to the user. + + return result.full(); + } + + /** + * Find the "num_closest" nearest neighbors to the \a query_point[0:dim-1]. + * Their indices are stored inside the result object. \sa radiusSearch, + * findNeighbors \note nChecks_IGNORED is ignored but kept for compatibility + * with the original FLANN interface. \return Number `N` of valid points in + * the result set. Only the first `N` entries in `out_indices` and + * `out_distances_sq` will be valid. Return may be less than `num_closest` + * only if the number of elements in the tree is less than `num_closest`. + */ + size_t knnSearch(const ElementType *query_point, const size_t num_closest, + IndexType *out_indices, DistanceType *out_distances_sq, + const int /* nChecks_IGNORED */ = 10) const { + nanoflann::KNNResultSet resultSet(num_closest); + resultSet.init(out_indices, out_distances_sq); + this->findNeighbors(resultSet, query_point, nanoflann::SearchParams()); + return resultSet.size(); + } + + /** + * Find all the neighbors to \a query_point[0:dim-1] within a maximum radius. 
+ * The output is given as a vector of pairs, of which the first element is a + * point index and the second the corresponding distance. Previous contents of + * \a IndicesDists are cleared. + * + * If searchParams.sorted==true, the output list is sorted by ascending + * distances. + * + * For a better performance, it is advisable to do a .reserve() on the vector + * if you have any wild guess about the number of expected matches. + * + * \sa knnSearch, findNeighbors, radiusSearchCustomCallback + * \return The number of points within the given radius (i.e. indices.size() + * or dists.size() ) + */ + size_t + radiusSearch(const ElementType *query_point, const DistanceType &radius, + std::vector> &IndicesDists, + const SearchParams &searchParams) const { + RadiusResultSet resultSet(radius, IndicesDists); + const size_t nFound = + radiusSearchCustomCallback(query_point, resultSet, searchParams); + if (searchParams.sorted) + std::sort(IndicesDists.begin(), IndicesDists.end(), IndexDist_Sorter()); + return nFound; + } + + /** + * Just like radiusSearch() but with a custom callback class for each point + * found in the radius of the query. See the source of RadiusResultSet<> as a + * start point for your own classes. \sa radiusSearch + */ + template + size_t radiusSearchCustomCallback( + const ElementType *query_point, SEARCH_CALLBACK &resultSet, + const SearchParams &searchParams = SearchParams()) const { + this->findNeighbors(resultSet, query_point, searchParams); + return resultSet.size(); + } + + /** @} */ + +public: + /** Make sure the auxiliary list \a vind has the same size than the current + * dataset, and re-generate if size has changed. */ + void init_vind() { + // Create a permutable array of indices to the input vectors. 
+ BaseClassRef::m_size = dataset.kdtree_get_point_count(); + if (BaseClassRef::vind.size() != BaseClassRef::m_size) + BaseClassRef::vind.resize(BaseClassRef::m_size); + for (size_t i = 0; i < BaseClassRef::m_size; i++) + BaseClassRef::vind[i] = i; + } + + void computeBoundingBox(BoundingBox &bbox) { + resize(bbox, (DIM > 0 ? DIM : BaseClassRef::dim)); + if (dataset.kdtree_get_bbox(bbox)) { + // Done! It was implemented in derived class + } else { + const size_t N = dataset.kdtree_get_point_count(); + if (!N) + throw std::runtime_error("[nanoflann] computeBoundingBox() called but " + "no data points found."); + for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) { + bbox[i].low = bbox[i].high = this->dataset_get(*this, 0, i); + } + for (size_t k = 1; k < N; ++k) { + for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) { + if (this->dataset_get(*this, k, i) < bbox[i].low) + bbox[i].low = this->dataset_get(*this, k, i); + if (this->dataset_get(*this, k, i) > bbox[i].high) + bbox[i].high = this->dataset_get(*this, k, i); + } + } + } + } + + /** + * Performs an exact search in the tree starting from a node. + * \tparam RESULTSET Should be any ResultSet + * \return true if the search should be continued, false if the results are + * sufficient + */ + template + bool searchLevel(RESULTSET &result_set, const ElementType *vec, + const NodePtr node, DistanceType mindistsq, + distance_vector_t &dists, const float epsError) const { + /* If this is a leaf node, then do check and return. */ + if ((node->child1 == NULL) && (node->child2 == NULL)) { + // count_leaf += (node->lr.right-node->lr.left); // Removed since was + // neither used nor returned to the user. + DistanceType worst_dist = result_set.worstDist(); + for (IndexType i = node->node_type.lr.left; i < node->node_type.lr.right; + ++i) { + const IndexType index = BaseClassRef::vind[i]; // reorder... : i; + DistanceType dist = distance.evalMetric( + vec, index, (DIM > 0 ? 
DIM : BaseClassRef::dim)); + if (dist < worst_dist) { + if (!result_set.addPoint(dist, BaseClassRef::vind[i])) { + // the resultset doesn't want to receive any more points, we're done + // searching! + return false; + } + } + } + return true; + } + + /* Which child branch should be taken first? */ + int idx = node->node_type.sub.divfeat; + ElementType val = vec[idx]; + DistanceType diff1 = val - node->node_type.sub.divlow; + DistanceType diff2 = val - node->node_type.sub.divhigh; + + NodePtr bestChild; + NodePtr otherChild; + DistanceType cut_dist; + if ((diff1 + diff2) < 0) { + bestChild = node->child1; + otherChild = node->child2; + cut_dist = distance.accum_dist(val, node->node_type.sub.divhigh, idx); + } else { + bestChild = node->child2; + otherChild = node->child1; + cut_dist = distance.accum_dist(val, node->node_type.sub.divlow, idx); + } + + /* Call recursively to search next level down. */ + if (!searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError)) { + // the resultset doesn't want to receive any more points, we're done + // searching! + return false; + } + + DistanceType dst = dists[idx]; + mindistsq = mindistsq + cut_dist - dst; + dists[idx] = cut_dist; + if (mindistsq * epsError <= result_set.worstDist()) { + if (!searchLevel(result_set, vec, otherChild, mindistsq, dists, + epsError)) { + // the resultset doesn't want to receive any more points, we're done + // searching! + return false; + } + } + dists[idx] = dst; + return true; + } + +public: + /** Stores the index in a binary file. + * IMPORTANT NOTE: The set of data points is NOT stored in the file, so when + * loading the index object it must be constructed associated to the same + * source of data points used while building it. See the example: + * examples/saveload_example.cpp \sa loadIndex */ + void saveIndex(FILE *stream) { this->saveIndex_(*this, stream); } + + /** Loads a previous index from a binary file. 
+ * IMPORTANT NOTE: The set of data points is NOT stored in the file, so the + * index object must be constructed associated to the same source of data + * points used while building the index. See the example: + * examples/saveload_example.cpp \sa saveIndex */ + void loadIndex(FILE *stream) { this->loadIndex_(*this, stream); } + +}; // class KDTree + +/** kd-tree dynamic index + * + * Contains the k-d trees and other information for indexing a set of points + * for nearest-neighbor matching. + * + * The class "DatasetAdaptor" must provide the following interface (can be + * non-virtual, inlined methods): + * + * \code + * // Must return the number of data points + * inline size_t kdtree_get_point_count() const { ... } + * + * // Must return the dim'th component of the idx'th point in the class: + * inline T kdtree_get_pt(const size_t idx, const size_t dim) const { ... } + * + * // Optional bounding-box computation: return false to default to a standard + * bbox computation loop. + * // Return true if the BBOX was already computed by the class and returned + * in "bb" so it can be avoided to redo it again. + * // Look at bb.size() to find out the expected dimensionality (e.g. 2 or 3 + * for point clouds) template <class BBOX> bool kdtree_get_bbox(BBOX &bb) const + * { + * bb[0].low = ...; bb[0].high = ...; // 0th dimension limits + * bb[1].low = ...; bb[1].high = ...; // 1st dimension limits + * ... + * return true; + * } + * + * \endcode + * + * \tparam DatasetAdaptor The user-provided adaptor (see comments above). + * \tparam Distance The distance metric to use: nanoflann::metric_L1, + * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \tparam DIM + * Dimensionality of data points (e.g.
3 for 3D points) \tparam IndexType Will + * be typically size_t or int + */ +template +class KDTreeSingleIndexDynamicAdaptor_ + : public KDTreeBaseClass, + Distance, DatasetAdaptor, DIM, IndexType> { +public: + /** + * The dataset used by this index + */ + const DatasetAdaptor &dataset; //!< The source of our data + + KDTreeSingleIndexAdaptorParams index_params; + + std::vector &treeIndex; + + Distance distance; + + typedef typename nanoflann::KDTreeBaseClass< + nanoflann::KDTreeSingleIndexDynamicAdaptor_, + Distance, DatasetAdaptor, DIM, IndexType> + BaseClassRef; + + typedef typename BaseClassRef::ElementType ElementType; + typedef typename BaseClassRef::DistanceType DistanceType; + + typedef typename BaseClassRef::Node Node; + typedef Node *NodePtr; + + typedef typename BaseClassRef::Interval Interval; + /** Define "BoundingBox" as a fixed-size or variable-size container depending + * on "DIM" */ + typedef typename BaseClassRef::BoundingBox BoundingBox; + + /** Define "distance_vector_t" as a fixed-size or variable-size container + * depending on "DIM" */ + typedef typename BaseClassRef::distance_vector_t distance_vector_t; + + /** + * KDTree constructor + * + * Refer to docs in README.md or online in + * https://github.com/jlblancoc/nanoflann + * + * The KD-Tree point dimension (the length of each point in the datase, e.g. 3 + * for 3D points) is determined by means of: + * - The \a DIM template parameter if >0 (highest priority) + * - Otherwise, the \a dimensionality parameter of this constructor. 
+ * + * @param inputData Dataset with the input features + * @param params Basically, the maximum leaf node size + */ + KDTreeSingleIndexDynamicAdaptor_( + const int dimensionality, const DatasetAdaptor &inputData, + std::vector &treeIndex_, + const KDTreeSingleIndexAdaptorParams ¶ms = + KDTreeSingleIndexAdaptorParams()) + : dataset(inputData), index_params(params), treeIndex(treeIndex_), + distance(inputData) { + BaseClassRef::root_node = NULL; + BaseClassRef::m_size = 0; + BaseClassRef::m_size_at_index_build = 0; + BaseClassRef::dim = dimensionality; + if (DIM > 0) + BaseClassRef::dim = DIM; + BaseClassRef::m_leaf_max_size = params.leaf_max_size; + } + + /** Assignment operator definiton */ + KDTreeSingleIndexDynamicAdaptor_ + operator=(const KDTreeSingleIndexDynamicAdaptor_ &rhs) { + KDTreeSingleIndexDynamicAdaptor_ tmp(rhs); + std::swap(BaseClassRef::vind, tmp.BaseClassRef::vind); + std::swap(BaseClassRef::m_leaf_max_size, tmp.BaseClassRef::m_leaf_max_size); + std::swap(index_params, tmp.index_params); + std::swap(treeIndex, tmp.treeIndex); + std::swap(BaseClassRef::m_size, tmp.BaseClassRef::m_size); + std::swap(BaseClassRef::m_size_at_index_build, + tmp.BaseClassRef::m_size_at_index_build); + std::swap(BaseClassRef::root_node, tmp.BaseClassRef::root_node); + std::swap(BaseClassRef::root_bbox, tmp.BaseClassRef::root_bbox); + std::swap(BaseClassRef::pool, tmp.BaseClassRef::pool); + return *this; + } + + /** + * Builds the index + */ + void buildIndex() { + BaseClassRef::m_size = BaseClassRef::vind.size(); + this->freeIndex(*this); + BaseClassRef::m_size_at_index_build = BaseClassRef::m_size; + if (BaseClassRef::m_size == 0) + return; + computeBoundingBox(BaseClassRef::root_bbox); + BaseClassRef::root_node = + this->divideTree(*this, 0, BaseClassRef::m_size, + BaseClassRef::root_bbox); // construct the tree + } + + /** \name Query methods + * @{ */ + + /** + * Find set of nearest neighbors to vec[0:dim-1]. Their indices are stored + * inside the result object. 
+ * + * Params: + * result = the result object in which the indices of the + * nearest-neighbors are stored vec = the vector for which to search the + * nearest neighbors + * + * \tparam RESULTSET Should be any ResultSet + * \return True if the requested neighbors could be found. + * \sa knnSearch, radiusSearch + */ + template + bool findNeighbors(RESULTSET &result, const ElementType *vec, + const SearchParams &searchParams) const { + assert(vec); + if (this->size(*this) == 0) + return false; + if (!BaseClassRef::root_node) + return false; + float epsError = 1 + searchParams.eps; + + // fixed or variable-sized container (depending on DIM) + distance_vector_t dists; + // Fill it with zeros. + assign(dists, (DIM > 0 ? DIM : BaseClassRef::dim), + static_cast(0)); + DistanceType distsq = this->computeInitialDistances(*this, vec, dists); + + searchLevel(result, vec, BaseClassRef::root_node, distsq, dists, + epsError); // "count_leaf" parameter removed since was neither + // used nor returned to the user. + + return result.full(); + } + + /** + * Find the "num_closest" nearest neighbors to the \a query_point[0:dim-1]. + * Their indices are stored inside the result object. \sa radiusSearch, + * findNeighbors \note nChecks_IGNORED is ignored but kept for compatibility + * with the original FLANN interface. \return Number `N` of valid points in + * the result set. Only the first `N` entries in `out_indices` and + * `out_distances_sq` will be valid. Return may be less than `num_closest` + * only if the number of elements in the tree is less than `num_closest`. 
+ */ + size_t knnSearch(const ElementType *query_point, const size_t num_closest, + IndexType *out_indices, DistanceType *out_distances_sq, + const int /* nChecks_IGNORED */ = 10) const { + nanoflann::KNNResultSet resultSet(num_closest); + resultSet.init(out_indices, out_distances_sq); + this->findNeighbors(resultSet, query_point, nanoflann::SearchParams()); + return resultSet.size(); + } + + /** + * Find all the neighbors to \a query_point[0:dim-1] within a maximum radius. + * The output is given as a vector of pairs, of which the first element is a + * point index and the second the corresponding distance. Previous contents of + * \a IndicesDists are cleared. + * + * If searchParams.sorted==true, the output list is sorted by ascending + * distances. + * + * For a better performance, it is advisable to do a .reserve() on the vector + * if you have any wild guess about the number of expected matches. + * + * \sa knnSearch, findNeighbors, radiusSearchCustomCallback + * \return The number of points within the given radius (i.e. indices.size() + * or dists.size() ) + */ + size_t + radiusSearch(const ElementType *query_point, const DistanceType &radius, + std::vector> &IndicesDists, + const SearchParams &searchParams) const { + RadiusResultSet resultSet(radius, IndicesDists); + const size_t nFound = + radiusSearchCustomCallback(query_point, resultSet, searchParams); + if (searchParams.sorted) + std::sort(IndicesDists.begin(), IndicesDists.end(), IndexDist_Sorter()); + return nFound; + } + + /** + * Just like radiusSearch() but with a custom callback class for each point + * found in the radius of the query. See the source of RadiusResultSet<> as a + * start point for your own classes. 
\sa radiusSearch + */ + template + size_t radiusSearchCustomCallback( + const ElementType *query_point, SEARCH_CALLBACK &resultSet, + const SearchParams &searchParams = SearchParams()) const { + this->findNeighbors(resultSet, query_point, searchParams); + return resultSet.size(); + } + + /** @} */ + +public: + void computeBoundingBox(BoundingBox &bbox) { + resize(bbox, (DIM > 0 ? DIM : BaseClassRef::dim)); + + if (dataset.kdtree_get_bbox(bbox)) { + // Done! It was implemented in derived class + } else { + const size_t N = BaseClassRef::m_size; + if (!N) + throw std::runtime_error("[nanoflann] computeBoundingBox() called but " + "no data points found."); + for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) { + bbox[i].low = bbox[i].high = + this->dataset_get(*this, BaseClassRef::vind[0], i); + } + for (size_t k = 1; k < N; ++k) { + for (int i = 0; i < (DIM > 0 ? DIM : BaseClassRef::dim); ++i) { + if (this->dataset_get(*this, BaseClassRef::vind[k], i) < bbox[i].low) + bbox[i].low = this->dataset_get(*this, BaseClassRef::vind[k], i); + if (this->dataset_get(*this, BaseClassRef::vind[k], i) > bbox[i].high) + bbox[i].high = this->dataset_get(*this, BaseClassRef::vind[k], i); + } + } + } + } + + /** + * Performs an exact search in the tree starting from a node. + * \tparam RESULTSET Should be any ResultSet + */ + template + void searchLevel(RESULTSET &result_set, const ElementType *vec, + const NodePtr node, DistanceType mindistsq, + distance_vector_t &dists, const float epsError) const { + /* If this is a leaf node, then do check and return. */ + if ((node->child1 == NULL) && (node->child2 == NULL)) { + // count_leaf += (node->lr.right-node->lr.left); // Removed since was + // neither used nor returned to the user. + DistanceType worst_dist = result_set.worstDist(); + for (IndexType i = node->node_type.lr.left; i < node->node_type.lr.right; + ++i) { + const IndexType index = BaseClassRef::vind[i]; // reorder... 
: i; + if (treeIndex[index] == -1) + continue; + DistanceType dist = distance.evalMetric( + vec, index, (DIM > 0 ? DIM : BaseClassRef::dim)); + if (dist < worst_dist) { + if (!result_set.addPoint( + static_cast(dist), + static_cast( + BaseClassRef::vind[i]))) { + // the resultset doesn't want to receive any more points, we're done + // searching! + return; // false; + } + } + } + return; + } + + /* Which child branch should be taken first? */ + int idx = node->node_type.sub.divfeat; + ElementType val = vec[idx]; + DistanceType diff1 = val - node->node_type.sub.divlow; + DistanceType diff2 = val - node->node_type.sub.divhigh; + + NodePtr bestChild; + NodePtr otherChild; + DistanceType cut_dist; + if ((diff1 + diff2) < 0) { + bestChild = node->child1; + otherChild = node->child2; + cut_dist = distance.accum_dist(val, node->node_type.sub.divhigh, idx); + } else { + bestChild = node->child2; + otherChild = node->child1; + cut_dist = distance.accum_dist(val, node->node_type.sub.divlow, idx); + } + + /* Call recursively to search next level down. */ + searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError); + + DistanceType dst = dists[idx]; + mindistsq = mindistsq + cut_dist - dst; + dists[idx] = cut_dist; + if (mindistsq * epsError <= result_set.worstDist()) { + searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError); + } + dists[idx] = dst; + } + +public: + /** Stores the index in a binary file. + * IMPORTANT NOTE: The set of data points is NOT stored in the file, so when + * loading the index object it must be constructed associated to the same + * source of data points used while building it. See the example: + * examples/saveload_example.cpp \sa loadIndex */ + void saveIndex(FILE *stream) { this->saveIndex_(*this, stream); } + + /** Loads a previous index from a binary file. 
+ * IMPORTANT NOTE: The set of data points is NOT stored in the file, so the + * index object must be constructed associated to the same source of data + * points used while building the index. See the example: + * examples/saveload_example.cpp \sa saveIndex */ + void loadIndex(FILE *stream) { this->loadIndex_(*this, stream); } +}; + +/** kd-tree dynamic index + * + * Class that creates multiple static indices and merges their results to behave + * as a single dynamic index, as proposed in the Logarithmic Approach. + * + * Example of usage: + * examples/dynamic_pointcloud_example.cpp + * + * \tparam DatasetAdaptor The user-provided adaptor (see comments above). + * \tparam Distance The distance metric to use: nanoflann::metric_L1, + * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. \tparam DIM + * Dimensionality of data points (e.g. 3 for 3D points) \tparam IndexType Will + * be typically size_t or int + */ +template +class KDTreeSingleIndexDynamicAdaptor { +public: + typedef typename Distance::ElementType ElementType; + typedef typename Distance::DistanceType DistanceType; + +protected: + size_t m_leaf_max_size; + size_t treeCount; + size_t pointCount; + + /** + * The dataset used by this index + */ + const DatasetAdaptor &dataset; //!< The source of our data + + std::vector treeIndex; //!< treeIndex[idx] is the index of tree in which + //!< point at idx is stored. treeIndex[idx]=-1 + //!< means that point has been removed. + + KDTreeSingleIndexAdaptorParams index_params; + + int dim; //!< Dimensionality of each data point + + typedef KDTreeSingleIndexDynamicAdaptor_ + index_container_t; + std::vector index; + +public: + /** Get a const ref to the internal list of indices; the number of indices is + * adapted dynamically as the dataset grows in size.
*/ + const std::vector &getAllIndices() const { return index; } + +private: + /** finds position of least significant unset bit */ + int First0Bit(IndexType num) { + int pos = 0; + while (num & 1) { + num = num >> 1; + pos++; + } + return pos; + } + + /** Creates multiple empty trees to handle dynamic support */ + void init() { + typedef KDTreeSingleIndexDynamicAdaptor_ + my_kd_tree_t; + std::vector index_( + treeCount, my_kd_tree_t(dim /*dim*/, dataset, treeIndex, index_params)); + index = index_; + } + +public: + Distance distance; + + /** + * KDTree constructor + * + * Refer to docs in README.md or online in + * https://github.com/jlblancoc/nanoflann + * + * The KD-Tree point dimension (the length of each point in the datase, e.g. 3 + * for 3D points) is determined by means of: + * - The \a DIM template parameter if >0 (highest priority) + * - Otherwise, the \a dimensionality parameter of this constructor. + * + * @param inputData Dataset with the input features + * @param params Basically, the maximum leaf node size + */ + KDTreeSingleIndexDynamicAdaptor(const int dimensionality, + const DatasetAdaptor &inputData, + const KDTreeSingleIndexAdaptorParams ¶ms = + KDTreeSingleIndexAdaptorParams(), + const size_t maximumPointCount = 1000000000U) + : dataset(inputData), index_params(params), distance(inputData) { + treeCount = static_cast(std::log2(maximumPointCount)); + pointCount = 0U; + dim = dimensionality; + treeIndex.clear(); + if (DIM > 0) + dim = DIM; + m_leaf_max_size = params.leaf_max_size; + init(); + const size_t num_initial_points = dataset.kdtree_get_point_count(); + if (num_initial_points > 0) { + addPoints(0, num_initial_points - 1); + } + } + + /** Deleted copy constructor*/ + KDTreeSingleIndexDynamicAdaptor( + const KDTreeSingleIndexDynamicAdaptor &) = delete; + + /** Add points to the set, Inserts all points from [start, end] */ + void addPoints(IndexType start, IndexType end) { + size_t count = end - start + 1; + treeIndex.resize(treeIndex.size() 
+ count); + for (IndexType idx = start; idx <= end; idx++) { + int pos = First0Bit(pointCount); + index[pos].vind.clear(); + treeIndex[pointCount] = pos; + for (int i = 0; i < pos; i++) { + for (int j = 0; j < static_cast(index[i].vind.size()); j++) { + index[pos].vind.push_back(index[i].vind[j]); + if (treeIndex[index[i].vind[j]] != -1) + treeIndex[index[i].vind[j]] = pos; + } + index[i].vind.clear(); + index[i].freeIndex(index[i]); + } + index[pos].vind.push_back(idx); + index[pos].buildIndex(); + pointCount++; + } + } + + /** Remove a point from the set (Lazy Deletion) */ + void removePoint(size_t idx) { + if (idx >= pointCount) + return; + treeIndex[idx] = -1; + } + + /** + * Find set of nearest neighbors to vec[0:dim-1]. Their indices are stored + * inside the result object. + * + * Params: + * result = the result object in which the indices of the + * nearest-neighbors are stored vec = the vector for which to search the + * nearest neighbors + * + * \tparam RESULTSET Should be any ResultSet + * \return True if the requested neighbors could be found. + * \sa knnSearch, radiusSearch + */ + template + bool findNeighbors(RESULTSET &result, const ElementType *vec, + const SearchParams &searchParams) const { + for (size_t i = 0; i < treeCount; i++) { + index[i].findNeighbors(result, &vec[0], searchParams); + } + return result.full(); + } +}; + +/** An L2-metric KD-tree adaptor for working with data directly stored in an + * Eigen Matrix, without duplicating the data storage. Each row in the matrix + * represents a point in the state space. + * + * Example of usage: + * \code + * Eigen::Matrix mat; + * // Fill out "mat"... + * + * typedef KDTreeEigenMatrixAdaptor< Eigen::Matrix > + * my_kd_tree_t; const int max_leaf = 10; my_kd_tree_t mat_index(mat, max_leaf + * ); mat_index.index->buildIndex(); mat_index.index->... 
\endcode + * + * \tparam DIM If set to >0, it specifies a compile-time fixed dimensionality + * for the points in the data set, allowing more compiler optimizations. \tparam + * Distance The distance metric to use: nanoflann::metric_L1, + * nanoflann::metric_L2, nanoflann::metric_L2_Simple, etc. + */ +template +struct KDTreeEigenMatrixAdaptor { + typedef KDTreeEigenMatrixAdaptor self_t; + typedef typename MatrixType::Scalar num_t; + typedef typename MatrixType::Index IndexType; + typedef + typename Distance::template traits::distance_t metric_t; + typedef KDTreeSingleIndexAdaptor + index_t; + + index_t *index; //! The kd-tree index for the user to call its methods as + //! usual with any other FLANN index. + + /// Constructor: takes a const ref to the matrix object with the data points + KDTreeEigenMatrixAdaptor(const size_t dimensionality, + const std::reference_wrapper &mat, + const int leaf_max_size = 10) + : m_data_matrix(mat) { + const auto dims = mat.get().cols(); + if (size_t(dims) != dimensionality) + throw std::runtime_error( + "Error: 'dimensionality' must match column count in data matrix"); + if (DIM > 0 && int(dims) != DIM) + throw std::runtime_error( + "Data set dimensionality does not match the 'DIM' template argument"); + index = + new index_t(static_cast(dims), *this /* adaptor */, + nanoflann::KDTreeSingleIndexAdaptorParams(leaf_max_size)); + index->buildIndex(); + } + +public: + /** Deleted copy constructor */ + KDTreeEigenMatrixAdaptor(const self_t &) = delete; + + ~KDTreeEigenMatrixAdaptor() { delete index; } + + const std::reference_wrapper m_data_matrix; + + /** Query for the \a num_closest closest points to a given point (entered as + * query_point[0:dim-1]). Note that this is a short-cut method for + * index->findNeighbors(). The user can also call index->... methods as + * desired. \note nChecks_IGNORED is ignored but kept for compatibility with + * the original FLANN interface. 
+ */ + inline void query(const num_t *query_point, const size_t num_closest, + IndexType *out_indices, num_t *out_distances_sq, + const int /* nChecks_IGNORED */ = 10) const { + nanoflann::KNNResultSet resultSet(num_closest); + resultSet.init(out_indices, out_distances_sq); + index->findNeighbors(resultSet, query_point, nanoflann::SearchParams()); + } + + /** @name Interface expected by KDTreeSingleIndexAdaptor + * @{ */ + + const self_t &derived() const { return *this; } + self_t &derived() { return *this; } + + // Must return the number of data points + inline size_t kdtree_get_point_count() const { + return m_data_matrix.get().rows(); + } + + // Returns the dim'th component of the idx'th point in the class: + inline num_t kdtree_get_pt(const IndexType idx, size_t dim) const { + return m_data_matrix.get().coeff(idx, IndexType(dim)); + } + + // Optional bounding-box computation: return false to default to a standard + // bbox computation loop. + // Return true if the BBOX was already computed by the class and returned in + // "bb" so it can be avoided to redo it again. Look at bb.size() to find out + // the expected dimensionality (e.g. 2 or 3 for point clouds) + template bool kdtree_get_bbox(BBOX & /*bb*/) const { + return false; + } + + /** @} */ + +}; // end of KDTreeEigenMatrixAdaptor + /** @} */ + +/** @} */ // end of grouping +} // namespace nanoflann + +#endif /* NANOFLANN_HPP_ */ diff --git a/torch-points3d/torch_points3d/modules/KPConv/kernel_points.py b/torch-points3d/torch_points3d/modules/KPConv/kernel_points.py new file mode 100644 index 0000000..ecdae37 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/KPConv/kernel_points.py @@ -0,0 +1,413 @@ +# +# +# 0=================================0 +# | Kernel Point Convolutions | +# 0=================================0 +# +# +# ---------------------------------------------------------------------------------------------------------------------- +# +# Functions handling the disposition of kernel points. 
def create_3D_rotations(axis, angle):
    """
    Build one 3x3 rotation matrix per (axis, angle) pair with the axis-angle closed form.
    No normalization is done here, so the result is only a proper rotation for unit-length axes.
    :param axis: float32[N, 3] rotation axes
    :param angle: float32[N,] rotation angles (radians, as consumed by np.cos / np.sin)
    :return: float32[N, 3, 3]
    """

    # Axis components, one vector of length N each
    x, y, z = axis[:, 0], axis[:, 1], axis[:, 2]

    cos_a = np.cos(angle)
    sin_a = np.sin(angle)
    one_minus_cos = 1 - cos_a

    # Products shared by several matrix entries
    omc_x = one_minus_cos * x
    omc_xy = omc_x * y
    omc_xz = omc_x * z
    omc_yz = one_minus_cos * y * z
    sin_x = sin_a * x
    sin_y = sin_a * y
    sin_z = sin_a * z

    # The nine entries in row-major order
    entries = [
        cos_a + one_minus_cos * (x * x), omc_xy - sin_z, omc_xz + sin_y,
        omc_xy + sin_z, cos_a + one_minus_cos * (y * y), omc_yz - sin_x,
        omc_xz - sin_y, omc_yz + sin_x, cos_a + one_minus_cos * (z * z),
    ]

    return np.stack(entries, axis=1).reshape((-1, 3, 3))
def spherical_Lloyd(radius, num_cells, dimension=3, fixed='center', approximation='monte-carlo',
                    approx_n=5000, max_iter=500, momentum=0.9, verbose=0):
    """
    Creation of kernel point via Lloyd algorithm. We use an approximation of the algorithm, and compute the Voronoi
    cell centers with discretization of space. The exact formula is not trivial with part of the sphere as sides.
    :param radius: Radius of the kernels
    :param num_cells: Number of cell (kernel points) in the Voronoi diagram.
    :param dimension: dimension of the space
    :param fixed: fix position of certain kernel points ('none', 'center' or 'verticals')
    :param approximation: Approximation method for Lloyd's algorithm ('discretization', 'monte-carlo')
    :param approx_n: Number of point used for approximation.
    :param max_iter: Number of iterations of the algorithm (there is no early stop, all of them are run).
    :param momentum: Momentum of the low pass filter smoothing kernel point positions
    :param verbose: display option (truthy: log the largest move of every iteration)
    :return: points [num_cells, dimension]
    :raises ValueError: unknown approximation method, or 'discretization' with a dimension other than 2, 3 or 4
    """

    #######################
    # Parameters definition
    #######################

    # Radius used for optimization (points are rescaled afterwards)
    radius0 = 1.0

    #######################
    # Kernel initialization
    #######################

    # Random kernel points, rejection-sampled in the outer shell (0.9 * radius0 < r < radius0) of the sphere
    kernel_points = np.zeros((0, dimension))
    while kernel_points.shape[0] < num_cells:
        new_points = np.random.rand(num_cells, dimension) * 2 * radius0 - radius0
        kernel_points = np.vstack((kernel_points, new_points))
        d2 = np.sum(np.power(kernel_points, 2), axis=1)
        kernel_points = kernel_points[np.logical_and(d2 < radius0 ** 2, (0.9 * radius0) ** 2 < d2), :]
    kernel_points = kernel_points[:num_cells, :].reshape((num_cells, -1))

    # Optional fixing
    if fixed == 'center':
        kernel_points[0, :] *= 0
    if fixed == 'verticals':
        kernel_points[:3, :] *= 0
        kernel_points[1, -1] += 2 * radius0 / 3
        kernel_points[2, -1] -= 2 * radius0 / 3

    ##############################
    # Approximation initialization
    ##############################

    # Initialize discretization if this method is chosen
    if approximation == 'discretization':
        side_n = int(np.floor(approx_n ** (1. / dimension)))
        dl = 2 * radius0 / side_n
        coords = np.arange(-radius0 + dl / 2, radius0, dl)
        if dimension == 2:
            x, y = np.meshgrid(coords, coords)
            X = np.vstack((np.ravel(x), np.ravel(y))).T
        elif dimension == 3:
            x, y, z = np.meshgrid(coords, coords, coords)
            X = np.vstack((np.ravel(x), np.ravel(y), np.ravel(z))).T
        elif dimension == 4:
            x, y, z, t = np.meshgrid(coords, coords, coords, coords)
            X = np.vstack((np.ravel(x), np.ravel(y), np.ravel(z), np.ravel(t))).T
        else:
            raise ValueError('Unsupported dimension (max is 4)')
    elif approximation == 'monte-carlo':
        X = np.zeros((0, dimension))
    else:
        raise ValueError('Wrong approximation method chosen: "{:s}"'.format(approximation))

    # Only points inside the sphere are used
    d2 = np.sum(np.power(X, 2), axis=1)
    X = X[d2 < radius0 * radius0, :]

    #####################
    # Kernel optimization
    #####################

    # Warning if at least one kernel point has no cell (sticky: never reset once raised)
    warning = False

    # `iteration` rather than `iter`, to avoid shadowing the builtin
    for iteration in range(max_iter):

        # In the case of monte-carlo, renew the sampled points
        if approximation == 'monte-carlo':
            X = np.random.rand(approx_n, dimension) * 2 * radius0 - radius0
            d2 = np.sum(np.power(X, 2), axis=1)
            X = X[d2 < radius0 * radius0, :]

        # Squared distances from every sample to every kernel point: [n_approx, K]
        differences = np.expand_dims(X, 1) - kernel_points
        sq_distances = np.sum(np.square(differences), axis=2)

        # Compute cell centers
        cell_inds = np.argmin(sq_distances, axis=1)
        centers = []
        for c in range(num_cells):
            bool_c = (cell_inds == c)
            num_c = np.sum(bool_c.astype(np.int32))
            if num_c > 0:
                centers.append(np.sum(X[bool_c, :], axis=0) / num_c)
            else:
                # Empty cell: keep the kernel point where it is
                warning = True
                centers.append(kernel_points[c])

        # Update kernel points with low pass filter to smooth monte carlo
        centers = np.vstack(centers)
        moves = (1 - momentum) * (centers - kernel_points)
        kernel_points += moves

        # Optional fixing
        if fixed == 'center':
            kernel_points[0, :] *= 0
        if fixed == 'verticals':
            kernel_points[0, :] *= 0
            kernel_points[:3, :-1] *= 0

        if verbose:
            # logging.log() needs an explicit level as first argument; use the level helpers instead
            max_move = np.max(np.linalg.norm(moves, axis=1))
            logging.info('iter {:5d} / max move = {:f}'.format(iteration, max_move))
            if warning:
                logging.warning('At least one point has no cell')

    # Rescale kernels with real radius
    return kernel_points * radius
def kernel_point_optimization_debug(radius, num_points, num_kernels=1, dimension=3,
                                    fixed='center', ratio=0.66, verbose=0, max_iter=10000):
    """
    Creation of kernel point via optimization of potentials.
    :param radius: Radius of the kernels
    :param num_points: points composing kernels
    :param num_kernels: number of wanted kernels
    :param dimension: dimension of the space
    :param fixed: fix position of certain kernel points ('none', 'center' or 'verticals')
    :param ratio: ratio of the radius where you want the kernels points to be placed
    :param verbose: display option (truthy: log the max gradient norm of every step)
    :param max_iter: maximum number of optimization steps (previously hard-coded to 10000)
    :return: (points [num_kernels, num_points, dimension], saved gradient norms [num_steps_run, num_kernels])
    """

    #######################
    # Parameters definition
    #######################

    # Radius used for optimization (points are rescaled afterwards)
    radius0 = 1
    diameter0 = 2

    # Factor multiplicating gradients for moving points (~learning rate)
    moving_factor = 1e-2
    continuous_moving_decay = 0.9995

    # Gradient threshold to stop optimization
    thresh = 1e-5

    # Gradient clipping value
    clip = 0.05 * radius0

    # A negative cap behaves like zero steps
    max_iter = max(int(max_iter), 0)

    #######################
    # Kernel initialization
    #######################

    # Random kernel points, rejection-sampled inside the ball d2 < 0.5 * radius0^2
    kernel_points = np.random.rand(num_kernels * num_points - 1, dimension) * diameter0 - radius0
    while (kernel_points.shape[0] < num_kernels * num_points):
        new_points = np.random.rand(num_kernels * num_points - 1, dimension) * diameter0 - radius0
        kernel_points = np.vstack((kernel_points, new_points))
        d2 = np.sum(np.power(kernel_points, 2), axis=1)
        kernel_points = kernel_points[d2 < 0.5 * radius0 * radius0, :]
    kernel_points = kernel_points[:num_kernels * num_points, :].reshape((num_kernels, num_points, -1))

    # Optional fixing
    if fixed == 'center':
        kernel_points[:, 0, :] *= 0
    if fixed == 'verticals':
        kernel_points[:, :3, :] *= 0
        kernel_points[:, 1, -1] += 2 * radius0 / 3
        kernel_points[:, 2, -1] -= 2 * radius0 / 3

    #####################
    # Kernel optimization
    #####################

    saved_gradient_norms = np.zeros((max_iter, num_kernels))
    old_gradient_norms = np.zeros((num_kernels, num_points))

    # `step` keeps the index of the last step that wrote into saved_gradient_norms (-1 if none ran).
    # A bounded for-loop replaces `while step < 10000: step += 1`, which could index one row past the buffer.
    step = -1
    for step in range(max_iter):

        # Compute gradients
        # *****************

        # Derivative of the sum of potentials of all points
        A = np.expand_dims(kernel_points, axis=2)
        B = np.expand_dims(kernel_points, axis=1)
        interd2 = np.sum(np.power(A - B, 2), axis=-1)
        inter_grads = (A - B) / (np.power(np.expand_dims(interd2, -1), 3 / 2) + 1e-6)
        inter_grads = np.sum(inter_grads, axis=1)

        # Derivative of the radius potential
        circle_grads = 10 * kernel_points

        # All gradients
        gradients = inter_grads + circle_grads

        if fixed == 'verticals':
            gradients[:, 1:3, :-1] = 0

        # Stop condition
        # **************

        # Compute norm of gradients
        gradients_norms = np.sqrt(np.sum(np.power(gradients, 2), axis=-1))
        saved_gradient_norms[step, :] = np.max(gradients_norms, axis=1)

        # Stop if all moving points are gradients fixed (low gradients diff)

        if fixed == 'center' and np.max(np.abs(old_gradient_norms[:, 1:] - gradients_norms[:, 1:])) < thresh:
            break
        elif fixed == 'verticals' and np.max(np.abs(old_gradient_norms[:, 3:] - gradients_norms[:, 3:])) < thresh:
            break
        elif np.max(np.abs(old_gradient_norms - gradients_norms)) < thresh:
            break
        old_gradient_norms = gradients_norms

        # Move points
        # ***********

        # Clip gradient to get moving dists
        moving_dists = np.minimum(moving_factor * gradients_norms, clip)

        # Fix central point
        if fixed == 'center':
            moving_dists[:, 0] = 0
        if fixed == 'verticals':
            moving_dists[:, 0] = 0

        # Move points
        kernel_points -= np.expand_dims(moving_dists, -1) * gradients / np.expand_dims(gradients_norms + 1e-6, -1)

        if verbose:
            # The message skips the first three points; fall back to all points when there are no others,
            # so that np.max is never called on an empty slice.
            tracked = gradients_norms[:, 3:] if num_points > 3 else gradients_norms
            # logging.log() needs an explicit level as first argument; use the level helper instead
            logging.info('step {:5d} / max grad = {:f}'.format(step, np.max(tracked)))

        # moving factor decay
        moving_factor *= continuous_moving_decay

    # Remove unused lines in the saved gradients
    saved_gradient_norms = saved_gradient_norms[:step + 1, :]

    # Rescale radius to fit the wanted ratio of radius
    r = np.sqrt(np.sum(np.power(kernel_points, 2), axis=-1))
    kernel_points *= ratio / np.mean(r[:, 1:])

    # Rescale kernels with real radius
    return kernel_points * radius, saved_gradient_norms
def load_kernels(radius, num_kpoints, dimension, fixed, lloyd=False):
    """
    Generate one kernel point disposition, randomly rotated, slightly jittered and scaled by `radius`.
    :param radius: scale applied to the unit-radius kernel
    :param num_kpoints: number of kernel points
    :param dimension: dimension of the space (a random rotation is only applied for 2 and 3)
    :param fixed: forwarded to the kernel generators ('none', 'center' or 'verticals')
    :param lloyd: use spherical_Lloyd instead of the potential optimization; forced on above 30 points
    :return: float32[num_kpoints, dimension]
    """
    # Kernel directory. Nothing in this function reads or writes a disposition file; the directory is
    # still created because callers may rely on that side effect. exist_ok avoids the exists/makedirs race.
    kernel_dir = 'kernels/dispositions'
    makedirs(kernel_dir, exist_ok=True)

    # Too many points: switch to Lloyd
    if num_kpoints > 30:
        lloyd = True

    # Kernels are always generated. The previous `if not exists(kernel_file)` guard skipped generation
    # whenever a 'k_*.ply' file happened to be present, leaving kernel_points unbound (UnboundLocalError),
    # because there is no code loading that file.
    if lloyd:
        # Create kernels
        kernel_points = spherical_Lloyd(1.0,
                                        num_kpoints,
                                        dimension=dimension,
                                        fixed=fixed,
                                        verbose=0)

    else:
        # Create kernels
        kernel_points, grad_norms = kernel_point_optimization_debug(1.0,
                                                                    num_kpoints,
                                                                    num_kernels=100,
                                                                    dimension=dimension,
                                                                    fixed=fixed,
                                                                    verbose=0)

        # Find best candidate: lowest max gradient norm at the last recorded step
        best_k = np.argmin(grad_norms[-1, :])

        # Keep only that kernel
        kernel_points = kernel_points[best_k, :, :]

    # Random rotations for the kernel
    # N.B. 4D random rotations not supported yet
    # NOTE(review): the tests below compare against 'vertical', while the generators document 'verticals';
    # as written the first branch is taken for every documented value of `fixed` -- confirm the intent.
    R = np.eye(dimension)
    theta = np.random.rand() * 2 * np.pi
    if dimension == 2:
        if fixed != 'vertical':
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, -s], [s, c]], dtype=np.float32)

    elif dimension == 3:
        if fixed != 'vertical':
            # Rotation around the last axis only
            c, s = np.cos(theta), np.sin(theta)
            R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]], dtype=np.float32)

        else:
            phi = (np.random.rand() - 0.5) * np.pi

            # Create the first vector in cartesian coordinates
            u = np.array([np.cos(theta) * np.cos(phi), np.sin(theta) * np.cos(phi), np.sin(phi)])

            # Choose a random rotation angle
            alpha = np.random.rand() * 2 * np.pi

            # Create the rotation matrix with this vector and angle
            R = create_3D_rotations(np.reshape(u, (1, -1)), np.reshape(alpha, (1, -1)))[0]

    R = R.astype(np.float32)

    # Add a small noise
    kernel_points = kernel_points + np.random.normal(scale=0.01, size=kernel_points.shape)

    # Scale kernels
    kernel_points = radius * kernel_points

    # Rotate kernels
    kernel_points = np.matmul(kernel_points, R)

    return kernel_points.astype(np.float32)
class MinkowskiPointNet(nn.Module):
    """
    PointNet-like network built from MinkowskiEngine layers: a per-point stack of
    Linear -> BatchNorm -> activation layers, a global pooling, a head made of the same kind of
    layers, dropout and a final biased linear layer.
    """

    def __init__(self, in_channels, out_channels, activation="relu", global_pool="max", embedding_channel=1024, D=3,
                 dropout=0.0, bn_momentum=.1,
                 **kwargs):
        super().__init__()
        # A single activation module instance, reused at every position of both stacks
        self.act_fn = ACTIVATIONS[activation]()

        # The first linear layer expects D + in_channels input features
        self.blocks = self._linear_norm_act((D + in_channels, 64, 128, embedding_channel), bn_momentum)
        self.global_pool = GLOBAL_POOL[global_pool]()

        self.mlp = self._linear_norm_act((embedding_channel, 512, 256), bn_momentum)
        self.dp1 = ME.MinkowskiDropout(dropout)
        self.final = ME.MinkowskiLinear(256, out_channels, bias=True)

    def _linear_norm_act(self, widths, bn_momentum):
        """Chain bias-free Linear -> BatchNorm -> activation triplets over consecutive channel widths."""
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(ME.MinkowskiLinear(n_in, n_out, bias=False))
            layers.append(ME.MinkowskiBatchNorm(n_out, momentum=bn_momentum))
            layers.append(self.act_fn)
        return nn.Sequential(*layers)

    @custom_fwd(cast_inputs=torch.float32)
    def forward(self, x):
        # Inputs are cast to float32 by the decorator when running under autocast
        pooled = self.global_pool(self.blocks(x))
        head = self.dp1(self.mlp(pooled))
        return self.final(head)
class ResNetBase(nn.Module):
    """
    Sparse ResNet backbone with a global pooling and a linear output layer.

    Subclasses configure the network through class attributes:
    BLOCK (residual block class), LAYERS (blocks per stage), STRIDES (stride per stage),
    PLANES (channels per stage) and INIT_DIM (channels of the stem convolution).
    """
    BLOCK = None
    LAYERS = ()
    # Declared here so that the "STRIDES is not defined" assertion below can actually fire
    # (without it, a subclass lacking STRIDES failed with an AttributeError instead).
    STRIDES = None
    INIT_DIM = 64
    PLANES = (64, 128, 256, 512)

    def __init__(self, in_channels, out_channels, activation="relu", D=3, first_stride=2, dropout=0.0, drop_path=0.0,
                 bn_momentum=0.1, norm_type="bn", global_pool="mean", use_gn=False, bias=True, **kwargs):
        nn.Module.__init__(self)
        self.D = D
        self.bias = bias
        self.bn_momentum = bn_momentum
        self.cross_dims = []
        self.drop_path = drop_path
        assert self.BLOCK is not None, "BLOCK is not defined"
        assert self.PLANES is not None, "PLANES is not defined"
        assert self.STRIDES is not None, "STRIDES is not defined"

        # One activation module instance shared by the stem and every residual block
        self.act_fn = ACTIVATIONS[activation]()
        self.norm_type = norm_type
        if norm_type == "bn":
            self.norm_layer = partial(N.MinkowskiBatchNorm, momentum=bn_momentum)
        elif norm_type == "bn_no_affine":
            self.norm_layer = partial(N.MinkowskiBatchNorm, momentum=bn_momentum, affine=False)
        elif norm_type == "in":
            self.norm_layer = N.MinkowskiInstanceNorm
        elif norm_type == "ln":
            self.norm_layer = MinkowskiLayerNorm
        else:
            raise NotImplementedError(
                f"Choose either 'bn', 'bn_no_affine', 'in', or 'ln'. Given: {norm_type}"
            )

        # self.inplanes tracks the channel count flowing into the next stage; _make_layer updates it
        self.inplanes = self.INIT_DIM
        first_out_planes = self.inplanes
        # Stem: 7-wide convolution + norm + activation, followed by a stride-2 max pooling
        self.blocks = [
            nn.Sequential(
                ConvNormActivation(
                    in_channels, first_out_planes, kernel_size=7, stride=first_stride, D=D,
                    bias=bias, activation_layer=self.act_fn, norm_layer=self.norm_layer
                ),
                ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)
            )
        ]

        # One residual stage per (PLANES, LAYERS, STRIDES) triple; zip stops at the shortest of the three
        for planes, layers, stride in zip(self.PLANES, self.LAYERS, self.STRIDES):
            self.blocks.append(
                self._make_layer(self.BLOCK, planes, layers, stride=stride)
            )
        self.blocks = nn.ModuleList(self.blocks)

        self.glob_avg = GLOBAL_POOL[global_pool]()  # dimension=D)
        if dropout > 0:
            self.glob_avg = nn.Sequential(
                self.glob_avg,
                ME.MinkowskiDropout(dropout),
            )

        self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True)

        self.apply(self.init_weights)

    @staticmethod
    def init_weights(m):
        """Reset batch norms to (weight=1, bias=0); truncated-normal (std 0.02) kernels, zero biases."""
        if isinstance(m, ME.MinkowskiBatchNorm):
            nn.init.constant_(m.bn.weight, 1)
            nn.init.constant_(m.bn.bias, 0)

        if isinstance(m, ME.MinkowskiConvolution):
            nn.init.trunc_normal_(m.kernel, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        if isinstance(m, ME.MinkowskiLinear):
            nn.init.trunc_normal_(m.linear.weight, std=.02)
            if m.linear.bias is not None:
                nn.init.constant_(m.linear.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        """Build one stage of `blocks` residual blocks and advance self.inplanes to the stage output width."""
        downsample = None
        # A projection shortcut is needed whenever the first block changes resolution or channel count
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, dimension=self.D,
                    dilation=1, bias=self.bias,
                ),
                self.norm_layer(planes * block.expansion),
            )
        layers = [block(
            self.inplanes, planes, self.act_fn, stride=stride, dilation=dilation, downsample=downsample,
            dimension=self.D, drop_path=self.drop_path, bias=self.bias, norm_layer=self.norm_layer
        )]
        self.inplanes = planes * block.expansion
        # Remaining blocks of the stage keep the resolution (stride 1, no projection)
        for _ in range(1, blocks):
            layers.append(block(
                self.inplanes, planes, self.act_fn, stride=1, dilation=dilation, dimension=self.D,
                drop_path=self.drop_path, bias=self.bias, norm_layer=self.norm_layer
            ))

        return nn.Sequential(*layers)

    def forward(self, x):
        for block in self.blocks:
            x = block(x)

        x = self.glob_avg(x)
        return self.final(x)
_custom_models = sys.modules[__name__]


def initialize_minkowski_unet(
        model_name, in_channels, out_channels, D=3, conv1_kernel_size=3, **kwargs
):
    """
    Instantiate the network class called `model_name` that is exposed by this module.

    The class is looked up as an attribute of the module itself and constructed with keyword
    arguments only: in_channels, out_channels, D, conv1_kernel_size, then everything in kwargs.
    An unknown name raises AttributeError (from getattr).
    """
    ctor_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "D": D,
        "conv1_kernel_size": conv1_kernel_size,
    }
    ctor_kwargs.update(kwargs)

    model_class = getattr(_custom_models, model_name)
    return model_class(**ctor_kwargs)
class ResBlock(ME.MinkowskiNetwork):
    """
    Basic ResNet type block: two (3-wide convolution, batch norm, ReLU) groups plus a shortcut.

    Parameters
    ----------
    input_nc:
        Number of input channels
    output_nc:
        number of output channels
    convolution
        Either MinkowskConvolution or MinkowskiConvolutionTranspose
    dimension:
        Dimension of the spatial grid
    """

    def __init__(self, input_nc, output_nc, convolution, dimension=3):
        ME.MinkowskiNetwork.__init__(self, dimension)

        # Settings common to every convolution of the block
        shared = dict(stride=1, dilation=1, bias=False, dimension=dimension)

        # Main branch: the first group maps input_nc -> output_nc, the second keeps output_nc
        main_branch = Seq()
        for channels_in in (input_nc, output_nc):
            main_branch.append(
                convolution(in_channels=channels_in, out_channels=output_nc, kernel_size=3, **shared)
            )
            main_branch.append(ME.MinkowskiBatchNorm(output_nc))
            main_branch.append(ME.MinkowskiReLU())
        self.block = main_branch

        # Shortcut: a 1-wide convolution + batch norm only when the channel count changes
        shortcut = None
        if input_nc != output_nc:
            shortcut = Seq()
            shortcut.append(
                convolution(in_channels=input_nc, out_channels=output_nc, kernel_size=1, **shared)
            )
            shortcut.append(ME.MinkowskiBatchNorm(output_nc))
        self.downsample = shortcut

    def forward(self, x):
        out = self.block(x)
        # Identity shortcut unless a projection was built in __init__
        residual = self.downsample(x) if self.downsample else x
        out += residual
        return out
class BottleneckBlock(ME.MinkowskiNetwork):
    """
    Bottleneck block with residual: 1-wide convolution down to output_nc // reduction channels,
    3-wide convolution at that width, 1-wide convolution up to output_nc, each followed by
    batch norm and ReLU.

    Parameters
    ----------
    input_nc:
        Number of input channels
    output_nc:
        number of output channels
    convolution
        Convolution class used for every layer of the block
    dimension:
        Dimension of the spatial grid
    reduction:
        Divisor applied to output_nc to get the inner width
    """

    def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=4):
        # The base class must be initialised before any submodule is assigned; without this call the
        # first `self.block = ...` below raises ("cannot assign module before Module.__init__() call").
        ME.MinkowskiNetwork.__init__(self, dimension)
        self.block = (
            Seq()
            .append(
                convolution(
                    in_channels=input_nc,
                    out_channels=output_nc // reduction,
                    kernel_size=1,
                    stride=1,
                    dilation=1,
                    bias=False,
                    dimension=dimension,
                )
            )
            .append(ME.MinkowskiBatchNorm(output_nc // reduction))
            .append(ME.MinkowskiReLU())
            .append(
                convolution(
                    output_nc // reduction,
                    output_nc // reduction,
                    kernel_size=3,
                    stride=1,
                    dilation=1,
                    bias=False,
                    dimension=dimension,
                )
            )
            .append(ME.MinkowskiBatchNorm(output_nc // reduction))
            .append(ME.MinkowskiReLU())
            .append(
                convolution(
                    output_nc // reduction,
                    output_nc,
                    kernel_size=1,
                    stride=1,
                    dilation=1,
                    bias=False,
                    dimension=dimension,
                )
            )
            .append(ME.MinkowskiBatchNorm(output_nc))
            .append(ME.MinkowskiReLU())
        )

        # Projection shortcut only when the channel count changes
        if input_nc != output_nc:
            self.downsample = (
                Seq()
                .append(
                    convolution(
                        in_channels=input_nc,
                        out_channels=output_nc,
                        kernel_size=1,
                        stride=1,
                        dilation=1,
                        bias=False,
                        dimension=dimension,
                    )
                )
                .append(ME.MinkowskiBatchNorm(output_nc))
            )
        else:
            self.downsample = None

    def forward(self, x):
        out = self.block(x)
        if self.downsample:
            out += self.downsample(x)
        else:
            out += x
        return out


class SELayer(torch.nn.Module):
    """
    Squeeze and excite layer: global pooling, a two-layer bottleneck ending in a sigmoid,
    then a broadcast multiplication of the input by the resulting per-channel gates.

    Parameters
    ----------
    channel:
        size of the input and output
    reduction:
        magnitude of the compression
    dimension:
        dimension of the kernels (accepted for interface symmetry, not used by the layer)
    """

    def __init__(self, channel, reduction=16, dimension=3):
        # Global coords does not require coords_key
        super(SELayer, self).__init__()
        self.fc = torch.nn.Sequential(
            ME.MinkowskiLinear(channel, channel // reduction),
            ME.MinkowskiReLU(),
            ME.MinkowskiLinear(channel // reduction, channel),
            ME.MinkowskiSigmoid(),
        )
        self.pooling = ME.MinkowskiGlobalPooling()
        self.broadcast_mul = ME.MinkowskiBroadcastMultiplication()

    def forward(self, x):
        y = self.pooling(x)
        y = self.fc(y)
        return self.broadcast_mul(x, y)


class SEBlock(ResBlock):
    """
    ResBlock with SE layer applied to the main branch before the shortcut is added
    """

    def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=16):
        # Forward the requested dimension (it used to be hard-coded to 3, ignoring the argument)
        super().__init__(input_nc, output_nc, convolution, dimension=dimension)
        self.SE = SELayer(output_nc, reduction=reduction, dimension=dimension)

    def forward(self, x):
        out = self.block(x)
        out = self.SE(out)
        if self.downsample:
            out += self.downsample(x)
        else:
            out += x
        return out
class SEBottleneckBlock(BottleneckBlock):
    """
    BottleneckBlock with SE layer applied to the main branch before the shortcut is added.
    `reduction` only sizes the SE layer; the bottleneck itself always uses a reduction of 4.
    """

    def __init__(self, input_nc, output_nc, convolution, dimension=3, reduction=16):
        # Forward the requested dimension (it used to be hard-coded to 3, ignoring the argument)
        super().__init__(input_nc, output_nc, convolution, dimension=dimension, reduction=4)
        self.SE = SELayer(output_nc, reduction=reduction, dimension=dimension)

    def forward(self, x):
        out = self.block(x)
        out = self.SE(out)
        if self.downsample:
            out += self.downsample(x)
        else:
            out += x
        return out


# Handle on this module, used to resolve block classes from their name
_res_blocks = sys.modules[__name__]


class ResNetDown(ME.MinkowskiNetwork):
    """
    Resnet block that looks like

    in --- strided conv ---- Block ---- sum --[... N times]
                         |              |
                         |-- 1x1 - BN --|

    `down_conv_nn` holds (input channels, output channels). `block` is the name of a block class
    defined in this module. With N <= 0 only the input convolution is applied.
    """

    CONVOLUTION = ME.MinkowskiConvolution

    def __init__(
            self, down_conv_nn=(), kernel_size=2, dilation=1, dimension=3, stride=2, N=1, block="ResBlock", **kwargs
    ):
        # Immutable default instead of a shared list; an empty value still fails on indexing as before
        block = getattr(_res_blocks, block)
        ME.MinkowskiNetwork.__init__(self, dimension)
        # A strided input convolution keeps the input width (the blocks change it afterwards);
        # otherwise the input convolution maps straight to the output width.
        if stride > 1:
            conv1_output = down_conv_nn[0]
        else:
            conv1_output = down_conv_nn[1]

        self.conv_in = (
            Seq()
            .append(
                self.CONVOLUTION(
                    in_channels=down_conv_nn[0],
                    out_channels=conv1_output,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=False,
                    dimension=dimension,
                )
            )
            .append(ME.MinkowskiBatchNorm(conv1_output))
            .append(ME.MinkowskiReLU())
        )

        if N > 0:
            self.blocks = Seq()
            for _ in range(N):
                self.blocks.append(block(conv1_output, down_conv_nn[1], self.CONVOLUTION, dimension=dimension))
                # Every block after the first one takes the output width as input
                conv1_output = down_conv_nn[1]
        else:
            self.blocks = None

    def forward(self, x):
        out = self.conv_in(x)
        if self.blocks:
            out = self.blocks(out)
        return out


class ResNetUp(ResNetDown):
    """
    Same as Down conv but for the Decoder: uses transposed convolutions and concatenates an
    optional skip connection to the input before running the block.
    """

    CONVOLUTION = ME.MinkowskiConvolutionTranspose

    def __init__(self, up_conv_nn=(), kernel_size=2, dilation=1, dimension=3, stride=2, N=1, **kwargs):
        super().__init__(
            down_conv_nn=up_conv_nn,
            kernel_size=kernel_size,
            dilation=dilation,
            dimension=dimension,
            stride=stride,
            N=N,
            **kwargs
        )

    def forward(self, x, skip):
        if skip is not None:
            inp = ME.cat(x, skip)
        else:
            inp = x
        return super().forward(inp)
def get_norm(norm_type, n_channels, D, bn_momentum=0.1):
    """
    Build the normalization layer selected by `norm_type` for `n_channels` features.

    BATCH_NORM gives a batch norm with momentum `bn_momentum`, INSTANCE_NORM an instance norm, and
    INSTANCE_BATCH_NORM an instance norm followed by a batch norm. Any other value raises ValueError.
    `D` is accepted but not used.
    """
    if norm_type == NormType.BATCH_NORM:
        return ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)

    if norm_type == NormType.INSTANCE_NORM:
        return ME.MinkowskiInstanceNorm(n_channels)

    if norm_type == NormType.INSTANCE_BATCH_NORM:
        instance_part = ME.MinkowskiInstanceNorm(n_channels)
        batch_part = ME.MinkowskiBatchNorm(n_channels, momentum=bn_momentum)
        return nn.Sequential(instance_part, batch_part)

    raise ValueError(f"Norm type: {norm_type} not supported")
def convert_conv_type(conv_type, kernel_size, D):
    """
    Translate a ConvType into the arguments needed to build a kernel generator.

    :param conv_type: ConvType member
    :param kernel_size: int or sequence of ints
    :param D: dimension of the space
    :return: (region_type, axis_types, kernel_size). For the SPATIAL_* types the kernel size is
        reduced to three spatial entries, with a trailing 1 added when D == 4; otherwise it is
        returned unchanged. axis_types is None except for SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS
        with D >= 4.
    :raises AssertionError: conv_type is not a ConvType, or a SPATIO_TEMPORAL_* type is used with D != 4
    """
    # Local import: the `collections.Sequence` alias was removed in Python 3.10, the ABC lives in
    # collections.abc. Importing here keeps the block independent from the file-level imports.
    from collections.abc import Sequence

    assert isinstance(conv_type, ConvType), "conv_type must be of ConvType"
    region_type = conv_to_region_type[conv_type]
    axis_types = None
    if conv_type in (ConvType.SPATIAL_HYPERCUBE, ConvType.SPATIAL_HYPERCROSS):
        # No temporal convolution
        if isinstance(kernel_size, Sequence):
            kernel_size = kernel_size[:3]
        else:
            kernel_size = [kernel_size, ] * 3
        if D == 4:
            # Build a new list so that tuple kernel sizes work too (tuples have no append)
            kernel_size = list(kernel_size) + [1]
    elif conv_type in (ConvType.SPATIO_TEMPORAL_HYPERCUBE, ConvType.SPATIO_TEMPORAL_HYPERCROSS):
        # conv_type conversion already handled
        assert D == 4
    elif conv_type == ConvType.SPATIAL_HYPERCUBE_TEMPORAL_HYPERCROSS:
        # Define the CUBIC conv kernel for spatial dims and CROSS conv for temp dim
        if D < 4:
            region_type = ME.RegionType.HYPER_CUBE
        else:
            axis_types = [ME.RegionType.HYPER_CUBE, ] * 3
            if D == 4:
                axis_types.append(ME.RegionType.HYPER_CROSS)
    # ConvType.HYPERCUBE and ConvType.HYPERCROSS need no adjustment
    return region_type, axis_types, kernel_size
region_type=region_type, axis_types=axis_types, dimension=D + ) + + return ME.MinkowskiAvgPooling( + kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D + ) + + +def avg_unpool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D + ) + + return ME.MinkowskiAvgUnpooling( + kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D + ) + + +def sum_pool(kernel_size, stride=1, dilation=1, conv_type=ConvType.HYPERCUBE, D=-1): + assert D > 0, "Dimension must be a positive integer" + region_type, axis_types, kernel_size = convert_conv_type(conv_type, kernel_size, D) + kernel_generator = ME.KernelGenerator( + kernel_size, stride, dilation, region_type=region_type, axis_types=axis_types, dimension=D + ) + + return ME.MinkowskiSumPooling( + kernel_size=kernel_size, stride=stride, dilation=dilation, kernel_generator=kernel_generator, dimension=D + ) + + +class ConvNormActivation(nn.Module): + def __init__(self, input_channels, out_channels, kernel_size, stride, norm_layer, + activation_layer, bias, D): + super().__init__() + self.conv = ME.MinkowskiConvolution( + input_channels, out_channels, kernel_size=kernel_size, stride=stride, dimension=D, bias=bias + ) + self.norm = norm_layer(out_channels) + self.act = nn.Identity() if activation_layer is None else activation_layer + + def forward(self, x): + return self.act(self.norm(self.conv(x))) + + +def batch_norm(X, moving_mean, moving_var, gamma, beta, training, momentum, eps, meanpool): + # Use is_grad_enabled to determine whether we are in training mode + if not torch.is_grad_enabled() or not training: + # In prediction mode, use mean and 
variance obtained by moving average + X_hat = (X.F - moving_mean) / torch.sqrt(moving_var + eps) + else: + assert len(X.shape) in (2, 4) + + # When using a fully connected layer, calculate the mean and + # variance on the feature dimension + mean = meanpool(X).F.mean(0) + diff = X.F - mean + var = (diff ** 2).mean(0) + + # In training mode, the current mean and variance are used + X_hat = diff / torch.sqrt(var + eps) + # Update the mean and variance using moving average + if training: + moving_mean = (1.0 - momentum) * moving_mean + momentum * mean + moving_var = (1.0 - momentum) * moving_var + momentum * var + return gamma * X_hat + beta # Scale and shift + + +class MinkowskiBatchNorm(nn.BatchNorm1d): + r"""A batch normalization layer for a sparse tensor. + + See the pytorch :attr:`torch.nn.BatchNorm1d` for more details. + """ + + def __init__( + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + ): + super(MinkowskiBatchNorm, self).__init__(num_features, eps, momentum, affine, track_running_stats) + self.meanpool = ME.MinkowskiGlobalAvgPooling() + + def forward(self, input_): + input = input_.F + self._check_input_dim(input) + + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] + + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). 
+ """ + output = batch_norm( + input_, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + self.momentum, + self.eps, + self.meanpool + ) + + return SparseTensor( + output, + coordinate_map_key=input_.coordinate_map_key, + coordinate_manager=input_.coordinate_manager, + ) + + def __repr__(self): + s = "({}, eps={}, momentum={}, affine={}, track_running_stats={})".format( + self.num_features, + self.eps, + self.momentum, + self.affine, + self.track_running_stats, + ) + return self.__class__.__name__ + s + + +# from https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/utils.py +class MinkowskiGRN(nn.Module): + """ GRN layer for sparse tensors. + """ + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, dim)) + self.beta = nn.Parameter(torch.zeros(1, dim)) + + def forward(self, x): + cm = x.coordinate_manager + in_key = x.coordinate_map_key + + Gx = torch.norm(x.F, p=2, dim=0, keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return SparseTensor( + self.gamma * (x.F * Nx) + self.beta + x.F, + coordinate_map_key=in_key, + coordinate_manager=cm + ) + + +class MinkowskiDropPath(nn.Module): + """ Drop Path for sparse tensors. 
+ """ + + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): + super(MinkowskiDropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + if not self.training: + return x + keep_prob = 1 - self.drop_prob + mask = torch.cat([ + torch.ones(len(_)) if random.uniform(0, 1) > self.drop_prob + else torch.zeros(len(_)) for _ in x.decomposed_coordinates + ]).view(-1, 1).to(x.device) + if keep_prob > 0.0 and self.scale_by_keep: + mask.div_(keep_prob) + return SparseTensor( + x.F * mask, + coordinate_map_key=x.coordinate_map_key, + coordinate_manager=x.coordinate_manager) + + +class MinkowskiLayerNorm(nn.Module): + """ Channel-wise layer normalization for sparse tensors. + """ + + def __init__( + self, + normalized_shape, + eps=1e-6, + ): + super(MinkowskiLayerNorm, self).__init__() + self.ln = nn.LayerNorm(normalized_shape, eps=eps) + + def forward(self, input): + output = self.ln(input.F) + return SparseTensor( + output, + coordinate_map_key=input.coordinate_map_key, + coordinate_manager=input.coordinate_manager) diff --git a/torch-points3d/torch_points3d/modules/MinkowskiEngine/modules.py b/torch-points3d/torch_points3d/modules/MinkowskiEngine/modules.py new file mode 100644 index 0000000..3add6a7 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/MinkowskiEngine/modules.py @@ -0,0 +1,378 @@ +import torch.nn as nn +import MinkowskiEngine as ME +from .common import ConvType, NormType + +from torch_points3d.utils.config import is_list + + +class BasicBlock(nn.Module): + """This module implements a basic residual convolution block using MinkowskiEngine + + Parameters + ---------- + inplanes: int + Input dimension + planes: int + Output dimension + dilation: int + Dilation value + downsample: nn.Module + If provided, downsample will be applied on input before doing residual addition + bn_momentum: float + Input dimension + """ + + EXPANSION = 1 + + def __init__(self, inplanes, planes, 
stride=1, dilation=1, downsample=None, bn_momentum=0.1, dimension=-1): + super(BasicBlock, self).__init__() + assert dimension > 0 + + self.conv1 = ME.MinkowskiConvolution( + inplanes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension + ) + self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + self.conv2 = ME.MinkowskiConvolution( + planes, planes, kernel_size=3, stride=1, dilation=dilation, dimension=dimension + ) + self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + self.relu = ME.MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + EXPANSION = 4 + + def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, bn_momentum=0.1, dimension=-1): + super(Bottleneck, self).__init__() + assert dimension > 0 + + self.conv1 = ME.MinkowskiConvolution(inplanes, planes, kernel_size=1, dimension=dimension) + self.norm1 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + + self.conv2 = ME.MinkowskiConvolution( + planes, planes, kernel_size=3, stride=stride, dilation=dilation, dimension=dimension + ) + self.norm2 = ME.MinkowskiBatchNorm(planes, momentum=bn_momentum) + + self.conv3 = ME.MinkowskiConvolution(planes, planes * self.EXPANSION, kernel_size=1, dimension=dimension) + self.norm3 = ME.MinkowskiBatchNorm(planes * self.EXPANSION, momentum=bn_momentum) + + self.relu = ME.MinkowskiReLU(inplace=True) + self.downsample = downsample + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + + if 
self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class BaseResBlock(nn.Module): + def __init__( + self, + feat_in, + feat_mid, + feat_out, + kernel_sizes=[], + strides=[], + dilations=[], + has_biases=[], + kernel_generators=[], + kernel_size=3, + stride=1, + dilation=1, + bias=False, + kernel_generator=None, + norm_layer=ME.MinkowskiBatchNorm, + activation=ME.MinkowskiReLU, + bn_momentum=0.1, + dimension=-1, + **kwargs + ): + + super(BaseResBlock, self).__init__() + assert dimension > 0 + + modules = [] + + convolutions_dim = [[feat_in, feat_mid], [feat_mid, feat_mid], [feat_mid, feat_out]] + + kernel_sizes = self.create_arguments_list(kernel_sizes, kernel_size) + strides = self.create_arguments_list(strides, stride) + dilations = self.create_arguments_list(dilations, dilation) + has_biases = self.create_arguments_list(has_biases, bias) + kernel_generators = self.create_arguments_list(kernel_generators, kernel_generator) + + for conv_dim, kernel_size, stride, dilation, has_bias, kernel_generator in zip( + convolutions_dim, kernel_sizes, strides, dilations, has_biases, kernel_generators + ): + + modules.append( + ME.MinkowskiConvolution( + conv_dim[0], + conv_dim[1], + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + bias=has_bias, + kernel_generator=kernel_generator, + dimension=dimension, + ) + ) + + if norm_layer: + modules.append(norm_layer(conv_dim[1], momentum=bn_momentum)) + + if activation: + modules.append(activation(inplace=True)) + + self.conv = nn.Sequential(*modules) + + @staticmethod + def create_arguments_list(arg_list, arg): + if len(arg_list) == 3: + return arg_list + return [arg for _ in range(3)] + + def forward(self, x): + return x, self.conv(x) + + +class ResnetBlockDown(BaseResBlock): + def __init__( + self, + down_conv_nn=[], + kernel_sizes=[], + strides=[], + dilations=[], + kernel_size=3, + stride=1, + dilation=1, + norm_layer=ME.MinkowskiBatchNorm, + 
activation=ME.MinkowskiReLU, + bn_momentum=0.1, + dimension=-1, + down_stride=2, + **kwargs + ): + + super(ResnetBlockDown, self).__init__( + down_conv_nn[0], + down_conv_nn[1], + down_conv_nn[2], + kernel_sizes=kernel_sizes, + strides=strides, + dilations=dilations, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + norm_layer=norm_layer, + activation=activation, + bn_momentum=bn_momentum, + dimension=dimension, + ) + + self.downsample = nn.Sequential( + ME.MinkowskiConvolution( + down_conv_nn[0], down_conv_nn[2], kernel_size=2, stride=down_stride, dimension=dimension + ), + ME.MinkowskiBatchNorm(down_conv_nn[2]), + ) + + def forward(self, x): + + residual, x = super().forward(x) + + return self.downsample(residual) + x + + +class ResnetBlockUp(BaseResBlock): + def __init__( + self, + up_conv_nn=[], + kernel_sizes=[], + strides=[], + dilations=[], + kernel_size=3, + stride=1, + dilation=1, + norm_layer=ME.MinkowskiBatchNorm, + activation=ME.MinkowskiReLU, + bn_momentum=0.1, + dimension=-1, + up_stride=2, + skip=True, + **kwargs + ): + + self.skip = skip + + super(ResnetBlockUp, self).__init__( + up_conv_nn[0], + up_conv_nn[1], + up_conv_nn[2], + kernel_sizes=kernel_sizes, + strides=strides, + dilations=dilations, + kernel_size=kernel_size, + stride=stride, + dilation=dilation, + norm_layer=norm_layer, + activation=activation, + bn_momentum=bn_momentum, + dimension=dimension, + ) + + self.upsample = ME.MinkowskiConvolutionTranspose( + up_conv_nn[0], up_conv_nn[2], kernel_size=2, stride=up_stride, dimension=dimension + ) + + def forward(self, x, x_skip): + residual, x = super().forward(x) + + x = self.upsample(residual) + x + + if self.skip: + return ME.cat(x, x_skip) + else: + return x + + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16, D=-1): + # Global coords does not require coords_key + super(SELayer, self).__init__() + self.fc = nn.Sequential( + ME.MinkowskiLinear(channel, channel // reduction), + 
ME.MinkowskiReLU(inplace=True), + ME.MinkowskiLinear(channel // reduction, channel), + ME.MinkowskiSigmoid(), + ) + self.pooling = ME.MinkowskiGlobalPooling(dimension=D) + self.broadcast_mul = ME.MinkowskiBroadcastMultiplication(dimension=D) + + def forward(self, x): + y = self.pooling(x) + y = self.fc(y) + return self.broadcast_mul(x, y) + + +class SEBasicBlock(BasicBlock): + def __init__( + self, inplanes, planes, stride=1, dilation=1, downsample=None, conv_type=ConvType.HYPERCUBE, reduction=16, D=-1 + ): + super(SEBasicBlock, self).__init__( + inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, conv_type=conv_type, D=D + ) + self.se = SELayer(planes, reduction=reduction, D=D) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class SEBasicBlockBN(SEBasicBlock): + NORM_TYPE = NormType.BATCH_NORM + + +class SEBasicBlockIN(SEBasicBlock): + NORM_TYPE = NormType.INSTANCE_NORM + + +class SEBasicBlockIBN(SEBasicBlock): + NORM_TYPE = NormType.INSTANCE_BATCH_NORM + + +class SEBottleneck(Bottleneck): + def __init__( + self, inplanes, planes, stride=1, dilation=1, downsample=None, conv_type=ConvType.HYPERCUBE, D=3, reduction=16 + ): + super(SEBottleneck, self).__init__( + inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, conv_type=conv_type, D=D + ) + self.se = SELayer(planes * self.expansion, reduction=reduction, D=D) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + out = self.se(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = 
self.relu(out) + + return out + + +class SEBottleneckBN(SEBottleneck): + NORM_TYPE = NormType.BATCH_NORM + + +class SEBottleneckIN(SEBottleneck): + NORM_TYPE = NormType.INSTANCE_NORM + + +class SEBottleneckIBN(SEBottleneck): + NORM_TYPE = NormType.INSTANCE_BATCH_NORM diff --git a/torch-points3d/torch_points3d/modules/MinkowskiEngine/networks.py b/torch-points3d/torch_points3d/modules/MinkowskiEngine/networks.py new file mode 100644 index 0000000..ea22972 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/MinkowskiEngine/networks.py @@ -0,0 +1,310 @@ +import torch.nn as nn + +import MinkowskiEngine as ME +from .modules import BasicBlock, Bottleneck + + +class ResNetBase(nn.Module): + BLOCK = None + LAYERS = () + INIT_DIM = 64 + PLANES = (64, 128, 256, 512) + + def __init__(self, in_channels, out_channels, D=3, **kwargs): + nn.Module.__init__(self) + self.D = D + assert self.BLOCK is not None, "BLOCK is not defined" + assert self.PLANES is not None, "PLANES is not defined" + self.network_initialization(in_channels, out_channels, D) + self.weight_initialization() + + def network_initialization(self, in_channels, out_channels, D): + + self.inplanes = self.INIT_DIM + self.conv1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=5, stride=2, dimension=D) + + self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) + self.relu = ME.MinkowskiReLU(inplace=True) + + self.pool = ME.MinkowskiAvgPooling(kernel_size=2, stride=2, dimension=D) + + self.layer1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0], stride=2) + self.layer2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1], stride=2) + self.layer3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2], stride=2) + self.layer4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3], stride=2) + + self.conv5 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=3, stride=3, dimension=D) + self.bn5 = ME.MinkowskiBatchNorm(self.inplanes) + + self.glob_avg = 
ME.MinkowskiGlobalMaxPooling()#dimension=D) + + self.final = ME.MinkowskiLinear(self.inplanes, out_channels, bias=True) + + def weight_initialization(self): + for m in self.modules(): + if isinstance(m, ME.MinkowskiConvolution): + ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu") + + if isinstance(m, ME.MinkowskiBatchNorm): + nn.init.constant_(m.bn.weight, 1) + nn.init.constant_(m.bn.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1, bn_momentum=0.1): + downsample = None + if stride != 1 or self.inplanes != planes * block.EXPANSION: + downsample = nn.Sequential( + ME.MinkowskiConvolution( + self.inplanes, planes * block.EXPANSION, kernel_size=1, stride=stride, dimension=self.D + ), + ME.MinkowskiBatchNorm(planes * block.EXPANSION), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride=stride, dilation=dilation, downsample=downsample, dimension=self.D) + ) + self.inplanes = planes * block.EXPANSION + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, stride=1, dilation=dilation, dimension=self.D)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.pool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.conv5(x) + x = self.bn5(x) + x = self.relu(x) + + x = self.glob_avg(x) + return self.final(x) + + +class ResNet14(ResNetBase): + BLOCK = BasicBlock + LAYERS = (1, 1, 1, 1) + + +class ResNet18(ResNetBase): + BLOCK = BasicBlock + LAYERS = (2, 2, 2, 2) + + +class ResNet34(ResNetBase): + BLOCK = BasicBlock + LAYERS = (3, 4, 6, 3) + + +class ResNet50(ResNetBase): + BLOCK = Bottleneck + LAYERS = (3, 4, 6, 3) + + +class ResNet101(ResNetBase): + BLOCK = Bottleneck + LAYERS = (3, 4, 23, 3) + + +class MinkUNetBase(ResNetBase): + BLOCK = None + PLANES = None + DILATIONS = (1, 1, 1, 1, 1, 1, 1, 1) + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + INIT_DIM = 32 + 
OUT_TENSOR_STRIDE = 1 + + # To use the model, must call initialize_coords before forward pass. + # Once data is processed, call clear to reset the model before calling + # initialize_coords + def __init__(self, in_channels, out_channels, D=3, **kwargs): + ResNetBase.__init__(self, in_channels, out_channels, D) + + def network_initialization(self, in_channels, out_channels, D): + # Output of the first conv concated to conv6 + self.inplanes = self.INIT_DIM + self.conv0p1s1 = ME.MinkowskiConvolution(in_channels, self.inplanes, kernel_size=5, dimension=D) + + self.bn0 = ME.MinkowskiBatchNorm(self.inplanes) + + self.conv1p1s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + self.bn1 = ME.MinkowskiBatchNorm(self.inplanes) + + self.block1 = self._make_layer(self.BLOCK, self.PLANES[0], self.LAYERS[0]) + + self.conv2p2s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + self.bn2 = ME.MinkowskiBatchNorm(self.inplanes) + + self.block2 = self._make_layer(self.BLOCK, self.PLANES[1], self.LAYERS[1]) + + self.conv3p4s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + + self.bn3 = ME.MinkowskiBatchNorm(self.inplanes) + self.block3 = self._make_layer(self.BLOCK, self.PLANES[2], self.LAYERS[2]) + + self.conv4p8s2 = ME.MinkowskiConvolution(self.inplanes, self.inplanes, kernel_size=2, stride=2, dimension=D) + self.bn4 = ME.MinkowskiBatchNorm(self.inplanes) + self.block4 = self._make_layer(self.BLOCK, self.PLANES[3], self.LAYERS[3]) + + self.convtr4p16s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[4], kernel_size=2, stride=2, dimension=D + ) + self.bntr4 = ME.MinkowskiBatchNorm(self.PLANES[4]) + + self.inplanes = self.PLANES[4] + self.PLANES[2] * self.BLOCK.EXPANSION + self.block5 = self._make_layer(self.BLOCK, self.PLANES[4], self.LAYERS[4]) + self.convtr5p8s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, 
self.PLANES[5], kernel_size=2, stride=2, dimension=D + ) + self.bntr5 = ME.MinkowskiBatchNorm(self.PLANES[5]) + + self.inplanes = self.PLANES[5] + self.PLANES[1] * self.BLOCK.EXPANSION + self.block6 = self._make_layer(self.BLOCK, self.PLANES[5], self.LAYERS[5]) + self.convtr6p4s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[6], kernel_size=2, stride=2, dimension=D + ) + self.bntr6 = ME.MinkowskiBatchNorm(self.PLANES[6]) + + self.inplanes = self.PLANES[6] + self.PLANES[0] * self.BLOCK.EXPANSION + self.block7 = self._make_layer(self.BLOCK, self.PLANES[6], self.LAYERS[6]) + self.convtr7p2s2 = ME.MinkowskiConvolutionTranspose( + self.inplanes, self.PLANES[7], kernel_size=2, stride=2, dimension=D + ) + self.bntr7 = ME.MinkowskiBatchNorm(self.PLANES[7]) + + self.inplanes = self.PLANES[7] + self.INIT_DIM + self.block8 = self._make_layer(self.BLOCK, self.PLANES[7], self.LAYERS[7]) + + self.final = ME.MinkowskiConvolution(self.PLANES[7], out_channels, kernel_size=1, bias=True, dimension=D) + self.relu = ME.MinkowskiReLU(inplace=True) + + def forward(self, x): + out = self.conv0p1s1(x) + out = self.bn0(out) + out_p1 = self.relu(out) + + out = self.conv1p1s2(out_p1) + out = self.bn1(out) + out = self.relu(out) + out_b1p2 = self.block1(out) + + out = self.conv2p2s2(out_b1p2) + out = self.bn2(out) + out = self.relu(out) + out_b2p4 = self.block2(out) + + out = self.conv3p4s2(out_b2p4) + out = self.bn3(out) + out = self.relu(out) + out_b3p8 = self.block3(out) + + # tensor_stride=16 + out = self.conv4p8s2(out_b3p8) + out = self.bn4(out) + out = self.relu(out) + out = self.block4(out) + + # tensor_stride=8 + out = self.convtr4p16s2(out) + out = self.bntr4(out) + out = self.relu(out) + + out = ME.cat(out, out_b3p8) + out = self.block5(out) + + # tensor_stride=4 + out = self.convtr5p8s2(out) + out = self.bntr5(out) + out = self.relu(out) + + out = ME.cat(out, out_b2p4) + out = self.block6(out) + + # tensor_stride=2 + out = self.convtr6p4s2(out) + out = 
self.bntr6(out) + out = self.relu(out) + + out = ME.cat(out, out_b1p2) + out = self.block7(out) + + # tensor_stride=1 + out = self.convtr7p2s2(out) + out = self.bntr7(out) + out = self.relu(out) + + out = ME.cat(out, out_p1) + out = self.block8(out) + + return self.final(out) + + +class MinkUNet14(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (1, 1, 1, 1, 1, 1, 1, 1) + + +class MinkUNet18(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (2, 2, 2, 2, 2, 2, 2, 2) + + +class MinkUNet34(MinkUNetBase): + BLOCK = BasicBlock + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +class MinkUNet50(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 6, 2, 2, 2, 2) + + +class MinkUNet101(MinkUNetBase): + BLOCK = Bottleneck + LAYERS = (2, 3, 4, 23, 2, 2, 2, 2) + + +class MinkUNet14A(MinkUNet14): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +class MinkUNet14B(MinkUNet14): + PLANES = (32, 64, 128, 256, 128, 128, 128, 128) + + +class MinkUNet14C(MinkUNet14): + PLANES = (32, 64, 128, 256, 192, 192, 128, 128) + + +class MinkUNet14D(MinkUNet14): + PLANES = (32, 64, 128, 256, 384, 384, 384, 384) + + +class MinkUNet18A(MinkUNet18): + PLANES = (32, 64, 128, 256, 128, 128, 96, 96) + + +class MinkUNet18B(MinkUNet18): + PLANES = (32, 64, 128, 256, 128, 128, 128, 128) + + +class MinkUNet18D(MinkUNet18): + PLANES = (32, 64, 128, 256, 384, 384, 384, 384) + + +class MinkUNet34A(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 64) + + +class MinkUNet34B(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 64, 32) + + +class MinkUNet34C(MinkUNet34): + PLANES = (32, 64, 128, 256, 256, 128, 96, 96) diff --git a/torch-points3d/torch_points3d/modules/MinkowskiEngine/senet_block.py b/torch-points3d/torch_points3d/modules/MinkowskiEngine/senet_block.py new file mode 100644 index 0000000..5692d34 --- /dev/null +++ b/torch-points3d/torch_points3d/modules/MinkowskiEngine/senet_block.py @@ -0,0 +1,147 @@ +# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu). 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy of +# this software and associated documentation files (the "Software"), to deal in +# the Software without restriction, including without limitation the rights to +# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies +# of the Software, and to permit persons to whom the Software is furnished to do +# so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural +# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part +# of the code. 
+import torch +import torch.nn as nn + +import MinkowskiEngine as ME +from torch.cuda.amp import custom_fwd + +from .resnet_block import BasicBlock, Bottleneck + + +class SELayer(nn.Module): + + def __init__(self, channel, act_fn, reduction=16, dimension=-1): + # Global coords does not require coords_key + super(SELayer, self).__init__() + self.fc = nn.Sequential( + ME.MinkowskiLinear(channel, channel // reduction), + act_fn, + ME.MinkowskiLinear(channel // reduction, channel), + ME.MinkowskiSigmoid()) + self.pooling = ME.MinkowskiGlobalPooling() + self.broadcast_mul = ME.MinkowskiBroadcastMultiplication() + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, x): + y = self.pooling(x) + y = self.fc(y) + return self.broadcast_mul(x, y) + + +class SEBasicBlock(BasicBlock): + + def __init__(self, + inplanes, + planes, + act_fn, + norm_layer, + stride=1, + dilation=1, + downsample=None, + reduction=16, + drop_path=0.0, + bias: bool = True, + dimension=-1): + super(SEBasicBlock, self).__init__( + inplanes, + planes, + act_fn, + norm_layer, + stride=stride, + dilation=dilation, + downsample=downsample, + drop_path=drop_path, + bias=bias, + dimension=dimension) + self.se = SELayer(planes, act_fn, reduction=reduction, dimension=dimension) + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.se(out) + + residual = self.downsample(residual) + + out = self.drop_path(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + + def __init__(self, + inplanes, + planes, + act_fn, + norm_layer, + stride=1, + dilation=1, + downsample=None, + dimension=-1, + drop_path=0.0, + bias: bool = True, + reduction=16): + super(SEBottleneck, self).__init__( + inplanes, + planes, + act_fn, + norm_layer, + stride=stride, + dilation=dilation, + downsample=downsample, + drop_path=drop_path, + bias=bias, + dimension=dimension) + 
self.se = SELayer(planes * self.expansion, act_fn, reduction=reduction, dimension=dimension) + + @custom_fwd(cast_inputs=torch.float32) + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.norm3(out) + out = self.se(out) + + residual = self.downsample(residual) + + out = self.drop_path(out) + residual + out = self.relu(out) + + return out diff --git a/torch-points3d/torch_points3d/modules/__init__.py b/torch-points3d/torch_points3d/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torch-points3d/torch_points3d/trainer.py b/torch-points3d/torch_points3d/trainer.py new file mode 100644 index 0000000..305a808 --- /dev/null +++ b/torch-points3d/torch_points3d/trainer.py @@ -0,0 +1,521 @@ +import copy +import logging +import os +import time +from contextlib import nullcontext + +import torch + +torch.multiprocessing.set_sharing_strategy('file_system') +import torch.autograd.profiler +# PyTorch Profiler import +import torch.profiler +from omegaconf import ListConfig +from torch import nn + +from torch_points3d.datasets.base_dataset import BaseDataset +# Import building function for model and dataset +from torch_points3d.datasets.dataset_factory import instantiate_dataset +# Import from metrics +from torch_points3d.metrics.base_tracker import BaseTracker +from torch_points3d.metrics.colored_tqdm import Coloredtqdm as Ctq +from torch_points3d.metrics.model_checkpoint import ModelCheckpoint +# Import BaseModel / BaseDataset for type checking +from torch_points3d.models.base_model import BaseModel +from torch_points3d.models.model_factory import instantiate_model +# Utils import +from torch_points3d.utils.colors import COLORS +from torch_points3d.utils.wandb_utils import Wandb +from torch_points3d.visualization import Visualizer + +log = logging.getLogger(__name__) + + +class 
Trainer: + """ + TorchPoints3d Trainer handles the logic between + - BaseModel, + - Dataset and its Tracker + - A custom ModelCheckpoint + - A custom Visualizer + It supports MC dropout - multiple voting_runs for val / test datasets + """ + + def __init__(self, cfg): + self._cfg = cfg + self._initialize_trainer() + + def _initialize_trainer(self): + if not self.has_training: + self._cfg.training = self._cfg + resume = bool(self._cfg.checkpoint_dir) + else: + resume = bool(self._cfg.training.checkpoint_dir) + + # Enable CUDNN BACKEND + torch.backends.cudnn.enabled = self.enable_cudnn + + # Get device + self._multi_gpu = False + if (isinstance(self._cfg.training.cuda, ListConfig) or self._cfg.training.cuda > -1) \ + and torch.cuda.is_available(): + device = "cuda" + if isinstance(self._cfg.training.cuda, ListConfig): + self._multi_gpu = True + else: + device = "cpu" + self._device = torch.device(device) + log.info("DEVICE : {}".format(self._device)) + + # Profiling + if self.profiling: + # Set the num_workers as torch.utils.bottleneck doesn't work well with it + self._cfg.training.num_workers = 0 + + # Start Wandb if public + if self.wandb_log: + Wandb.launch(self._cfg, self._cfg.training.wandb.public and self.wandb_log) + + # Checkpoint + + self._checkpoint: ModelCheckpoint = ModelCheckpoint( + self._cfg.training.checkpoint_dir, + self._cfg.model_name, + self._cfg.training.weight_name, + run_config=self._cfg, + resume=resume, + resume_opt=self._cfg.get("resume_opt", None) + ) + + # Create model and datasets + # always freshly init dataset instead of using checkpoint cfg + self._dataset: BaseDataset = instantiate_dataset(self._cfg.data) + if not self._checkpoint.is_empty: + self._model: BaseModel = self._checkpoint.create_model( + self._dataset, weight_name=self._cfg.training.weight_name + ) + self._model = self._model.to(self._device) + if self.has_training: + if self._model.optimizer is None and self.has_training: + self._model.init_optim(self._cfg) + else: + # 
workaround for https://github.com/pytorch/pytorch/issues/80809 for pytorch 1.12 + opt_dict = self._model.optimizer.state_dict() + + base_lr = self._cfg.training.optim.base_lr + + def nested_iter(dict_obj): + for k, v in dict_obj.items(): + if isinstance(v, dict): + nested_iter(v) + elif isinstance(v, list): + nested_iter((dict(zip(['list_' + str(i) for i in range(len(v))], v)))) + else: + if 'step' in k: + try: + if isinstance(v, torch.Tensor): + tst = v.cpu() + assert torch.all(tst == v) + else: + tst = v + except: + pass + dict_obj[k] = tst + elif "initial_lr" in k: + dict_obj[k] = base_lr + + nested_iter(opt_dict) + self._model.optimizer.load_state_dict(opt_dict) + if len(self._model.schedulers) == 0: + self._model.init_schedulers(self._cfg) + else: + self._checkpoint._checkpoint.load_optim_sched(self._model, load_state=self._checkpoint._resume_opt) + if self._model.grad_scale is None: + self._model.init_grad_scaler(self._cfg) + else: + self._checkpoint._checkpoint.load_grad_scale(self._model, load_state=self._checkpoint._resume_opt) + else: + self._model: BaseModel = instantiate_model(copy.deepcopy(self._cfg), self._dataset) + self._model.init_train_objects(self._cfg) + self._model.set_pretrained_weights() + if not self._checkpoint.validate(self._dataset.used_properties): + log.warning( + "The model will not be able to be used from pretrained weights without the corresponding dataset. 
def train(self):
    """Run the training loop from the checkpoint's start epoch.

    Each epoch is trained with ``_train_epoch``. When profiling is enabled
    the method returns ``0`` right after the first trained epoch. Otherwise,
    on every epoch where ``epoch % eval_frequency == 0``, the validation and
    test loaders are evaluated if the dataset has them.

    If the checkpoint's start epoch is already greater than the configured
    number of epochs, nothing is trained and a single test evaluation is run
    at the checkpoint's epoch.

    Returns:
        ``0`` when profiling cuts the run short, otherwise ``None``.
    """
    self._is_training = True
    start_epoch = self._checkpoint.start_epoch
    total_epochs = self._cfg.training.epochs

    for epoch in range(start_epoch, total_epochs):
        log.info("EPOCH %i / %i", epoch, total_epochs)

        self._train_epoch(epoch)

        if self.profiling:
            return 0

        if epoch % self.eval_frequency != 0:
            continue

        if self._dataset.has_val_loader:
            self._test_epoch(epoch, "val")

        if self._dataset.has_test_loader:
            self._test_epoch(epoch, "test")

    # Single test evaluation in resume case.
    # The loop above cannot have run in this branch, so there is no loop
    # variable to reuse; evaluate at the checkpoint's epoch, as eval() does.
    if start_epoch > total_epochs:
        if self._dataset.has_test_loader:
            self._test_epoch(start_epoch, "test")
train_loader) + + iter_data_time = time.time() + + if self.pytorch_profiler_log: + prof.step() + + if self.early_break: + break + + self._finalize_epoch(epoch, ) + + def _finalize_epoch(self, epoch): + self._tracker.finalise(**self.tracker_options) + if self._is_training: + metrics = self._tracker.get_publish_metrics(epoch) + p_metrics = self._checkpoint.save_best_models_under_current_metrics( + self._model, metrics, self._tracker.metric_func, self.wandb_log + ) + self._tracker.publish_metrics(p_metrics, epoch) + + if self.wandb_log and self._cfg.training.wandb.public: + Wandb.add_file(self._checkpoint.checkpoint_path) + if self._tracker._stage == "train": + log.info("Learning rate = %f" % self._model.learning_rate) + else: + if self.has_visualization: + + if self._tracker._stage == "test": + loaders = self._dataset.test_dataloaders + elif self._tracker._stage == "val": + loaders = [self._dataset.val_dataloader] + elif self._tracker._stage == "train": + loaders = [self._dataset.train_dataloader] + + self._visualizer.finalize_epoch(loaders) + + def _train_epoch(self, epoch: int): + + self._model.train() + self._tracker.reset("train") + self._visualizer.reset(epoch, "train") + train_loader = self._dataset.train_dataloader + n_batches = len(train_loader) + + with self.profiler_profile(epoch) as prof: + iter_data_time = time.time() + with Ctq(train_loader) as tq_train_loader: + for i, data in enumerate(tq_train_loader): + t_data = time.time() - iter_data_time + iter_start_time = time.time() + + with self.profiler_record_function('train_step'): + self._model.set_input(data, self._device) + self._model.optimize_parameters( + epoch, self._dataset.batch_size, + self._dataset.num_batches[self._dataset.train_dataset.name] + ) + + with self.profiler_record_function('track/log/visualize'): + if i % 10 == 0: + with torch.no_grad(): + self._tracker.track(self._model, data=data, **self.tracker_options) + + tq_train_loader.set_postfix( + **self._tracker.get_loss(), + 
data_loading=float(t_data), + iteration=float(time.time() - iter_start_time), + color=COLORS.TRAIN_COLOR + ) + + if self._visualizer.is_active: + self._visualizer.save_visuals(self._model, train_loader) + + iter_data_time = time.time() + + if self.pytorch_profiler_log: + prof.step() + + if self.early_break: + break + + if self.profiling: + if i > self.num_batches: + return 0 + + self._finalize_epoch(epoch) + + def _test_epoch(self, epoch, stage_name: str): + voting_runs = self._cfg.get("voting_runs", 1) + if stage_name == "test": + loaders = self._dataset.test_dataloaders + elif stage_name == "val": + loaders = [self._dataset.val_dataloader] + elif stage_name == "train": + loaders = [self._dataset.train_dataloader] + else: + raise NotImplemented(f"The following stage is not implemented: {stage_name}") + + self._model.eval() + if self.enable_dropout: + self._model.enable_dropout_in_eval() + + if self.enable_bn: + self._model.enable_bn_in_eval() + + for loader in loaders: + stage_name = loader.dataset.name + self._tracker.reset(stage_name) + if self.has_visualization: + self._visualizer.reset(epoch, stage_name) + if not self._dataset.has_labels(stage_name) and not self.tracker_options.get( + "make_submission", False + ): # No label, no submission -> do nothing + log.warning("No forward will be run on dataset %s." 
% stage_name) + continue + + with self.profiler_profile(epoch) as prof: + for i in range(voting_runs): + with Ctq(loader) as tq_loader: + for data in tq_loader: + with torch.no_grad(): + with self.profiler_record_function('test_step'): + self._model.set_input(data, self._device) + with torch.cuda.amp.autocast(enabled=self._model.is_mixed_precision()): + self._model.forward(epoch=epoch) + + with self.profiler_record_function('track/log/visualize'): + self._tracker.track(self._model, data=data, **self.tracker_options) + tq_loader.set_postfix(**self._tracker.get_loss(), color=COLORS.TEST_COLOR) + + if self.has_visualization and self._visualizer.is_active: + self._visualizer.save_visuals(self._model, loader) + + if self.pytorch_profiler_log: + prof.step() + + if self.early_break: + break + + if self.profiling: + if i > self.num_batches: + return 0 + + self._finalize_epoch(epoch) + self._tracker.print_summary() + + @property + def early_break(self): + return getattr(self._cfg.debugging, "early_break", False) and self._is_training + + @property + def profiling(self): + return getattr(self._cfg.debugging, "profiling", False) + + @property + def num_batches(self): + return getattr(self._cfg.debugging, "num_batches", 50) + + @property + def enable_cudnn(self): + return getattr(self._cfg.training, "enable_cudnn", True) + + @property + def enable_dropout(self): + return getattr(self._cfg, "enable_dropout", False) + + @property + def enable_bn(self): + return getattr(self._cfg, "enable_bn", False) + + @property + def has_visualization(self): + return getattr(self._cfg, "visualization", False) + + @property + def has_tensorboard(self): + return getattr(self._cfg.training, "tensorboard", False) + + _has_training = None + + @property + def has_training(self): + if self._has_training is None: + self._has_training = getattr(self._cfg, "training", None) is not None + return self._has_training + + @property + def precompute_multi_scale(self): + return self._model.conv_type == 
"PARTIAL_DENSE" and getattr(self._cfg.training, "precompute_multi_scale", False) + + @property + def wandb_log(self): + if getattr(self._cfg.training, "wandb", False): + return getattr(self._cfg.training.wandb, "log", False) + else: + return False + + @property + def tensorboard_log(self): + if self.has_tensorboard: + return getattr(self._cfg.training.tensorboard, "log", False) + else: + return False + + @property + def pytorch_profiler_log(self): + if self.tensorboard_log: + if getattr(self._cfg.training.tensorboard, "pytorch_profiler", False): + return getattr(self._cfg.training.tensorboard.pytorch_profiler, "log", False) + return False + + # pyTorch Profiler + def profiler_profile(self, epoch): + if (self.pytorch_profiler_log and ( + getattr(self._cfg.training.tensorboard.pytorch_profiler, "nb_epoch", 3) == 0 or epoch <= getattr( + self._cfg.training.tensorboard.pytorch_profiler, "nb_epoch", 3))): + return torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA] \ + if isinstance(self._cfg.training.cuda, ListConfig) or self._cfg.training.cuda > -1 else \ + [torch.profiler.ProfilerActivity.CPU], + schedule=torch.profiler.schedule( + skip_first=getattr(self._cfg.training.tensorboard.pytorch_profiler, "skip_first", 10), + wait=getattr(self._cfg.training.tensorboard.pytorch_profiler, "wait", 5), + warmup=getattr(self._cfg.training.tensorboard.pytorch_profiler, "warmup", 3), + active=getattr(self._cfg.training.tensorboard.pytorch_profiler, "active", 5), + repeat=getattr(self._cfg.training.tensorboard.pytorch_profiler, "repeat", 0)), + on_trace_ready=torch.profiler.tensorboard_trace_handler(self._tracker._tensorboard_dir), + record_shapes=getattr(self._cfg.training.tensorboard.pytorch_profiler, "record_shapes", True), + profile_memory=getattr(self._cfg.training.tensorboard.pytorch_profiler, "profile_memory", True), + with_stack=getattr(self._cfg.training.tensorboard.pytorch_profiler, "with_stack", True), + 
def profiler_record_function(self, name: str):
    """Return a context manager that labels a code region for the profiler.

    When ``pytorch_profiler_log`` is falsy a no-op ``nullcontext`` is
    returned; otherwise the region is wrapped in
    ``torch.autograd.profiler.record_function`` under ``name``.
    """
    if not self.pytorch_profiler_log:
        return nullcontext()
    return torch.autograd.profiler.record_function(name)
def nms_samecls(boxes, classes, scores, overlap_threshold=0.25):
    """Greedy non-maximum suppression restricted to boxes of the same class.

    Boxes are visited from highest to lowest score. A box is dropped only
    when its IoU with an already kept box of the *same* class exceeds
    ``overlap_threshold``.

    Parameters
    ----------
    boxes : [num_boxes, 6]
        xmin, ymin, zmin, xmax, ymax, zmax
    classes : [num_boxes]
        Class of each box
    scores : [num_boxes]
        Score of each box
    overlap_threshold : float, optional
        IoU above which a lower-scored box of the same class is suppressed,
        by default 0.25

    Returns
    -------
    list
        Indices of the kept boxes, best score first.
    """
    # Tensors are moved to host memory; anything else is used as given.
    boxes, scores, classes = (
        item.cpu().numpy() if torch.is_tensor(item) else item
        for item in (boxes, scores, classes)
    )

    lower = boxes[:, 0:3]
    upper = boxes[:, 3:6]
    extent = upper - lower
    volume = extent[:, 0] * extent[:, 1] * extent[:, 2]

    order = np.argsort(scores)  # ascending: the best candidate sits at the end
    kept = []
    while order.size != 0:
        best = order[-1]
        others = order[:-1]
        kept.append(best)

        # Per-axis overlap of the best box with every remaining box.
        overlap = np.maximum(
            0, np.minimum(upper[best], upper[others]) - np.maximum(lower[best], lower[others])
        )
        inter = overlap[:, 0] * overlap[:, 1] * overlap[:, 2]
        iou = inter / (volume[best] + volume[others] - inter)
        # Zero out the overlap with boxes of a different class.
        iou = iou * (classes[best] == classes[others])

        suppressed = np.where(iou > overlap_threshold)[0]
        order = np.delete(order, np.concatenate(([order.size - 1], suppressed)))

    return kept
def polygon_clip(subjectPolygon, clipPolygon):
    """ Clip a polygon against a convex polygon (Sutherland-Hodgman).

    Ref: https://rosettacode.org/wiki/Sutherland-Hodgman_polygon_clipping#Python

    Args:
      subjectPolygon: a list of (x,y) 2d points, any polygon.
      clipPolygon: a list of (x,y) 2d points, has to be *convex*
    Note:
      **points have to be counter-clockwise ordered**

    Return:
      a list of (x,y) vertex point for the intersection polygon, or None as
      soon as clipping against one edge leaves no vertex.
    """

    def is_inside(edge_from, edge_to, point):
        # Strict test: points lying exactly on the clip edge count as outside.
        return (edge_to[0] - edge_from[0]) * (point[1] - edge_from[1]) > \
            (edge_to[1] - edge_from[1]) * (point[0] - edge_from[0])

    def crossing(edge_from, edge_to, seg_from, seg_to):
        # Intersection of the clip-edge line with the line through the segment.
        dc = [edge_from[0] - edge_to[0], edge_from[1] - edge_to[1]]
        dp = [seg_from[0] - seg_to[0], seg_from[1] - seg_to[1]]
        n1 = edge_from[0] * edge_to[1] - edge_from[1] * edge_to[0]
        n2 = seg_from[0] * seg_to[1] - seg_from[1] * seg_to[0]
        n3 = 1.0 / (dc[0] * dp[1] - dc[1] * dp[0])
        return [(n1 * dp[0] - n2 * dc[0]) * n3, (n1 * dp[1] - n2 * dc[1]) * n3]

    clipped = subjectPolygon
    edge_from = clipPolygon[-1]

    for edge_to in clipPolygon:
        candidates, clipped = clipped, []
        previous = candidates[-1]

        for current in candidates:
            current_in = is_inside(edge_from, edge_to, current)
            previous_in = is_inside(edge_from, edge_to, previous)
            if current_in != previous_in:
                # The segment previous -> current crosses the clip edge.
                clipped.append(crossing(edge_from, edge_to, previous, current))
            if current_in:
                clipped.append(current)
            previous = current

        edge_from = edge_to
        if not clipped:
            return None
    return clipped
# Function to know if we have a CCW turn
def RightTurn(p1, p2, p3):
    """Return False when the cross-product test on p1, p2, p3 is >= 0, else True."""
    if (p3[1] - p1[1]) * (p2[0] - p1[0]) >= (p2[1] - p1[1]) * (p3[0] - p1[0]):
        return False
    return True


# Main algorithm:
def convex_hull_graham(P):
    """Compute the convex hull of a set of 2D points.

    The input is no longer sorted in place: every point is copied into a
    tuple first. This also lets a mix of lists, tuples and array rows be
    ordered without a TypeError.

    Args:
        P: iterable of (x, y) points.

    Returns:
        List of hull vertices as tuples. Inputs with fewer than three points
        are returned sorted and otherwise unchanged (an empty input gives an
        empty list).
    """
    points = sorted(tuple(p) for p in P)  # Sort a private copy of the points
    if len(points) < 3:
        # Too few points to form a turn; also avoids indexing past the end.
        return points

    L_upper = [points[0], points[1]]  # Initialize upper part
    # Compute the upper part of the hull
    for point in points[2:]:
        L_upper.append(point)
        while len(L_upper) > 2 and not RightTurn(L_upper[-1], L_upper[-2], L_upper[-3]):
            del L_upper[-2]

    L_lower = [points[-1], points[-2]]  # Initialize the lower part
    # Compute the lower part of the hull
    for point in reversed(points[:-2]):
        L_lower.append(point)
        while len(L_lower) > 2 and not RightTurn(L_lower[-1], L_lower[-2], L_lower[-3]):
            del L_lower[-2]

    # Both chains share their end points; keep them only once.
    del L_lower[0]
    del L_lower[-1]
    return L_upper + L_lower  # Build the full hull
def colored_print(color, msg):
    """Log ``msg`` at INFO level, prefixed by ``color`` and followed by the reset code."""
    reset = COLORS.END_NO_TOKEN
    text = color + msg + reset
    log.info(text)
class ConvolutionFormatFactory:
    """Helpers for interpreting the convolution format named in a configuration."""

    @staticmethod
    def check_is_dense_format(conv_type):
        """Tell whether ``conv_type`` names the dense convolution format.

        The comparison against the ``ConvolutionFormat`` values ignores case.

        Returns:
            True for the dense format, False for partial dense, message
            passing and sparse.

        Raises:
            NotImplementedError: if ``conv_type`` matches none of the formats.
        """
        requested = conv_type.lower()
        non_dense_formats = (
            ConvolutionFormat.PARTIAL_DENSE,
            ConvolutionFormat.MESSAGE_PASSING,
            ConvolutionFormat.SPARSE,
        )
        if any(requested == fmt.value.lower() for fmt in non_dense_formats):
            return False
        if requested == ConvolutionFormat.DENSE.value.lower():
            return True
        raise NotImplementedError("Conv type {} not supported".format(conv_type))
class DistributionNeighbour(object):
    """Histogram of how many valid neighbours were found per query for one radius.

    Bin ``k`` counts how often exactly ``k`` valid neighbours were reported
    through ``add_valid_neighbours``.
    """

    def __init__(self, radius, bins=1000):
        self._radius = radius
        self._bins = bins
        self._histogram = np.zeros(self._bins)

    def reset(self):
        """Replace the histogram with a fresh all-zero one."""
        self._histogram = np.zeros(self._bins)

    @property
    def radius(self):
        return self._radius

    @property
    def histogram(self):
        return self._histogram

    @property
    def histogram_non_zero(self):
        """Histogram truncated after its last non-zero bin.

        Returns an empty array when nothing has been counted yet, instead
        of failing with an IndexError.
        """
        # The first non-zero entry of the reversed cumulative sum marks the
        # last non-zero bin of the histogram.
        filled = np.cumsum(self._histogram[::-1]).nonzero()[0]
        if filled.size == 0:
            return self._histogram[:0]
        idx = len(self._histogram) - filled[0]
        return self._histogram[:idx]

    def add_valid_neighbours(self, points):
        """Increment, for every count in ``points``, the bin with that index."""
        for num_valid in points:
            self._histogram[num_valid] += 1

    def __repr__(self):
        return "{}(radius={}, bins={})".format(self.__class__.__name__, self._radius, self._bins)
def euler_angles_to_rotation_matrix(theta, random_order=False):
    """Build a 3x3 rotation matrix from one angle per axis.

    Args:
        theta: indexable whose entries 0, 1 and 2 are the angles about the
            x, y and z axis; each entry is passed to ``torch.cos`` and
            ``torch.sin``.
        random_order: when True the three axis rotations are shuffled with
            ``random.shuffle`` before being composed.

    Returns:
        The product ``M2 @ M1 @ M0`` of the (possibly shuffled) list
        ``[R_x, R_y, R_z]``, i.e. ``R_z @ R_y @ R_x`` by default.
    """
    cos_x, sin_x = torch.cos(theta[0]), torch.sin(theta[0])
    cos_y, sin_y = torch.cos(theta[1]), torch.sin(theta[1])
    cos_z, sin_z = torch.cos(theta[2]), torch.sin(theta[2])

    rot_x = torch.tensor([[1, 0, 0], [0, cos_x, -sin_x], [0, sin_x, cos_x]])
    rot_y = torch.tensor([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]])
    rot_z = torch.tensor([[cos_z, -sin_z, 0], [sin_z, cos_z, 0], [0, 0, 1]])

    factors = [rot_x, rot_y, rot_z]
    if random_order:
        random.shuffle(factors)
    first, second, third = factors
    return torch.mm(third, torch.mm(second, first))
def rodrigues(axis, theta):
    """
    given an axis of norm one and an angle, compute the rotation matrix using rodrigues formula
    R = I + sin(theta) * K + (1 - cos(theta)) * K @ K, with K the cross product matrix of the axis
    source : https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
    """
    device = axis.device
    angle = torch.tensor([theta], device=device)
    skew = get_cross_product_matrix(axis)
    identity = torch.eye(3, device=device)
    first_order = torch.sin(angle) * skew
    second_order = (1 - torch.cos(angle)) * skew.mm(skew)
    return identity + first_order + second_order
def get_activation(act_opt, create_cls=True):
    """Look up an activation in ``torch.nn``.

    Parameters:
        act_opt -- either the attribute name of the activation in ``torch.nn``, or a
                   dict-like option (as decided by ``is_dict``) holding a "name" entry
                   plus the keyword arguments for the constructor
        create_cls -- True returns an instance built with those keyword arguments,
                      False returns the class itself
    """
    if is_dict(act_opt):
        # work on a copy so the caller's option object is left untouched
        kwargs = dict(act_opt)
        cls_name = kwargs.pop("name")
    else:
        kwargs = {}
        cls_name = act_opt

    activation_cls = getattr(torch.nn, cls_name)
    return activation_cls(**kwargs) if create_cls else activation_cls
def resolve_model(model_config, dataset, tested_task):
    """ Substitutes constants into the model config, in place.

    The placeholders FEAT, TASK and N_CLS are derived from the dataset and the task;
    entries of ``model_config.define_constants`` (when present) are added on top and
    may override them. The actual substitution is delegated to ``resolve``.
    """
    substitutions = {}
    substitutions["FEAT"] = max(dataset.feature_dimension, 0)
    substitutions["TASK"] = tested_task
    # datasets without a num_classes attribute get None
    substitutions["N_CLS"] = dataset.num_classes if hasattr(dataset, "num_classes") else None

    has_user_constants = "define_constants" in model_config.keys()
    if has_user_constants:
        user_constants = dict(model_config.define_constants)
        substitutions.update(user_constants)

    resolve(model_config, substitutions)
def flatten_dict(d, parent_key="", sep="_"):
    """Flatten nested mappings into one dict whose keys are the joined key path.

    Adapted from
    https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys

    Example:
        flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y': 10}}, 'd': [1, 2, 3]})
        -> {'a': 1, 'c_a': 2, 'c_b_x': 5, 'c_b_y': 10, 'd': [1, 2, 3]}

    Parameters:
        d -- mapping to flatten; only MutableMapping values are descended into,
             every other value (lists included) is kept as is
        parent_key -- prefix for the keys of ``d``; a falsy prefix means no prefix
        sep -- string placed between the parts of a key path

    Returns a new dict. Top-level keys are kept unchanged; nested keys are formatted
    into the path, so non-string nested keys (e.g. ints) no longer raise TypeError.
    """
    # Imported explicitly: ``import collections`` alone does not guarantee that the
    # ``collections.abc`` submodule is loaded on every Python version.
    from collections.abc import MutableMapping

    flat = {}
    for k, v in d.items():
        new_key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, MutableMapping):
            flat.update(flatten_dict(v, new_key, sep=sep))
        else:
            flat[new_key] = v
    return flat
def estimate_transfo(xyz, xyz_target):
    """
    Estimate the rotation and translation mapping xyz onto xyz_target using the
    Kabsch algorithm.

    Parameters:
        xyz : N x 3 tensor of source points
        xyz_target: N x 3 tensor of target points, row i matching row i of xyz
    Returns:
        4 x 4 homogeneous matrix T with T[:3, :3] the rotation R and T[:3, 3] the
        translation t, chosen so that R @ p + t approximates the matching target point
    Raises:
        AssertionError if the two point sets do not have the same shape
    """
    # The previous check compared xyz.shape with itself and could never fail.
    assert xyz.shape == xyz_target.shape

    centroid = xyz.mean(0)
    centroid_target = xyz_target.mean(0)

    # cross-covariance of the centred point sets
    Q = (xyz - centroid).T.mm(xyz_target - centroid_target) / len(xyz)
    U, S, V = torch.svd(Q)

    # d = -1 flags a reflection; flipping the last axis turns it back into a rotation.
    # The correction matrix is filled in place rather than wrapping the tensor d in
    # torch.tensor([...]), so it keeps the dtype and device of the data.
    d = torch.det(V.mm(U.T))
    correction = torch.eye(3, dtype=Q.dtype, device=xyz.device)
    correction[2, 2] = d

    R = V.mm(correction).mm(U.T)
    t = centroid_target - R @ centroid

    T = torch.eye(4, device=xyz.device)
    T[:3, :3] = R
    T[:3, 3] = t
    return T
def get_matrix_system(xyz, xyz_target, weight):
    """
    Assemble the weighted linear system (A, b) used by the reweighted least squares.

    xyz size N x 3
    xyz_target size N x 3
    weight size N (any shape that flattens to N)

    Returns A of size 3N x 6 and b of size 3N x 1. Rows are grouped per coordinate:
    the N x-rows first, then the N y-rows, then the N z-rows. Columns 0..2 hold the
    weighted minus cross product matrix of each point, columns 3..5 the weighted
    identity; b holds the weighted residual xyz_target - xyz.
    """
    assert xyz.shape == xyz_target.shape
    n_pts = xyz.shape[0]
    w = weight.view(-1)
    px, py, pz = xyz[:, 0], xyz[:, 1], xyz[:, 2]

    # one N x 6 slab per coordinate; flattening the first two axes stacks them x, y, z
    A = torch.zeros(3, n_pts, 6, device=xyz.device)
    A[0, :, 1] = w * pz
    A[0, :, 2] = -w * py
    A[0, :, 3] = w * 1
    A[1, :, 0] = -w * pz
    A[1, :, 2] = w * px
    A[1, :, 4] = w * 1
    A[2, :, 0] = w * py
    A[2, :, 1] = -w * px
    A[2, :, 5] = w * 1

    residual = xyz_target - xyz
    b = torch.cat([w * residual[:, axis] for axis in range(3)], 0)
    return A.view(3 * n_pts, 6), b.view(-1, 1)
def teaser_pp_registration(
    xyz,
    xyz_target,
    noise_bound=0.05,
    cbar2=1,
    rotation_gnc_factor=1.4,
    rotation_max_iterations=100,
    rotation_cost_threshold=1e-12,
):
    """
    estimate the rotation and translation with the RobustRegistrationSolver of
    teaserpp_python (scaling estimation disabled, GNC_TLS rotation estimation).

    xyz and xyz_target must have the same shape; they are handed to the solver
    transposed (presumably N x 3 tensors, so the solver receives 3 x N arrays).
    The keyword arguments are copied onto the solver parameters of the same name.
    Returns a 4 x 4 matrix on the device of xyz holding the solver's rotation and
    translation.
    """
    assert xyz.shape == xyz_target.shape
    # optional dependency, only needed when this function is actually called
    import teaserpp_python

    solver_cls = teaserpp_python.RobustRegistrationSolver

    # Populating the parameters
    params = solver_cls.Params()
    params.cbar2 = cbar2
    params.noise_bound = noise_bound
    params.estimate_scaling = False
    params.rotation_estimation_algorithm = solver_cls.ROTATION_ESTIMATION_ALGORITHM.GNC_TLS
    params.rotation_gnc_factor = rotation_gnc_factor
    params.rotation_max_iterations = rotation_max_iterations
    params.rotation_cost_threshold = rotation_cost_threshold

    solver = solver_cls(params)
    source_np = xyz.T.detach().cpu().numpy()
    target_np = xyz_target.T.detach().cpu().numpy()
    solver.solve(source_np, target_np)
    solution = solver.getSolution()

    device = xyz.device
    T_res = torch.eye(4, device=device)
    T_res[:3, :3] = torch.from_numpy(solution.rotation).to(device)
    T_res[:3, 3] = torch.from_numpy(solution.translation).to(device)
    return T_res
class RunningStats:
    """Streaming mean / variance accumulator (Welford's update).

    Values are fed one at a time through ``push``; ``mean``, ``variance`` (sample
    variance, divided by n - 1) and ``std`` can be read at any point.
    """

    def __init__(self):
        self.n = 0
        self.old_m = 0
        self.new_m = 0
        self.old_s = 0
        self.new_s = 0

    def clear(self):
        # resetting the count is enough: the next push re-seeds the mean, and
        # mean()/variance() ignore the stored sums while n is too small
        self.n = 0

    def push(self, x):
        self.n += 1

        if self.n != 1:
            delta = x - self.old_m
            self.new_m = self.old_m + delta / self.n
            self.new_s = self.old_s + delta * (x - self.new_m)
            self.old_m, self.old_s = self.new_m, self.new_s
            return

        # first sample seeds the recurrence
        self.old_m = self.new_m = x
        self.old_s = 0

    def mean(self):
        if not self.n:
            return 0.0
        return self.new_m

    def variance(self):
        if self.n > 1:
            return self.new_s / (self.n - 1)
        return 0.0

    def std(self):
        return np.sqrt(self.variance())
class SamplingStrategy(object):
    """Callable that picks the index of one point of a data object.

    Strategies:
        random -- uniform index in [0, len(data.pos))
        freq_class_based -- first draws a label from data.y with a probability that
            grows as the label gets rarer (mean_count / count, square-rooted for the
            "sqrt" class weight method, then normalised), then draws uniformly among
            the indices carrying that label

    Raises ValueError at construction time for an unknown strategy, or for an unknown
    class weight method when the strategy needs one. Previously both were silently
    ignored and surfaced later as an AttributeError.
    """

    STRATEGIES = ["random", "freq_class_based"]
    CLASS_WEIGHT_METHODS = ["sqrt"]

    def __init__(self, strategy="random", class_weight_method="sqrt"):
        strategy = strategy.lower()
        if strategy not in self.STRATEGIES:
            raise ValueError("Unknown strategy {!r}, expected one of {}".format(strategy, self.STRATEGIES))
        self._strategy = strategy

        class_weight_method = class_weight_method.lower()
        if class_weight_method in self.CLASS_WEIGHT_METHODS:
            self._class_weight_method = class_weight_method
        elif strategy == "freq_class_based":
            raise ValueError(
                "Unknown class_weight_method {!r}, expected one of {}".format(
                    class_weight_method, self.CLASS_WEIGHT_METHODS
                )
            )
        else:
            # unused by the "random" strategy: stay permissive, but keep the
            # attribute defined so that __repr__ works
            self._class_weight_method = None

    def __call__(self, data):
        if self._strategy == "random":
            return np.random.randint(0, len(data.pos))

        if self._strategy == "freq_class_based":
            labels = np.asarray(data.y)
            uni, uni_counts = np.unique(labels, return_counts=True)
            # inverse frequency: rare labels get the larger weights
            weights = uni_counts.mean() / uni_counts
            if self._class_weight_method == "sqrt":
                weights = np.sqrt(weights)
            weights /= np.sum(weights)
            chosen_label = np.random.choice(uni, p=weights)
            return np.random.choice(np.argwhere(labels == chosen_label).flatten())

        raise NotImplementedError

    def __repr__(self):
        return "{}(strategy={}, class_weight_method={})".format(
            self.__class__.__name__, self._strategy, self._class_weight_method
        )
model") + tested_dataset_class = getattr(cfg.data, "class") + optim = cfg.training.get("optim", None) + otimizer_class = getattr(optim.optimizer, "class") if optim is not None else "loaded model" + lr_scheduler = cfg.training.get("lr_scheduler", None) + scheduler_class = getattr(lr_scheduler, "class") if lr_scheduler is not None else "loaded model" + tags = [ + cfg.model_name, + model_class.split(".")[0], + tested_dataset_class, + otimizer_class, + scheduler_class, + ] + + wandb_args = {} + wandb_args["project"] = cfg.training.wandb.project + wandb_args["tags"] = tags + wandb_args["resume"] = "allow" + Wandb._set_to_wandb_args(wandb_args, cfg, "name") + Wandb._set_to_wandb_args(wandb_args, cfg, "entity") + Wandb._set_to_wandb_args(wandb_args, cfg, "notes") + Wandb._set_to_wandb_args(wandb_args, cfg, "config") + Wandb._set_to_wandb_args(wandb_args, cfg, "id") + + try: + commit_sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + gitdiff = subprocess.check_output(["git", "diff", "--", "':!notebooks'"]).decode() + except: + commit_sha = "n/a" + gitdiff = "" + + config = wandb_args.get("config", {}) + wandb_args["config"] = { + **config, + "run_path": os.getcwd(), + "commit": commit_sha, + } + + wandb_args["settings"] = wandb.Settings(start_method="fork") + + wandb.init(**wandb_args) + shutil.copyfile( + os.path.join(os.getcwd(), ".hydra/config.yaml"), os.path.join(os.getcwd(), ".hydra/hydra-config.yaml") + ) + wandb.save(os.path.join(os.getcwd(), ".hydra/hydra-config.yaml")) + wandb.save(os.path.join(os.getcwd(), ".hydra/overrides.yaml")) + + with open("change.patch", "w") as f: + f.write(gitdiff) + wandb.save(os.path.join(os.getcwd(), "change.patch")) + + @staticmethod + def add_file(file_path: str): + if not Wandb.IS_ACTIVE: + raise RuntimeError("wandb is inactive, please launch first.") + import wandb + + filename = os.path.basename(file_path) + shutil.copyfile(file_path, os.path.join(wandb.run.dir, filename)) diff --git 
class ExperimentFolder:
    """View over one run directory: its checkpoint and its optional ``viz`` folder.

    The ``viz`` folder is indexed as ``viz/<epoch>/<split>/<file>`` by ``get_splits``,
    ``get_files`` and ``load_ply``. The directory listing is taken once, at
    construction time.
    """

    POS_KEYS = ["x", "y", "z"]

    def __init__(self, run_path):
        self._run_path = run_path
        # Set eagerly: it used to be assigned only as a side effect of `contains_viz`,
        # so `epochs` / `get_splits` / `get_files` raised AttributeError when used first.
        self._viz_path = os.path.join(run_path, "viz")
        self._model_name = None
        self._stats = None
        self._find_files()

    def _find_files(self):
        self._files = os.listdir(self._run_path)

    def __repr__(self):
        # Part of the path following "outputs"; ExperimentManager compares run paths
        # against this string. Falls back to the full path instead of raising
        # IndexError when the run does not live under an "outputs" folder.
        parts = self._run_path.split("outputs")
        return parts[1] if len(parts) > 1 else self._run_path

    @property
    def model_name(self):
        """File name of the checkpoint; None until `contains_trained_model` found one."""
        return self._model_name

    @property
    def epochs(self):
        return os.listdir(self._viz_path)

    def get_splits(self, epoch):
        return os.listdir(os.path.join(self._viz_path, str(epoch)))

    def get_files(self, epoch, split):
        return os.listdir(os.path.join(self._viz_path, str(epoch), split))

    def load_ply(self, epoch, split, file):
        """Load (and cache on the instance) one ply file of the viz folder.

        Returns a dict with "xyz" (the POS_KEYS columns), "columns" (names of the
        other columns), "name" and one entry per other column, or None when the file
        does not exist. The first load now returns the data as well; it used to
        return None and only cache hits returned the dict.
        """
        self._data_name = "data_{}_{}_{}".format(epoch, split, file)
        if hasattr(self, self._data_name):
            return getattr(self, self._data_name)

        path_to_ply = os.path.join(self._viz_path, str(epoch), split, file)
        if not os.path.exists(path_to_ply):
            print("The file doesn' t exist: Wierd !")
            return None

        plydata = PlyData.read(path_to_ply)
        arr = np.asarray([e.data for e in plydata.elements])
        names = list(arr.dtype.names)
        pos_indices = [names.index(n) for n in self.POS_KEYS]
        non_pos_indices = {n: names.index(n) for n in names if n not in self.POS_KEYS}
        arr_ = rfn.structured_to_unstructured(arr).squeeze()
        xyz = arr_[:, pos_indices]
        data = {"xyz": xyz, "columns": non_pos_indices.keys(), "name": self._data_name}
        for n, i in non_pos_indices.items():
            data[n] = arr_[:, i]
        setattr(self, self._data_name, data)
        return data

    @property
    def current_pointcloud(self):
        """Data of the last `load_ply` call (requires that call to have succeeded)."""
        return getattr(self, self._data_name)

    @property
    def contains_viz(self):
        """True when the run has a non-empty ``viz`` directory (computed once)."""
        if not hasattr(self, "_contains_viz"):
            # Exact name + isdir: the old substring test ("viz" in f) fired on any
            # entry containing "viz" and then listed a directory that may not exist.
            has_dir = "viz" in self._files and os.path.isdir(self._viz_path)
            self._contains_viz = has_dir and len(os.listdir(self._viz_path)) > 0
        return self._contains_viz

    @property
    def contains_trained_model(self):
        """True when an entry containing ".pt" exists; the first one becomes `model_name`."""
        if not hasattr(self, "_contains_trained_model"):
            self._contains_trained_model = False
            for f in self._files:
                if ".pt" in f:
                    self._model_name = f
                    self._contains_trained_model = True
                    break
        return self._contains_trained_model

    def extract_stats(self):
        """Return (number of train epochs, {metric: {split: value}}) for the "best" metrics.

        Reads the "stats" entry of the checkpoint; only metrics whose name contains
        "best" are kept, taken from the last recorded epoch of each split.
        """
        path_to_checkpoint = os.path.join(self._run_path, self.model_name)
        # map_location="cpu": only the stats are read here, so a checkpoint saved
        # from a GPU run must stay loadable on a machine without that device
        stats = torch.load(path_to_checkpoint, map_location="cpu")["stats"]
        self._stats = stats
        num_epoch = len(stats["train"])
        stats_dict = defaultdict(dict)
        for split_name in stats.keys():
            if len(stats[split_name]) > 0:
                latest_epoch = stats[split_name][-1]
                for metric_name in latest_epoch.keys():
                    if "best" in metric_name:
                        stats_dict[metric_name][split_name] = latest_epoch[metric_name]
        return num_epoch, stats_dict
_collect_experiments(self): + self._experiment_with_models = defaultdict(list) + run_paths = glob(os.path.join(self._experiments_root, "outputs", "*", "*")) + for run_path in run_paths: + experiment = ExperimentFolder(run_path) + if experiment.contains_trained_model: + self._experiment_with_models[experiment.model_name].append(experiment) + + self._find_experiments_with_viz() + + def _find_experiments_with_viz(self): + if not hasattr(self, "_experiment_with_viz"): + self._experiment_with_viz = defaultdict(list) + for model_name in self._experiment_with_models.keys(): + for experiment in self._experiment_with_models[model_name]: + if experiment.contains_viz: + self._experiment_with_viz[experiment.model_name].append(experiment) + + @property + def model_name_wviz(self): + keys = list(self._experiment_with_viz.keys()) + return [k.replace(".pt", "") for k in keys] + + @property + def current_pointcloud(self): + return self._current_experiment.current_pointcloud + + def load_ply_file(self, file): + if hasattr(self, "_current_split"): + self._current_file = file + self._current_experiment.load_ply(self._current_epoch, self._current_split, self._current_file) + else: + return [] + + def from_split_to_file(self, split_name): + if hasattr(self, "_current_epoch"): + self._current_split = split_name + return self._current_experiment.get_files(self._current_epoch, self._current_split) + else: + return [] + + def from_epoch_to_split(self, epoch): + if hasattr(self, "_current_experiment"): + self._current_epoch = epoch + return self._current_experiment.get_splits(self._current_epoch) + else: + return [] + + def from_paths_to_epoch(self, run_path): + for exp in self._current_exps: + if str(run_path) == str(exp.__repr__()): + self._current_experiment = exp + return sorted(self._current_experiment.epochs) + + def get_model_wviz_paths(self, model_path): + model_name = model_path + ".pt" + self._current_exps = self._experiment_with_viz[model_name] + return self._current_exps + + def 
display_stats(self): + print("") + for model_name in self._experiment_with_models.keys(): + colored_print(COLORS.Green, str(model_name)) + for experiment in self._experiment_with_models[model_name]: + print(experiment) + num_epoch, stats = experiment.extract_stats() + colored_print(COLORS.Red, "Epoch: {}".format(num_epoch)) + for metric_name in stats: + sentence = "" + for split_name in stats[metric_name].keys(): + sentence += "{}: {}, ".format(split_name, stats[metric_name][split_name]) + metric_sentence = metric_name + "({})".format(sentence[:-2]) + colored_print(COLORS.BBlue, metric_sentence) + print("") + print("") diff --git a/torch-points3d/torch_points3d/visualization/visualizer.py b/torch-points3d/torch_points3d/visualization/visualizer.py new file mode 100644 index 0000000..fa2b3f8 --- /dev/null +++ b/torch-points3d/torch_points3d/visualization/visualizer.py @@ -0,0 +1,414 @@ +import logging +import os +from itertools import product +from math import log10, ceil +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +import wandb +from matplotlib.cm import get_cmap +from plyfile import PlyData, PlyElement + +from torch_points3d.utils.config import is_list + +log = logging.getLogger(__name__) + + +class Visualizer: + """Initialize the Visualizer class. + Parameters: + viz_conf (OmegaConf Dictionary) -- stores all config for the visualizer + num_batches (dict) -- This dictionary maps stage_name to #batches + batch_size (int) -- Current batch size usef + save_dir (str) -- The path used by hydra to store the experiment + + This class is responsible to save visuals into different formats. Currently supported formats are: + ply -- Either an ascii or binary ply file, with the labels and gt stored as columns + tensorboard -- Visualize point cloud in tensorboard + wandb -- Upload point cloud to wandb. WARNING: This can become very slow, both in training and on the web. 
+ Make sure you properly limit the num_samples_per_epoch and wandb_max_points. + csv -- creates a csv with predictions + + The configuration looks like this: + visualization: + activate: False # Whether to activate the visualizer + format: ["ply", "tensorboard"] # 'pointcloud' is deprecated, use 'ply' instead + num_samples_per_epoch: 2 # If negative, it will save all elements + deterministic: True # False -> Randomly sample elements from epoch to epoch + deterministic_seed: 0 # Random seed used to generate consistant keys if deterministic is True + saved_keys: # Mapping from Data Object to structured numpy + pos: [['x', 'float'], ['y', 'float'], ['z', 'float']] + y: [['l', 'float']] + pred: [['p', 'float']] + indices: # List of indices to be saved (support "train", "test", "val") + train: [0, 3] + # Format specific options: + ply_format: binary_big_endian # PLY format (support "binary_big_endian", "binary_little_endian", "ascii") + tensorboard_mesh: # Mapping from mesh name and propety use to color + label: 'y' + prediction: 'pred' + wandb_max_points: 10000 # Limits the size of the cloud that gets uploaded by random sampling. + # "-1" saves the entire cloud + wandb_cmap: # Applies a color map to the point cloud. Allows custom coloring of different classes. 
+ - [0, 0, 0] # class 0 + - [255, 255, 255] # class 1 + - [128, 128, 128] # class 2 + """ + + def __init__(self, viz_conf, num_batches, batch_size, save_dir, tracker): + # From configuration and dataset + for stage_name, stage_num_sample in num_batches.items(): + setattr(self, "{}_num_batches".format(stage_name), stage_num_sample) + self._batch_size = batch_size + self._activate = viz_conf.activate + self._format = [viz_conf.format] if not is_list(viz_conf.format) else viz_conf.format + self._num_samples_per_epoch = int(viz_conf.num_samples_per_epoch) + self._deterministic = viz_conf.deterministic + self._seed = viz_conf.deterministic_seed if viz_conf.deterministic_seed is not None else 0 + self._tracker = tracker + + self._saved_keys = viz_conf.saved_keys + self._tensorboard_mesh = {} + self.save_dir = save_dir + self._viz_path = os.path.join(save_dir, "viz") + + # Internal state + self._stage = None + self._current_epoch = None + + # format-specific initialization + if "ply" in self._format: + self._ply_format = viz_conf.ply_format if viz_conf.ply_format is not None else "binary_big_endian" + + if "tensorboard" in self._format: + if not tracker._use_tensorboard: + log.warning("Tensorboard visualization specified, but tensorboard isn't active.") + else: + self._tensorboard_mesh = viz_conf.tensorboard_mesh + + # SummaryWriter for tensorboard loging + self._writer = tracker._writer + + if "wandb" in self._format: + if not self._tracker._wandb: + log.warning("Wandb visualization specified, but Wandb isn't active.") + else: + self._wandb_cmap = viz_conf.get("wandb_cmap", None) + self._max_points = viz_conf.wandb_max_points if viz_conf.get("wandb_max_points", + None) is not None else -1 + + if "gpkg" in self._format or "csv" in self._format: + self.dfs = [] + + self._indices = {} + self._contains_indices = False + + try: + indices = getattr(viz_conf, "indices", None) + except: + indices = None + + if indices: + for split in ["train", "test", "val"]: + if split in 
indices: + split_indices = indices[split] + self._indices[split] = np.asarray(split_indices) + self._contains_indices = True + + def finalize_epoch(self, loaders): + if "gpkg" in self._format or "csv" in self._format: + df = pd.concat(self.dfs) + + for loader in loaders: + + dataset = loader.dataset + for area_name in np.unique(df.area): + area_pred = df.query("area == @area_name").set_index("label_idx").drop("area", axis=1) + + if "csv" in self._format: + dff = area_pred.query(f"stage == '{self._stage}'") + del dff["stage"] + file = self.save_dir / Path(f"{area_name}_{self._stage}_preds.csv") + dff.to_csv(file, mode='a', header=not file.exists()) + + if "gpkg" in self._format: + file = self.save_dir / Path(f"{area_name}_preds.gpkg") + area_label = dataset.areas.get(area_name, None) + if area_label is None: + continue + area_label = area_label["labels"] + area_df = area_label.join(area_pred, rsuffix="_pred", how="inner") + area_df = area_df.query(f"stage == '{self._stage}'") + del area_df["stage"] + area_df.to_file(file, mode='a' if file.exists() else "w") + + def get_indices(self, stage): + """This function is responsible to calculate the indices to be saved""" + if self._contains_indices: + return + stage_num_batches = getattr(self, "{}_num_batches".format(stage)) + total_items = (stage_num_batches - 1) * self._batch_size + if stage_num_batches > 0: + if self._num_samples_per_epoch < 0: # All elements should be saved. 
+ if stage_num_batches > 0: + self._indices[stage] = np.arange(total_items) + else: + self._indices[stage] = None + else: + if self._num_samples_per_epoch > total_items: + log.warning("Number of samples to save is higher than the number of available elements") + self._indices[stage] = self._rng.permutation(total_items)[: self._num_samples_per_epoch] + + @property + def is_active(self): + return self._activate + + def reset(self, epoch, stage): + """This function is responsible to restore the visualizer + to start a new epoch on a new stage + """ + self._current_epoch = epoch + self._seen_batch = 0 + self._stage = stage + if self._deterministic: + self._rng = np.random.default_rng(self._seed) + else: + self._rng = np.random.default_rng() + if self._activate: + self.get_indices(stage) + + def _extract_from_PYG(self, item, pos_idx): + num_samples = item.batch.shape[0] + batch_mask = item.batch == pos_idx + out_data = {} + for k in item.keys: + if torch.is_tensor(item[k]) and (k in self._saved_keys.keys() or k in self._tensorboard_mesh.values()): + if item[k].shape[0] == num_samples: + out_data[k] = item[k][batch_mask] + return out_data + + def _extract_from_dense(self, item, pos_idx): + assert ( # TODO only true if task is segmentation + item.y.shape[0] == item.pos.shape[0] + ), "y and pos should have the same number of samples. 
Something is probably wrong with your data to visualise" + num_samples = item.y.shape[0] + out_data = {} + for k in item.keys: + if torch.is_tensor(item[k]) and (k in self._saved_keys.keys() or k in self._tensorboard_mesh.values()): + if item[k].shape[0] == num_samples: + out_data[k] = item[k][pos_idx] + return out_data + + def _dict_to_structured_npy(self, item): + item.keys() + out = [] + dtypes = [] + for k, v in item.items(): + v_npy = v.detach().cpu().numpy() + if len(v_npy.shape) == 1: + v_npy = v_npy[..., np.newaxis] + for dtype in self._saved_keys[k]: + dtypes.append(dtype) + out.append(v_npy) + + out = np.concatenate(out, axis=-1) + dtypes = np.dtype([tuple(d) for d in dtypes]) + return np.asarray([tuple(o) for o in out], dtype=dtypes) + + def save_visuals(self, model, loader): + """This function is responsible to save the data into .ply objects + Parameters: + model -- Contains the model including visuals + loader - Contains the dataloader + Make sure the saved_keys within the config maps to the Data attributes. 
+ """ + if self._stage in self._indices: + visuals = model.get_current_visuals() + + if any(format in self._format for format in ["csv", "gpkg", "ply"]): + dataset = loader.dataset + pred = {} + data_vis = model.data_visual + if model.has_reg_targets: + preds = model.get_reg_output().detach().cpu().numpy() + for i, pred_name in enumerate(dataset.reg_targets): + pred[pred_name] = preds[:, i] + if model.has_mol_targets: + preds = model.get_mol_output().detach().cpu().numpy() + for i, pred_name in enumerate(dataset.mol_targets): + pred[pred_name] = preds[:, i] + if model.has_cls_targets: + preds = model.get_cls_output() + for i, pred_name in enumerate(dataset.cls_targets): + pred[pred_name] = preds[i].argmax(1).detach().cpu().numpy() + pred[f"{pred_name}_prob"] = preds[i].softmax(1).detach().cpu().numpy().max(1) + pred["area"] = data_vis.area_name + + label_idx_ = [idx[0] for idx in data_vis.label_idx] + pred["label_idx"] = label_idx_ + + df = pd.DataFrame(pred) + + if "gpkg" in self._format or "csv" in self._format: + df["stage"] = self._stage + self.dfs.append(df) + + is_ply = "ply" in self._format + if is_ply: + viz_path = Path(self._viz_path) + viz_path.mkdir(exist_ok=True) + for i, (_, sample) in enumerate(df.iterrows()): + area_name = sample["area"] + area_path = viz_path / area_name + area_path.mkdir(exist_ok=True) + label_idx = sample['label_idx'] + file = area_path / f"{label_idx}.ply" + out_item = self._extract_from_PYG(data_vis, i) + out_item = self._dict_to_structured_npy(out_item) + self.save_ply(out_item, f"{area_name}_{label_idx}", file) + + if all(format not in self._format for format in ["wandb", "tensorboard"]): + return + stage_num_batches = getattr(self, "{}_num_batches".format(self._stage)) + batch_indices = self._indices[self._stage] // self._batch_size + pos_indices = self._indices[self._stage] % self._batch_size + for idx in np.argwhere(self._seen_batch == batch_indices).flatten(): + pos_idx = pos_indices[idx] + for visual_name, item in 
visuals.items(): + if hasattr(item, "batch") and item.batch is not None: # The PYG dataloader has been used + out_item = self._extract_from_PYG(item, pos_idx) + else: + out_item = self._extract_from_dense(item, pos_idx) + + if "tensorboard" in self._format and self._tracker._use_tensorboard: + self.save_tensorboard(out_item, visual_name, stage_num_batches) + + out_item = self._dict_to_structured_npy(out_item) + gt_name = "{}_{}_{}_gt".format(self._current_epoch, self._seen_batch, pos_idx) + pred_name = "{}_{}_{}".format(self._current_epoch, self._seen_batch, pos_idx) + + if "wandb" in self._format and self._tracker._wandb: + self.save_wandb(out_item, gt_name, pred_name) + + self._seen_batch += 1 + + def save_ply(self, npy_array, visual_name, filename): + + el = PlyElement.describe(npy_array, visual_name) + if self._ply_format == "ascii": + PlyData([el], text=True).write(filename) + elif self._ply_format == "binary_little_endian": + PlyData([el], byte_order="<").write(filename) + elif self._ply_format == "binary_big_endian": + PlyData([el], byte_order=">").write(filename) + else: + PlyData([el]).write(filename) + + def save_tensorboard(self, out_item, visual_name, stage_num_batches): + pos = out_item["pos"].detach().cpu().unsqueeze(0) + colors = get_cmap("tab10") + config_dict = {"material": {"size": 0.3}} + + for label, k in self._tensorboard_mesh.items(): + value = out_item[k].detach().cpu() + + if len(value.shape) == 2 and value.shape[1] == 3: + if value.min() >= 0 and value.max() <= 1: + value = (255 * value).type(torch.uint8).unsqueeze(0) + else: + value = value.type(torch.uint8).unsqueeze(0) + elif len(value.shape) == 1 and value.shape[0] == 1: + value = np.tile((255 * colors(value.numpy() % 10))[:, 0:3].astype(np.uint8), (pos.shape[0], 1)).reshape( + (1, -1, 3) + ) + elif len(value.shape) == 1 or value.shape[1] == 1: + value = (255 * colors(value.numpy() % 10))[:, 0:3].astype(np.uint8).reshape((1, -1, 3)) + else: + continue + + self._writer.add_mesh( + 
self._stage + "/" + visual_name + "/" + label, + pos, + colors=value, + config_dict=config_dict, + global_step=(self._current_epoch - 1) * (10 ** ceil(log10(stage_num_batches + 1))) + self._seen_batch, + ) + + def gen_bb_corners(self, points): + points_min = np.min(points, axis=0) + points_max = np.max(points, axis=0) + points_min_max = np.stack([points_min, points_max], axis=0) + + bb_points = [] + for x, y, z in [i for i in product(range(2), repeat=3)]: # 2^3 binary combination table + bb_points.append([points_min_max[x, 0], points_min_max[y, 1], points_min_max[z, 2]]) + return bb_points + + def apply_cmap(self, val): + out = np.zeros((val.shape[0], 3), dtype=int) + for label, color in enumerate(self._wandb_cmap): + out[val == label] = color + return out + + PRED_COLOR = [255, 0, 0] # red + GT_COLOR = [124, 255, 0] # green + + # https://docs.wandb.ai/guides/track/log/media#3d-visualizations + def save_wandb(self, out_item, gt_name, pred_name): + if self._max_points > 0: + out_item = out_item[self._rng.permutation(len(out_item))[: self._max_points]] + if self._wandb_cmap is None: + assert (out_item["p"].max() + 1) <= 14, "Wandb classes must be in 1-14" + assert (out_item["l"].max() + 1) <= 14, "Wandb classes must be in 1-14" + + pred_points = np.stack([out_item["x"], out_item["y"], out_item["z"], out_item["p"] + 1], axis=1) + gt_points = np.stack([out_item["x"], out_item["y"], out_item["z"], out_item["l"] + 1], axis=1) + else: + pred_colors = self.apply_cmap(out_item["p"]) + gt_colors = self.apply_cmap(out_item["l"]) + pred_points = np.stack( + [out_item["x"], out_item["y"], out_item["z"], pred_colors[:, 0], pred_colors[:, 1], pred_colors[:, 2]], + axis=1, + ) + gt_points = np.stack( + [out_item["x"], out_item["y"], out_item["z"], gt_colors[:, 0], gt_colors[:, 1], gt_colors[:, 2]], axis=1 + ) + + corners = self.gen_bb_corners(pred_points) + + pred_scene = wandb.Object3D( + { + "type": "lidar/beta", + "points": pred_points, + "boxes": np.array( # draw 3d boxes + [ 
+ { + "corners": corners, + "label": pred_name, + "color": self.PRED_COLOR, + } + ] + ), + } + ) + gt_scene = wandb.Object3D( + { + "type": "lidar/beta", + "points": gt_points, + "boxes": np.array( # draw 3d boxes + [ + { + "corners": corners, + "label": gt_name, + "color": self.GT_COLOR, + } + ] + ), + } + ) + + gt_scene_name = "{}/gt".format(self._stage) + pred_scene_name = "{}/pred".format(self._stage) + wandb.log({pred_scene_name: pred_scene, gt_scene_name: gt_scene, "epoch": self._current_epoch}) diff --git a/torch-points3d/train.py b/torch-points3d/train.py new file mode 100644 index 0000000..79b97fc --- /dev/null +++ b/torch-points3d/train.py @@ -0,0 +1,22 @@ +import hydra +from hydra.core.global_hydra import GlobalHydra +from omegaconf import OmegaConf +from torch_points3d.trainer import Trainer + + +@hydra.main(config_path="conf", config_name="config") +def main(cfg): + OmegaConf.set_struct(cfg, False) # This allows getattr and hasattr methods to function correctly + if cfg.pretty_print: + print(OmegaConf.to_yaml(cfg)) + + trainer = Trainer(cfg) + trainer.train() + # + # # https://github.com/facebookresearch/hydra/issues/440 + GlobalHydra.get_state().clear() + return 0 + + +if __name__ == "__main__": + main()