@@ -207,7 +207,14 @@ def save_model_to_file(self, savefile=None, limit_size=False):
207
207
self .get_savename (how = 'save' )
208
208
savefile = self .savenames [0 ]+ '_model.pickle'
209
209
if hasattr (self ,'trace' ):
210
- self .trace .to_netcdf (self .savenames [0 ]+ '_trace.nc' )
210
+ try :
211
+ self .trace .to_netcdf (self .savenames [0 ]+ '_trace.nc' )
212
+ except :
213
+ try :
214
+ #Stacking/unstacking removes Multitrace objects:
215
+ self .trace .unstack ().to_netcdf (self .savenames [0 ]+ '_trace.nc' )
216
+ except :
217
+ print ("Still a save error after unstacking" )
211
218
excl_types = [az .InferenceData ]
212
219
cloudpickle .dumps ({attr :getattr (self ,attr ) for attr in self .__dict__ },open (savefile ,'wb' ))
213
220
@@ -4643,7 +4650,7 @@ def plot_corner(self,corner_vars=None,use_marg=True,truths=None):
4643
4650
if not self.assume_circ:
4644
4651
corner_vars+=['multi_ecc','multi_omega']
4645
4652
'''
4646
- samples = pm . trace_to_dataframe ( self .trace , varnames = corner_vars )
4653
+ samples = self .make_table ( cols = corner_vars )
4647
4654
#print(samples.shape,samples.columns)
4648
4655
assert samples .shape [1 ]< 50
4649
4656
@@ -4715,13 +4722,13 @@ def make_table(self,short=True,save=True,cols=['all']):
4715
4722
cols_to_remove += ['mono_' + col + 's' ,'duo_' + col + 's' ]
4716
4723
medvars = [var for var in self .trace if not np .any ([icol in var for icol in cols_to_remove ])]
4717
4724
#print(cols_to_remove, medvars)
4718
- df = pm .summary (self .trace ,var_names = medvars ,stat_funcs = {"5%" : lambda x : np .percentile (x , 5 ),
4725
+ df = pm .summary (self .trace . unstack () ,var_names = medvars ,stat_funcs = {"5%" : lambda x : np .percentile (x , 5 ),
4719
4726
"-$1\sigma$" : lambda x : np .percentile (x , 15.87 ),
4720
4727
"median" : lambda x : np .percentile (x , 50 ),
4721
4728
"+$1\sigma$" : lambda x : np .percentile (x , 84.13 ),
4722
4729
"95%" : lambda x : np .percentile (x , 95 )},round_to = 5 )
4723
4730
else :
4724
- df = pm .summary (self .trace ,var_names = cols ,stat_funcs = {"5%" : lambda x : np .percentile (x , 5 ),
4731
+ df = pm .summary (self .trace . unstack () ,var_names = cols ,stat_funcs = {"5%" : lambda x : np .percentile (x , 5 ),
4725
4732
"-$1\sigma$" : lambda x : np .percentile (x , 15.87 ),
4726
4733
"median" : lambda x : np .percentile (x , 50 ),
4727
4734
"+$1\sigma$" : lambda x : np .percentile (x , 84.13 ),
@@ -4880,23 +4887,22 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
4880
4887
print ("time range" ,Time (time_start + self .lc .jd_base ,format = 'jd' ).isot ,
4881
4888
"->" ,Time (time_end + self .lc .jd_base ,format = 'jd' ).isot )
4882
4889
4883
-
4884
4890
if check_TESS :
4885
4891
sect_start_ends = self .check_TESS ()
4886
4892
4887
4893
all_trans_fin = pd .DataFrame ()
4888
4894
loopplanets = self .duos + self .trios + self .multis if include_multis else self .duos + self .trios
4889
-
4895
+ all_unq_trans = []
4890
4896
for pl in loopplanets :
4891
4897
all_trans = pd .DataFrame ()
4892
4898
if pl in self .duos + self .trios :
4893
4899
sum_all_probs = np .logaddexp .reduce (np .nanmedian (self .trace ['logprob_marg_' + pl ],axis = 0 ))
4894
- trans_p0 = np .floor (np .nanmedian (time_start - self .trace ['t0_2_' + pl ])/ np .nanmedian (self .trace ['per_' + pl ],axis = 0 ))
4895
- trans_p1 = np .ceil (np .nanmedian (time_end - self .trace ['t0_2_' + pl ])/ np .nanmedian (self .trace ['per_' + pl ],axis = 0 ))
4900
+ trans_p0 = np .floor (np .nanmedian (time_start - self .trace ['t0_2_' + pl ]. values )/ np .nanmedian (self .trace ['per_' + pl ]. values ,axis = 0 ))
4901
+ trans_p1 = np .ceil (np .nanmedian (time_end - self .trace ['t0_2_' + pl ]. values )/ np .nanmedian (self .trace ['per_' + pl ]. values ,axis = 0 ))
4896
4902
n_trans = trans_p1 - trans_p0
4897
4903
elif pl in self .multis :
4898
- trans_p0 = [np .floor (np .nanmedian (time_start - self .trace ['t0_' + pl ])/ np .nanmedian (self .trace ['per_' + pl ], axis = 0 ))]
4899
- trans_p1 = [np .ceil (np .nanmedian (time_end - self .trace ['t0_' + pl ])/ np .nanmedian (self .trace ['per_' + pl ], axis = 0 ))]
4904
+ trans_p0 = [np .floor (np .nanmedian (time_start - self .trace ['t0_' + pl ]. values )/ np .nanmedian (self .trace ['per_' + pl ]. values ))]
4905
+ trans_p1 = [np .ceil (np .nanmedian (time_end - self .trace ['t0_' + pl ]. values )/ np .nanmedian (self .trace ['per_' + pl ]. values ))]
4900
4906
n_trans = [trans_p1 [0 ]- trans_p0 [0 ]]
4901
4907
#print(pl,trans_p0,trans_p1,n_trans)
4902
4908
#print(np.nanmedian(self.trace['t0_2_'+pl])+np.nanmedian(self.trace['per_'+pl],axis=0)*trans_p0)
@@ -4909,21 +4915,21 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
4909
4915
if 'tdur' in self .fit_params or pl in self .multis :
4910
4916
dur = np .nanpercentile (self .trace ['tdur_' + pl ],percentiles )
4911
4917
naliases = [0 ] if pl in self .multis else np .arange (self .planets [pl ]['npers' ])
4918
+ idfs = []
4912
4919
for nd in naliases :
4913
4920
if n_trans [nd ]> 0 :
4914
4921
if pl in self .duos + self .trios :
4915
4922
int_alias = int (self .planets [pl ]['period_int_aliases' ][nd ])
4916
- transits = np .nanpercentile (np .vstack ([self .trace ['t0_2_' + pl ]+ ntr * self .trace ['per_' + pl ][:,nd ] for ntr in np .arange (trans_p0 [nd ],trans_p1 [nd ])]),percentiles ,axis = 1 )
4923
+ transits = np .nanpercentile (np .vstack ([self .trace ['t0_2_' + pl ]. values + ntr * self .trace ['per_' + pl ]. values [:,nd ] for ntr in np .arange (trans_p0 [nd ],trans_p1 [nd ])]),percentiles ,axis = 1 )
4917
4924
if 'tdur' in self .marginal_params :
4918
4925
dur = np .nanpercentile (self .trace ['tdur_' + pl ][:,nd ],percentiles )
4919
4926
logprobs = np .nanmedian (self .trace ['logprob_marg_' + pl ][:,nd ])- sum_all_probs
4920
4927
else :
4921
- transits = np .nanpercentile (np .vstack ([self .trace ['t0_' + pl ]+ ntr * self .trace ['per_' + pl ] for ntr in np .arange (trans_p0 [nd ],trans_p1 [nd ])]),percentiles ,axis = 1 )
4928
+ transits = np .nanpercentile (np .column_stack ([self .trace ['t0_' + pl ]. values + ntr * self .trace ['per_' + pl ]. values for ntr in np .arange (trans_p0 [nd ],trans_p1 [nd ], 1.0 )]),percentiles ,axis = 0 )
4922
4929
int_alias = 1
4923
4930
logprobs = np .array ([0.0 ])
4924
4931
#Getting the aliases for this:
4925
-
4926
- idf = pd .DataFrame ({'transit_mid_date' :Time (transits [2 ]+ self .lc .jd_base ,format = 'jd' ).isot ,
4932
+ idfs += [pd .DataFrame ({'transit_mid_date' :Time (transits [2 ]+ self .lc .jd_base ,format = 'jd' ).isot ,
4927
4933
'transit_mid_med' :transits [2 ],
4928
4934
'transit_dur_med' :np .tile (dur [2 ],len (transits [2 ])),
4929
4935
'transit_dur_-1sig' :np .tile (dur [1 ],len (transits [2 ])),
@@ -4945,8 +4951,8 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
4945
4951
'prob' :np .tile (np .exp (logprobs ),len (transits [2 ])),
4946
4952
'planet_name' :np .tile ('multi_' + pl ,len (transits [2 ])) if pl in self .multis else np .tile ('duo_' + pl ,len (transits [2 ])),
4947
4953
'alias_n' :np .tile (nd ,len (transits [2 ])),
4948
- 'alias_p' :np .tile (np .nanmedian (self .trace ['per_' + pl ][:,nd ]),len (transits [2 ])) if pl in self .duos + self .trios else np .nanmedian (self .trace ['per_' + pl ])})
4949
- all_trans = all_trans . append ( idf )
4954
+ 'alias_p' :np .tile (np .nanmedian (self .trace ['per_' + pl ]. values [:,nd ]),len (transits [2 ])) if pl in self .duos + self .trios else np .tile ( np . nanmedian (self .trace ['per_' + pl ]. values ), len ( transits [ 2 ]))})]
4955
+ all_trans = pd . concat ( idfs )
4950
4956
unq_trans = all_trans .sort_values ('log_prob' ).copy ().drop_duplicates ('transit_fractions' )
4951
4957
unq_trans = unq_trans .set_index (np .arange (len (unq_trans )))
4952
4958
unq_trans ['aliases_ns' ]= unq_trans ['alias_n' ].values .astype (str )
@@ -4960,7 +4966,8 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
4960
4966
unq_trans .loc [i ,'aliases_ps' ]= ',' .join (list (np .round (oths ['alias_p' ].values ,4 ).astype (str )))
4961
4967
unq_trans .loc [i ,'num_aliases' ]= len (oths )
4962
4968
unq_trans .loc [i ,'total_prob' ]= np .sum (oths ['prob' ].values )
4963
- all_trans_fin = all_trans_fin .append (unq_trans )
4969
+ all_unq_trans += [unq_trans ]
4970
+ all_trans_fin = pd .concat (all_unq_trans )
4964
4971
all_trans_fin = all_trans_fin .loc [(all_trans_fin ['transit_end_+2sig' ]> time_start )* (all_trans_fin ['transit_start_-2sig' ]< time_end )].sort_values ('transit_mid_med' )
4965
4972
all_trans_fin = all_trans_fin .set_index (np .arange (len (all_trans_fin )))
4966
4973
@@ -5006,7 +5013,7 @@ def cheops_RMS(self, Gmag, tdur):
5006
5013
def make_cheops_OR (self , DR2ID = None , pl = None , min_eff = 45 , oot_min_orbits = 1.0 , timing_sigma = 3 , t_start = None , t_end = None , Texp = None ,
5007
5014
max_orbits = 14 , min_pretrans_orbits = 0.5 , min_intrans_orbits = None , orbits_flex = 1.4 , observe_sigma = 2 ,
5008
5015
observe_threshold = None , max_ORs = None ,prio_1_threshold = 0.25 , prio_3_threshold = 0.0 , targetnamestring = None ,
5009
- min_orbits = 4.0 , outfilesuffix = '_output_ORs.csv' ,avoid_TESS = True ,pre_post_TESS = "pre" ):
5016
+ min_orbits = 4.0 , outfilesuffix = '_output_ORs.csv' ,avoid_TESS = True , pre_post_TESS = "pre" , prog_id = "0072 " ):
5010
5017
"""Given a list of observable transits (which are outputted from `trace_to_cheops_transits`),
5011
5018
create a csv which can be run by pycheops make_xml_files to produce input observing requests (both to FC and observing tool).
5012
5019
@@ -5057,7 +5064,6 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
5057
5064
gaiainfo = Gaia .launch_job_async ("SELECT * \
5058
5065
FROM gaiadr2.gaia_source \
5059
5066
WHERE gaiadr2.gaia_source.source_id="+ str (DR2ID )).results .to_pandas ().iloc [0 ]
5060
-
5061
5067
gaia_colour = (gaiainfo ['phot_bp_mean_mag' ]- gaiainfo ['phot_rp_mean_mag' ])
5062
5068
V = gaiainfo ['phot_g_mean_mag' ]+ 0.0176 + 0.00686 * gaia_colour + 0.1732 * gaia_colour ** 2
5063
5069
Verr = 1.09 / gaiainfo ['phot_g_mean_flux_over_error' ]+ 0.045858
@@ -5157,7 +5163,8 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
5157
5163
ser ['_DEJ2000' ]= old_radec .dec .to_string (sep = ':' )
5158
5164
ser ['pmra' ]= gaiainfo ['pmra' ]
5159
5165
ser ['pmdec' ]= gaiainfo ['pmdec' ]
5160
- ser ['parallax' ]= gaiainfo ['plx' ]
5166
+
5167
+ ser ['parallax' ]= gaiainfo ['plx' ] if 'plx' in gaiainfo else gaiainfo ['parallax' ]
5161
5168
ser ['SpTy' ]= SpTy
5162
5169
ser ['Gmag' ]= gaiainfo ['phot_g_mean_mag' ]
5163
5170
ser ['dr2_g_mag' ]= gaiainfo ['phot_g_mean_mag' ]
@@ -5167,7 +5174,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
5167
5174
ser ['Vmag' ]= V
5168
5175
ser ['e_Vmag' ]= Verr
5169
5176
5170
- ser ['Programme_ID' ]= '0048'
5177
+ ser ['Programme_ID' ]= prog_id
5171
5178
ser ['BJD_early' ]= t_start
5172
5179
ser ['BJD_late' ]= t_end
5173
5180
#Total observing time must cover duration, and either the full timing bound (i.e. assuming 3 sigma), or the oot_min_orbits (if the timing precision is better than the oot_min_orbits)
@@ -5236,7 +5243,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
5236
5243
#ser["EndPh1"]=((row['end_earliest']-row['mid'])/100)
5237
5244
#ser["Effic1"]=50
5238
5245
ser ['N_Ranges' ]= 0
5239
- out_tab = out_tab . append ( pd .Series (ser ,name = nper ) )
5246
+ out_tab . loc [ nper , list ( ser . keys ())] = pd .Series (ser ,name = nper )
5240
5247
out_tab ['MinEffDur' ]= out_tab ['MinEffDur' ].values .astype (int )
5241
5248
#print(98.77*60*out_tab['T_visit'].values)
5242
5249
out_tab ['T_visit' ]= (98.77 * 60 * out_tab ['T_visit' ].values ).astype (int ) #in seconds
@@ -5384,12 +5391,12 @@ def to_latex_table(self,varnames='all',order='columns'):
5384
5391
print ("Making Latex Table" )
5385
5392
if not hasattr (self ,'savenames' ):
5386
5393
self .get_savename (how = 'save' )
5387
- if self .tracemask is None :
5394
+ if not hasattr ( self , 'tracemask' ) or self .tracemask is None :
5388
5395
self .tracemask = np .tile (True ,len (self .trace ['Rs' ]))
5389
5396
if varnames is None or varnames == 'all' :
5390
5397
varnames = [var for var in self .trace if var [- 2 :]!= '__' and var not in ['gp_pred' ,'light_curves' ]]
5391
5398
5392
- self .samples = pm . trace_to_dataframe ( self .trace , varnames = varnames )
5399
+ self .samples = self .make_table ( cols = varnames )
5393
5400
self .samples = self .samples .loc [self .tracemask ]
5394
5401
facts = {'r_pl' :109.07637 ,'Ms' :1.0 ,'rho' :1.0 ,"t0" :1.0 ,"period" :1.0 ,"vrel" :1.0 ,"tdur" :24 }
5395
5402
units = {'r_pl' :"$ R_\\ oplus $" ,'Ms' :"$ M_\\ odot $" ,'rho' :"$ \\ rho_\\ odot $" ,
0 commit comments