Skip to content

Commit 5521a1e

Browse files
committed
Still bug fixing pymc/arviz stuff
1 parent 777e9eb commit 5521a1e

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

MonoTools/fit.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,14 @@ def save_model_to_file(self, savefile=None, limit_size=False):
207207
self.get_savename(how='save')
208208
savefile=self.savenames[0]+'_model.pickle'
209209
if hasattr(self,'trace'):
210-
self.trace.to_netcdf(self.savenames[0]+'_trace.nc')
210+
try:
211+
self.trace.to_netcdf(self.savenames[0]+'_trace.nc')
212+
except:
213+
try:
214+
#Stacking/unstacking removes Multitrace objects:
215+
self.trace.unstack().to_netcdf(self.savenames[0]+'_trace.nc')
216+
except:
217+
print("Still a save error after unstacking")
211218
excl_types=[az.InferenceData]
212219
cloudpickle.dumps({attr:getattr(self,attr) for attr in self.__dict__},open(savefile,'wb'))
213220

@@ -4643,7 +4650,7 @@ def plot_corner(self,corner_vars=None,use_marg=True,truths=None):
46434650
if not self.assume_circ:
46444651
corner_vars+=['multi_ecc','multi_omega']
46454652
'''
4646-
samples = pm.trace_to_dataframe(self.trace, varnames=corner_vars)
4653+
samples =self.make_table(cols=corner_vars)
46474654
#print(samples.shape,samples.columns)
46484655
assert samples.shape[1]<50
46494656

@@ -4715,13 +4722,13 @@ def make_table(self,short=True,save=True,cols=['all']):
47154722
cols_to_remove+=['mono_'+col+'s','duo_'+col+'s']
47164723
medvars=[var for var in self.trace if not np.any([icol in var for icol in cols_to_remove])]
47174724
#print(cols_to_remove, medvars)
4718-
df = pm.summary(self.trace,var_names=medvars,stat_funcs={"5%": lambda x: np.percentile(x, 5),
4725+
df = pm.summary(self.trace.unstack(),var_names=medvars,stat_funcs={"5%": lambda x: np.percentile(x, 5),
47194726
"-$1\sigma$": lambda x: np.percentile(x, 15.87),
47204727
"median": lambda x: np.percentile(x, 50),
47214728
"+$1\sigma$": lambda x: np.percentile(x, 84.13),
47224729
"95%": lambda x: np.percentile(x, 95)},round_to=5)
47234730
else:
4724-
df = pm.summary(self.trace,var_names=cols,stat_funcs={"5%": lambda x: np.percentile(x, 5),
4731+
df = pm.summary(self.trace.unstack(),var_names=cols,stat_funcs={"5%": lambda x: np.percentile(x, 5),
47254732
"-$1\sigma$": lambda x: np.percentile(x, 15.87),
47264733
"median": lambda x: np.percentile(x, 50),
47274734
"+$1\sigma$": lambda x: np.percentile(x, 84.13),
@@ -4880,23 +4887,22 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
48804887
print("time range",Time(time_start+self.lc.jd_base,format='jd').isot,
48814888
"->",Time(time_end+self.lc.jd_base,format='jd').isot)
48824889

4883-
48844890
if check_TESS:
48854891
sect_start_ends=self.check_TESS()
48864892

48874893
all_trans_fin=pd.DataFrame()
48884894
loopplanets = self.duos+self.trios+self.multis if include_multis else self.duos+self.trios
4889-
4895+
all_unq_trans=[]
48904896
for pl in loopplanets:
48914897
all_trans=pd.DataFrame()
48924898
if pl in self.duos+self.trios:
48934899
sum_all_probs=np.logaddexp.reduce(np.nanmedian(self.trace['logprob_marg_'+pl],axis=0))
4894-
trans_p0=np.floor(np.nanmedian(time_start - self.trace['t0_2_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))
4895-
trans_p1=np.ceil(np.nanmedian(time_end - self.trace['t0_2_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))
4900+
trans_p0=np.floor(np.nanmedian(time_start - self.trace['t0_2_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values,axis=0))
4901+
trans_p1=np.ceil(np.nanmedian(time_end - self.trace['t0_2_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values,axis=0))
48964902
n_trans=trans_p1-trans_p0
48974903
elif pl in self.multis:
4898-
trans_p0=[np.floor(np.nanmedian(time_start - self.trace['t0_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))]
4899-
trans_p1=[np.ceil(np.nanmedian(time_end - self.trace['t0_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))]
4904+
trans_p0=[np.floor(np.nanmedian(time_start - self.trace['t0_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values))]
4905+
trans_p1=[np.ceil(np.nanmedian(time_end - self.trace['t0_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values))]
49004906
n_trans=[trans_p1[0]-trans_p0[0]]
49014907
#print(pl,trans_p0,trans_p1,n_trans)
49024908
#print(np.nanmedian(self.trace['t0_2_'+pl])+np.nanmedian(self.trace['per_'+pl],axis=0)*trans_p0)
@@ -4909,21 +4915,21 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
49094915
if 'tdur' in self.fit_params or pl in self.multis:
49104916
dur=np.nanpercentile(self.trace['tdur_'+pl],percentiles)
49114917
naliases=[0] if pl in self.multis else np.arange(self.planets[pl]['npers'])
4918+
idfs=[]
49124919
for nd in naliases:
49134920
if n_trans[nd]>0:
49144921
if pl in self.duos+self.trios:
49154922
int_alias=int(self.planets[pl]['period_int_aliases'][nd])
4916-
transits=np.nanpercentile(np.vstack([self.trace['t0_2_'+pl]+ntr*self.trace['per_'+pl][:,nd] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1)
4923+
transits=np.nanpercentile(np.vstack([self.trace['t0_2_'+pl].values+ntr*self.trace['per_'+pl].values[:,nd] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1)
49174924
if 'tdur' in self.marginal_params:
49184925
dur=np.nanpercentile(self.trace['tdur_'+pl][:,nd],percentiles)
49194926
logprobs=np.nanmedian(self.trace['logprob_marg_'+pl][:,nd])-sum_all_probs
49204927
else:
4921-
transits=np.nanpercentile(np.vstack([self.trace['t0_'+pl]+ntr*self.trace['per_'+pl] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1)
4928+
transits=np.nanpercentile(np.column_stack([self.trace['t0_'+pl].values+ntr*self.trace['per_'+pl].values for ntr in np.arange(trans_p0[nd],trans_p1[nd],1.0)]),percentiles,axis=0)
49224929
int_alias=1
49234930
logprobs=np.array([0.0])
49244931
#Getting the aliases for this:
4925-
4926-
idf=pd.DataFrame({'transit_mid_date':Time(transits[2]+self.lc.jd_base,format='jd').isot,
4932+
idfs+=[pd.DataFrame({'transit_mid_date':Time(transits[2]+self.lc.jd_base,format='jd').isot,
49274933
'transit_mid_med':transits[2],
49284934
'transit_dur_med':np.tile(dur[2],len(transits[2])),
49294935
'transit_dur_-1sig':np.tile(dur[1],len(transits[2])),
@@ -4945,8 +4951,8 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
49454951
'prob':np.tile(np.exp(logprobs),len(transits[2])),
49464952
'planet_name':np.tile('multi_'+pl,len(transits[2])) if pl in self.multis else np.tile('duo_'+pl,len(transits[2])),
49474953
'alias_n':np.tile(nd,len(transits[2])),
4948-
'alias_p':np.tile(np.nanmedian(self.trace['per_'+pl][:,nd]),len(transits[2])) if pl in self.duos+self.trios else np.nanmedian(self.trace['per_'+pl])})
4949-
all_trans=all_trans.append(idf)
4954+
'alias_p':np.tile(np.nanmedian(self.trace['per_'+pl].values[:,nd]),len(transits[2])) if pl in self.duos+self.trios else np.tile(np.nanmedian(self.trace['per_'+pl].values),len(transits[2]))})]
4955+
all_trans=pd.concat(idfs)
49504956
unq_trans = all_trans.sort_values('log_prob').copy().drop_duplicates('transit_fractions')
49514957
unq_trans = unq_trans.set_index(np.arange(len(unq_trans)))
49524958
unq_trans['aliases_ns']=unq_trans['alias_n'].values.astype(str)
@@ -4960,7 +4966,8 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
49604966
unq_trans.loc[i,'aliases_ps']=','.join(list(np.round(oths['alias_p'].values,4).astype(str)))
49614967
unq_trans.loc[i,'num_aliases']=len(oths)
49624968
unq_trans.loc[i,'total_prob']=np.sum(oths['prob'].values)
4963-
all_trans_fin=all_trans_fin.append(unq_trans)
4969+
all_unq_trans+=[unq_trans]
4970+
all_trans_fin=pd.concat(all_unq_trans)
49644971
all_trans_fin = all_trans_fin.loc[(all_trans_fin['transit_end_+2sig']>time_start)*(all_trans_fin['transit_start_-2sig']<time_end)].sort_values('transit_mid_med')
49654972
all_trans_fin = all_trans_fin.set_index(np.arange(len(all_trans_fin)))
49664973

@@ -5006,7 +5013,7 @@ def cheops_RMS(self, Gmag, tdur):
50065013
def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, timing_sigma=3, t_start=None, t_end=None, Texp=None,
50075014
max_orbits=14, min_pretrans_orbits=0.5, min_intrans_orbits=None, orbits_flex=1.4, observe_sigma=2,
50085015
observe_threshold=None, max_ORs=None,prio_1_threshold=0.25, prio_3_threshold=0.0, targetnamestring=None,
5009-
min_orbits=4.0, outfilesuffix='_output_ORs.csv',avoid_TESS=True,pre_post_TESS="pre"):
5016+
min_orbits=4.0, outfilesuffix='_output_ORs.csv',avoid_TESS=True, pre_post_TESS="pre", prog_id="0072"):
50105017
"""Given a list of observable transits (which are outputted from `trace_to_cheops_transits`),
50115018
create a csv which can be run by pycheops make_xml_files to produce input observing requests (both to FC and observing tool).
50125019
@@ -5057,7 +5064,6 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
50575064
gaiainfo=Gaia.launch_job_async("SELECT * \
50585065
FROM gaiadr2.gaia_source \
50595066
WHERE gaiadr2.gaia_source.source_id="+str(DR2ID)).results.to_pandas().iloc[0]
5060-
50615067
gaia_colour=(gaiainfo['phot_bp_mean_mag']-gaiainfo['phot_rp_mean_mag'])
50625068
V=gaiainfo['phot_g_mean_mag']+0.0176+0.00686*gaia_colour+0.1732*gaia_colour**2
50635069
Verr=1.09/gaiainfo['phot_g_mean_flux_over_error']+0.045858
@@ -5157,7 +5163,8 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
51575163
ser['_DEJ2000']=old_radec.dec.to_string(sep=':')
51585164
ser['pmra']=gaiainfo['pmra']
51595165
ser['pmdec']=gaiainfo['pmdec']
5160-
ser['parallax']=gaiainfo['plx']
5166+
5167+
ser['parallax']=gaiainfo['plx'] if 'plx' in gaiainfo else gaiainfo['parallax']
51615168
ser['SpTy']=SpTy
51625169
ser['Gmag']=gaiainfo['phot_g_mean_mag']
51635170
ser['dr2_g_mag']=gaiainfo['phot_g_mean_mag']
@@ -5167,7 +5174,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
51675174
ser['Vmag']=V
51685175
ser['e_Vmag']=Verr
51695176

5170-
ser['Programme_ID']='0048'
5177+
ser['Programme_ID']=prog_id
51715178
ser['BJD_early']=t_start
51725179
ser['BJD_late']=t_end
51735180
#Total observing time must cover duration, and either the full timing bound (i.e. assuming 3 sigma), or the oot_min_orbits (if the timing precision is better than the oot_min_orbits)
@@ -5236,7 +5243,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti
52365243
#ser["EndPh1"]=((row['end_earliest']-row['mid'])/100)
52375244
#ser["Effic1"]=50
52385245
ser['N_Ranges']=0
5239-
out_tab=out_tab.append(pd.Series(ser,name=nper))
5246+
out_tab.loc[nper,list(ser.keys())]=pd.Series(ser,name=nper)
52405247
out_tab['MinEffDur']=out_tab['MinEffDur'].values.astype(int)
52415248
#print(98.77*60*out_tab['T_visit'].values)
52425249
out_tab['T_visit']=(98.77*60*out_tab['T_visit'].values).astype(int) #in seconds
@@ -5384,12 +5391,12 @@ def to_latex_table(self,varnames='all',order='columns'):
53845391
print("Making Latex Table")
53855392
if not hasattr(self,'savenames'):
53865393
self.get_savename(how='save')
5387-
if self.tracemask is None:
5394+
if not hasattr(self,'tracemask') or self.tracemask is None:
53885395
self.tracemask=np.tile(True,len(self.trace['Rs']))
53895396
if varnames is None or varnames == 'all':
53905397
varnames=[var for var in self.trace if var[-2:]!='__' and var not in ['gp_pred','light_curves']]
53915398

5392-
self.samples = pm.trace_to_dataframe(self.trace, varnames=varnames)
5399+
self.samples = self.make_table(cols=varnames)
53935400
self.samples = self.samples.loc[self.tracemask]
53945401
facts={'r_pl':109.07637,'Ms':1.0,'rho':1.0,"t0":1.0,"period":1.0,"vrel":1.0,"tdur":24}
53955402
units={'r_pl':"$ R_\\oplus $",'Ms':"$ M_\\odot $",'rho':"$ \\rho_\\odot $",

0 commit comments

Comments
 (0)