forked from mehulghosal/small-neo-lightcurves
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgp_test.py
60 lines (51 loc) · 2.09 KB
/
gp_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import pymc3 as pm
import pymc3_ext as pmx
import aesara_theano_fallback.tensor as tt
from celerite2.theano import terms, GaussianProcess
import numpy as np
def gaussian_process(time, lightcurve, yerr, period):
    """Fit a celerite2 GP (SHO + rotation kernel) to a lightcurve; return the MAP point.

    The PyMC3 model contains:

    * a constant mean with a Normal prior centred on the sample mean of
      ``lightcurve`` (width = its sample standard deviation),
    * an extra white-noise jitter term added in quadrature to ``yerr``,
    * a non-periodic ``SHOTerm`` with fixed ``Q = 1/3``,
    * a quasi-periodic ``RotationTerm`` whose log-period prior is centred on
      ``log(period)``.

    Parameters
    ----------
    time : array_like
        Observation times, converted to a float array.
        NOTE(review): celerite2 presumably requires these to be sorted —
        confirm that callers sort them.
    lightcurve : array_like
        Observed values, same length as ``time``.
    yerr : array_like
        Per-point measurement uncertainties, in the same units as
        ``lightcurve``. Must have a positive mean.
    period : float
        Initial guess for the rotation period, in the units of ``time``.
        Must be strictly positive.

    Returns
    -------
    The maximum a posteriori solution returned by ``pmx.optimize()``.
    Presumably this is a dict keyed by model variable name, including the
    ``"period"`` and ``"pred"`` deterministics — verify against pymc3_ext.

    Raises
    ------
    ValueError
        If ``period`` is not strictly positive, or if ``yerr`` does not have
        a positive mean. In either case ``np.log`` would otherwise yield an
        invalid prior location.
    """
    # Accept lists as well as arrays: ``yerr ** 2`` below fails on a list.
    time = np.asarray(time, dtype=float)
    lightcurve = np.asarray(lightcurve, dtype=float)
    yerr = np.asarray(yerr, dtype=float)

    if not np.all(np.asarray(period) > 0):
        raise ValueError(f"period must be positive, got {period!r}")
    if not np.mean(yerr) > 0:
        raise ValueError("yerr must have a positive mean")

    with pm.Model():
        # The mean flux of the time series
        mean = pm.Normal("mean", mu=np.mean(lightcurve), sigma=np.std(lightcurve))

        # A jitter term describing excess white noise; the tight prior keeps
        # it close to the typical reported uncertainty.
        log_jitter = pm.Normal("log_jitter", mu=np.log(np.mean(yerr)), sigma=0.05)

        # A term to describe the non-periodic variability
        sigma = pm.InverseGamma(
            "sigma", **pmx.estimate_inverse_gamma_parameters(0.5, 5)
        )
        rho = pm.InverseGamma(
            "rho", **pmx.estimate_inverse_gamma_parameters(0.5, 5)
        )

        # The parameters of the RotationTerm kernel
        sigma_rot = pm.InverseGamma(
            "sigma_rot", **pmx.estimate_inverse_gamma_parameters(0.5, 5.0)
        )
        log_period = pm.Normal("log_period", mu=np.log(period), sigma=2.0)
        # Named ``period_rv`` so it does not shadow the ``period`` argument;
        # the model variable is still called "period".
        period_rv = pm.Deterministic("period", tt.exp(log_period))
        # A HalfNormal prior on log(Q0) restricts Q0 to values >= 1.
        log_Q0 = pm.HalfNormal("log_Q0", sigma=2.0)
        log_dQ = pm.Normal("log_dQ", mu=0.0, sigma=2.0)
        f = pm.Uniform("f", lower=0.1, upper=1.0)

        # Set up the Gaussian Process model
        kernel = terms.SHOTerm(sigma=sigma, rho=rho, Q=1 / 3.0)
        kernel += terms.RotationTerm(
            sigma=sigma_rot,
            period=period_rv,
            Q0=tt.exp(log_Q0),
            dQ=tt.exp(log_dQ),
            f=f,
        )
        gp = GaussianProcess(
            kernel,
            t=time,
            # Measurement variance plus the fitted jitter, in quadrature.
            diag=yerr**2 + tt.exp(2 * log_jitter),
            mean=mean,
            quiet=True,
        )

        # Compute the Gaussian Process likelihood and add it into the
        # PyMC3 model as a "potential"
        gp.marginal("gp", observed=lightcurve)

        # Compute the mean model prediction for plotting purposes
        pm.Deterministic("pred", gp.predict(lightcurve))

        # Optimize to find the maximum a posteriori parameters; this must run
        # inside the model context so pmx.optimize() can find the model.
        map_soln = pmx.optimize()

    return map_soln