Skip to content

Commit

Permalink
Fix sngls_minifollowup event selection/ordering [v23_release_branch] (#…
Browse files Browse the repository at this point in the history
…4824)

* Revert to maks_n_clustered in sngls_minifollowup [v23_ branch]

* REquire pre-numpy 2.0

* Revert "REquire pre-numpy 2.0"

This reverts commit 839f4ec.

* Trig_dict must have an IFO

* Remove testing printing

* Errors uring rebase
  • Loading branch information
GarethCabournDavies authored Aug 8, 2024
1 parent bbaa604 commit 69874f0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 34 deletions.
48 changes: 15 additions & 33 deletions bin/minifollowups/pycbc_sngl_minifollowup
Original file line number Diff line number Diff line change
Expand Up @@ -249,48 +249,30 @@ if args.maximum_duration is not None:
logging.info('Finding loudest clustered events')
rank_method = stat.get_statistic_from_opts(args, [args.instrument])

extra_kwargs = {}
for inputstr in args.statistic_keywords:
try:
key, value = inputstr.split(':')
extra_kwargs[key] = value
except ValueError:
err_txt = "--statistic-keywords must take input in the " \
"form KWARG1:VALUE1 KWARG2:VALUE2 KWARG3:VALUE3 ... " \
"Received {}".format(args.statistic_keywords)
raise ValueError(err_txt)

logging.info("Calculating statistic for %d triggers", len(trigs.snr))
sds = rank_method.single(trigs)
stat = rank_method.rank_stat_single((args.instrument, sds), **extra_kwargs)
logging.info("Clustering events over %.3fs window", args.cluster_window)
cid = coinc.cluster_over_time(stat, trigs.end_time,
args.cluster_window)
trigs.apply_mask(cid)
stat = stat[cid]
if len(trigs.snr) < num_events:
num_events = len(trigs.snr)

logging.info("Finding the loudest triggers")
loudest_idx = sorted(numpy.argsort(stat)[::-1][:num_events])
trigs.apply_mask(loudest_idx)
stat = stat[loudest_idx]
trigs.mask_to_n_loudest_clustered_events(
rank_method,
n_loudest=num_events,
cluster_window=args.cluster_window,
)

times = trigs.end_time
tids = trigs.template_id
trig_stat = trigs.stat
trig_snrs = trigs.snr

if isinstance(trigs.mask, numpy.ndarray) and trigs.mask.dtype == bool:
trigger_ids = numpy.flatnonzero(trigs.mask)
else:
trigger_ids = trigs.mask

# loop over number of loudest events to be followed up
order = stat.argsort()[::-1]
order = trig_stat.argsort()[::-1]
for rank, num_event in enumerate(order):
logging.info('Processing event: %s', num_event)
logging.info('Processing event: %s', rank)

files = wf.FileList([])
time = times[num_event]
ifo_time = '%s:%s' %(args.instrument, str(time))
if isinstance(trigs.mask, numpy.ndarray) and trigs.mask.dtype == bool:
tid = numpy.flatnonzero(trigs.mask)[num_event]
else:
tid = trigs.mask[num_event]
tid = trigger_ids[num_event]
ifo_tid = '%s:%s' %(args.instrument, str(tid))

layouts += (mini.make_sngl_ifo(workflow, sngl_file, tmpltbank_file,
Expand Down
3 changes: 2 additions & 1 deletion pycbc/io/hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,14 +483,15 @@ def checkbank(self, param):
% param)

def trig_dict(self):
"""Returns dict of the masked trigger valuse """
"""Returns dict of the masked trigger values """
mtrigs = {}
for k in self.trigs:
if len(self.trigs[k]) == len(self.trigs['end_time']):
if self.mask is not None:
mtrigs[k] = self.trigs[k][self.mask]
else:
mtrigs[k] = self.trigs[k][:]
mtrigs['ifo'] = self.ifo
return mtrigs

@classmethod
Expand Down

0 comments on commit 69874f0

Please sign in to comment.