Skip to content

Commit afe6f6c

Browse files
committed
Committing the changes required to produce the results for the 2021 transit paper
1 parent 1ed893b commit afe6f6c

File tree

7 files changed: +95 additions, −21 deletions

infer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,27 @@
33
import datetime
44
from pathlib import Path
55

6-
from ramjet.models.hades import Hades
6+
from ramjet.models.hades import Hades, FfiHades
7+
from ramjet.photometric_database.derived.tess_ffi_transit_databases import \
8+
TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase
79
from ramjet.photometric_database.derived.tess_two_minute_cadence_transit_databases import \
810
TessTwoMinuteCadenceStandardAndInjectedTransitDatabase
911
from ramjet.analysis.model_loader import get_latest_log_directory
1012
from ramjet.trial import infer
1113

12-
log_name = get_latest_log_directory(logs_directory='logs') # Uses the latest model in the log directory.
13-
# log_name = 'logs/baseline YYYY-MM-DD-hh-mm-ss' # Specify the path to the model to use.
14+
# log_name = get_latest_log_directory(logs_directory='logs') # Uses the latest model in the log directory.
15+
log_name = 'logs/FFI transit sai aeb FfiHades mag14 quick pos no neg cont from existing no random start 2020-12-19-16-10-27' # Specify the path to the model to use.
1416
saved_log_directory = Path(f'{log_name}')
1517
datetime_string = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
1618

1719
print('Setting up dataset...', flush=True)
18-
database = TessTwoMinuteCadenceStandardAndInjectedTransitDatabase()
20+
database = TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase()
1921
inference_dataset = database.generate_inference_dataset()
2022

2123
print('Loading model...', flush=True)
22-
model = Hades(database.number_of_label_types)
24+
model = FfiHades()
2325
model.load_weights(str(saved_log_directory.joinpath('model.ckpt'))).expect_partial()
2426

2527
print('Inferring...', flush=True)
2628
infer_results_path = saved_log_directory.joinpath(f'infer results {datetime_string}.csv')
27-
infer(model, inference_dataset, infer_results_path)
29+
infer(model, inference_dataset, infer_results_path, number_of_top_predictions_to_keep=5000)

ramjet/analysis/transit_fitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def round_value_to_significant_figures(value):
350350
if __name__ == '__main__':
351351
print('Opening Bokeh application on http://localhost:5006/')
352352
# Start the server.
353-
server = Server({'/': TransitFitter(tic_id=362043085).bokeh_application})
353+
server = Server({'/': TransitFitter(tic_id=297678377).bokeh_application})
354354
server.start()
355355
# Start the specific application on the server.
356356
server.io_loop.add_callback(server.show, "/")

ramjet/models/hades.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,48 @@ def call(self, inputs, training=False, mask=None):
4848
x = self.prediction_layer(x, training=training)
4949
outputs = self.reshape(x, training=training)
5050
return outputs
51+
52+
53+
class FfiHades(Model):
54+
def __init__(self):
55+
super().__init__()
56+
self.block0 = LightCurveNetworkBlock(filters=8, kernel_size=3, pooling_size=2, batch_normalization=False,
57+
dropout_rate=0)
58+
self.block1 = LightCurveNetworkBlock(filters=8, kernel_size=3, pooling_size=2)
59+
self.block2 = LightCurveNetworkBlock(filters=16, kernel_size=3, pooling_size=2)
60+
self.block3 = LightCurveNetworkBlock(filters=32, kernel_size=3, pooling_size=2)
61+
self.block4 = LightCurveNetworkBlock(filters=64, kernel_size=3, pooling_size=2)
62+
self.block5 = LightCurveNetworkBlock(filters=128, kernel_size=3, pooling_size=2)
63+
self.block6 = LightCurveNetworkBlock(filters=128, kernel_size=3, pooling_size=1)
64+
self.block7 = LightCurveNetworkBlock(filters=128, kernel_size=3, pooling_size=1)
65+
self.block8 = LightCurveNetworkBlock(filters=20, kernel_size=3, pooling_size=1, spatial=False)
66+
self.block9 = LightCurveNetworkBlock(filters=20, kernel_size=7, pooling_size=1)
67+
self.block10 = LightCurveNetworkBlock(filters=20, kernel_size=1, pooling_size=1, batch_normalization=False,
68+
dropout_rate=0)
69+
self.prediction_layer = Convolution1D(1, kernel_size=1, activation=sigmoid)
70+
self.reshape = Reshape([1])
71+
72+
def call(self, inputs, training=False, mask=None):
73+
"""
74+
The forward pass of the layer.
75+
76+
:param inputs: The input tensor.
77+
:param training: A boolean specifying if the layer should be in training mode.
78+
:param mask: A mask for the input tensor.
79+
:return: The output tensor of the layer.
80+
"""
81+
x = inputs
82+
x = self.block0(x, training=training)
83+
x = self.block1(x, training=training)
84+
x = self.block2(x, training=training)
85+
x = self.block3(x, training=training)
86+
x = self.block4(x, training=training)
87+
x = self.block5(x, training=training)
88+
x = self.block6(x, training=training)
89+
x = self.block7(x, training=training)
90+
x = self.block8(x, training=training)
91+
x = self.block9(x, training=training)
92+
x = self.block10(x, training=training)
93+
x = self.prediction_layer(x, training=training)
94+
outputs = self.reshape(x, training=training)
95+
return outputs

