-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
252 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import numpy as np | ||
|
||
from ramjet.data_interface.moa_data_interface import MoaDataInterface | ||
from ramjet.photometric_database.derived.moa_survey_light_curve_collection import MoaSurveyLightCurveCollection | ||
from ramjet.photometric_database.standard_and_injected_light_curve_database import StandardAndInjectedLightCurveDatabase | ||
|
||
from qusi.light_curve_collection import LabeledLightCurveCollection | ||
from qusi.light_curve_dataset import LightCurveDataset | ||
from qusi.light_curve_collection import LightCurveCollection | ||
|
||
|
||
|
||
def positive_label_function(path):
    """Return the label for a positive example; every path maps to 1."""
    positive_label = 1
    return positive_label
|
||
|
||
def negative_label_function(path):
    """Return the label for a negative example; every path maps to 0."""
    negative_label = 0
    return negative_label
|
||
class MoaSurveyMicrolensingAndNonMicrolensingDatabase(StandardAndInjectedLightCurveDatabase):
    """
    A class for a database of MOA light curves including non-microlensing, and microlensing collections.

    The collections are partitioned over 10 dataset splits: one split is held out for inference, the
    split immediately before it (wrapping around from 0 to 9) is used for validation, and the remaining
    8 splits are used for training.
    """
    moa_data_interface = MoaDataInterface()

    def __init__(self, test_split: int):
        """
        Build the training, validation, and inference collections for the given held-out split.

        :param test_split: The index (0 through 9) of the dataset split to hold out for inference.
        :raises ValueError: If `test_split` is not one of the 10 split indexes.
        """
        super().__init__()
        # NOTE(review): the original comment says the network side already assumes 10 splits;
        # confirm this stays in sync with that setting.
        number_of_splits = 10
        if test_split not in range(number_of_splits):
            # Fail with an explicit message rather than the opaque `list.remove(x): x not in list`.
            raise ValueError(
                f'test_split must be a split index from 0 to {number_of_splits - 1}, got {test_split!r}.')
        # The validation split is the one just before the test split, wrapping around at 0.
        validation_split = (test_split - 1) % number_of_splits
        train_splits = [split for split in range(number_of_splits)
                        if split not in (validation_split, test_split)]

        # Survey tags given label 1 (positive) and label 0 (negative) respectively.
        positive_tags = ['c', 'cf', 'cp', 'cw', 'cs', 'cb']
        negative_tags = ['v', 'n', 'nr', 'm', 'j', self.moa_data_interface.no_tag_string]

        # Training collections: the 8 splits that are neither the validation nor the test split.
        self.negative_training = MoaSurveyLightCurveCollection(
            survey_tags=list(negative_tags),
            label=0,
            dataset_splits=train_splits)
        self.positive_training = MoaSurveyLightCurveCollection(
            survey_tags=list(positive_tags),
            label=1,
            dataset_splits=train_splits)

        # Validation collections: the single split preceding the test split.
        self.negative_validation = MoaSurveyLightCurveCollection(
            survey_tags=list(negative_tags),
            label=0,
            dataset_splits=[validation_split])
        self.positive_validation = MoaSurveyLightCurveCollection(
            survey_tags=list(positive_tags),
            label=1,
            dataset_splits=[validation_split])

        # Inference collections: the held-out test split.
        self.negative_inference = MoaSurveyLightCurveCollection(
            survey_tags=list(negative_tags),
            label=0,
            dataset_splits=[test_split])
        self.positive_inference = MoaSurveyLightCurveCollection(
            survey_tags=list(positive_tags),
            label=1,
            dataset_splits=[test_split])
        # Every tag together, with no meaningful label, for scoring the whole test split.
        self.all_inference = MoaSurveyLightCurveCollection(
            survey_tags=positive_tags + negative_tags,
            label=np.nan,
            dataset_splits=[test_split])

    # QUSI structure
    @staticmethod
    def _new_labeled_collection(collection, label_function):
        """Wrap a MOA survey collection as a qusi labeled collection using the given label function."""
        return LabeledLightCurveCollection.new(
            get_paths_function=collection.get_paths,
            load_times_and_fluxes_from_path_function=collection.load_times_and_fluxes_from_path,
            load_label_from_path_function=label_function)

    def _new_labeled_dataset(self, positive_collection, negative_collection):
        """Build a qusi dataset from a positive and a negative MOA survey collection, in that order."""
        return LightCurveDataset.new(
            standard_light_curve_collections=[
                self._new_labeled_collection(positive_collection, positive_label_function),
                self._new_labeled_collection(negative_collection, negative_label_function)])

    def get_microlensing_train_dataset(self):
        """
        Create the qusi training dataset from the positive and negative training collections.

        :return: The dataset produced by `LightCurveDataset.new`.
        """
        return self._new_labeled_dataset(self.positive_training, self.negative_training)

    def get_microlensing_validation_dataset(self):
        """
        Create the qusi validation dataset from the positive and negative validation collections.

        :return: The dataset produced by `LightCurveDataset.new`.
        """
        return self._new_labeled_dataset(self.positive_validation, self.negative_validation)

    def get_microlensing_infer_collection(self):
        """
        Create the unlabeled qusi collection covering every light curve in the test split.

        :return: The collection produced by `LightCurveCollection.new`.
        """
        return LightCurveCollection.new(
            get_paths_function=self.all_inference.get_paths,
            load_times_and_fluxes_from_path_function=self.all_inference.load_times_and_fluxes_from_path)
