Skip to content

Commit

Permalink
Merge pull request #2550 from hwpang/bac_correlation
Browse files Browse the repository at this point in the history
Allow user to plot correlation matrix for AECs
  • Loading branch information
JacksonBurns authored Dec 6, 2023
2 parents 8162ccb + c384e50 commit f4eb3b0
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 14 deletions.
71 changes: 68 additions & 3 deletions arkane/encorr/ae.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
small, predetermined set of molecules.
"""

import os
import importlib
import json
import logging
Expand All @@ -41,12 +42,19 @@
import numpy as np
from scipy.stats import distributions

try:
import matplotlib.pyplot as plt
except ImportError as e:
plt = None
matplotlib_exception = e

from rmgpy import constants
from rmgpy.molecule import get_element, Molecule

import arkane.encorr.data as data
from arkane.encorr.reference import ReferenceDatabase
from arkane.modelchem import LevelOfTheory, CompositeLevelOfTheory
from arkane.encorr.bac import _covariance_to_correlation

# List of species labels that will be used for fitting (labels should match reference database)
SPECIES_LABELS = [
Expand Down Expand Up @@ -103,20 +111,24 @@ def __init__(self,
self.overwrite = overwrite
self.ae = AE(species_energies)

def execute(self, output_file: str = None):
def execute(self, output_directory: str = None, plot: bool = False):
"""
Execute the atom energy job.
Args:
output_file: Write the fitted energies to this file.
output_directory: Write the fitted energies to this directory.
plot: Save plots of results.
"""
if self.level_of_theory is None:
logging.info('Fitting atom energies')
else:
logging.info(f'Fitting atom energies for {self.level_of_theory}')
self.ae.fit()

if output_file is not None:
if output_directory is not None:
model_chemistry_formatted = self.level_of_theory.to_model_chem().replace('//', '__').replace('/', '_')
output_file = os.path.join(output_directory, f'AEC_{model_chemistry_formatted}.out')

with open(output_file, 'a') as f:
if self.level_of_theory is not None:
f.write(f'# {self.level_of_theory}\n')
Expand All @@ -125,6 +137,10 @@ def execute(self, output_file: str = None):
f.writelines(self.ae.format_atom_energies(
'atom_energies' if self.level_of_theory is None else self.level_of_theory))

if plot:
correlation_path = os.path.join(output_directory, f'AEC_{model_chemistry_formatted}_correlation.pdf')
self.save_correlation_mat(correlation_path)

if self.write_to_database:
if self.level_of_theory is None:
raise Exception('Level of theory is required for writing to database')
Expand All @@ -134,6 +150,51 @@ def execute(self, output_file: str = None):
logging.warning('Could not write atom energies to database. Captured error:')
logging.warning(str(e))

def save_correlation_mat(self, path: str, labels: List[str] = None, **kwargs):
    """
    Save a visual representation of the parameter correlation matrix.

    Args:
        path: Path to save figure to.
        labels: Parameter labels. If omitted, the keys of ``self.ae.atom_energies`` are used.
        **kwargs: Plot options. Only ``fontsize`` is read; it sets the font size
            of both the x and y tick labels and defaults to 14.

    Raises:
        ImportError: If matplotlib could not be imported (the import error
            captured at module load time is re-raised).
        Exception: If ``self.ae.correlation`` is None, i.e. no fit has been run.
    """
    if plt is None:
        raise matplotlib_exception

    default_kwargs = {"fontsize": 14}
    kwargs = {**default_kwargs, **kwargs}

    correlation = self.ae.correlation
    if correlation is None:
        raise Exception('Fit AECs before saving correlation matrix!')

    if labels is None:
        labels = list(self.ae.atom_energies.keys())

    # Tick positions double as the row/column indices of the matrix.
    ticks = list(range(len(correlation)))

    fig, ax = plt.subplots()
    try:
        ax.matshow(correlation, cmap=plt.cm.PiYG)

        # Superimpose values as text
        for i in ticks:
            for j in ticks:
                c = correlation[j, i]
                ax.text(i, j, f'{c: .2f}', va='center', ha='center', fontsize=8)

        # Save lims because they get changed when modifying labels
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        # Both axes honour the caller-supplied font size.
        ax.set_xticklabels(labels, fontsize=kwargs["fontsize"])
        ax.set_yticklabels(labels, fontsize=kwargs["fontsize"])
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.tick_params(bottom=False, top=False, left=False, right=False)

        fig.savefig(path, bbox_inches='tight', pad_inches=0)
    finally:
        # Release the figure so repeated jobs do not accumulate open figures.
        plt.close(fig)



class AE:
"""
Expand All @@ -146,6 +207,7 @@ class AE:
def __init__(self, species_energies: Dict[str, float]):
self.species_energies = species_energies # Hartree
self.atom_energies = None
self.correlation = None # correlation matrix
self.confidence_intervals = None

for lbl in SPECIES_LABELS:
Expand Down Expand Up @@ -198,6 +260,9 @@ def fit(self):
ypred = x @ w
sigma2 = np.sum((y - ypred)**2) / (n - k) # MSE
cov = sigma2 * np.linalg.inv(x.T @ x) # covariance matrix

self.correlation = _covariance_to_correlation(cov)

se = np.sqrt(np.diag(cov)) # standard error
alpha = 0.05 # 95% confidence level
tdist = distributions.t.ppf(1 - alpha/2, n - k) # student-t
Expand Down
19 changes: 11 additions & 8 deletions arkane/encorr/bac.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
from scipy.stats import distributions
from sklearn.model_selection import KFold

try:
import matplotlib.pyplot as plt
except ImportError as e:
plt = None
matplotlib_exception = e

from rmgpy.quantity import ScalarQuantity

import arkane.encorr.data as data
Expand Down Expand Up @@ -190,10 +196,8 @@ def plot(self, output_directory: str, jobnum: int = 1):
output_directory: Save the plots in this directory.
jobnum: Job number
"""
try:
import matplotlib.pyplot as plt
except ImportError:
return
if plt is None:
raise matplotlib_exception

model_chemistry_formatted = self.level_of_theory.to_model_chem().replace('//', '__').replace('/', '_')
if self.crossval_n_folds == 1:
Expand Down Expand Up @@ -929,10 +933,9 @@ def save_correlation_mat(self, path: str, labels: List[str] = None):
path: Path to save figure to.
labels: Parameter labels.
"""
try:
import matplotlib.pyplot as plt
except ImportError:
return

if plt is None:
raise matplotlib_exception

if self.correlation is None:
raise BondAdditivityCorrectionError('Fit BACs before saving correlation matrix!')
Expand Down
2 changes: 1 addition & 1 deletion arkane/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def execute(self):
job.execute(output_directory=self.output_directory, plot=self.plot, jobnum=bacjob_num)
bacjob_num += 1
if isinstance(job, AEJob):
job.execute(output_file=output_file)
job.execute(output_directory=self.output_directory, plot=self.plot)

with open(chemkin_file, 'a') as f:
f.write('\n')
Expand Down
4 changes: 2 additions & 2 deletions documentation/source/users/arkane/running.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ to decrease; the latter two cause it to increase.
Suppressing plot creation
=========================

Arkane by default will generate many plot files. By adding the ``-p``/``--no-plot``
flag, Arkane will not generate any plots, reducing file size of output and
By default, Arkane generates many plot files, and files from previous runs are overwritten on subsequent runs.
By adding the ``-p``/``--no-plot`` flag, Arkane will not generate any plots, reducing the size of the output and
increasing the calculation speed.

Help
Expand Down

0 comments on commit f4eb3b0

Please sign in to comment.