ramjet/photometric_database/derived/tess_ffi_eclipsing_binary_lightcurve_collection.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
"""
22
Code representing the collection of TESS two minute cadence lightcurves containing eclipsing binaries.
33
"""
4-
from typing import Union, List
4+
import pandas as pd
5+
from pathlib import Path
6+
from typing import Union, List, Iterable
57

68
from peewee import Select
79

@@ -56,3 +58,15 @@ def get_sql_query(self) -> Select:
5658
TessEclipsingBinaryMetadata.tic_id.not_in(transit_tic_id_query))
5759
query = query.where(TessFfiLightcurveMetadata.tic_id.in_(eclipsing_binary_tic_id_query))
5860
return query
61+
62+
63+
class TessFfiQuickTransitNegativeLightcurveCollection(TessFfiLightcurveCollection):
64+
def __init__(self, dataset_splits: Union[List[int], None] = None,
65+
magnitude_range: (Union[float, None], Union[float, None]) = (None, None)):
66+
super().__init__(dataset_splits=dataset_splits, magnitude_range=magnitude_range)
67+
self.label = 0
68+
69+
def get_paths(self) -> Iterable[Path]:
70+
data_frame = pd.read_csv('quick_negative_paths.csv')
71+
paths = list(map(Path, data_frame['Lightcurve path'].values))
72+
return paths

ramjet/photometric_database/derived/tess_ffi_lightcurve_collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ def get_sql_query(self) -> Select:
3434
:return: The SQL query.
3535
"""
3636
query = TessFfiLightcurveMetadata().select()
37-
query = self.order_by_dataset_split_with_random_start(query, TessFfiLightcurveMetadata.dataset_split,
38-
self.dataset_splits)
37+
# query = self.order_by_dataset_split_with_random_start(query, TessFfiLightcurveMetadata.dataset_split,
38+
# self.dataset_splits)
3939
if self.magnitude_range[0] is not None and self.magnitude_range[1] is not None:
4040
query = query.where(TessFfiLightcurveMetadata.magnitude.between(*self.magnitude_range))
4141
elif self.magnitude_range[0] is not None:

ramjet/photometric_database/derived/tess_ffi_transit_databases.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from ramjet.photometric_database.derived.tess_ffi_eclipsing_binary_lightcurve_collection import \
2-
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection
2+
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection, TessFfiQuickTransitNegativeLightcurveCollection
33
from ramjet.photometric_database.derived.tess_ffi_lightcurve_collection import TessFfiLightcurveCollection
44
from ramjet.photometric_database.derived.tess_ffi_transit_lightcurve_collections import \
55
TessFfiConfirmedTransitLightcurveCollection, TessFfiNonTransitLightcurveCollection
@@ -15,11 +15,11 @@ def __init__(self):
1515
super().__init__()
1616
self.batch_size = 1000
1717
self.time_steps_per_example = 1000
18-
self.shuffle_buffer_size = 100000
18+
self.shuffle_buffer_size = 10000
1919
self.out_of_bounds_injection_handling = OutOfBoundsInjectionHandlingMethod.RANDOM_INJECTION_LOCATION
2020

2121

22-
magnitude_range = (0, 11)
22+
magnitude_range = (0, 14)
2323

2424

2525
class TessFfiStandardTransitDatabase(TessFfiDatabase):
@@ -106,23 +106,32 @@ class TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase(TessFfiDataba
106106
"""
107107
def __init__(self):
108108
super().__init__()
109+
self.shuffle_buffer_size = 10000
110+
self.number_of_parallel_processes_per_map = 6
109111
self.training_standard_lightcurve_collections = [
110112
TessFfiConfirmedTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
111113
TessFfiNonTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
112114
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)),
113-
magnitude_range=magnitude_range)
115+
magnitude_range=magnitude_range),
116+
# TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)),
117+
# magnitude_range=magnitude_range)
114118
]
115119
self.training_injectee_lightcurve_collection = TessFfiNonTransitLightcurveCollection(
116120
dataset_splits=list(range(8)), magnitude_range=magnitude_range)
117121
self.training_injectable_lightcurve_collections = [
118122
TessFfiConfirmedTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
119123
TessFfiNonTransitLightcurveCollection(dataset_splits=list(range(8)), magnitude_range=magnitude_range),
120124
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)),
121-
magnitude_range=magnitude_range)
125+
magnitude_range=magnitude_range),
126+
# TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)),
127+
# magnitude_range=magnitude_range)
122128
]
123129
self.validation_standard_lightcurve_collections = [
124130
TessFfiConfirmedTransitLightcurveCollection(dataset_splits=[8], magnitude_range=magnitude_range),
125131
TessFfiNonTransitLightcurveCollection(dataset_splits=[8], magnitude_range=magnitude_range),
126132
TessFfiAntiEclipsingBinaryForTransitLightcurveCollection(dataset_splits=list(range(8)),
127-
magnitude_range=magnitude_range)
133+
magnitude_range=magnitude_range),
134+
# TessFfiQuickTransitNegativeLightcurveCollection(dataset_splits=list(range(8)),
135+
# magnitude_range=magnitude_range)
128136
]
137+
self.inference_lightcurve_collections = [TessFfiLightcurveCollection(magnitude_range=magnitude_range)]

