forked from gwastro/pycbc
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
max_kde and copying data to output file
- Loading branch information
1 parent
93cb2ff
commit 78060c8
Showing
1 changed file
with
33 additions
and
198 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,214 +1,49 @@ | ||
#!/usr/bin/env python | ||
|
||
# Copyright 2023 Jam Sadiq, Praveen Kumar | ||
# | ||
# This program is free software; you can redistribute it and/or modify it | ||
# under the terms of the GNU General Public License as published by the | ||
# Free Software Foundation; either version 3 of the License, or (at your | ||
# option) any later version. | ||
# | ||
# This program is distributed in the hope that it will be useful, but | ||
# WITHOUT ANY WARRANTY; without even the implied warranty of | ||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General | ||
# Public License for more details. | ||
|
||
import numpy, h5py, operator, argparse, logging | ||
import numpy, h5py, argparse, logging | ||
from pycbc import init_logging | ||
import pycbc.conversions as convert | ||
from pycbc import libutils | ||
from pycbc.events import triggers | ||
akde = libutils.import_optional('awkde') | ||
kf = libutils.import_optional('sklearn.model_selection') | ||
|
||
|
||
parser = argparse.ArgumentParser(description=__doc__) | ||
parser.add_argument('--template-file', required=True, help='Hdf5 file with ' | ||
'template masses and spins') | ||
parser.add_argument('--nfold-template', type=int, | ||
help='Number of k-folds for template KDE cross validation') | ||
parser.add_argument('--fit-param', nargs='+', required=True, | ||
help='Parameters over which KDE is calculated') | ||
parser.add_argument('--log-param', nargs='+', choices=['True', 'False'], | ||
required=True) | ||
parser.add_argument('--output-file', required=True, help='Name of .hdf output') | ||
parser.add_argument('--make-template-kde', action='store_true') | ||
parser.add_argument('--fom-plot', help='Make a FOM plot for cross-validation' | ||
' and save it as this file') | ||
parser.add_argument('--alpha-grid', type=float, nargs="+", | ||
help='Grid of choices of sensitivity parameter alpha for' | ||
' local bandwidth') | ||
parser.add_argument('--bw-grid', type=float, nargs='+', | ||
help='Grid of choices of global bandwidth') | ||
parser.add_argument('--seed', type=int, nargs="+", | ||
help='Random number generator seed') | ||
parser.add_argument('--mchirp-downsample-power', type=float, | ||
help='Exponent value for the power law distribution') | ||
parser.add_argument('--min-ratio', type=float, | ||
parser.add_argument('--template-files', nargs='+', required=True, | ||
help='Hdf5 files with template masses and spins') | ||
parser.add_argument('--output-file', required=True, help='Name of output .hdf file') | ||
parser.add_argument('--min-ratio', type=float, | ||
help='Minimum ratio for template_kde relative to the maximum') | ||
parser.add_argument('--verbose', action='store_true') | ||
args = parser.parse_args() | ||
init_logging(verbose=args.verbose, format='%(asctime)s %(message)s') | ||
|
||
|
||
assert len(args.fit_param) == len(args.log_param) | ||
|
||
def kde_awkde(x, x_grid, alp=0.5, gl_bandwidth=None, ret_kde=False):
    """
    Fit an awkde ``GaussianKDE`` to `x` and evaluate it at `x_grid`.

    Parameters
    ----------
    x :
        Samples handed to the KDE's ``fit`` method.
    x_grid :
        Points handed to the KDE's ``predict`` method.
    alp : float, optional
        Value passed to ``GaussianKDE`` as ``alpha``. Default 0.5.
    gl_bandwidth : optional
        Value passed to ``GaussianKDE`` as ``glob_bw``. If None, the
        argument is left out so that awkde applies its own default.
    ret_kde : bool, optional
        If true, the fitted KDE object is returned as well.

    Returns
    -------
    y, or the tuple (kde, y) when `ret_kde` is set, where y is the
    output of ``kde.predict(x_grid)``.
    """
    # Build the constructor arguments once instead of duplicating the call
    kde_kwargs = {'alpha': alp, 'diag_cov': True}
    if gl_bandwidth is not None:
        kde_kwargs['glob_bw'] = gl_bandwidth
    kde = akde.GaussianKDE(**kde_kwargs)

    kde.fit(x)
    y = kde.predict(x_grid)

    if ret_kde:
        return kde, y
    return y
|
||
|
||
def kfcv_awkde(sample, bwchoice, alphachoice, k=2):
    """
    Compute the K-fold cross validated figure of merit for an awKDE with
    the given global bandwidth and sensitivity (alpha) parameters.

    The sample is split into `k` shuffled folds. For every fold the KDE is
    fit on the training split and evaluated on the held-out split; the
    fold score is the sum of the logs of those values. The returned value
    is the sum of the fold scores.
    """
    splitter = kf.KFold(n_splits=k, shuffle=True, random_state=None)
    fold_scores = [
        numpy.sum(numpy.log(
            kde_awkde(sample[fit_idx], sample[held_idx],
                      alp=alphachoice, gl_bandwidth=bwchoice)
        ))
        for fit_idx, held_idx in splitter.split(sample)
    ]
    # Total over all k folds
    return numpy.sum(fold_scores)
|
||
|
||
def optimizedparam(sampleval, bwgrid, alphagrid, nfold=2):
    """
    Grid search for the (bandwidth, alpha) pair with the largest
    cross-validated figure of merit.

    Every combination of `bwgrid` and `alphagrid` is scored with
    `kfcv_awkde` using `nfold` folds. If the command line option
    ``args.fom_plot`` is set, the scores are also plotted against alpha
    (one curve per bandwidth) and saved to that path.

    Returns
    -------
    (optbw, optalpha) : the best-scoring bandwidth and alpha.
    """
    n_samples, _ = sampleval.shape
    scores = {
        (bw, alpha): kfcv_awkde(sampleval, bw, alpha, k=nfold)
        for bw in bwgrid
        for alpha in alphagrid
    }
    best_bw, best_alpha = max(scores, key=scores.get)
    best_score = scores[(best_bw, best_alpha)]

    if not args.fom_plot:
        return best_bw, best_alpha

    # Plotting FOM parameters
    import matplotlib.pyplot as plt
    figure = plt.figure(figsize=(12,8))
    axes = figure.add_subplot(111)
    for bw in bwgrid:
        curve = [scores[(bw, alpha)] for alpha in alphagrid]
        axes.plot(alphagrid, curve, label='{0:.3f}'.format(bw))
    best_label = r'$\alpha={0:.3f},bw={1:.3f}$'.format(best_alpha, best_bw)
    axes.plot(best_alpha, best_score, 'ko', linewidth=10, label=best_label)
    axes.set_xlabel(r'$\alpha$', fontsize=15)
    axes.set_ylabel(r'$FOM$', fontsize=15)
    # Guess at a suitable range of FOM values to plot
    lower = best_score - 0.5 * n_samples
    upper = best_score + 0.2 * n_samples
    axes.set_ylim(lower, upper)
    axes.legend(loc='upper center', bbox_to_anchor=(0.5, 1.135), ncol=8)
    plt.savefig(args.fom_plot)
    plt.close()

    return best_bw, best_alpha
temp_files = [h5py.File(file, 'r') for file in args.template_files] | ||
num_files = len(temp_files) | ||
if num_files < 2: | ||
raise ValueError("At least two input files are required.") | ||
|
||
|
||
# Obtaining template parameters
# One read-only handle is used for both the parameter reads and the copy
# into the output file; the context manager guarantees it gets closed
# (previously the file was opened twice and the first handle never closed).
with h5py.File(args.template_file, 'r') as temp_file:
    mass1 = temp_file['mass1'][:]
    tid = numpy.arange(len(mass1))  # Array of template ids
    mass_spin = triggers.get_mass_spin(temp_file, tid)

    # Creating output file and save datasets and attributes from input files
    f_dest = h5py.File(args.output_file, 'w')
    f_dest.create_dataset("template_id", data=tid)
    template_pars = []
    for param, slog in zip(args.fit_param, args.log_param):
        pvals = triggers.get_param(param, args, *mass_spin)
        # Write the KDE param values to output file
        f_dest.create_dataset(param, data=pvals)
        # --log-param values arrive as the strings 'True' / 'False'
        if slog == 'False':
            logging.info('Using param: %s', param)
            template_pars.append(pvals)
        elif slog == 'True':
            logging.info('Using log param: %s', param)
            template_pars.append(numpy.log(pvals))
        else:
            raise ValueError("invalid log param argument, use 'True', or 'False'")

    # Copy standard data to output file
    f_dest.attrs['fit_param'] = args.fit_param
    f_dest.attrs['log_param'] = args.log_param
    temp_file.copy(temp_file["./"], f_dest["./"], "input_template_params")

# One row per template, one column per fit parameter
temp_samples = numpy.vstack((template_pars)).T
|
||
|
||
# Build the template KDE, optionally on a power-law downsampled subset of
# the templates, and store it in the output file.
if args.make_template_kde:

    # Rejection sampling to reduce computational load
    if args.mchirp_downsample_power is not None:
        logging.info('Downsampling with mchirp power '
                     f'{args.mchirp_downsample_power}')
        f_dest.attrs['mchirp_downsample_power'] = args.mchirp_downsample_power
        try:
            mchirp_index = args.fit_param.index('mchirp')
        except ValueError as err:
            raise ValueError("mchirp does not exist in args.fit_param") from err
        if not args.seed:
            raise ValueError("--seed is required when "
                             "--mchirp-downsample-power is given")

        mc_vals = template_pars[mchirp_index]
        # --log-param entries are the strings 'True' / 'False'. A plain truth
        # test is true for both, which would exponentiate linear mchirp
        # values, so compare against the string explicitly.
        if args.log_param[mchirp_index] == 'True':
            mc_vals = numpy.exp(mc_vals)
        power_vals = mc_vals ** args.mchirp_downsample_power
        # Acceptance probability per template, normalised to a maximum of 1
        probabilities = power_vals / numpy.max(power_vals)

        seeds = args.seed
        seed_values = list(map(int, seeds))
        f_dest.attrs['seeds'] = seed_values
        template_kde_list = []
        for seed in seeds:
            numpy.random.seed(seed)
            rand_nums = numpy.random.uniform(0, 1, len(mass1))
            ind = rand_nums < probabilities
            logging.info(f'{ind.sum()} templates after downsampling')
            kde_train_samples = temp_samples[ind]
            dataset_name = f'kde_train_samples_{seed}'
            f_dest.create_dataset(dataset_name, data=kde_train_samples)

            logging.info('Starting optimization of template KDE parameters')
            optbw, optalpha = optimizedparam(kde_train_samples, alphagrid=args.alpha_grid,
                                             bwgrid=args.bw_grid, nfold=args.nfold_template)
            logging.info('Bandwidth %.4f, alpha %.2f' % (optbw, optalpha))
            # NOTE: these attributes are overwritten on every pass, so the
            # file ends up holding the values from the last seed only
            f_dest.attrs['alpha'] = optalpha
            f_dest.attrs['bandwidth'] = optbw
            logging.info('Evaluating template KDE at seed {}'.format(seed))
            template_kde = kde_awkde(kde_train_samples, temp_samples,
                                     alp=optalpha, gl_bandwidth=optbw)
            template_kde_list.append(template_kde)

        # Element-wise maximum over the KDEs obtained with the different seeds
        template_kde = numpy.maximum.reduce(template_kde_list)
        # Compensation factor for downsampling of templates
        template_kde *= 1. / probabilities
        if args.min_ratio is not None:
            logging.info(f'Applying minimum template KDE ratio {args.min_ratio}')
            f_dest.attrs['min-kde-ratio'] = args.min_ratio
            min_val = args.min_ratio * numpy.max(template_kde)
            template_kde = numpy.maximum(template_kde, min_val)

    else:
        logging.info('Starting optimization of template KDE parameters')
        optbw, optalpha = optimizedparam(temp_samples, alphagrid=args.alpha_grid,
                                         bwgrid=args.bw_grid, nfold=args.nfold_template)
        logging.info('Bandwidth %.4f, alpha %.2f' % (optbw, optalpha))
        logging.info('Evaluating template KDE')
        template_kde = kde_awkde(temp_samples, temp_samples, alp=optalpha,
                                 gl_bandwidth=optbw)

    f_dest.create_dataset("data_kde", data=template_kde)
    f_dest.attrs['stat'] = "template-kde_file"
    f_dest.attrs['template-file'] = args.template_file
|
||
# Combine the input files: 'data_kde' becomes the element-wise maximum over
# all inputs, every other dataset present in all inputs is copied from the
# first file, and attributes are copied only when all inputs agree on them.
# The try/finally makes sure every file is closed even if combining fails.
try:
    first_file = temp_files[0]
    template_kde = None
    for name in first_file:
        if not isinstance(first_file[name], h5py.Dataset):
            continue
        if name == 'data_kde':
            data_combined = [temp_file['data_kde'][:] for temp_file in temp_files]
            template_kde = numpy.maximum.reduce(data_combined)
        elif all(name in temp_file and isinstance(temp_file[name], h5py.Dataset)
                 for temp_file in temp_files):
            # Only the first file's values are written, so only those are read
            f_dest.create_dataset(name, data=first_file[name][:])

    if template_kde is None:
        # Without this check the code below fails with an opaque NameError
        raise ValueError("No 'data_kde' dataset found in the first input file")

    for attr_name in first_file.attrs.keys():
        attr_values = [temp_file.attrs.get(attr_name) for temp_file in temp_files]
        if all(numpy.array_equal(attr_values[0], attr_val) for attr_val in attr_values[1:]):
            f_dest.attrs[attr_name] = attr_values[0]

    if args.min_ratio is not None:
        logging.info(f'Applying minimum template KDE ratio {args.min_ratio}')
        f_dest.attrs['min-kde-ratio'] = args.min_ratio
        min_val = args.min_ratio * numpy.max(template_kde)
        template_kde = numpy.maximum(template_kde, min_val)
    f_dest.create_dataset('data_kde', data=template_kde)
finally:
    for temp_file in temp_files:
        temp_file.close()
    f_dest.close()
|
||
logging.info('Done!') |