<xarray.Dataset>\n", + "Dimensions: (longitude: 16, latitude: 16, time: 291542)\n", + "Coordinates:\n", + " * longitude (longitude) float32 -81.0 -80.8 -80.6 -80.4 ... -78.4 -78.2 -78.0\n", + " * latitude (latitude) float32 45.0 44.8 44.6 44.4 ... 42.6 42.4 42.2 42.0\n", + " * time (time) datetime64[ns] 1990-01-01 ... 2023-04-05T13:00:00\n", + "Data variables:\n", + " t2m (time, latitude, longitude) float32 dask.array<chunksize=(8760, 16, 16), meta=np.ndarray>\n", + "Attributes:\n", + " Conventions: CF-1.6\n", + " history: 2023-04-10 08:34:25 GMT by grib_to_netcdf-2.25.1: /opt/ecmw...
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/mean_bias │ -0.4124308228492737 │\n", + "│ test/mean_bias_2m_temperature │ -0.4124308228492737 │\n", + "│ test/pearson │ 0.9967161095043267 │\n", + "│ test/pearsonr_2m_temperature │ 0.9967161095043267 │\n", + "│ test/rmse │ 0.523897647857666 │\n", + "│ test/rmse_2m_temperature │ 0.523897647857666 │\n", + "└───────────────────────────────┴───────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "trainer.test(model_module_downscaling, data_module)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "This visualization displays the results of a random test for the Resnet model. It shows the weather forecast for the spatial downscaling of the selected region (area = [42.0, -81.0, 45.0, -78.0]) in a short time period (1 days). The initial condition is the temperature of the date we selected to test, the ground truth is the actual temperature after 1 days, and the prediction is the outcome of the trained model. \n", + "\n", + "The mean square error of the Resnet model is relatively high, but its small size of less than 5MB makes it easy to deploy. This is a major advantage of the Resnet model, as it can be used in a variety of applications without taking up too much space." + ], + "metadata": { + "id": "PROj0EHeeRbl" + } + }, + { + "cell_type": "code", + "source": [ + "from climate_learn.utils import visualize\n", + "\n", + "# if samples = 2, we randomly pick 2 initial conditions in the test set\n", + "visualize(model_module_downscaling, data_module, samples = 2)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 240 + }, + "id": "EOtdiPKsrMP2", + "outputId": "f0e9f851-1c1c-4b8d-be8b-e27d59cde67d" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
<xarray.Dataset>\n", + "Dimensions: (longitude: 16, latitude: 16, time: 291542)\n", + "Coordinates:\n", + " * longitude (longitude) float32 -81.0 -80.8 -80.6 -80.4 ... -78.4 -78.2 -78.0\n", + " * latitude (latitude) float32 45.0 44.8 44.6 44.4 ... 42.6 42.4 42.2 42.0\n", + " * time (time) datetime64[ns] 1990-01-01 ... 2023-04-05T13:00:00\n", + "Data variables:\n", + " t2m (time, latitude, longitude) float32 dask.array<chunksize=(8760, 16, 16), meta=np.ndarray>\n", + "Attributes:\n", + " Conventions: CF-1.6\n", + " history: 2023-04-10 08:34:25 GMT by grib_to_netcdf-2.25.1: /opt/ecmw...
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/acc │ 0.6581142544746399 │\n", + "│ test/acc_2m_temperature_1.0_days │ 0.6581142544746399 │\n", + "│ test/w_rmse │ 2.6785058975219727 │\n", + "│ test/w_rmse_2m_temperature_1.0_days │ 2.6785058975219727 │\n", + "│ test_climatology_baseline/w_rmse │ 8.735182762145996 │\n", + "│ test_climatology_baseline/w_rmse_2m_temperature_1.0_d… │ 8.735182762145996 │\n", + "│ test_persistence_baseline/w_rmse │ 3.306151866912842 │\n", + "│ test_persistence_baseline/w_rmse_2m_temperature_1.0_d… │ 3.306151866912842 │\n", + "└────────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "trainer.test(model_module_forecasting, data_module)" + ] + }, + { + "cell_type": "markdown", + "source": [ + "This visualization displays the results of a random test for the Unet model. It shows the weather forecast for the temperature of the selected region (area = [42.0, -81.0, 45.0, -78.0]) in a short time period (1 days). The initial condition is the temperature of the date we selected to test, the ground truth is the actual temperature after 1 days, and the prediction is the outcome of the trained model. \n" + ], + "metadata": { + "id": "PROj0EHeeRbl" + } + }, + { + "cell_type": "code", + "source": [ + "from climate_learn.utils import visualize\n", + "\n", + "# if samples = 2, we randomly pick 2 initial conditions in the test set\n", + "visualize(model_module_forecasting, data_module, samples = 2)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 240 + }, + "id": "EOtdiPKsrMP2", + "outputId": "c0861764-b2cf-41f8-99f0-6f72cd5d1319" + }, + "execution_count": 49, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃ ┃ Name ┃ Type ┃ Params ┃\n", + "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│ 0 │ net │ ResNet │ 1.2 M │\n", + "│ 1 │ net.activation │ LeakyReLU │ 0 │\n", + "│ 2 │ net.image_proj │ PeriodicConv2D │ 6.4 K │\n", + "│ 3 │ net.image_proj.padding │ PeriodicPadding2D │ 0 │\n", + "│ 4 │ net.image_proj.conv │ Conv2d │ 6.4 K │\n", + "│ 5 │ net.blocks │ ModuleList │ 1.2 M │\n", + "│ 6 │ net.blocks.0 │ ResidualBlock │ 295 K │\n", + "│ 7 │ net.blocks.0.activation │ LeakyReLU │ 0 │\n", + "│ 8 │ net.blocks.0.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 9 │ net.blocks.0.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 10 │ net.blocks.0.conv1.conv │ Conv2d │ 147 K │\n", + "│ 11 │ net.blocks.0.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 12 │ net.blocks.0.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 13 │ net.blocks.0.conv2.conv │ Conv2d │ 147 K │\n", + "│ 14 │ net.blocks.0.shortcut │ Identity │ 0 │\n", + "│ 15 │ net.blocks.0.norm1 │ BatchNorm2d │ 256 │\n", + "│ 16 │ net.blocks.0.norm2 │ BatchNorm2d │ 256 │\n", + "│ 17 │ net.blocks.0.drop │ Dropout │ 0 │\n", + "│ 18 │ net.blocks.1 │ ResidualBlock │ 295 K │\n", + "│ 19 │ net.blocks.1.activation │ LeakyReLU │ 0 │\n", + "│ 20 │ net.blocks.1.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 21 │ net.blocks.1.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 22 │ net.blocks.1.conv1.conv │ Conv2d │ 147 K │\n", + "│ 23 │ net.blocks.1.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 24 │ net.blocks.1.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 25 │ net.blocks.1.conv2.conv │ Conv2d │ 147 K │\n", + "│ 26 │ net.blocks.1.shortcut │ Identity │ 0 │\n", + "│ 27 │ net.blocks.1.norm1 │ BatchNorm2d │ 256 │\n", + "│ 28 │ net.blocks.1.norm2 │ BatchNorm2d │ 256 │\n", + "│ 29 │ net.blocks.1.drop │ Dropout │ 0 │\n", + "│ 30 │ net.blocks.2 │ ResidualBlock │ 295 K │\n", + "│ 31 │ net.blocks.2.activation │ LeakyReLU │ 0 │\n", + "│ 32 │ net.blocks.2.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 33 │ net.blocks.2.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 34 │ net.blocks.2.conv1.conv │ Conv2d │ 147 K │\n", + "│ 35 │ net.blocks.2.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 36 │ net.blocks.2.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 37 │ net.blocks.2.conv2.conv │ Conv2d │ 147 K │\n", + "│ 38 │ net.blocks.2.shortcut │ Identity │ 0 │\n", + "│ 39 │ net.blocks.2.norm1 │ BatchNorm2d │ 256 │\n", + "│ 40 │ net.blocks.2.norm2 │ BatchNorm2d │ 256 │\n", + "│ 41 │ net.blocks.2.drop │ Dropout │ 0 │\n", + "│ 42 │ net.blocks.3 │ ResidualBlock │ 295 K │\n", + "│ 43 │ net.blocks.3.activation │ LeakyReLU │ 0 │\n", + "│ 44 │ net.blocks.3.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 45 │ net.blocks.3.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 46 │ net.blocks.3.conv1.conv │ Conv2d │ 147 K │\n", + "│ 47 │ net.blocks.3.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 48 │ net.blocks.3.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 49 │ net.blocks.3.conv2.conv │ Conv2d │ 147 K │\n", + "│ 50 │ net.blocks.3.shortcut │ Identity │ 0 │\n", + "│ 51 │ net.blocks.3.norm1 │ BatchNorm2d │ 256 │\n", + "│ 52 │ net.blocks.3.norm2 │ BatchNorm2d │ 256 │\n", + "│ 53 │ net.blocks.3.drop │ Dropout │ 0 │\n", + "│ 54 │ net.norm │ BatchNorm2d │ 256 │\n", + "│ 55 │ net.final │ PeriodicConv2D │ 6.3 K │\n", + "│ 56 │ net.final.padding │ PeriodicPadding2D │ 0 │\n", + "│ 57 │ net.final.conv │ Conv2d │ 6.3 K │\n", + "│ 58 │ denormalization │ Normalize │ 0 │\n", + "└────┴────────────────────────────┴───────────────────┴────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n", + "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│\u001b[2m \u001b[0m\u001b[2m0 \u001b[0m\u001b[2m \u001b[0m│ net │ ResNet │ 1.2 M │\n", + "│\u001b[2m \u001b[0m\u001b[2m1 \u001b[0m\u001b[2m \u001b[0m│ net.activation │ LeakyReLU │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m2 \u001b[0m\u001b[2m \u001b[0m│ net.image_proj │ PeriodicConv2D │ 6.4 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m3 \u001b[0m\u001b[2m \u001b[0m│ net.image_proj.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m4 \u001b[0m\u001b[2m \u001b[0m│ net.image_proj.conv │ Conv2d │ 6.4 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m5 \u001b[0m\u001b[2m \u001b[0m│ net.blocks │ ModuleList │ 1.2 M │\n", + "│\u001b[2m \u001b[0m\u001b[2m6 \u001b[0m\u001b[2m \u001b[0m│ net.blocks.0 │ ResidualBlock │ 295 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m7 \u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.activation │ LeakyReLU │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m8 \u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.conv1 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m9 \u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m10\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.conv1.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m11\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.conv2 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m12\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m13\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.conv2.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m14\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.shortcut │ Identity │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m15\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.norm1 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m16\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.norm2 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m17\u001b[0m\u001b[2m \u001b[0m│ net.blocks.0.drop │ Dropout │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m18\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1 │ ResidualBlock │ 295 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m19\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.activation │ LeakyReLU │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m20\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.conv1 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m21\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m22\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.conv1.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m23\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.conv2 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m24\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m25\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.conv2.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m26\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.shortcut │ Identity │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m27\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.norm1 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m28\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.norm2 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m29\u001b[0m\u001b[2m \u001b[0m│ net.blocks.1.drop │ Dropout │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m30\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2 │ ResidualBlock │ 295 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m31\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.activation │ LeakyReLU │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m32\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.conv1 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m33\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m34\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.conv1.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m35\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.conv2 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m36\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m37\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.conv2.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m38\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.shortcut │ Identity │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m39\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.norm1 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m40\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.norm2 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m41\u001b[0m\u001b[2m \u001b[0m│ net.blocks.2.drop │ Dropout │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m42\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3 │ ResidualBlock │ 295 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m43\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.activation │ LeakyReLU │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m44\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.conv1 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m45\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m46\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.conv1.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m47\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.conv2 │ PeriodicConv2D │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m48\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m49\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.conv2.conv │ Conv2d │ 147 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m50\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.shortcut │ Identity │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m51\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.norm1 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m52\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.norm2 │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m53\u001b[0m\u001b[2m \u001b[0m│ net.blocks.3.drop │ Dropout │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m54\u001b[0m\u001b[2m \u001b[0m│ net.norm │ BatchNorm2d │ 256 │\n", + "│\u001b[2m \u001b[0m\u001b[2m55\u001b[0m\u001b[2m \u001b[0m│ net.final │ PeriodicConv2D │ 6.3 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m56\u001b[0m\u001b[2m \u001b[0m│ net.final.padding │ PeriodicPadding2D │ 0 │\n", + "│\u001b[2m \u001b[0m\u001b[2m57\u001b[0m\u001b[2m \u001b[0m│ net.final.conv │ Conv2d │ 6.3 K │\n", + "│\u001b[2m \u001b[0m\u001b[2m58\u001b[0m\u001b[2m \u001b[0m│ denormalization │ Normalize │ 0 │\n", + "└────┴────────────────────────────┴───────────────────┴────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
Trainable params: 1.2 M \n", + "Non-trainable params: 0 \n", + "Total params: 1.2 M \n", + "Total estimated model params size (MB): 2 \n", + "\n" + ], + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 1.2 M \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 1.2 M \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 2 \n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f1b2b711e84a4e469d28ce506c5fbcc8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.fit(model_module, data_module)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k7JdGELMXpIw" + }, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 149, + "referenced_widgets": [ + "9f6338e3a9fb4a4fa4a3b5da14e6b471", + "d124da129802411086077527493bf39b" + ] + }, + "id": "kELMwe8lpm1e", + "outputId": "d58628cc-04be-4575-febc-a4812ef471f7" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9f6338e3a9fb4a4fa4a3b5da14e6b471", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/mean_bias_2m_temperature │ -0.6033839583396912 │\n", + "│ test/pearsonr_2m_temperature │ 0.9910030555527313 │\n", + "│ test/rmse_2m_temperature │ 2.859203577041626 │\n", + "└───────────────────────────────┴───────────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36mtest/mean_bias_2m_temperature\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m -0.6033839583396912 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtest/pearsonr_2m_temperature \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9910030555527313 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/rmse_2m_temperature \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 2.859203577041626 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────────┴───────────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.test(model_module, data_module)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [], + "machine_shape": "hm" + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "cl_env", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "5b35d5811d64db97cad819926e9e0ba09b354a75e2ee95b259c11201fc783944" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "859e5ae0abf849e4b3210ac64ac5b65a": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "9f6338e3a9fb4a4fa4a3b5da14e6b471": { + "model_module": "@jupyter-widgets/output", + "model_module_version": "1.0.0", + "model_name": "OutputModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/output", + "_model_module_version": "1.0.0", + "_model_name": "OutputModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/output", + "_view_module_version": "1.0.0", + "_view_name": "OutputView", + "layout": "IPY_MODEL_d124da129802411086077527493bf39b", + "msg_id": "", + "outputs": [ + { + "data": { + "text/html": "
Testing ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 35/35 0:00:15 • 0:00:00 2.23it/s \n\n", + "text/plain": "\u001b[37mTesting\u001b[0m \u001b[38;2;98;6;224m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[37m35/35\u001b[0m \u001b[38;5;245m0:00:15 • 0:00:00\u001b[0m \u001b[38;5;249m2.23it/s\u001b[0m \n" + }, + "metadata": {}, + "output_type": "display_data" + } + ] + } + }, + "d124da129802411086077527493bf39b": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "1.2.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "f1b2b711e84a4e469d28ce506c5fbcc8": { + "model_module": "@jupyter-widgets/output", + "model_module_version": "1.0.0", + "model_name": "OutputModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/output", + "_model_module_version": "1.0.0", + "_model_name": "OutputModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/output", + "_view_module_version": "1.0.0", + "_view_name": "OutputView", + "layout": "IPY_MODEL_859e5ae0abf849e4b3210ac64ac5b65a", + "msg_id": "", + "outputs": [ + { + "data": { + "text/html": "
Epoch 4/4 ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24/24 0:00:19 • 0:00:00 1.24it/s loss: 0.0597 train/2m_temperature: \n 0.056 train/loss: 0.056 \n\n", + "text/plain": "Epoch 4/4 \u001b[38;2;98;6;224m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[37m24/24\u001b[0m \u001b[38;5;245m0:00:19 • 0:00:00\u001b[0m \u001b[38;5;249m1.24it/s\u001b[0m \u001b[37mloss: 0.0597 train/2m_temperature: \u001b[0m\n \u001b[37m0.056 train/loss: 0.056 \u001b[0m\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ] + } + }, + "52ab009c24e447888173d3245b5afe83": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_c2fbc605175d4bde95231ebe3407d3c8", + "IPY_MODEL_a5cd0b21898e432ea97f7bf3eb67a3d5", + "IPY_MODEL_f32846c4505849b7837bd729e4585af4" + ], + "layout": "IPY_MODEL_4f1fc7bc22174ec89a84648e03a7a6ab" + } + }, + "c2fbc605175d4bde95231ebe3407d3c8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_085dc4645b054ad88b486e1f3c5aecad", + "placeholder": "", + "style": "IPY_MODEL_06f95eceadc54ddfaf2d33b35d8c1ab9", + "value": "Testing DataLoader 0: 100%" + } + }, + "a5cd0b21898e432ea97f7bf3eb67a3d5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_81538e16d30b403f8e25fc9b93e20946", + "max": 206, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_ed6d23a2e45a46b3aba8503b67ecb09e", + "value": 206 + } + }, + "f32846c4505849b7837bd729e4585af4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_f5224890ce8e485794a9f5f3579392db", + "placeholder": "", + "style": "IPY_MODEL_98fd61a92bfc46c89a075f9963eefac6", + "value": " 206/206 [00:18<00:00, 10.89it/s]" + } + }, + "4f1fc7bc22174ec89a84648e03a7a6ab": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": "inline-flex", + "flex": null, + "flex_flow": "row wrap", + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": "100%" + } + }, + "085dc4645b054ad88b486e1f3c5aecad": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "06f95eceadc54ddfaf2d33b35d8c1ab9": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "81538e16d30b403f8e25fc9b93e20946": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": "2", + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "ed6d23a2e45a46b3aba8503b67ecb09e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "f5224890ce8e485794a9f5f3579392db": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "98fd61a92bfc46c89a075f9963eefac6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/Modeling_for_Toronto_Region(For_Prediction_and_Downscaling).ipynb b/Modeling_for_Toronto_Region(For_Prediction_and_Downscaling).ipynb new file mode 100644 index 0000000..d81291c --- /dev/null +++ b/Modeling_for_Toronto_Region(For_Prediction_and_Downscaling).ipynb @@ -0,0 +1,9260 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "rSRCNgYzUwaf" + }, + "source": [ + "\n", + "# Software Requirements" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "aj4B6nixRIYp" + }, + "outputs": [], + "source": [ + "%%bash\n", + "python -m pip install --upgrade pip\n", + "pip install git+https://github.com/ProfessorGuineapig/climate-learn.git" + ] + }, + { + "cell_type": "code", + "source": [ + "#After installing \"rich\", it is necessary to restart your runtime in order to ensure that all of the necessary components are properly loaded and running. \n", + "!pip install -U rich" + ], + "metadata": { + "id": "MdgpjeAUqnhd", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "978f2698-78c4-4ea7-c344-6b10196d9aca" + }, + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Requirement already satisfied: rich in /usr/local/lib/python3.9/dist-packages (13.3.3)\n", + "Collecting rich\n", + " Downloading rich-13.3.4-py3-none-any.whl (238 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m238.7/238.7 kB\u001b[0m \u001b[31m23.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: markdown-it-py<3.0.0,>=2.2.0 in /usr/local/lib/python3.9/dist-packages (from rich) (2.2.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.9/dist-packages (from rich) (2.14.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.9/dist-packages (from markdown-it-py<3.0.0,>=2.2.0->rich) (0.1.2)\n", + "Installing collected packages: rich\n", + " Attempting uninstall: rich\n", + " Found existing installation: rich 13.3.3\n", + " Uninstalling rich-13.3.3:\n", + " Successfully uninstalled rich-13.3.3\n", + "Successfully installed rich-13.3.4\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "!pip install lion-pytorch" + ], + "metadata": { + "id": "kjbWLYYlp4VY", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "78f67d31-1b83-4996-e653-0bd85b97ac52" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", + "Collecting lion-pytorch\n", + " Downloading lion_pytorch-0.0.7-py3-none-any.whl (4.3 kB)\n", + "Requirement already satisfied: torch>=1.6 in /usr/local/lib/python3.9/dist-packages (from lion-pytorch) (2.0.0+cu118)\n", + "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/dist-packages (from torch>=1.6->lion-pytorch) (2.0.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.9/dist-packages (from torch>=1.6->lion-pytorch) (1.11.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/dist-packages (from torch>=1.6->lion-pytorch) (3.1.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.9/dist-packages (from torch>=1.6->lion-pytorch) (3.1)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.9/dist-packages (from torch>=1.6->lion-pytorch) (3.11.0)\n", + "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/dist-packages (from torch>=1.6->lion-pytorch) (4.5.0)\n", + "Requirement already satisfied: cmake in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.6->lion-pytorch) (3.25.2)\n", + "Requirement already satisfied: lit in /usr/local/lib/python3.9/dist-packages (from triton==2.0.0->torch>=1.6->lion-pytorch) (16.0.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/dist-packages (from jinja2->torch>=1.6->lion-pytorch) (2.1.2)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/dist-packages (from sympy->torch>=1.6->lion-pytorch) (1.3.0)\n", + "Installing collected packages: lion-pytorch\n", + "Successfully installed lion-pytorch-0.0.7\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "db5Zg6l6RJql", + "outputId": "cc62b2b0-d99d-4d32-94d9-019c9193eb9e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/drive\n" + ] + } + ], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/drive')" + ] + }, + { + "cell_type": "code", + "source": [ + "#If you already have the data from the 'Data preprocessing' section, there is no need to go through this part.\n", + "#@title # Loop over the years and download the data for each year ERA5 { run: \"auto\" }\n", + "# Define the years list\n", + "from climate_learn.data import download1\n", + "years = []\n", + "for y in range(1995, 2023):\n", + " years.append(str(y))\n", + "\n", + "# Define the API key\n", + "api_key = \"Your API_KEY\" # Change to your_api_key\n", + "\n", + "# Define the root directory for data downloads\n", + "root_dir = \"/content/drive/MyDrive/Climate/.climate_tutorial\"\n", + "\n", + "# Define the dataset and variable to download\n", + "\n", + "\n", + "\n", + "dataset = \"era5\"\n", + "variable = \"2m_temperature\"\n", + "\n", + "#This method of selecting an area of interest can be seen here.: https://youtu.be/EIe7IBMqhsw\n", + "'''\n", + "The area list corresponds to the latitude and longitude boundaries of the region of interest in the xarray dataset.\n", + "\n", + "The first and third elements of the area list correspond to the minimum and maximum latitude values of the region of interest, respectively. In this case, the minimum latitude is -5.2 and the maximum latitude is 31.\n", + "\n", + "The second and fourth elements of the area list correspond to the minimum and maximum longitude values of the region of interest, respectively. In this case, the minimum longitude is 34 and the maximum longitude is 45.\n", + "\n", + "The latitude and longitude coordinates in the xarray dataset should fall within these boundary values to be considered part of the region of interest.\n", + "\n", + "'''\n", + "\n", + "\n", + "area = [42.0, -81.0, 45.0, -78.0]\n", + "\n", + "resolution = str(0.2)\n", + "\n", + "# Loop over the years and download the data for each year\n", + "for i, year in enumerate(years):\n", + " #download_copernicus(root=root_dir, dataset=dataset, variable=variable, year=year, api_key=api_key)\n", + " download1(root = root_dir, source = \"copernicus\", variable = variable, dataset = dataset, year = year, resolution=resolution, area=area, api_key = api_key)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HfJYVFoef6ux", + "outputId": "4e8384bc-b4be-4b7a-dedb-444b4c118b03" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 1995 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_1995_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 1996 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_1996_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 1997 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_1997_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 1998 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_1998_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 1999 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_1999_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2000 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2000_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2001 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2001_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2002 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2002_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2003 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2003_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2004 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2004_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2005 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2005_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2006 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2006_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2007 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2007_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2008 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2008_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2009 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2009_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2010 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2010_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2011 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2011_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2012 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2012_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2013 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2013_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2014 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2014_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2015 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2015_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2016 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2016_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2017 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2017_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2018 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2018_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2019 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2019_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2020 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2020_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2021 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2021_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Downloading era5 2m_temperature data for year 2022 from copernicus to /content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.5/2m_temperature/2m_temperature_2022_0.5deg.nc\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + } + ] + }, + { + "cell_type": "code", + "source": [ + "from climate_learn.utils.data import load_dataset, view\n", + "from climate_learn.data import download1\n", + "\n", + "dataset = load_dataset(\"/content/drive/MyDrive/Climate/.climate_tutorial/data/copernicus/era5/0.20/2m_temperature\")\n", + "view(dataset)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 374 + }, + "id": "DL_CjAt0OS-7", + "outputId": "c1282247-1fd7-409f-c2ea-9275b2a4a1de" + }, + "execution_count": 7, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
<xarray.Dataset>\n", + "Dimensions: (longitude: 16, latitude: 16, time: 291542)\n", + "Coordinates:\n", + " * longitude (longitude) float32 -81.0 -80.8 -80.6 -80.4 ... -78.4 -78.2 -78.0\n", + " * latitude (latitude) float32 45.0 44.8 44.6 44.4 ... 42.6 42.4 42.2 42.0\n", + " * time (time) datetime64[ns] 1990-01-01 ... 2023-04-05T13:00:00\n", + "Data variables:\n", + " t2m (time, latitude, longitude) float32 dask.array<chunksize=(8760, 16, 16), meta=np.ndarray>\n", + "Attributes:\n", + " Conventions: CF-1.6\n", + " history: 2023-04-10 08:34:25 GMT by grib_to_netcdf-2.25.1: /opt/ecmw...
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃ ┃ Name ┃ Type ┃ Params ┃\n", + "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│ 0 │ net │ ResNet │ 1.2 M │\n", + "│ 1 │ net.activation │ LeakyReLU │ 0 │\n", + "│ 2 │ net.image_proj │ PeriodicConv2D │ 6.4 K │\n", + "│ 3 │ net.image_proj.padding │ PeriodicPadding2D │ 0 │\n", + "│ 4 │ net.image_proj.conv │ Conv2d │ 6.4 K │\n", + "│ 5 │ net.blocks │ ModuleList │ 1.2 M │\n", + "│ 6 │ net.blocks.0 │ ResidualBlock │ 295 K │\n", + "│ 7 │ net.blocks.0.activation │ LeakyReLU │ 0 │\n", + "│ 8 │ net.blocks.0.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 9 │ net.blocks.0.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 10 │ net.blocks.0.conv1.conv │ Conv2d │ 147 K │\n", + "│ 11 │ net.blocks.0.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 12 │ net.blocks.0.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 13 │ net.blocks.0.conv2.conv │ Conv2d │ 147 K │\n", + "│ 14 │ net.blocks.0.shortcut │ Identity │ 0 │\n", + "│ 15 │ net.blocks.0.norm1 │ BatchNorm2d │ 256 │\n", + "│ 16 │ net.blocks.0.norm2 │ BatchNorm2d │ 256 │\n", + "│ 17 │ net.blocks.0.drop │ Dropout │ 0 │\n", + "│ 18 │ net.blocks.1 │ ResidualBlock │ 295 K │\n", + "│ 19 │ net.blocks.1.activation │ LeakyReLU │ 0 │\n", + "│ 20 │ net.blocks.1.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 21 │ net.blocks.1.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 22 │ net.blocks.1.conv1.conv │ Conv2d │ 147 K │\n", + "│ 23 │ net.blocks.1.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 24 │ net.blocks.1.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 25 │ net.blocks.1.conv2.conv │ Conv2d │ 147 K │\n", + "│ 26 │ net.blocks.1.shortcut │ Identity │ 0 │\n", + "│ 27 │ net.blocks.1.norm1 │ BatchNorm2d │ 256 │\n", + "│ 28 │ net.blocks.1.norm2 │ BatchNorm2d │ 256 │\n", + "│ 29 │ net.blocks.1.drop │ Dropout │ 0 │\n", + "│ 30 │ net.blocks.2 │ ResidualBlock │ 295 K │\n", + "│ 31 │ net.blocks.2.activation │ LeakyReLU │ 0 │\n", + "│ 32 │ net.blocks.2.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 33 │ net.blocks.2.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 34 │ net.blocks.2.conv1.conv │ Conv2d │ 147 K │\n", + "│ 35 │ net.blocks.2.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 36 │ net.blocks.2.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 37 │ net.blocks.2.conv2.conv │ Conv2d │ 147 K │\n", + "│ 38 │ net.blocks.2.shortcut │ Identity │ 0 │\n", + "│ 39 │ net.blocks.2.norm1 │ BatchNorm2d │ 256 │\n", + "│ 40 │ net.blocks.2.norm2 │ BatchNorm2d │ 256 │\n", + "│ 41 │ net.blocks.2.drop │ Dropout │ 0 │\n", + "│ 42 │ net.blocks.3 │ ResidualBlock │ 295 K │\n", + "│ 43 │ net.blocks.3.activation │ LeakyReLU │ 0 │\n", + "│ 44 │ net.blocks.3.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 45 │ net.blocks.3.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 46 │ net.blocks.3.conv1.conv │ Conv2d │ 147 K │\n", + "│ 47 │ net.blocks.3.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 48 │ net.blocks.3.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 49 │ net.blocks.3.conv2.conv │ Conv2d │ 147 K │\n", + "│ 50 │ net.blocks.3.shortcut │ Identity │ 0 │\n", + "│ 51 │ net.blocks.3.norm1 │ BatchNorm2d │ 256 │\n", + "│ 52 │ net.blocks.3.norm2 │ BatchNorm2d │ 256 │\n", + "│ 53 │ net.blocks.3.drop │ Dropout │ 0 │\n", + "│ 54 │ net.norm │ BatchNorm2d │ 256 │\n", + "│ 55 │ net.final │ PeriodicConv2D │ 6.3 K │\n", + "│ 56 │ net.final.padding │ PeriodicPadding2D │ 0 │\n", + "│ 57 │ net.final.conv │ Conv2d │ 6.3 K │\n", + "│ 58 │ denormalization │ Normalize │ 0 │\n", + "└────┴────────────────────────────┴───────────────────┴────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 1.2 M \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 1.2 M \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 4 \n" + ], + "text/html": [ + "
Trainable params: 1.2 M \n", + "Non-trainable params: 0 \n", + "Total params: 1.2 M \n", + "Total estimated model params size (MB): 4 \n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Sanity Checking: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "91cd3e035e1242469d0523c5d484bac6" + } + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n", + "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "19d29e9993744499a031b0163365a3a7" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "144c80019d8145bd852bf060a3b3b873" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "0d3593a416fc41b1b2c8ff9822481b81" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "d2ed0a957ff64b11b9b904934b1718fe" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "c02da1bf9e334f29acc27e6d6c3eb60c" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "092789608c284789ac5c439cf6622669" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "890e1405dff2467891672aa315d03ee8" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "d877a9ff477f4c668f4fe94c18fe6f05" + } + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "model_module_forcast_saving = load_model(name = \"resnet\", task = \"forecasting\", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)\n", + "#model_module_forcast_saving = load_model(name = \"unet\", task = \"forecasting\", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)" + ], + "metadata": { + "id": "_fKXEoWuV982" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model_module_forecast.lat" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZJ3cpiFYauIl", + "outputId": "9d20f22d-d30d-4e78-f01c-2d49d0f2d267" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "array([45. , 44.8, 44.6, 44.4, 44.2, 44. , 43.8, 43.6, 43.4, 43.2, 43. ,\n", + " 42.8, 42.6, 42.4, 42.2, 42. ], dtype=float32)" + ] + }, + "metadata": {}, + "execution_count": 16 + } + ] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import torch\n", + "\n", + "#The specific file for saving the trained model has been created and is ready to be reused. \n", + "save_dir = \"/content/drive/MyDrive/Climate/.climate_tutorial/trained model/\"\n", + "if not os.path.exists(save_dir):\n", + " os.makedirs(save_dir)\n", + "\n", + "save_path = \"/content/drive/MyDrive/Climate/.climate_tutorial/trained model/trained_model__Toronto_04_12_0_20_Res_0.05(1d&batch_size128).pth\" #When saving your work, it is important to remember to choose a unique file name. \n", + "# Save model state_dict and denormalization layer information\n", + "torch.save({\n", + " 'model_state_dict': model_module_forecast.state_dict(),\n", + " 'denormalization_mean': model_module_forecast.denormalization.mean,\n", + " 'denormalization_std': model_module_forecast.denormalization.std,\n", + " 'pred_range': model_module_forecast.pred_range,\n", + " 'lat': model_module_forecast.lat,\n", + " 'test_clim': model_module_forecast.test_clim,\n", + " 'train_clim': model_module_forecast.train_clim\n", + "}, save_path)\n" + ], + "metadata": { + "id": "58urPMF1gbSh" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fCR5yoIAUBTa" + }, + "source": [ + "## Evaluation \n" + ] + }, + { + "cell_type": "code", + "source": [ + "trainer.test(model_module_forecast, data_module)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 576, + "referenced_widgets": [ + "901997c21abd47bf809e1714ea3dca15", + "f7599963c6b34ca1a94e77eb1b892eb4", + "97e1867838f54cdea500de95992b5fda", + "cf487f7185854668af144f3f154e1fc2", + "e71fe50c588143758f6f3bb434a113f0", + "b4b634ac431f4749b91add29be4e960a", + "17a35ba1222142f3b59b5689655e494a", + "e5096d3be2374ca287211f7300967647", + "d9712ac135d349e7a8b3a4c5428c611a", + "7e50fc1e036841299155f792eb25bf16", + "def32741e239424587b0b521f90d4b87" + ] + }, + "id": "CQhPrCMWMn07", + "outputId": "352c2a12-6c06-4382-a1b0-99ba678ca8b5" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Creating train dataset\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + " 0%| | 0/28 [00:00, ?it/s]\u001b[A\n", + " 18%|█▊ | 5/28 [00:00<00:00, 46.77it/s]\u001b[A\n", + " 36%|███▌ | 10/28 [00:00<00:00, 47.64it/s]\u001b[A\n", + " 54%|█████▎ | 15/28 [00:00<00:00, 47.85it/s]\u001b[A\n", + " 71%|███████▏ | 20/28 [00:00<00:00, 48.14it/s]\u001b[A\n", + "100%|██████████| 28/28 [00:00<00:00, 47.60it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Creating val dataset\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + "100%|██████████| 2/2 [00:00<00:00, 40.97it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Creating test dataset\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\n", + "100%|██████████| 3/3 [00:00<00:00, 47.80it/s]\n", + "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Testing: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "901997c21abd47bf809e1714ea3dca15" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test/acc \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5284543633460999 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/acc_2m_temperature_1.0_days \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5284543633460999 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/w_rmse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.2339284420013428 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/w_rmse_2m_temperature_1.0_days \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.2339284420013428 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_climatology_baseline/w_rmse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 9.015924453735352 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtest_climatology_baseline/w_rmse_2m_temperature_1.0_d…\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 9.015924453735352 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_persistence_baseline/w_rmse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.3931617736816406 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtest_persistence_baseline/w_rmse_2m_temperature_1.0_d…\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.3931617736816406 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test_ridge_regression_baseline/w_rmse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.3932297229766846 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtest_ridge_regression_baseline/w_rmse_2m_temperature_…\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 3.3932297229766846 \u001b[0m\u001b[35m \u001b[0m│\n", + "└────────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┘\n" + ], + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/acc │ 0.5284543633460999 │\n", + "│ test/acc_2m_temperature_1.0_days │ 0.5284543633460999 │\n", + "│ test/w_rmse │ 3.2339284420013428 │\n", + "│ test/w_rmse_2m_temperature_1.0_days │ 3.2339284420013428 │\n", + "│ test_climatology_baseline/w_rmse │ 9.015924453735352 │\n", + "│ test_climatology_baseline/w_rmse_2m_temperature_1.0_d… │ 9.015924453735352 │\n", + "│ test_persistence_baseline/w_rmse │ 3.3931617736816406 │\n", + "│ test_persistence_baseline/w_rmse_2m_temperature_1.0_d… │ 3.3931617736816406 │\n", + "│ test_ridge_regression_baseline/w_rmse │ 3.3932297229766846 │\n", + "│ test_ridge_regression_baseline/w_rmse_2m_temperature_… │ 3.3932297229766846 │\n", + "└────────────────────────────────────────────────────────┴────────────────────────────────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "This visualization displays the results of a random test for the Resnet model. It shows the weather forecast for the temperature of the selected region (area = [42.0, -81.0, 45.0, -78.0]) in a short time period (1 days). The initial condition is the temperature of the date we selected to test, the ground truth is the actual temperature after 1 days, and the prediction is the outcome of the trained model. \n", + "\n", + "The mean square error of the Resnet model is relatively high, but its small size of less than 5MB makes it easy to deploy. This is a major advantage of the Resnet model, as it can be used in a variety of applications without taking up too much space." + ], + "metadata": { + "id": "PROj0EHeeRbl" + } + }, + { + "cell_type": "code", + "source": [ + "from climate_learn.utils import visualize\n", + "\n", + "# if samples = 2, we randomly pick 2 initial conditions in the test set\n", + "visualize(model_module_forecast, data_module, samples = 2)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 371 + }, + "id": "EOtdiPKsrMP2", + "outputId": "e078e07e-1eba-4fab-ee77-d264bf631e13" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n", + "┃ ┃ Name ┃ Type ┃ Params ┃\n", + "┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n", + "│ 0 │ net │ ResNet │ 1.2 M │\n", + "│ 1 │ net.activation │ LeakyReLU │ 0 │\n", + "│ 2 │ net.image_proj │ PeriodicConv2D │ 6.4 K │\n", + "│ 3 │ net.image_proj.padding │ PeriodicPadding2D │ 0 │\n", + "│ 4 │ net.image_proj.conv │ Conv2d │ 6.4 K │\n", + "│ 5 │ net.blocks │ ModuleList │ 1.2 M │\n", + "│ 6 │ net.blocks.0 │ ResidualBlock │ 295 K │\n", + "│ 7 │ net.blocks.0.activation │ LeakyReLU │ 0 │\n", + "│ 8 │ net.blocks.0.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 9 │ net.blocks.0.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 10 │ net.blocks.0.conv1.conv │ Conv2d │ 147 K │\n", + "│ 11 │ net.blocks.0.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 12 │ net.blocks.0.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 13 │ net.blocks.0.conv2.conv │ Conv2d │ 147 K │\n", + "│ 14 │ net.blocks.0.shortcut │ Identity │ 0 │\n", + "│ 15 │ net.blocks.0.norm1 │ BatchNorm2d │ 256 │\n", + "│ 16 │ net.blocks.0.norm2 │ BatchNorm2d │ 256 │\n", + "│ 17 │ net.blocks.0.drop │ Dropout │ 0 │\n", + "│ 18 │ net.blocks.1 │ ResidualBlock │ 295 K │\n", + "│ 19 │ net.blocks.1.activation │ LeakyReLU │ 0 │\n", + "│ 20 │ net.blocks.1.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 21 │ net.blocks.1.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 22 │ net.blocks.1.conv1.conv │ Conv2d │ 147 K │\n", + "│ 23 │ net.blocks.1.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 24 │ net.blocks.1.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 25 │ net.blocks.1.conv2.conv │ Conv2d │ 147 K │\n", + "│ 26 │ net.blocks.1.shortcut │ Identity │ 0 │\n", + "│ 27 │ net.blocks.1.norm1 │ BatchNorm2d │ 256 │\n", + "│ 28 │ net.blocks.1.norm2 │ BatchNorm2d │ 256 │\n", + "│ 29 │ net.blocks.1.drop │ Dropout │ 0 │\n", + "│ 30 │ net.blocks.2 │ ResidualBlock │ 295 K │\n", + "│ 31 │ net.blocks.2.activation │ LeakyReLU │ 0 │\n", + "│ 32 │ net.blocks.2.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 33 │ net.blocks.2.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 34 │ net.blocks.2.conv1.conv │ Conv2d │ 147 K │\n", + "│ 35 │ net.blocks.2.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 36 │ net.blocks.2.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 37 │ net.blocks.2.conv2.conv │ Conv2d │ 147 K │\n", + "│ 38 │ net.blocks.2.shortcut │ Identity │ 0 │\n", + "│ 39 │ net.blocks.2.norm1 │ BatchNorm2d │ 256 │\n", + "│ 40 │ net.blocks.2.norm2 │ BatchNorm2d │ 256 │\n", + "│ 41 │ net.blocks.2.drop │ Dropout │ 0 │\n", + "│ 42 │ net.blocks.3 │ ResidualBlock │ 295 K │\n", + "│ 43 │ net.blocks.3.activation │ LeakyReLU │ 0 │\n", + "│ 44 │ net.blocks.3.conv1 │ PeriodicConv2D │ 147 K │\n", + "│ 45 │ net.blocks.3.conv1.padding │ PeriodicPadding2D │ 0 │\n", + "│ 46 │ net.blocks.3.conv1.conv │ Conv2d │ 147 K │\n", + "│ 47 │ net.blocks.3.conv2 │ PeriodicConv2D │ 147 K │\n", + "│ 48 │ net.blocks.3.conv2.padding │ PeriodicPadding2D │ 0 │\n", + "│ 49 │ net.blocks.3.conv2.conv │ Conv2d │ 147 K │\n", + "│ 50 │ net.blocks.3.shortcut │ Identity │ 0 │\n", + "│ 51 │ net.blocks.3.norm1 │ BatchNorm2d │ 256 │\n", + "│ 52 │ net.blocks.3.norm2 │ BatchNorm2d │ 256 │\n", + "│ 53 │ net.blocks.3.drop │ Dropout │ 0 │\n", + "│ 54 │ net.norm │ BatchNorm2d │ 256 │\n", + "│ 55 │ net.final │ PeriodicConv2D │ 6.3 K │\n", + "│ 56 │ net.final.padding │ PeriodicPadding2D │ 0 │\n", + "│ 57 │ net.final.conv │ Conv2d │ 6.3 K │\n", + "│ 58 │ denormalization │ Normalize │ 0 │\n", + "└────┴────────────────────────────┴───────────────────┴────────┘\n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "\u001b[1mTrainable params\u001b[0m: 1.2 M \n", + "\u001b[1mNon-trainable params\u001b[0m: 0 \n", + "\u001b[1mTotal params\u001b[0m: 1.2 M \n", + "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 4 \n" + ], + "text/html": [ + "
Trainable params: 1.2 M \n", + "Non-trainable params: 0 \n", + "Total params: 1.2 M \n", + "Total estimated model params size (MB): 4 \n", + "\n" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Sanity Checking: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "e6bc068500464f4f8dce41fe51de81d5" + } + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n", + "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Training: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "5fb59f5832d64c0da41a30854179315d" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "663c97c1c58d4425b443f2045d0cb776" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "33d2d2540d444888808ff3be82983815" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "7d7afc4161bd47ab9caf3b66d5219f05" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "e2545c38b2174930a3f4ed54120cd35d" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Validation: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "a3d42258c9c04f278866586703d4aec9" + } + }, + "metadata": {} + } + ], + "source": [ + "trainer.fit(model_module_downscaling, data_module)" + ] + }, + { + "cell_type": "code", + "source": [ + "model_module_downscaling_saving = load_model(name = \"resnet\", task = \"downscaling\", model_kwargs = model_kwargs, optim_kwargs = optim_kwargs)" + ], + "metadata": { + "id": "GsfEXylHB3RY" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "import torch\n", + "\n", + "#The specific file for saving the trained model has been created and is ready to be reused. \n", + "save_dir = \"/content/drive/MyDrive/Climate/.climate_tutorial/trained model/\"\n", + "if not os.path.exists(save_dir):\n", + " os.makedirs(save_dir)\n", + "\n", + "save_path = \"/content/drive/MyDrive/Climate/.climate_tutorial/trained model/trained_model_Toronto_04_10_downscaling_Res.pth\" #When saving your work, it is important to remember to choose a unique file name. \n", + "# Save model state_dict and denormalization layer information\n", + "torch.save({\n", + " 'model_state_dict': model_module_downscaling.state_dict(),\n", + " 'denormalization_mean': model_module_downscaling.denormalization.mean,\n", + " 'denormalization_std': model_module_downscaling.denormalization.std,\n", + " 'pred_range': model_module_downscaling.pred_range,\n", + " 'lat': model_module_downscaling.lat,\n", + " 'test_clim': model_module_downscaling.test_clim,\n", + " 'train_clim': model_module_downscaling.train_clim\n", + "}, save_path)" + ], + "metadata": { + "id": "AlbYNJJ6CByj" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "model_module_downscaling" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8xhXEaiSHKIA", + "outputId": "b24cd1f0-940a-41f5-92eb-3b1ec307a156" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "DownscaleLitModule(\n", + " (net): ResNet(\n", + " (activation): LeakyReLU(negative_slope=0.3)\n", + " (image_proj): PeriodicConv2D(\n", + " (padding): PeriodicPadding2D()\n", + " (conv): Conv2d(1, 128, kernel_size=(7, 7), stride=(1, 1))\n", + " )\n", + " (blocks): ModuleList(\n", + " (0-3): 4 x ResidualBlock(\n", + " (activation): LeakyReLU(negative_slope=0.3)\n", + " (conv1): PeriodicConv2D(\n", + " (padding): PeriodicPadding2D()\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", + " )\n", + " (conv2): PeriodicConv2D(\n", + " (padding): PeriodicPadding2D()\n", + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))\n", + " )\n", + " (shortcut): Identity()\n", + " (norm1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (drop): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (norm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " (final): PeriodicConv2D(\n", + " (padding): PeriodicPadding2D()\n", + " (conv): Conv2d(128, 1, kernel_size=(7, 7), stride=(1, 1))\n", + " )\n", + " )\n", + " (denormalization): Normalize(mean=[-26.242485], std=[0.09332103])\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 30 + } + ] + }, + { + "cell_type": "code", + "source": [ + "import torch\n", + "from torchvision.transforms import Normalize\n", + "\n", + "save_path = \"/content/drive/MyDrive/Climate/.climate_tutorial/trained model/trained_model_03_23_downscaling_Res.pth\"\n", + "\n", + "# Load the state_dict and denormalization layer information\n", + "checkpoint = torch.load(save_path)\n", + "model_state_dict = checkpoint['model_state_dict']\n", + "denormalization_mean = checkpoint['denormalization_mean']\n", + "denormalization_std = checkpoint['denormalization_std']\n", + "pred_range = checkpoint['pred_range']\n", + "lat = checkpoint['lat']\n", + "test_clim = checkpoint['test_clim']\n", + "train_clim = checkpoint['train_clim']\n", + "\n", + "# Update the model_module_d with the loaded state_dict\n", + "model_module_d.load_state_dict(model_state_dict)\n", + "model_module_d.pred_range = pred_range\n", + "model_module_d.lat = lat\n", + "model_module_d.test_clim = test_clim\n", + "model_module_d.train_clim = train_clim\n", + "\n", + "# Recreate the denormalization layer using the saved mean and standard deviation\n", + "model_module_d.denormalization = Normalize(mean=denormalization_mean, std=denormalization_std)" + ], + "metadata": { + "id": "JgaTJuLvCsP2" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "k7JdGELMXpIw" + }, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 424, + "referenced_widgets": [ + "bb03783a139d4483b70f03e5de1738b4", + "ec931188552c41b99fc2b934e4101dd9", + "0de8d5b2382649a2a8ad15b5daeff6ef", + "44899ffbce344660ab39a21969fceb41", + "e87a6f33cd5f41579bbc3364818a4cfa", + "1eb79a8ca14b488b9b7480331b6ff25a", + "8492c77f2ca14c63a35d99ecbe2996fd", + "5ff7420f1ab7436ea359e36789eeb28d", + "4bc265b2b3e340e0ba0e515eef917354", + "f7002f4bb22b4bab958acb53ff1df045", + "008458b114b449ea9fb22d186c106a9f" + ] + }, + "id": "kELMwe8lpm1e", + "outputId": "d7cecc7d-8020-4690-ec5a-d826bdcd5cd4" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Creating train dataset\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 28/28 [00:00<00:00, 66.03it/s]\n", + "100%|██████████| 28/28 [00:00<00:00, 50.40it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Creating val dataset\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 2/2 [00:00<00:00, 63.40it/s]\n", + "100%|██████████| 2/2 [00:00<00:00, 51.30it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Creating test dataset\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 3/3 [00:00<00:00, 64.02it/s]\n", + "100%|██████████| 3/3 [00:00<00:00, 52.38it/s]\n", + "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, test_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 12 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.\n", + " rank_zero_warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Testing: 0it [00:00, ?it/s]" + ], + "application/vnd.jupyter.widget-view+json": { + "version_major": 2, + "version_minor": 0, + "model_id": "bb03783a139d4483b70f03e5de1738b4" + } + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m test/mean_bias \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m -0.4142647683620453 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtest/mean_bias_2m_temperature\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m -0.4142647683620453 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/pearson \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9967347608306023 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36mtest/pearsonr_2m_temperature \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9967347608306023 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/rmse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5263026356697083 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m test/rmse_2m_temperature \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.5263026356697083 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────────┴───────────────────────────────┘\n" + ], + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ test/mean_bias │ -0.4142647683620453 │\n", + "│ test/mean_bias_2m_temperature │ -0.4142647683620453 │\n", + "│ test/pearson │ 0.9967347608306023 │\n", + "│ test/pearsonr_2m_temperature │ 0.9967347608306023 │\n", + "│ test/rmse │ 0.5263026356697083 │\n", + "│ test/rmse_2m_temperature │ 0.5263026356697083 │\n", + "└───────────────────────────────┴───────────────────────────────┘\n", + "\n" + ] + }, + "metadata": {} + } + ], + "source": [ + "trainer.test(model_module_downscaling, data_module)" + ] + }, + { + "cell_type": "code", + "source": [ + "from climate_learn.utils import visualize\n", + "\n", + "# if samples = 2, we randomly pick 2 initial conditions in the test set\n", + "visualize(model_module_downscaling, data_module, samples = 2)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 371 + }, + "id": "ldJCrHxJFVFX", + "outputId": "0284c208-4b60-4546-9cb3-cde778e9d91a" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "