Issues sampling with PyMCv5 (gp.marginal fails when combined modelling GP+transit)
I am having issues using celerite2 with PyMC. In the past (with PyMC3) I always used gp.marginal(observed=y-extra_model) to sample models which included both a GP and other components (e.g. a transit model), and this worked without issue. For whatever reason that is no longer the case with PyMC v5, and I get TypeError: Variables that depend on other nodes cannot be used for observed data.
I thought an easy alternative would be to initialise with gp.compute(), generate a predicted GP curve with gp.predict(), and then model everything the "classical" way in PyMC using pm.Normal(mu=gp_pred+extra_model, sigma=y_err, observed=y). But this gives completely different, and horrendously overfitted, results from using gp.marginal() for the same model. (see below)
So I would love some advice on how to model combined celerite + additional functions:
a) Is there any way to sample using gp.marginal() where the observed data can depend on other PyMC parameters? For example, could the mean function be a vector of N_t values (one per timestamp) rather than a single value, so that the transit model can be included that way?
b) How should sampling within PyMC be done if using gp.marginal() with y-extra_model is not possible? Should we be using gp.predict() for this purpose at all, or is there just a step I'm missing which is causing the drastic overfitting?
Some code as a MWE:
import pymc as pm
import pymc_ext as pmx
import celerite2.pymc
import arviz as az
import exoplanet as xo
import numpy as np
import matplotlib.pyplot as plt
# Build the synthetic data set used by every model in this example:
# a sum of sinusoids (the correlated signal the GP should absorb) plus one
# injected transit and white noise. Nothing is seeded, so each run draws a
# different realisation.

# Initialising some sinusoidal terms to act as something for the GP to remove:
sin_amps = np.exp(np.random.normal(-3, 0.2, 5))
sin_t0s = np.random.normal(0, 15, 5)
sin_pers = np.exp(np.random.normal(2, 0.5, 5))

# Initialising transit parameters (the values injected into the fake data):
i_Rs = 0.8
i_Ms = 0.76
i_us = np.array([0.1, 0.3])
i_t0 = 3.197652
i_P = 12.59219  # days
i_b = 0.393
i_rpl = 3.1309  # Rearth
# Radius ratio Rp/Rs. This must be the inverse of the relation the models use
# (rpl = rprs * Rs * 109.2), so the stellar radius belongs in the denominator.
# The previous expression `i_rpl/109.2*i_Rs` multiplied by i_Rs instead.
i_rprs = i_rpl / (109.2 * i_Rs)

# Creating fake data by doing LimbDarkLightCurve
t = np.arange(0, 50, 1 / 50)
# Per-point uncertainty, sized from `t` rather than a hard-coded length.
flux_err = np.full(t.shape, 0.15)

injected_orbit = xo.orbits.KeplerianOrbit(
    r_star=i_Rs, m_star=i_Ms, period=i_P, t0=i_t0, b=i_b
)
# get_light_curve(...).eval() is indexed with [:, 0] to obtain a 1-D series.
transit_signal = 1000 * xo.LimbDarkLightCurve(i_us).get_light_curve(
    orbit=injected_orbit, r=i_rprs * i_Rs, t=t
).eval()[:, 0]
sinusoid_signal = np.sum(
    sin_amps[:, None]
    * np.sin(2 * np.pi * (t[None, :] - sin_t0s[:, None]) / sin_pers[:, None]),
    axis=0,
)
pure_flux = transit_signal + sinusoid_signal
flux = pure_flux + np.random.normal(0.0, np.nanmedian(flux_err), t.size)

# Plotting to check:
plt.plot(t, flux, '.')
plt.plot(t, pure_flux, '--', alpha=0.7)
The anticipated behaviour, using gp.marginal() (no transit):
# GP-only reference model: the flux is described purely by a celerite2 SHO
# Gaussian process plus a constant mean, with the likelihood supplied by
# gp.marginal(). No transit term is included here.
with pm.Model() as model:
    # Log of an extra white-noise term, added in quadrature to flux_err below.
    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma(
        "sigma",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01
        ),
    )
    w0 = pm.InverseGamma(
        "w0",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=(2 * np.pi) / 10, upper=(2 * np.pi) / 0.2, target=0.01
        ),
    )
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1 / np.sqrt(2))
    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5 * np.nanstd(flux))

    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
    gp.compute(
        t,
        yerr=pm.math.sqrt(flux_err**2 + pm.math.exp(logjit) ** 2),
        quiet=True,
    )
    loglik = gp.marginal("loglik", observed=flux)
    # Tracked so the GP prediction can be plotted from the posterior.
    gp_pred = pm.Deterministic("gp_pred", gp.predict(flux, return_var=False))

    # Start sampling from the MAP point. `initvals` is the current name of the
    # keyword formerly called `start`.
    wmarg_init_soln = pm.find_MAP()
    wmarg_trace = pm.sample(initvals=wmarg_init_soln)

