diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 0984f0d4..ebff6fba 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -31,10 +31,11 @@ jobs: # Run the tests and specify the temporary directory as output_dir - name: Run tests - timeout-minutes: 30 + timeout-minutes: 45 run: | pytest tests/test_lace.py - pytest tests/plot_mpg.py + pytest tests/plot_mpg_gp.py + pytest tests/plot_mpg_nn.py - name: List generated plots run: | diff --git a/notebooks/Tutorial_emulator.py b/notebooks/Tutorial_emulator.py index 5a6bbafb..f37ebaf5 100644 --- a/notebooks/Tutorial_emulator.py +++ b/notebooks/Tutorial_emulator.py @@ -6,9 +6,9 @@ # extension: .py # format_name: percent # format_version: '1.3' -# jupytext_version: 1.16.1 +# jupytext_version: 1.15.2 # kernelspec: -# display_name: Python 3 (ipykernel) +# display_name: NERSC Python # language: python # name: python3 # --- diff --git a/tests/plot_mpg_gp.py b/tests/plot_mpg_gp.py new file mode 100644 index 00000000..0d458e10 --- /dev/null +++ b/tests/plot_mpg_gp.py @@ -0,0 +1,73 @@ +# Import necessary modules +## General python modules +import numpy as np +import matplotlib +matplotlib.use('Agg') # Use a non-interactive backend for rendering plots +from matplotlib import pyplot as plt +import matplotlib.cm as cm +import os +import argparse # Used for parsing command-line arguments + +## LaCE specific modules +import lace +from lace.emulator.nn_emulator import NNEmulator +from lace.archive import nyx_archive, gadget_archive +from lace.utils import poly_p1d +from lace.utils.plotting_functions import plot_p1d_vs_emulator + + +def test(): + """ + Function to plot emulated P1D using specified archive (Nyx or Gadget). + + Parameters: + archive (str): Archive to use for data ('Nyx' or 'Gadget') + """ + archive_name = 'Gadget' + # Get the base directory of the lace module + repo = os.path.dirname(lace.__path__[0]) + "/" + + # Define the parameters for the emulator specific to Gadget + emu_params = ['Delta2_p', 'n_p', 'mF', 'sigT_Mpc', 'gamma', 'kF_Mpc'] + training_set='Pedersen21' + emulator_label='Pedersen23' + + + # Initialize a GadgetArchive instance for postprocessing data + archive = gadget_archive.GadgetArchive(postproc="Cabayol23") + + # Directory for saving plots + save_dir = f'{repo}data/tmp_validation_figures/{archive_name}/' + #save_dir = '{repo}tmp/validation_figures/' + # Create the directory if it does not exist + os.makedirs(save_dir, exist_ok=True) + + for ii, sim in enumerate(archive.list_sim): + if sim == 'mpg_central': + if sim in archive.list_sim_test: + emulator = GPEmulator( + training_set=training_set, + emulator_label=emulator_label, + emu_params=emu_params, + ) + else: + emulator = GPEmulator( + training_set=training_set, + emulator_label=emulator_label, + emu_params=emu_params, + drop_sim=sim, + ) + + # Get testing data for the current simulation + testing_data = archive.get_testing_data(sim_label=f'{sim}') + if sim != 'nyx_central': + testing_data = [d for d in testing_data if d['val_scaling'] == 1] + + # Plot and save the emulated P1D + save_path = f'{save_dir}{sim}{emulator_label}.png' + plot_p1d_vs_emulator(testing_data, emulator, save_path=save_path) + + return + +# Call the function to execute the test +test() diff --git a/tests/plot_mpg.py b/tests/plot_mpg_nn.py similarity index 97% rename from tests/plot_mpg.py rename to tests/plot_mpg_nn.py index 43442a92..77d8aa1a 100644 --- a/tests/plot_mpg.py +++ b/tests/plot_mpg_nn.py @@ -70,7 +70,7 @@ def test(): testing_data = [d for d in testing_data if d['val_scaling'] == 1] # Plot and save the emulated P1D - save_path = f'{save_dir}{sim}.png' + save_path = f'{save_dir}{sim}{emulator_label}.png' plot_p1d_vs_emulator(testing_data, emulator, save_path=save_path) return