Skip to content

Commit

Permalink
add GP to CI
Browse files Browse the repository at this point in the history
  • Loading branch information
Laura Cabayol-Garcia committed Oct 7, 2024
1 parent 8415329 commit 34457ab
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 5 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/python-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@ jobs:

# Run the tests and specify the temporary directory as output_dir
- name: Run tests
timeout-minutes: 30
timeout-minutes: 45
run: |
pytest tests/test_lace.py
pytest tests/plot_mpg.py
pytest tests/plot_mpg_gp.py
pytest tests/plot_mpg_nn.py
- name: List generated plots
run: |
Expand Down
4 changes: 2 additions & 2 deletions notebooks/Tutorial_emulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.16.1
# jupytext_version: 1.15.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# display_name: NERSC Python
# language: python
# name: python3
# ---
Expand Down
73 changes: 73 additions & 0 deletions tests/plot_mpg_gp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Import necessary modules
## General python modules
import numpy as np
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for rendering plots
from matplotlib import pyplot as plt
import matplotlib.cm as cm
import os
import argparse # Used for parsing command-line arguments

## LaCE specific modules
import lace
from lace.emulator.nn_emulator import NNEmulator
from lace.archive import nyx_archive, gadget_archive
from lace.utils import poly_p1d
from lace.utils.plotting_functions import plot_p1d_vs_emulator


def test():
"""
Function to plot emulated P1D using specified archive (Nyx or Gadget).
Parameters:
archive (str): Archive to use for data ('Nyx' or 'Gadget')
"""
archive_name = 'Gadget'
# Get the base directory of the lace module
repo = os.path.dirname(lace.__path__[0]) + "/"

# Define the parameters for the emulator specific to Gadget
emu_params = ['Delta2_p', 'n_p', 'mF', 'sigT_Mpc', 'gamma', 'kF_Mpc']
training_set='Pedersen21'
emulator_label='Pedersen23'


# Initialize a GadgetArchive instance for postprocessing data
archive = gadget_archive.GadgetArchive(postproc="Cabayol23")

# Directory for saving plots
save_dir = f'{repo}data/tmp_validation_figures/{archive_name}/'
#save_dir = '{repo}tmp/validation_figures/'
# Create the directory if it does not exist
os.makedirs(save_dir, exist_ok=True)

for ii, sim in enumerate(archive.list_sim):
if sim == 'mpg_central':
if sim in archive.list_sim_test:
emulator = GPEmulator(
training_set=training_set,
emulator_label=emulator_label,
emu_params=emu_params,
)
else:
emulator = GPEmulator(
training_set=training_set,
emulator_label=emulator_label,
emu_params=emu_params,
drop_sim=sim,
)

# Get testing data for the current simulation
testing_data = archive.get_testing_data(sim_label=f'{sim}')
if sim != 'nyx_central':
testing_data = [d for d in testing_data if d['val_scaling'] == 1]

# Plot and save the emulated P1D
save_path = f'{save_dir}{sim}{emulator_label}.png'
plot_p1d_vs_emulator(testing_data, emulator, save_path=save_path)

return

# Call the function to execute the test
test()
2 changes: 1 addition & 1 deletion tests/plot_mpg.py → tests/plot_mpg_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test():
testing_data = [d for d in testing_data if d['val_scaling'] == 1]

# Plot and save the emulated P1D
save_path = f'{save_dir}{sim}.png'
save_path = f'{save_dir}{sim}{emulator_label}.png'
plot_p1d_vs_emulator(testing_data, emulator, save_path=save_path)

return
Expand Down

0 comments on commit 34457ab

Please sign in to comment.