xarray serialization error when using observed custom distribution
Describe the bug
Unable to save traces to file, which is essential for running code on a cluster.
To Reproduce
Follow the case study Fitting TESS data (https://gallery.exoplanet.codes/tutorials/tess/), except using
lc = lc_file.remove_nans().remove_outliers().normalize()
instead of
lc = lc_file.remove_nans().normalize().remove_outliers()
since the tutorial's original order of transformations raised an unrelated error in my case.
After sampling, try to save the trace as one commonly saves arviz.InferenceData objects:
trace.to_netcdf('results')
This will raise the following error:
ValueError Traceback (most recent call last)
<ipython-input-23-c0e5828e59ee> in <module>
----> 1 trace.to_netcdf('results')
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/arviz/data/inference_data.py in to_netcdf(self, filename, compress, groups)
390 if compress:
391 kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
--> 392 data.to_netcdf(filename, mode=mode, group=group, **kwargs)
393 data.close()
394 mode = "a"
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/core/dataset.py in to_netcdf(self, path, mode, format, group, engine, encoding, unlimited_dims, compute, invalid_netcdf)
1900 from ..backends.api import to_netcdf
1901
-> 1902 return to_netcdf(
1903 self,
1904 path,
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/api.py in to_netcdf(dataset, path_or_file, mode, format, group, engine, encoding, unlimited_dims, compute, multifile, invalid_netcdf)
1070 # TODO: allow this work (setting up the file for writing array data)
1071 # to be parallelized with dask
-> 1072 dump_to_store(
1073 dataset, store, writer, encoding=encoding, unlimited_dims=unlimited_dims
1074 )
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/api.py in dump_to_store(dataset, store, writer, encoder, encoding, unlimited_dims)
1117 variables, attrs = encoder(variables, attrs)
1118
-> 1119 store.store(variables, attrs, check_encoding, writer, unlimited_dims=unlimited_dims)
1120
1121
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/common.py in store(self, variables, attributes, check_encoding_set, writer, unlimited_dims)
259 writer = ArrayWriter()
260
--> 261 variables, attributes = self.encode(variables, attributes)
262
263 self.set_attributes(attributes)
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/backends/common.py in encode(self, variables, attributes)
348 # All NetCDF files get CF encoded by default, without this attempting
349 # to write times, for example, would fail.
--> 350 variables, attributes = cf_encoder(variables, attributes)
351 variables = {k: self.encode_variable(v) for k, v in variables.items()}
352 attributes = {k: self.encode_attribute(v) for k, v in attributes.items()}
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in cf_encoder(variables, attributes)
853 _update_bounds_encoding(variables)
854
--> 855 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
856
857 # Remove attrs from bounds variables (issue #2921)
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in <dictcomp>(.0)
853 _update_bounds_encoding(variables)
854
--> 855 new_vars = {k: encode_cf_variable(v, name=k) for k, v in variables.items()}
856
857 # Remove attrs from bounds variables (issue #2921)
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in encode_cf_variable(var, needs_copy, name)
273 var = maybe_default_fill_value(var)
274 var = maybe_encode_bools(var)
--> 275 var = ensure_dtype_not_object(var, name=name)
276
277 for attr_name in CF_RELATED_DATA:
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in ensure_dtype_not_object(var, name)
231 data[missing] = fill_value
232 else:
--> 233 data = _copy_with_dtype(data, dtype=_infer_dtype(data, name))
234
235 assert data.dtype.kind != "O" or data.dtype.metadata
/cluster/apps/nss/gcc-8.2.0/python/3.9.9/x86_64/lib64/python3.9/site-packages/xarray/conventions.py in _infer_dtype(array, name)
165 return dtype
166
--> 167 raise ValueError(
168 "unable to infer dtype on variable {!r}; xarray "
169 "cannot serialize arbitrary Python objects".format(name)
ValueError: unable to infer dtype on variable 'ecc_prior'; xarray cannot serialize arbitrary Python objects
Expected behavior
I expect the trace to be saved as usual. I can individually save
trace.posterior.to_netcdf('posterior')
trace.log_likelihood.to_netcdf('log_likelihood')
trace.sample_stats.to_netcdf('sample_stats')
However, the same error is raised when trying to save
trace.observed_data.to_netcdf('observed_data')
My setup
- Version of exoplanet: 0.5.2
- Operating system: I reproduced this error once on Linux and once on macOS 10.15
- Python version: python 3.9.9
- Installation method: pip install -U "exoplanet[extras]"
- Version of arviz: 0.11.4
- Version of pymc3: 3.11.4
Has anyone encountered this problem before, and does anyone know how to solve it or have a suggestion for a workaround? Thank you very much in advance.
I have seen this before, and I don't totally understand why this happens. In the short term, I'd recommend using the groups argument to to_netcdf (care of @avivajpeyi):
trace.to_netcdf(
    filename=...,
    groups=["posterior", "log_likelihood", "sample_stats"],
)
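When reading the file back, az.from_netcdf should work as usual; the observed_data group will simply be absent. A minimal sketch, assuming the trace was saved to a file named 'results' as above:

import arviz as az

# Load the partially-saved trace; only the groups listed above are present
trace = az.from_netcdf("results")
print(trace)  # the repr lists the available groups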
Here's a simpler snippet that fails with the same issue:
import pymc3 as pm
import exoplanet as xo

with pm.Model() as model:
    ecc = pm.Uniform("ecc")
    xo.eccentricity.kipping13("ecc_prior", fixed=True, observed=ecc)
    trace = pm.sample(return_inferencedata=True)

trace.to_netcdf("test")
This could be used to debug and find a longer-term solution.
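For anyone digging into this, a quick (hypothetical) way to confirm the cause is to inspect the dtype of the offending variable, since the traceback points at 'ecc_prior' in the observed_data group:

# The observed_data entry holds an object-dtype array (a tensor, not data),
# which is what xarray refuses to serialize to netCDF
print(trace.observed_data["ecc_prior"].dtype)  # expected: object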
@dfm Thank you very much for the quick reply, the boiled-down version, and the workaround. I would still be curious to know whether trace.observed_data can also be saved, or how the custom exoplanet distributions would need to be changed.
The issue is somehow related to the fact that observed_data is required to be fixed (i.e. the same for every sample), but here we're overloading the observed argument to be a "tensor" that depends on the parameters. We could hack it to use pm.Potential or something, but then it would no longer work as a prior:
ecc = xo.eccentricity.kipping13("ecc_prior", fixed=True)
Perhaps there's some other PyMC3 trickery that we could use, but I don't know what it would look like!
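To make the pm.Potential idea above concrete, here is a hedged sketch (not exoplanet's actual API): the Beta(0.867, 3.03) density stands in for the single-population fit from Kipping (2013), whereas the real ecc_prior is a two-population beta mixture. Because nothing parameter-dependent is marked as observed, the trace should serialize:

import pymc3 as pm

with pm.Model() as model:
    ecc = pm.Uniform("ecc")
    # Add the prior's log-density as a potential term instead of an observed
    # variable; Beta(0.867, 3.03) is a stand-in for the Kipping (2013) prior
    pm.Potential("ecc_prior", pm.Beta.dist(alpha=0.867, beta=3.03).logp(ecc))
    trace = pm.sample(return_inferencedata=True)

trace.to_netcdf("test")  # no object-dtype observed_data, so this should save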
I see. I played around some more and found that the xarray serialization problem is not limited to exoplanet's custom distributions but also occurs for celerite2.theano.GaussianProcess objects under certain circumstances:
import numpy as np
import exoplanet as xo
import pymc3 as pm
import aesara_theano_fallback.tensor as tt
from celerite2.theano import terms, GaussianProcess

# Simulate a transit light curve with white noise
x = np.linspace(-1, 1, 1000)
texp = x[1] - x[0]
true_orbit = xo.orbits.KeplerianOrbit(period=4, t0=0)
y = xo.LimbDarkLightCurve([0, 0]).get_light_curve(orbit=true_orbit, t=x, r=0.1, texp=texp).eval()
y = y.reshape(-1) + np.random.normal(loc=0.0, scale=0.001, size=len(x))

with pm.Model() as model:
    t0 = pm.Normal("t0", mu=0.1, sd=0.5)
    orbit = xo.orbits.KeplerianOrbit(period=3.99, t0=t0)
    light_curves = xo.LimbDarkLightCurve([0, 0]).get_light_curve(orbit=orbit, r=0.1, t=x, texp=texp)
    resid = y - tt.sum(light_curves, axis=-1)
    kernel = terms.SHOTerm(sigma=np.std(y), rho=np.std(y), Q=1 / np.sqrt(2))
    gp = GaussianProcess(kernel, t=x, yerr=np.std(y))
    gp.marginal("gp", observed=resid)  # no error if resid is replaced by y
    # Sample
    trace = pm.sample(return_inferencedata=True, tune=100, draws=200)

trace.to_netcdf("test")
This snippet raises the error
ValueError: unable to infer dtype on variable 'gp'; xarray cannot serialize arbitrary Python objects
However, there is no error if gp.marginal("gp", observed=resid) is replaced with gp.marginal("gp", observed=y), i.e. if we marginalise on a quantity that is independent of any random variables.
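This suggests a possible workaround sketch for the GP case: keep observed fixed by passing the raw data and folding the parameter-dependent light curve into the GP mean. celerite2's GaussianProcess does take a mean argument, but whether a tensor-valued mean behaves correctly with the theano backend here is an assumption I have not verified:

# Replace the last three model lines above with something like this:
mean_model = tt.sum(light_curves, axis=-1)  # parameter-dependent mean
gp = GaussianProcess(kernel, t=x, yerr=np.std(y), mean=mean_model)
gp.marginal("gp", observed=y)  # observed is now a constant numpy array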
Yes. This is exactly the same issue. The "observed data" isn't a constant if it depends on the model!
You are right, of course; I just wanted to point out that the problem is not limited to exoplanet's custom distributions.
Haha, good point. The celerite2 design (mistakes?) is the same, and also my own :D
Somewhat off-topic: thank you very much for your great work. For tackling scientific problems, I find it very pleasant to use a higher-level language like Python while still getting the performance of a lower-level language like C++, so I am grateful that packages like celerite2 exist. I have thought a few times that in the long run it might be worthwhile to translate exoplanet to Julia, which could avoid the dependence on several other languages that is currently required and might improve performance. Needless to say, such an undertaking would also be very labour-intensive and would take quite some time to realise.
On-topic: I have now asked in the PyMC forum here how this issue is usually handled, and I will let you know if anyone comes up with a neat solution.
Thanks!
If you're interested in a Julia version, there has been some work on implementing at least some parts over in: https://github.com/JuliaAstro/Transits.jl And we have this old implementation of celerite in Julia (I'm not sure how compatible it'll be with recent versions of Julia): https://github.com/ericagol/celerite.jl
Thanks for pointing out these projects, I will have a look at them!
> If you're interested in a Julia version, there has been some work on implementing at least some parts over in: https://github.com/JuliaAstro/Transits.jl
Also https://github.com/rodluger/Limbdark.jl