diff --git a/arkane/encorr/ae.py b/arkane/encorr/ae.py index 3638705dcd..b06bd025c9 100644 --- a/arkane/encorr/ae.py +++ b/arkane/encorr/ae.py @@ -32,6 +32,7 @@ small, predetermined set of molecules. """ +import os import importlib import json import logging @@ -41,12 +42,19 @@ import numpy as np from scipy.stats import distributions +try: + import matplotlib.pyplot as plt +except ImportError as e: + plt = None + matplotlib_exception = e + from rmgpy import constants from rmgpy.molecule import get_element, Molecule import arkane.encorr.data as data from arkane.encorr.reference import ReferenceDatabase from arkane.modelchem import LevelOfTheory, CompositeLevelOfTheory +from arkane.encorr.bac import _covariance_to_correlation # List of species labels that will be used for fitting (labels should match reference database) SPECIES_LABELS = [ @@ -103,12 +111,13 @@ def __init__(self, self.overwrite = overwrite self.ae = AE(species_energies) - def execute(self, output_file: str = None): + def execute(self, output_directory: str = None, plot: bool = False): """ Execute the atom energy job. Args: - output_file: Write the fitted energies to this file. + output_directory: Write the fitted energies to this directory. + plot: Save plots of results. 
""" if self.level_of_theory is None: logging.info('Fitting atom energies') @@ -116,7 +125,10 @@ def execute(self, output_file: str = None): logging.info(f'Fitting atom energies for {self.level_of_theory}') self.ae.fit() - if output_file is not None: + if output_directory is not None: + model_chemistry_formatted = self.level_of_theory.to_model_chem().replace('//', '__').replace('/', '_') + output_file = os.path.join(output_directory, f'AEC_{model_chemistry_formatted}.out') + with open(output_file, 'a') as f: if self.level_of_theory is not None: f.write(f'# {self.level_of_theory}\n') @@ -125,6 +137,10 @@ def execute(self, output_file: str = None): f.writelines(self.ae.format_atom_energies( 'atom_energies' if self.level_of_theory is None else self.level_of_theory)) + if plot: + correlation_path = os.path.join(output_directory, f'AEC_{model_chemistry_formatted}_correlation.pdf') + self.save_correlation_mat(correlation_path) + if self.write_to_database: if self.level_of_theory is None: raise Exception('Level of theory is required for writing to database') @@ -134,6 +150,51 @@ def execute(self, output_file: str = None): logging.warning('Could not write atom energies to database. Captured error:') logging.warning(str(e)) + def save_correlation_mat(self, path: str, labels: List[str] = None, **kwargs): + """ + Save a visual representation of the parameter correlation matrix. + + Args: + path: Path to save figure to. + labels: Parameter labels. 
+ """ + + if plt is None: + raise matplotlib_exception + + default_kwargs = {"fontsize": 14} + kwargs = {**default_kwargs, **kwargs} + + if self.ae.correlation is None: + raise Exception('Fit AECs before saving correlation matrix!') + + if labels is None: + labels = list(self.ae.atom_energies.keys()) + + fig, ax = plt.subplots() + ax.matshow(self.ae.correlation, cmap=plt.cm.PiYG) + + # Superimpose values as text + for i in range(len(self.ae.correlation)): + for j in range(len(self.ae.correlation)): + c = self.ae.correlation[j, i] + ax.text(i, j, f'{c: .2f}', va='center', ha='center', fontsize=8) + + # Save lims because they get changed when modifying labels + xlim = ax.get_xlim() + ylim = ax.get_ylim() + + ax.set_xticks(list(range(len(self.ae.correlation)))) + ax.set_yticks(list(range(len(self.ae.correlation)))) + ax.set_xticklabels(labels, fontsize=kwargs["fontsize"]) + ax.set_yticklabels(labels, fontsize=14) + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.tick_params(bottom=False, top=False, left=False, right=False) + + fig.savefig(path, bbox_inches='tight', pad_inches=0) + + class AE: """ @@ -146,6 +207,7 @@ class AE: def __init__(self, species_energies: Dict[str, float]): self.species_energies = species_energies # Hartree self.atom_energies = None + self.correlation = None # correlation matrix self.confidence_intervals = None for lbl in SPECIES_LABELS: @@ -198,6 +260,9 @@ def fit(self): ypred = x @ w sigma2 = np.sum((y - ypred)**2) / (n - k) # MSE cov = sigma2 * np.linalg.inv(x.T @ x) # covariance matrix + + self.correlation = _covariance_to_correlation(cov) + se = np.sqrt(np.diag(cov)) # standard error alpha = 0.05 # 95% confidence level tdist = distributions.t.ppf(1 - alpha/2, n - k) # student-t diff --git a/arkane/encorr/bac.py b/arkane/encorr/bac.py index f9b09f22c3..564cadeb15 100644 --- a/arkane/encorr/bac.py +++ b/arkane/encorr/bac.py @@ -52,6 +52,12 @@ from scipy.stats import distributions from sklearn.model_selection import KFold +try: + import 
matplotlib.pyplot as plt +except ImportError as e: + plt = None + matplotlib_exception = e + from rmgpy.quantity import ScalarQuantity import arkane.encorr.data as data @@ -190,10 +196,8 @@ def plot(self, output_directory: str, jobnum: int = 1): output_directory: Save the plots in this directory. jobnum: Job number """ - try: - import matplotlib.pyplot as plt - except ImportError: - return + if plt is None: + raise matplotlib_exception model_chemistry_formatted = self.level_of_theory.to_model_chem().replace('//', '__').replace('/', '_') if self.crossval_n_folds == 1: @@ -929,10 +933,9 @@ def save_correlation_mat(self, path: str, labels: List[str] = None): path: Path to save figure to. labels: Parameter labels. """ - try: - import matplotlib.pyplot as plt - except ImportError: - return + + if plt is None: + raise matplotlib_exception if self.correlation is None: raise BondAdditivityCorrectionError('Fit BACs before saving correlation matrix!') diff --git a/arkane/main.py b/arkane/main.py index d3adf14a17..b5345f99c0 100644 --- a/arkane/main.py +++ b/arkane/main.py @@ -222,7 +222,7 @@ def execute(self): job.execute(output_directory=self.output_directory, plot=self.plot, jobnum=bacjob_num) bacjob_num += 1 if isinstance(job, AEJob): - job.execute(output_file=output_file) + job.execute(output_directory=self.output_directory, plot=self.plot) with open(chemkin_file, 'a') as f: f.write('\n') diff --git a/documentation/source/users/arkane/running.rst b/documentation/source/users/arkane/running.rst index f5d674feba..f6cc90d8e4 100644 --- a/documentation/source/users/arkane/running.rst +++ b/documentation/source/users/arkane/running.rst @@ -34,8 +34,8 @@ to decrease; the latter two cause it to increase. Suppressing plot creation ========================= -Arkane by default will generate many plot files. 
By adding the ``-p``/``--no-plot`` -flag, Arkane will not generate any plots, reducing file size of output and +By default, Arkane generates many plot files, and existing plot files are overwritten on subsequent runs. +By adding the ``-p``/``--no-plot`` flag, Arkane will not generate any plots, reducing the size of the output and increasing the calculation speed. Help