# def get_microlensing_finite_test_dataset(self): | ||
# positive_test_light_curve_collection = LabeledLightCurveCollection.new( | ||
# get_paths_function=self.positive_inference.get_paths, | ||
# load_times_and_fluxes_from_path_function=self.positive_inference.load_times_and_fluxes_from_path, | ||
# load_label_from_path_function=positive_label_function) | ||
# negative_test_light_curve_collection = LabeledLightCurveCollection.new( | ||
# get_paths_function=self.negative_inference.get_paths, | ||
# load_times_and_fluxes_from_path_function=self.negative_inference.load_times_and_fluxes_from_path, | ||
# load_label_from_path_function=negative_label_function) | ||
# test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new( | ||
# standard_light_curve_collections=[positive_test_light_curve_collection, | ||
# negative_test_light_curve_collection]) | ||
# return test_light_curve_dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import torch | ||
import pandas as pd | ||
|
||
from moa_dataset import MoaSurveyMicrolensingAndNonMicrolensingDatabase | ||
|
||
from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset | ||
from qusi.hadryss_model import Hadryss | ||
from qusi.infer_session import get_device, infer_session | ||
|
||
|
||
def read_wandb_name_csv(test_split_, csv_path='inferences/wandb.csv'):
    """
    Look up the run name recorded for a given test split.

    Reads the CSV at `csv_path` and returns the `Name` value of the first row whose `Tags` column
    equals `550k M vs NM Split {test_split_}`, with surrounding whitespace stripped.

    :param test_split_: The test split number used to build the tag to search for.
    :param csv_path: Path of the CSV file containing the `Name` and `Tags` columns.
    :return: The stripped run name of the first matching row.
    :raises IndexError: If no row carries the tag for this split.
    """
    tag = f"550k M vs NM Split {test_split_}"
    runs = pd.read_csv(csv_path)
    matching_names = runs.loc[runs["Tags"] == tag, 'Name']
    if matching_names.empty:
        # Same exception type the unguarded `.values[0]` lookup raised, but with an actionable message.
        raise IndexError(f"No row with tag {tag!r} found in {csv_path!r}.")
    return matching_names.iloc[0].strip()
|
||
|
||
def main(test_split_, type_, eventname_):
    """
    Run inference on the held-out split and write the scored light curve paths to a CSV file.

    Loads model weights from `sessions/{eventname_}_latest_model.pt`, scores the database's
    inference collection, and writes the paths sorted by descending score to
    `inferences/results_{type_}_{test_split_}.csv`.

    :param test_split_: The split index passed to the database as the held-out test split.
    :param type_: Label inserted into the results file name.
    :param eventname_: Name prefix of the saved model file (printed as the WANDB name).
    """
    import os
    import warnings

    print('Test split #: ', test_split_)
    print('Type: ', type_)
    print('WANDB name: ', eventname_)
    database = MoaSurveyMicrolensingAndNonMicrolensingDatabase(test_split=test_split_)
    infer_light_curve_collection = database.get_microlensing_infer_collection()
    test_light_curve_dataset = FiniteStandardLightCurveDataset.new(
        light_curve_collections=[infer_light_curve_collection])

    model = Hadryss.new()
    device = get_device()
    model.load_state_dict(torch.load(f'sessions/{eventname_}_latest_model.pt', map_location=device))
    # Only one dataset is passed in, so only the first result is used.
    confidences = list(infer_session(infer_datasets=[test_light_curve_dataset], model=model,
                                     batch_size=100, device=device)[0])
    paths = list(database.all_inference.get_paths())
    if len(paths) != len(confidences):
        # `zip` would silently drop the surplus entries; surface the mismatch instead of hiding it.
        warnings.warn(
            f'Number of paths ({len(paths)}) differs from number of confidences ({len(confidences)}); '
            f'results are truncated to the shorter of the two.')
    # NOTE(review): pairing assumes `get_paths` yields paths in the same order the dataset
    # was scored in - confirm.
    sorted_paths_with_confidences = sorted(
        zip(paths, confidences), key=lambda path_with_confidence: path_with_confidence[1], reverse=True)
    print(sorted_paths_with_confidences)
    df = pd.DataFrame(sorted_paths_with_confidences, columns=['Path', 'Score'])
    df['Path'] = df['Path'].astype(str)
    # File name -> text before the first '.' -> last '_'-separated token.
    file_names = df['Path'].str.split('/').str[-1]
    df['lightcurve_name'] = file_names.str.split('.').str[0].str.split('_').str[-1]
    # Make sure the output directory exists so a long inference run is not lost at the final write.
    os.makedirs('inferences', exist_ok=True)
    df.to_csv(f'inferences/results_{type_}_{test_split_}.csv')

    print()