train.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
from tensorflow.python.keras import callbacks
66
from tensorflow.python.keras.losses import BinaryCrossentropy
77

8-
from ramjet.models.hades import Hades
8+
from ramjet.basic_models import SimplePoolingLightcurveCnn2, FfiSimplePoolingLightcurveCnn2
9+
from ramjet.models.hades import Hades, FfiHades
10+
from ramjet.photometric_database.derived.tess_ffi_transit_databases import \
11+
TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase
912
from ramjet.photometric_database.derived.tess_two_minute_cadence_transit_databases import \
1013
TessTwoMinuteCadenceStandardAndInjectedTransitDatabase
1114

@@ -14,9 +17,9 @@ def train():
1417
"""Runs the training."""
1518
print('Starting training process...', flush=True)
1619
# Basic training settings.
17-
trial_name = f'baseline' # Add any desired run name details to this string.
18-
database = TessTwoMinuteCadenceStandardAndInjectedTransitDatabase()
19-
model = Hades(database.number_of_label_types)
20+
trial_name = f'FFI transit sai aeb FfiHades mag14 quick pos no neg cont from existing no random start' # Add any desired run name details to this string.
21+
model = FfiHades()
22+
database = TessFfiStandardAndInjectedTransitAntiEclipsingBinaryDatabase()
2023
# database.batch_size = 100 # Reducing the batch size may help if you are running out of memory.
2124
epochs_to_run = 1000
2225
logs_directory = 'logs'
@@ -33,14 +36,15 @@ def train():
3336
training_dataset, validation_dataset = database.generate_datasets()
3437
optimizer = tf.optimizers.Adam(learning_rate=1e-4)
3538
loss_metric = BinaryCrossentropy(name='Loss')
36-
metrics = [tf.keras.metrics.AUC(num_thresholds=20, name='Area_under_ROC_curve', multi_label=True),
39+
metrics = [tf.keras.metrics.AUC(num_thresholds=20, name='Area_under_ROC_curve'),
3740
tf.metrics.SpecificityAtSensitivity(0.9, name='Specificity_at_90_percent_sensitivity'),
3841
tf.metrics.SensitivityAtSpecificity(0.9, name='Sensitivity_at_90_percent_specificity'),
3942
tf.metrics.BinaryAccuracy(name='Accuracy'), tf.metrics.Precision(name='Precision'),
4043
tf.metrics.Recall(name='Recall')]
4144

4245
# Compile and train model.
4346
model.compile(optimizer=optimizer, loss=loss_metric, metrics=metrics)
47+
model.load_weights('/att/gpfsfs/briskfs01/ppl/golmsche/ramjet/logs/FFI transit sai aeb FfiHades mag13 quick pos no neg cont from existing no random start 2020-10-08-17-11-05/model.ckpt')
4448
try:
4549
model.fit(training_dataset, epochs=epochs_to_run, validation_data=validation_dataset,
4650
callbacks=[tensorboard_callback, model_checkpoint_callback], steps_per_epoch=5000,

Comments (0)