Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
Laura Cabayol-Garcia committed Sep 30, 2024
1 parent 2b1e408 commit d9e3ea8
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions tests/plot_mpg.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
import pytest
import argparse
import os
import numpy as np
import matplotlib
matplotlib.use('Agg') # Use a non-interactive backend for rendering plots
matplotlib.use('Agg')
from matplotlib import pyplot as plt
import lace
from lace.emulator.nn_emulator import NNEmulator
from lace.archive import gadget_archive
from lace.utils.plotting_functions import plot_p1d_vs_emulator

@pytest.fixture
def output_dir(pytestconfig):
# Get the custom argument passed with pytest
return pytestconfig.getoption("output_dir")

def test(output_dir):
"""
Function to plot emulated P1D using specified archive and save plots to output directory.
Parameters:
output_dir (str): Directory to save the generated plots.
"""
archive_name = 'Gadget'
# Get the base directory of the lace module
repo = os.path.dirname(lace.__path__[0]) + "/"

# Define the parameters for the emulator specific to Gadget
emu_params = ['Delta2_p', 'n_p', 'mF', 'sigT_Mpc', 'gamma', 'kF_Mpc']
training_set = 'Cabayol23'
emulator_label = 'Cabayol23+'
model_path = f'{repo}data/NNmodels/Cabayol23+/Cabayol23+_drop_sim'

# Initialize a GadgetArchive instance for postprocessing data
archive = gadget_archive.GadgetArchive(postproc="Cabayol23")

# Create the output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

for ii, sim in enumerate(['mpg_1', 'mpg_central']):
Expand All @@ -52,24 +52,15 @@ def test(output_dir):
train=False,
)

# Get testing data for the current simulation
testing_data = archive.get_testing_data(sim_label=f'{sim}')
if sim != 'nyx_central':
testing_data = [d for d in testing_data if d['val_scaling'] == 1]

# Plot and save the emulated P1D
save_path = os.path.join(output_dir, f'{sim}.png')
plot_p1d_vs_emulator(testing_data, emulator, save_path=save_path)

return

if __name__ == "__main__":
# Set up argument parser
parser = argparse.ArgumentParser(description="Generate and save emulator plots.")
parser.add_argument('--output_dir', type=str, required=True, help="Directory to save the plots.")

# Parse arguments
args = parser.parse_args()

# Run the test function with the specified output directory
test(args.output_dir)
def pytest_addoption(parser):
# Add custom options to pytest command
parser.addoption("--output_dir", action="store", default="tmp/validation_figures", help="Directory to save plots")

0 comments on commit d9e3ea8

Please sign in to comment.