# Posterior-median GP prediction over the data.
plt.plot(t, flux, '.', alpha=0.6)
plt.plot(t, np.nanmedian(wmarg_trace.posterior['gp_pred'], axis=(0, 1)), '-')
plt.savefig("MWE_fit_wmarg.png")
The arviz summary:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
logjit -3.097 0.194 -3.436 -2.759 0.004 0.003 2858.0 2106.0 1.0
mean -0.003 0.024 -0.048 0.045 0.001 0.000 2031.0 1516.0 1.0
sigma 0.119 0.017 0.089 0.150 0.000 0.000 1796.0 1514.0 1.0
w0 1.599 0.428 0.916 2.432 0.011 0.008 1431.0 1696.0 1.0
So a GP-only model works fine.
The behaviour when including an additional non-celerite mean function (with exoplanet transit):
# GP + transit model written the PyMC3 way: the transit model is subtracted
# from the data and the residual is handed to gp.marginal() as `observed`.
# Under PyMC v5 this raises the TypeError reproduced in the output below; the
# block is kept as-is to demonstrate that failure.
with pm.Model() as model:
    # Stellar and orbital priors.
    Rs = pm.Normal("Rs", mu=0.8, sigma=0.02)
    Ms = pm.Normal("Ms", mu=0.78, sigma=0.02)
    P = pm.Normal("P", mu=12.6, sigma=0.01)
    t0 = pm.Normal("t0", mu=3.21, sigma=0.04)

    # Radius ratio is sampled in log space; rpl converts it with the 109.2 factor.
    log_rprs = pm.Normal("log_rprs", mu=-4, sigma=3)
    rprs = pm.Deterministic("rprs", pm.math.exp(log_rprs))
    rpl = pm.Deterministic("rpl", rprs * Rs * 109.2)
    b = xo.distributions.ImpactParameter("b", ror=rprs)

    # `orb` is assigned but not used below; the light curve builds its own orbit.
    orb = xo.orbits.KeplerianOrbit(r_star=Rs, m_star=Ms, period=P, t0=t0, b=b)
    u_stars = xo.distributions.QuadLimbDark(
        "u_star", testval=np.array([0.3, 0.2])
    )
    limb_dark_model = xo.LimbDarkLightCurve(u_stars)
    light_curve = 1000 * limb_dark_model.get_light_curve(
        orbit=xo.orbits.KeplerianOrbit(
            r_star=Rs, m_star=Ms, period=P, t0=t0, b=b
        ),
        r=rprs * Rs,
        t=t,
    )

    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma_prior = pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01
    )
    sigma = pm.InverseGamma("sigma", **sigma_prior)
    w0_prior = pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2 * np.pi) / 10, upper=(2 * np.pi) / 0.2, target=0.01
    )
    w0 = pm.InverseGamma("w0", **w0_prior)
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1 / np.sqrt(2))
    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5 * np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)

    total_yerr = pm.math.sqrt(flux_err**2 + pm.math.exp(logjit) ** 2)
    gp.compute(t, yerr=total_yerr, quiet=True)

    # This call fails: `observed` is a symbolic expression that depends on
    # model variables, which PyMC v5 rejects.
    loglik = gp.marginal("loglik", observed=flux - light_curve)
    pm.find_MAP()
