Skip to content

Commit efe961a

Browse files
committed
Make executable. Add option to get ML model from config file
1 parent ce35553 commit efe961a

File tree

1 file changed

+46
-8
lines changed

1 file changed

+46
-8
lines changed

jwql/instrument_monitors/nircam_monitors/wisp_finder.py

100644100755
Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
3333
"""
3434

35+
import argparse
3536
import datetime
3637
import logging
3738
import os
@@ -51,7 +52,7 @@
5152
from jwql.utils.utils import get_config
5253
from jwql.website.apps.jwql.archive_database_update import files_in_filesystem
5354
from jwql.website.apps.jwql.models import Anomalies, RootFileInfo
54-
from . import prepare_wisp_pngs
55+
from jwql.instrument_monitors.nircam_monitors import prepare_wisp_pngs
5556

5657
if 1>0:
5758
os.environ.setdefault("DJANGO_SETTINGS_MODULE", "jwql.website.jwql_proj.settings")
@@ -137,6 +138,35 @@ def define_model_architecture():
137138
return model
138139

139140

141+
def define_options(parser=None, usage=None, conflict_handler='resolve'):
142+
"""Add command line options
143+
144+
Parrameters
145+
-----------
146+
parser : argparse.parser
147+
Parser object
148+
149+
usage : str
150+
Usage string
151+
152+
conflict_handler : str
153+
Conflict handling strategy
154+
155+
Returns
156+
-------
157+
parser : argparse.parser
158+
Parser object with added options
159+
"""
160+
if parser is None:
161+
parser = argparse.ArgumentParser(usage=usage, conflict_handler=conflict_handler)
162+
163+
parser.add_argument('-m', '--model_filename', type=str, default=None, help='Filename of saved ML model. (default=%(default)s)')
164+
parser.add_argument('-s', '--starting_date', type=float, default=None, help='Earliest MJD to search for data. If None, date is retrieved from database.')
165+
parser.add_argument('-e', '--ending_date', type=float, default=None, help='Latest MJD to search for data. If None, the current date is used.')
166+
parser.add_argument('-f', '--file_list', type=str, nargs='+', default=None, help='List of full paths to files to run the monitor on.')
167+
return parser
168+
169+
140170
def load_ml_model(model_filename):
141171
"""Load the ML model for wisp prediction
142172
@@ -145,13 +175,8 @@ def load_ml_model(model_filename):
145175
model_filename : str
146176
Location of file containing the model. e.g. /path/to/my_best_model.pth
147177
"""
148-
149-
#model = torch.load(model_filename)
150-
151178
model = define_model_architecture()
152179
model.load_state_dict(torch.load(model_filename))
153-
154-
155180
model.eval() # Set model to evaluation mode
156181
return model
157182

@@ -240,7 +265,7 @@ def remove_duplicate_files(file_list):
240265
return unique_files
241266

242267

243-
def run(model_filename, starting_date=None, ending_date=None, file_list=None):
268+
def run(model_filename=None, starting_date=None, ending_date=None, file_list=None):
244269
"""Run the wisp finder monitor. From user-input dates or dates retrieved from
245270
the database, query MAST for all NIRCam NRCB4 full-frame imaging mode data. For
246271
each file, create a png file continaing an image of the rate file, scaled to a
@@ -264,6 +289,11 @@ def run(model_filename, starting_date=None, ending_date=None, file_list=None):
264289
to run the wisp prediction for. If this list is provided, the MAST query
265290
is skipped.
266291
"""
292+
# If no model_filename is given, the retrieve the default model_filename
293+
# from the config file
294+
if model_filename is None:
295+
model_filename = get_config()['wisp_finder_ML_model']
296+
267297
if file_list is None:
268298

269299
# If ending_date is not provided, set it equal to the current time
@@ -308,7 +338,6 @@ def run(model_filename, starting_date=None, ending_date=None, file_list=None):
308338
# Create png
309339
working_dir = os.path.dirname(working_filepath)
310340
png_filename = prepare_wisp_pngs.run(working_filepath, out_dir=working_dir)
311-
print(png_filename)
312341

313342
# Predict
314343
prediction = predict_wisp(model, png_filename, transform)
@@ -339,3 +368,12 @@ def run(model_filename, starting_date=None, ending_date=None, file_list=None):
339368
else:
340369
print('What dates do we add to the database in this case?')
341370

371+
372+
if __name__ == '__main__':
373+
parser = define_options()
374+
args = parser.parse_args()
375+
376+
run(args.model_filename,
377+
file_list=args.file_list,
378+
starting_date=args.starting_date,
379+
ending_date=args.ending_date)

0 commit comments

Comments
 (0)