moa config
stelais committed Apr 12, 2024
1 parent 37cf34f commit f3b93aa
Showing 5 changed files with 253 additions and 4 deletions.
121 changes: 121 additions & 0 deletions microlensing/moa_dataset.py
@@ -0,0 +1,121 @@
import numpy as np

from ramjet.data_interface.moa_data_interface import MoaDataInterface
from ramjet.photometric_database.derived.moa_survey_light_curve_collection import MoaSurveyLightCurveCollection
from ramjet.photometric_database.standard_and_injected_light_curve_database import StandardAndInjectedLightCurveDatabase

from qusi.light_curve_collection import LabeledLightCurveCollection
from qusi.light_curve_dataset import LightCurveDataset
from qusi.light_curve_collection import LightCurveCollection



def positive_label_function(path):
return 1


def negative_label_function(path):
return 0

class MoaSurveyMicrolensingAndNonMicrolensingDatabase(StandardAndInjectedLightCurveDatabase):
"""
    A class for a database of MOA light curves, including both microlensing and non-microlensing collections.
"""
moa_data_interface = MoaDataInterface()

def __init__(self, test_split: int):
super().__init__()
validation_split = (test_split - 1) % 10
train_splits = list(range(10))
train_splits.remove(validation_split)
train_splits.remove(test_split)
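        # e.g., test_split=9 -> validation_split=8, train_splits=[0, 1, 2, 3, 4, 5, 6, 7]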
# self.number_of_label_values = 1
# self.number_of_parallel_processes_per_map = 5
# self.time_steps_per_example = 18000
# self.shuffle_buffer_size = 1000
# self.include_time_as_channel = False

        # Note that the network already has number_of_splits: int = 10 set.
        # Creating the training collection | the eight remaining splits = 80% of the data
self.negative_training = MoaSurveyLightCurveCollection(
survey_tags=['v', 'n', 'nr', 'm', 'j', self.moa_data_interface.no_tag_string],
label=0,
dataset_splits=train_splits)
self.positive_training = MoaSurveyLightCurveCollection(
survey_tags=['c', 'cf', 'cp', 'cw', 'cs', 'cb'],
label=1,
dataset_splits=train_splits)

        # Creating the validation collection | split (test_split - 1) % 10 = 10% of the data
self.negative_validation = MoaSurveyLightCurveCollection(
survey_tags=['v', 'n', 'nr', 'm', 'j', self.moa_data_interface.no_tag_string],
label=0,
dataset_splits=[validation_split])
self.positive_validation = MoaSurveyLightCurveCollection(
survey_tags=['c', 'cf', 'cp', 'cw', 'cs', 'cb'],
label=1,
dataset_splits=[validation_split])

        # Creating the inference collection | split test_split = 10% of the data
self.negative_inference = MoaSurveyLightCurveCollection(
survey_tags=['v', 'n', 'nr', 'm', 'j', self.moa_data_interface.no_tag_string],
label=0,
dataset_splits=[test_split])
self.positive_inference = MoaSurveyLightCurveCollection(
survey_tags=['c', 'cf', 'cp', 'cw', 'cs', 'cb'],
label=1,
dataset_splits=[test_split])
self.all_inference = MoaSurveyLightCurveCollection(
survey_tags=['c', 'cf', 'cp', 'cw', 'cs', 'cb',
'v', 'n', 'nr', 'm', 'j', self.moa_data_interface.no_tag_string],
label=np.nan,
dataset_splits=[test_split])

# QUSI structure
def get_microlensing_train_dataset(self):
positive_train_light_curve_collection = LabeledLightCurveCollection.new(
get_paths_function=self.positive_training.get_paths,
load_times_and_fluxes_from_path_function=self.positive_training.load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_train_light_curve_collection = LabeledLightCurveCollection.new(
get_paths_function=self.negative_training.get_paths,
load_times_and_fluxes_from_path_function=self.negative_training.load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
train_light_curve_dataset = LightCurveDataset.new(
standard_light_curve_collections=[positive_train_light_curve_collection,
negative_train_light_curve_collection])
# print('check "properties" of the train_light_curve_dataset', train_light_curve_dataset)
return train_light_curve_dataset

def get_microlensing_validation_dataset(self):
positive_validation_light_curve_collection = LabeledLightCurveCollection.new(
get_paths_function=self.positive_validation.get_paths,
load_times_and_fluxes_from_path_function=self.positive_validation.load_times_and_fluxes_from_path,
load_label_from_path_function=positive_label_function)
negative_validation_light_curve_collection = LabeledLightCurveCollection.new(
get_paths_function=self.negative_validation.get_paths,
load_times_and_fluxes_from_path_function=self.negative_validation.load_times_and_fluxes_from_path,
load_label_from_path_function=negative_label_function)
validation_light_curve_dataset = LightCurveDataset.new(
standard_light_curve_collections=[positive_validation_light_curve_collection,
negative_validation_light_curve_collection])
return validation_light_curve_dataset

def get_microlensing_infer_collection(self):
infer_light_curve_collection = LightCurveCollection.new(
get_paths_function=self.all_inference.get_paths,
load_times_and_fluxes_from_path_function=self.all_inference.load_times_and_fluxes_from_path)
return infer_light_curve_collection
# def get_microlensing_finite_test_dataset(self):
# positive_test_light_curve_collection = LabeledLightCurveCollection.new(
# get_paths_function=self.positive_inference.get_paths,
# load_times_and_fluxes_from_path_function=self.positive_inference.load_times_and_fluxes_from_path,
# load_label_from_path_function=positive_label_function)
# negative_test_light_curve_collection = LabeledLightCurveCollection.new(
# get_paths_function=self.negative_inference.get_paths,
# load_times_and_fluxes_from_path_function=self.negative_inference.load_times_and_fluxes_from_path,
# load_label_from_path_function=negative_label_function)
# test_light_curve_dataset = FiniteStandardLightCurveObservationDataset.new(
# standard_light_curve_collections=[positive_test_light_curve_collection,
# negative_test_light_curve_collection])
# return test_light_curve_dataset
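
A minimal usage sketch of the class above (these are the same calls moa_train.py and moa_infer.py below make); for test_split=9, validation draws from split 8 and training from splits 0-7:

    from moa_dataset import MoaSurveyMicrolensingAndNonMicrolensingDatabase

    # Hold out split 9 for inference; split 8 becomes validation, splits 0-7 training.
    database = MoaSurveyMicrolensingAndNonMicrolensingDatabase(test_split=9)
    train_dataset = database.get_microlensing_train_dataset()
    validation_dataset = database.get_microlensing_validation_dataset()
    infer_collection = database.get_microlensing_infer_collection()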
69 changes: 69 additions & 0 deletions microlensing/moa_infer.py
@@ -0,0 +1,69 @@
import torch
import pandas as pd

from moa_dataset import MoaSurveyMicrolensingAndNonMicrolensingDatabase

from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset
from qusi.hadryss_model import Hadryss
from qusi.infer_session import get_device, infer_session


def read_wandb_name_csv(test_split_):
    """Looks up the wandb run name recorded for the given test split in inferences/wandb.csv."""
    df = pd.read_csv('inferences/wandb.csv')
    df_temp = df[df["Tags"] == f"550k M vs NM Split {test_split_}"]
    wandb_name = df_temp['Name'].values[0]
    return wandb_name.strip()


def main(test_split_, type_, eventname_):
print('Test split #: ', test_split_)
print('Type: ', type_)
print('WANDB name: ', eventname_)
database = MoaSurveyMicrolensingAndNonMicrolensingDatabase(test_split=test_split_)
infer_light_curve_collection = database.get_microlensing_infer_collection()
test_light_curve_dataset = FiniteStandardLightCurveDataset.new(
light_curve_collections=[infer_light_curve_collection])

model = Hadryss.new()
device = get_device()
model.load_state_dict(torch.load(f'sessions/{eventname_}_latest_model.pt', map_location=device))
confidences = infer_session(infer_datasets=[test_light_curve_dataset], model=model,
batch_size=100, device=device)[0]
paths = list(database.all_inference.get_paths())
paths_with_confidences = zip(paths, confidences)
sorted_paths_with_confidences = sorted(
paths_with_confidences, key=lambda path_with_confidence: path_with_confidence[1], reverse=True)
print(sorted_paths_with_confidences)
df = pd.DataFrame(sorted_paths_with_confidences, columns=['Path', 'Score'])
df['Path'] = df['Path'].astype(str)
    # The light curve name is the token after the last '_' in the file name stem.
    df['lightcurve_name'] = df['Path'].str.split('/').str[-1].str.split('.').str[0].str.split('_').str[-1]
df.to_csv(f'inferences/results_{type_}_{test_split_}.csv')

print()


if __name__ == '__main__':
import sys
import time
start_time = time.time()
# total arguments
n = len(sys.argv)
print("Total arguments passed:", n)
# Arguments passed
python_script_name = sys.argv[0]
split_number = int(sys.argv[1])
wandb_name = str(read_wandb_name_csv(split_number))

main(test_split_=split_number, type_='550k', eventname_=wandb_name)

# main(test_split_=int(0), type_='550k', eventname_='gs66-ponyta')
# main(test_split_=int(0), type_='550k', eventname_='graceful-serenity-51')
end_time = time.time()
print('Time taken: ', end_time - start_time)
# gs66-fugu-550k
# main(test_split_=int(0), type_='550k', eventname_='confused-resonance-22')

# The calls below work on my computer; the one above does not work on fugu.
# main(test_split_=int(0), type_='550k', eventname_='gs66-ponyta')
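
The entry point above takes the split number as its only argument (for example, python moa_infer.py 9) and assumes inferences/wandb.csv provides the Tags and Name columns read by read_wandb_name_csv, plus a sessions/<run_name>_latest_model.pt checkpoint saved during training.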
54 changes: 54 additions & 0 deletions microlensing/moa_train.py
@@ -0,0 +1,54 @@
from qusi.hadryss_model import Hadryss
from qusi.train_hyperparameter_configuration import TrainHyperparameterConfiguration
from qusi.train_logging_configuration import TrainLoggingConfiguration
from qusi.train_session import train_session

from torchmetrics.classification import (BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinarySpecificity,
BinaryStatScores)

from moa_dataset import MoaSurveyMicrolensingAndNonMicrolensingDatabase
from wrapped_metrics import WrappedBinaryPrecision, WrappedBinaryRecall

from tqdm import tqdm

def main(test_split):
    # WandB
logging_configuration = TrainLoggingConfiguration.new(wandb_project='qusi_moa', wandb_entity='ramjet')

# Database
database = MoaSurveyMicrolensingAndNonMicrolensingDatabase(test_split=test_split)
train_light_curve_dataset = database.get_microlensing_train_dataset()
validation_light_curve_dataset = database.get_microlensing_validation_dataset()

# model and config
model = Hadryss.new()
train_hyperparameter_configuration = TrainHyperparameterConfiguration.new(
batch_size=100, cycles=50, train_steps_per_cycle=100, validation_steps_per_cycle=10)

# Metrics
# metric_functions = [BinaryAccuracy(), BinaryAUROC(), BinaryRecall(),
# BinaryPrecision(), BinaryROC(), BinaryConfusionMatrix()]

metric_functions = [BinaryAccuracy(), BinaryAUROC(), BinaryF1Score(), BinarySpecificity(),
WrappedBinaryPrecision(), WrappedBinaryRecall()]
# metric_functions = [BinaryAccuracy()]

# Train!
train_session(train_datasets=[train_light_curve_dataset], validation_datasets=[validation_light_curve_dataset],
model=model, hyperparameter_configuration=train_hyperparameter_configuration,
logging_configuration=logging_configuration, metric_functions=metric_functions)


if __name__ == '__main__':
import sys
# total arguments
n = len(sys.argv)
print("Total arguments passed:", n)
# Arguments passed
print("\nName of Python script:", sys.argv[0])
print("\nSplit #:", sys.argv[1])
# for i in tqdm([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]):
# print('split is ', i)
# main(test_split=i)

main(test_split=int(sys.argv[1]))
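
Training likewise runs one fold per invocation, with the split number as the only argument (for example, python moa_train.py 0); the commented-out loop above sketches sweeping all ten splits in a single process.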
7 changes: 6 additions & 1 deletion src/qusi/infer_session.py
@@ -3,6 +3,7 @@
from torch.nn import Module
from torch.types import Device
from torch.utils.data import DataLoader
from tqdm import tqdm

from qusi.finite_standard_light_curve_dataset import FiniteStandardLightCurveDataset

@@ -24,19 +25,23 @@ def infer_session(

def get_device() -> Device:
if torch.cuda.is_available():
print("Using CUDA")
device = torch.device("cuda")
else:
print("Using CPU")
device = torch.device("cpu")
return device


def infer_phase(dataloader, model: Module, device: Device):
batch_count = 0
batches_of_predicted_targets = []
model = model.to(device=device)
model.eval()
with torch.no_grad():
-        for input_features in dataloader:
+        for input_features in tqdm(dataloader):
input_features_on_device = input_features.to(device, non_blocking=True)
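            # Likely needed because numpy-loaded light curves arrive as float64 while the model
            # parameters default to float32; mismatched dtypes would fail at the first layer.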
input_features_on_device = input_features_on_device.to(dtype=torch.float32) # SIS added this line
batch_predicted_targets = model(input_features_on_device)
batches_of_predicted_targets.append(batch_predicted_targets)
batch_count += 1
6 changes: 3 additions & 3 deletions src/ramjet/data_interface/moa_data_interface.py
@@ -30,10 +30,10 @@ def survey_tag_to_path_list_dictionary(self) -> dict[str, list[Path]]:
"""
if self.survey_tag_to_path_list_dictionary_ is None:
takahiro_sumi_nine_year_events_data_frame = self.read_corrected_nine_year_events_table_as_data_frame(
Path("data/moa_microlensing/moa9yr_events_oct2018.txt")
Path("data/moa_microlensing_550k/candlist_2023Oct12.txt")
)
self.survey_tag_to_path_list_dictionary_ = self.group_paths_by_tag_in_events_data_frame(
list(Path("data/moa_microlensing").glob("**/*.cor.feather")), takahiro_sumi_nine_year_events_data_frame
list(Path("data/moa_microlensing_550k").glob("**/*.cor.feather")), takahiro_sumi_nine_year_events_data_frame
)
return self.survey_tag_to_path_list_dictionary_

@@ -107,7 +107,7 @@ def get_yuki_hirao_events_data_frame() -> pd.DataFrame:
tbl = soup.find("table")
events_data_frame = pd.read_html(str(tbl))[0]
events_data_frame[["field", "clr", "chip", "subfield", "id"]] = events_data_frame["MOA INTERNAL ID"].str.split(
"-", 4, expand=True
"-", n=4, expand=True
)
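    # Above, n is passed by keyword: newer pandas deprecates positional arguments to str.split other than pat.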
events_data_frame["chip"] = events_data_frame["chip"].astype(int)
events_data_frame["subfield"] = events_data_frame["subfield"].astype(int)