Output:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[25], line 26
24 gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)
25 gp.compute(t, yerr=pm.math.sqrt(flux_err** 2 + pm.math.exp(logjit)**2), quiet=True)
---> 26 loglik=gp.marginal("loglik", observed=flux - light_curve)
27 pm.find_MAP()
File ~/miniconda3/envs/chx/lib/python3.9/site-packages/celerite2/pymc/celerite2.py:96, in GaussianProcess.marginal(self, name, **kwargs)
93 from celerite2.pymc.distribution import CeleriteNormal
95 self._add_citations_to_pymc_model(**kwargs)
---> 96 return CeleriteNormal(
97 name,
98 self._mean_value,
99 self._norm,
100 self._t,
101 self._c,
102 self._U,
103 self._W,
104 self._d,
105 **kwargs
106 )
File ~/miniconda3/envs/chx/lib/python3.9/site-packages/pymc/distributions/distribution.py:413, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
409 kwargs["shape"] = tuple(observed.shape)
411 rv_out = cls.dist(*args, **kwargs)
--> 413 rv_out = model.register_rv(
414 rv_out,
415 name,
416 observed,
417 total_size,
418 dims=dims,
419 transform=transform,
420 initval=initval,
421 )
423 # add in pretty-printing support
424 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)
File ~/miniconda3/envs/chx/lib/python3.9/site-packages/pymc/model/core.py:1265, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
1252 else:
1253 if (
1254 isinstance(observed, Variable)
1255 and not isinstance(observed, GenTensorVariable)
(...)
1263 and not is_minibatch(observed)
1264 ):
-> 1265 raise TypeError(
1266 "Variables that depend on other nodes cannot be used for observed data."
1267 f"The data variable was: {observed}"
1268 )
1270 # `rv_var` is potentially changed by `make_obs_var`,
1271 # for example into a new graph for imputation of missing data.
1272 rv_var = self.make_obs_var(rv_var, observed, dims, transform, total_size)
TypeError: Variables that depend on other nodes cannot be used for observed data.The data variable was: Sub.0
I have verified that the same error occurs across different computers (both my M2 Mac and linux server).
The behaviour when sampling with the output of gp.predict():
# Alternative attempt: use gp.predict() as the mean of an ordinary pm.Normal
# likelihood instead of gp.marginal(). This is the variant that produces the
# badly over-fitted posterior summarised below. gp.predict(flux) is already
# conditioned on `flux`, and nothing in this likelihood penalises an overly
# flexible GP.
with pm.Model() as model:
    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma(
        "sigma",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01
        ),
    )
    w0 = pm.InverseGamma(
        "w0",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=(2 * np.pi) / 10, upper=(2 * np.pi) / 0.2, target=0.01
        ),
    )
    # Use the `terms` namespace, matching the other models in this example.
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1 / np.sqrt(2))
    mean_flux = pm.Normal("mean", mu=0.0, sigma=0.5 * np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux)

    # Measurement error and jitter in quadrature; shared by the GP and the
    # Normal likelihood.
    total_yerr = pm.math.sqrt(flux_err**2 + pm.math.exp(logjit) ** 2)
    gp.compute(t, yerr=total_yerr, quiet=True)

    gp_pred = pm.Deterministic("gp_pred", gp.predict(flux, return_var=False))
    loglik = pm.Normal("loglik", mu=gp_pred, sigma=total_yerr, observed=flux)

    # `initvals` is the current name of the keyword formerly called `start`.
    nomarg_init_soln = pm.find_MAP()
    nomarg_trace = pm.sample(initvals=nomarg_init_soln)

plt.plot(t, flux, '.', alpha=0.6)
plt.plot(t, np.nanmedian(nomarg_trace.posterior['gp_pred'], axis=(0, 1)), '-')
plt.savefig("MWE_fit_nomarg.png")
The arviz summary:
mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat
logjit -5.297 0.391 -6.014 -4.597 0.009 0.006 2263.0 1961.0 1.00
mean -0.000 0.084 -0.150 0.164 0.002 0.001 2608.0 2583.0 1.00
sigma 1.559 0.874 0.652 2.833 0.030 0.021 1389.0 833.0 1.00
w0 163.862 205.731 21.684 393.793 7.451 5.271 992.0 849.0 1.01
This is clearly extremely over-fitted for some reason...
Ok, it looks like using pm.Potential(gp.log_likelihood(y-extra_model)) is the way to go:
# GP + transit model using pm.Potential: the GP log-likelihood of the
# transit-subtracted flux is added to the model directly, which avoids passing
# a symbolic expression as `observed`.
with pm.Model() as model:
    # Stellar and orbital priors.
    Rs = pm.Normal("Rs", mu=0.8, sigma=0.02)
    Ms = pm.Normal("Ms", mu=0.78, sigma=0.02)
    P = pm.Normal("P", mu=12.6, sigma=0.01)
    t0 = pm.Normal("t0", mu=3.21, sigma=0.04)

    log_rprs = pm.Normal("log_rprs", mu=-4, sigma=3, initval=-2)
    rprs = pm.Deterministic("rprs", pm.math.exp(log_rprs))
    rpl = pm.Deterministic("rpl", rprs * Rs * 109.2)
    b = xo.distributions.ImpactParameter("b", ror=rprs, initval=0.4)

    # Build the orbit once and reuse it (previously `orb` was created but an
    # identical second orbit was constructed inline for the light curve).
    orb = xo.orbits.KeplerianOrbit(r_star=Rs, m_star=Ms, period=P, t0=t0, b=b)
    # NOTE(review): `testval` differs from the `initval` keyword used above;
    # confirm the installed exoplanet release accepts it.
    u_stars = xo.distributions.QuadLimbDark(
        "u_star", testval=np.array([0.3, 0.2])
    )
    lightcurve = pm.Deterministic(
        'lightcurve',
        1000 * xo.LimbDarkLightCurve(u_stars).get_light_curve(
            orbit=orb, r=rprs * Rs, t=t
        ),
    )

    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma(
        "sigma",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01
        ),
    )
    w0 = pm.InverseGamma(
        "w0",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=(2 * np.pi) / 10, upper=(2 * np.pi) / 0.2, target=0.01
        ),
    )
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1 / np.sqrt(2))
    mean_flux = pm.Normal("mean_flux", mu=0.0, sigma=0.5 * np.nanstd(flux))
    gp = celerite2.pymc.GaussianProcess(
        kernel,
        mean=mean_flux,
        t=t,
        diag=flux_err**2 + pm.math.exp(logjit) ** 2,
    )

    # `lightcurve` is 2-D, so collapse axis 1 before subtracting it from the
    # 1-D flux. The residual feeds both the likelihood and the prediction.
    transit_removed = flux - pm.math.sum(lightcurve, axis=1)
    loglik = pm.Potential("loglik", gp.log_likelihood(transit_removed))
    gp_pred = pm.Deterministic(
        "gp_pred", gp.predict(transit_removed, return_var=False)
    )

    # No find_MAP() warm start is used for this model.
    nomarg_trans_trace = pm.sample()

# Posterior median of GP prediction + transit model, plotted over the data.
plt.plot(t, flux, '.', alpha=0.6)
plt.plot(
    t,
    np.nanmedian(
        nomarg_trans_trace.posterior['gp_pred'].values
        + nomarg_trans_trace.posterior['lightcurve'].values[:, :, :, 0],
        axis=(0, 1),
    ),
    '-',
)
plt.savefig("MWE_fit_nomarg_trans.png")
So that seems to fix it! I am still apprehensive about this approach, though, because pm.Potential makes the likelihood something of a black box - that sometimes doesn't play well with arviz functions like WAIC - so any advice on directly calling pm.Normal or gp.marginal would still be useful imho.
I'm also coming across this issue, thanks @hposborn for the current fix! I imagine the only workaround to this is if marginal calls something like pm.Potential(pm.logp(pm.MvNorm(<gp.marginal's vars>), observed=y-transit_model))... It seems that from PyMC v5 onwards, observed values must be fixed data and can no longer depend on model variables.
My "fix" doesn't really work - I end up with compilation errors when running large models that way. So I'd love to know the official way of sampling with pymcv5 and both a GP and additional models...
Hi @hposborn , I believe the solution to this is to indeed model the classical way of lc+GP = observed, but this is accessed through setting the GP mean as your light curve. Using your code:
# Fragment: meant to run inside the `with pm.Model()` block, after the transit
# parameters (Rs, Ms, P, t0, b, rprs, u_stars) have been defined.
# The transit model enters through the GP mean, so `observed` stays plain data.
lightcurve = pm.Deterministic(
    'lightcurve',
    1000 * xo.LimbDarkLightCurve(u_stars).get_light_curve(
        orbit=xo.orbits.KeplerianOrbit(
            r_star=Rs, m_star=Ms, period=P, t0=t0, b=b
        ),
        r=rprs * Rs,
        t=t,
    ),
)
logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

# Initialising GP:
sigma = pm.InverseGamma(
    "sigma",
    **pmx.utils.estimate_inverse_gamma_parameters(
        lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01
    ),
)
w0 = pm.InverseGamma(
    "w0",
    **pmx.utils.estimate_inverse_gamma_parameters(
        lower=(2 * np.pi) / 10, upper=(2 * np.pi) / 0.2, target=0.01
    ),
)
kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1 / np.sqrt(2))
mean_flux = pm.Normal("mean_flux", mu=0.0, sigma=0.5 * np.nanstd(flux))

# get_light_curve returns a 2-D array, so sum over axis 1 to get a 1-D mean
# that lines up with `flux`; adding the 2-D array directly gives the GP a
# mean of the wrong shape.
transit_mean = pm.math.sum(lightcurve, axis=1)
gp = celerite2.pymc.GaussianProcess(kernel, mean=mean_flux + transit_mean)
gp.compute(t, diag=flux_err**2 + pm.math.exp(logjit) ** 2, quiet=True)
pm.Deterministic('gp_pred', gp.predict(flux))
gp.marginal('obs', observed=flux)
(Obviously making sure flux and lightcurve+mean_flux are centered around the same value)
Let me know if this works! I believe this should be the solution moving from pymc>=5
Hi @TylerFair. You are correct! In fact, I thought this might be the way to go (I mentioned the shape of the mean function in my initial text), but I guess when I was testing I forgot that xo.LimbDarkLightCurve(...).get_light_curve() returns a 2D array rather than a 1D vector, which was the true cause of my issues. When including a quick sum along the short axis, i.e. mean=mean_flux + pm.math.sum(light_curve,axis=1), the sampling works fine.
# Working GP + transit model for PyMC v5: the transit light curve is folded
# into the GP mean function, so gp.marginal() can be given the raw flux as
# `observed`.
with pm.Model() as model:
    # Stellar and orbital priors.
    Rs = pm.Normal("Rs", mu=0.8, sigma=0.02)
    Ms = pm.Normal("Ms", mu=0.78, sigma=0.02)
    P = pm.Normal("P", mu=12.6, sigma=0.01)
    t0 = pm.Normal("t0", mu=3.21, sigma=0.04)

    log_rprs = pm.Normal("log_rprs", mu=-4, sigma=3)
    rprs = pm.Deterministic("rprs", pm.math.exp(log_rprs))
    rpl = pm.Deterministic("rpl", rprs * Rs * 109.2)
    b = xo.distributions.ImpactParameter("b", ror=rprs)

    # Build the orbit once and reuse it (previously `orb` was created but an
    # identical second orbit was constructed inline for the light curve).
    orb = xo.orbits.KeplerianOrbit(r_star=Rs, m_star=Ms, period=P, t0=t0, b=b)
    # NOTE(review): `testval` is the older keyword spelling; confirm the
    # installed exoplanet release accepts it.
    u_stars = xo.distributions.QuadLimbDark(
        "u_star", testval=np.array([0.3, 0.2])
    )

    logjit = pm.Normal('logjit', mu=np.log(np.std(flux)), sigma=1)

    # Initialising GP:
    sigma = pm.InverseGamma(
        "sigma",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=np.nanmedian(flux_err), upper=np.ptp(flux), target=0.01
        ),
    )
    w0 = pm.InverseGamma(
        "w0",
        **pmx.utils.estimate_inverse_gamma_parameters(
            lower=(2 * np.pi) / 10, upper=(2 * np.pi) / 0.2, target=0.01
        ),
    )
    kernel = celerite2.pymc.terms.SHOTerm(sigma=sigma, w0=w0, Q=1 / np.sqrt(2))
    mean_flux = pm.Normal("mean_flux", mu=0.0, sigma=0.5 * np.nanstd(flux))

    light_curve = pm.Deterministic(
        "light_curve",
        1000 * xo.LimbDarkLightCurve(u_stars).get_light_curve(
            orbit=orb, r=rprs * Rs, t=t
        ),
    )

    # `light_curve` is 2-D; summing over axis 1 gives the 1-D mean the GP needs.
    gp = celerite2.pymc.GaussianProcess(
        kernel, mean=mean_flux + pm.math.sum(light_curve, axis=1)
    )
    gp.compute(
        t,
        yerr=pm.math.sqrt(flux_err**2 + pm.math.exp(logjit) ** 2),
        quiet=True,
    )
    gp_pred = pm.Deterministic("gp_pred", gp.predict(flux, return_var=False))
    loglik = gp.marginal("loglik", observed=flux)

    sum_trans_trace = pm.sample(chains=12, cores=6)