|
||
|
||
if __name__ == '__main__':
    # Command-line entry point: `python <script> <test_split>`.
    import sys
    import time

    # A monotonic clock is used so the elapsed time is unaffected by system clock adjustments.
    start_time = time.perf_counter()
    # Total arguments passed (the script name counts as the first one).
    print("Total arguments passed:", len(sys.argv))
    if len(sys.argv) < 2:
        # Without this check a missing argument surfaces as a bare IndexError.
        sys.exit(f'Usage: python {sys.argv[0]} <test_split>')
    split_number = int(sys.argv[1])
    # The run name for this split is looked up in the CSV of recorded runs.
    wandb_name = str(read_wandb_name_csv(split_number))

    main(test_split_=split_number, type_='550k', eventname_=wandb_name)

    # To bypass the CSV lookup, call main directly with a known run name, e.g.
    # main(test_split_=0, type_='550k', eventname_='gs66-ponyta').
    # NOTE(review): earlier notes here said some run names worked on a local machine but not on
    # "fugu" - confirm which saved models are available on each machine.
    end_time = time.perf_counter()
    print('Time taken: ', end_time - start_time)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
from qusi.hadryss_model import Hadryss | ||
from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration | ||
from qusi.train_logging_configuration import TrainLoggingConfiguration | ||
from qusi.train_session import train_session | ||
|
||
from torchmetrics.classification import (BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinarySpecificity, | ||
BinaryStatScores) | ||
|
||
from moa_dataset import MoaSurveyMicrolensingAndNonMicrolensingDatabase | ||
from wrapped_metrics import WrappedBinaryPrecision, WrappedBinaryRecall | ||
|
||
from tqdm import tqdm | ||
|
||
def main(test_split):
    """
    Train a Hadryss model on the MOA microlensing vs. non-microlensing data for one test split.

    :param test_split: The split index passed to the database as the held-out test split.
    """
    # Weights & Biases logging target.
    wandb_logging = TrainLoggingConfiguration.new(wandb_project='qusi_moa', wandb_entity='ramjet')

    # Training and validation datasets both come from the same database instance.
    moa_database = MoaSurveyMicrolensingAndNonMicrolensingDatabase(test_split=test_split)
    training_datasets = [moa_database.get_microlensing_train_dataset()]
    validation_datasets = [moa_database.get_microlensing_validation_dataset()]

    # Network and training schedule.
    network = Hadryss.new()
    hyperparameters = TrainHyperparameterConfiguration.new(
        batch_size=100,
        cycles=50,
        train_steps_per_cycle=100,
        validation_steps_per_cycle=10,
    )

    # Metrics passed to the training session.
    metrics = [
        BinaryAccuracy(),
        BinaryAUROC(),
        BinaryF1Score(),
        BinarySpecificity(),
        WrappedBinaryPrecision(),
        WrappedBinaryRecall(),
    ]

    train_session(
        train_datasets=training_datasets,
        validation_datasets=validation_datasets,
        model=network,
        hyperparameter_configuration=hyperparameters,
        logging_configuration=wandb_logging,
        metric_functions=metrics,
    )
|
||
|
||
if __name__ == '__main__':
    # Command-line entry point: `python <script> <test_split>`.
    import sys

    # Total arguments passed (the script name counts as the first one).
    print("Total arguments passed:", len(sys.argv))
    print("\nName of Python script:", sys.argv[0])
    if len(sys.argv) < 2:
        # Without this check a missing argument surfaces as a bare IndexError.
        sys.exit(f'Usage: python {sys.argv[0]} <test_split>')
    print("\nSplit #:", sys.argv[1])

    # To train every split in a single run, call main(test_split=i) for each i in range(10).
    main(test_split=int(sys.argv[1]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters