exojax
Release v1.2
Release branch for v1.2.
I will merge this branch into master
and release v1.2 on November 21st unless any serious issues are found.
Comments and reviews from anyone are welcome!
- Tutorials: I will reorganize them as much as possible.
Unit Test Results
95 tests (95 passed, 0 skipped, 0 failed) in 1 suite and 1 file; run time 1m 9s
Results for commit 3ff219b7.
Thank you very much, and sorry for the delay in reviewing.
When I run Cross_Section_using_Discrete_Integral_Transform.ipynb, `qt = mdbCO.qr_interp(0,Tfix)`
gives the following error:
File "/home/kawashima/ExoJAX/test/cs.py", line 23, in <module>
qt = mdbCO.qr_interp(0,Tfix)
File "/home/kawashima/anaconda3/lib/python3.9/site-packages/ExoJAX-1.2-py3.9.egg/exojax/spec/api.py", line 650, in qr_interp
isotope_index = _isotope_index_from_isotope_number(
File "/home/kawashima/anaconda3/lib/python3.9/site-packages/ExoJAX-1.2-py3.9.egg/exojax/spec/api.py", line 713, in _isotope_index_from_isotope_number
isotope_index = np.where(uniqiso == isotope)[0][0]
IndexError: index 0 is out of bounds for axis 0 with size 0
It seems to me that the current version of qr_interp cannot handle isotope=0 (the all-isotopes case). Is my understanding correct?
The same error occurs for the following code, too:
@ykawashima Thanks! I will check the case of isotope=0 tomorrow. Update: addressed by #321.
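The IndexError in the traceback comes from indexing an empty result: `np.where(uniqiso == isotope)[0]` returns no positions when the requested isotope number is not among the unique isotope numbers, which is the situation for isotope=0 ("all isotopes"). A minimal sketch of a guarded lookup follows; the function name `isotope_index` is hypothetical, and this is not necessarily how #321 resolved the issue.

```python
import numpy as np


def isotope_index(uniqiso, isotope):
    """Hypothetical guarded lookup of `isotope` in the unique isotope numbers.

    np.where(...)[0] is an empty array when `isotope` is absent from
    `uniqiso` (e.g. isotope=0, meaning "all isotopes"), so indexing it
    with [0] would raise IndexError. Fail with a clearer message instead.
    """
    matches = np.where(np.asarray(uniqiso) == isotope)[0]
    if matches.size == 0:
        raise ValueError(
            f"isotope {isotope} not in {list(uniqiso)}; "
            "isotope=0 (all isotopes) needs separate handling"
        )
    return int(matches[0])
```

The caller (for example a partition-function interpolation routine) can then branch on isotope=0 explicitly rather than relying on the lookup.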
Hi Kawahara-san, thank you for the request and apologies for the delay.
I am not sure if it is a bug or just me, but I get the following message whenever I run "mdbCO_HITEMP=api.MdbHitemp('CO',nus, gpu_transfer=True)"
a second time (i.e., after it has downloaded the line list):
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_146160/168903778.py in <module>

~/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev518+gb7efcbd-py3.7.egg/exojax/spec/api.py in __init__(self, path, nurange, margin, crit, Ttyp, isotope, gpu_transfer, inherit_dataframe)
    273     verbose=True,
    274     chunksize=100000,
--> 275     parallel=True,
    276 )
    277

~/anaconda3/envs/astroconda/lib/python3.7/site-packages/radis-0.14-py3.7.egg/radis/api/hitempapi.py in __init__(self, name, molecule, local_databases, engine, verbose, chunksize, parallel)
    124     engine,
    125     verbose=verbose,
--> 126     parallel=parallel,
    127 )
    128 self.chunksize = chunksize

~/anaconda3/envs/astroconda/lib/python3.7/site-packages/radis-0.14-py3.7.egg/radis/api/dbmanager.py in __init__(self, name, molecule, local_databases, engine, extra_params, verbose, parallel, nJobs, batch_size)
    126 ):  # TODO: replace with pathlib
    127     raise ValueError(
--> 128         f"Databank {self.name} is already registered in radis.json but the declared path ({registered_path_abspath}) is not in the expected local databases folder ({local_databases_abspath}). Please fix/delete the radis.json entry, change the databank_name, or change the default local databases path entry 'DEFAULT_DOWNLOAD_PATH' in radis.config or ~/radis.json"
    129     )
    130

ValueError: Databank HITEMP-{molecule} is already registered in radis.json but the declared path (/mnt/phoenest/stevanus/planetspecgen/tutorials/co-05_hitemp2019.hdf5) is not in the expected local databases folder (/mnt/phoenest/stevanus/PlanetSpecGen). Please fix/delete the radis.json entry, change the databank_name, or change the default local databases path entry 'DEFAULT_DOWNLOAD_PATH' in radis.config or ~/radis.json
```
I tried changing DEFAULT_DOWNLOAD_PATH in radis.json, but it did not help. To make it work, I have to delete "~/radis.json" every time I rerun the code after the line list has been downloaded.
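The message says that the path registered in radis.json is not inside the expected local databases folder. A minimal sketch of such a containment check is below; the helper name `is_inside` is hypothetical and radis's actual implementation may differ. Applied to the two paths in this message, it shows why they fail: the registered file sits under `planetspecgen/tutorials`, while the expected folder is `PlanetSpecGen`, and on a case-sensitive Linux file system those are different directories.

```python
import os


def is_inside(path: str, folder: str) -> bool:
    """Return True if `path` lies inside `folder` after normalising both.

    Hypothetical sketch of the kind of check described by the radis
    ValueError message; not radis's actual code.
    """
    path = os.path.abspath(path)
    folder = os.path.abspath(folder)
    # commonpath is the deepest shared directory; it equals `folder`
    # only when `path` is located somewhere below `folder`.
    return os.path.commonpath([path, folder]) == folder


# The two paths reported in the error message above.
registered = "/mnt/phoenest/stevanus/planetspecgen/tutorials/co-05_hitemp2019.hdf5"
expected = "/mnt/phoenest/stevanus/PlanetSpecGen"
inside = is_inside(registered, expected)  # False: different (case-sensitive) folders
```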
@astrostevanus Thanks! Honestly, I do not fully understand how databanks work in radis. We will fix this by the next release. Meanwhile, remove radis.json when you get that error (that is what I do...).
Oh okay, I'll remove radis.json from time to time for now. I found possible bugs in the "Forward_modeling" notebook tutorial. I think `mdbCO.generate_jnp_arrays()` is missing before `SijM=jit(vmap(SijT,(0,None,None,None,0)))(Tarr,mdbCO.logsij0,mdbCO.nu_lines,mdbCO.elower,qt)`.
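For readers unfamiliar with the `jit(vmap(SijT, (0, None, None, None, 0)))` pattern: the `in_axes` tuple maps over the temperature array and the matching partition-function ratios while broadcasting the same line arrays to every temperature, giving an `(n_T, n_lines)` matrix. The sketch below uses a toy stand-in `line_strength` built on the standard HITRAN temperature scaling; it is not exojax's `SijT`, and the numbers are made up. It does assume the line arrays already exist as arrays when the call is made, which I take to be what `generate_jnp_arrays()` prepares in the notebook.

```python
import jax
import jax.numpy as jnp

C2 = 1.4387769  # second radiation constant hc/k_B [cm K]
TREF = 296.0    # HITRAN reference temperature [K]


def line_strength(T, logsij0, nu_lines, elower, qt):
    """Toy stand-in for SijT: HITRAN-style line strengths at temperature T.

    logsij0: natural log of line strengths at TREF, shape (n_lines,)
    nu_lines, elower: line centres and lower-state energies [cm^-1]
    qt: partition-function ratio Q(T)/Q(TREF), a scalar after vmap.
    """
    boltzmann = -C2 * elower * (1.0 / T - 1.0 / TREF)
    stimulated = (1.0 - jnp.exp(-C2 * nu_lines / T)) / (
        1.0 - jnp.exp(-C2 * nu_lines / TREF)
    )
    return jnp.exp(logsij0 + boltzmann) * stimulated / qt


# in_axes=(0, None, None, None, 0): map over T and qt together,
# reuse the same line arrays for every temperature.
line_strength_matrix = jax.jit(jax.vmap(line_strength, (0, None, None, None, 0)))

Tarr = jnp.array([296.0, 1000.0, 2000.0])
qt = jnp.array([1.0, 3.0, 9.0])  # made-up partition-function ratios
logsij0 = jnp.log(jnp.array([1e-20, 2e-21, 5e-22]))
nu_lines = jnp.array([4200.0, 4250.0, 4300.0])
elower = jnp.array([0.0, 500.0, 3000.0])

SijM = line_strength_matrix(Tarr, logsij0, nu_lines, elower, qt)  # (n_T, n_lines)
```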
After generating the jnp arrays, everything went well except at the edges of the raw spectrum, which look like this:
I guess this is due to the convolution used for the rotational and instrumental broadening?
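Assuming the broadening is a zero-padded, same-length convolution (exojax's response functions may treat the edges differently), the edge artifact can be reproduced on a flat toy spectrum: the points within half a kernel width of each end are averaged with zeros outside the grid and are pulled down. A common remedy is to compute the model on a grid somewhat wider than the data and then drop that margin. The boxcar kernel here is only a stand-in for the actual rotational or instrumental kernel.

```python
import numpy as np

flux = np.ones(50)         # flat toy spectrum on a uniform grid
kernel = np.ones(5) / 5.0  # normalised boxcar standing in for a broadening kernel

# mode="same" keeps the grid length; samples beyond the ends count as zero,
# so the first and last len(kernel)//2 points are pulled below 1.
broadened = np.convolve(flux, kernel, mode="same")

# Remedy: model a wider range than the data, then discard the margin.
margin = kernel.size // 2
trimmed = broadened[margin:-margin]
```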
In optimize_spectrum_JAXopt, I think wavenumber_grid should be used instead of nugrid. Additionally, I ran into the following error when running res = gd.run(init_params=initpar).
The same error also occurs when I run mcmc.run(rng_key_, nu1=nusd, y1=nflux) in the Reverse_modeling notebook.
It seems that velocity_grid(resolution, vmax) in utils/grids.py needs to use jnp instead of np, but I am not sure.
TracerArrayConversionError Traceback (most recent call last)
/tmp/ipykernel_30790/332355194.py in <module> 1 gd = jaxopt.GradientDescent(fun=objective, maxiter=1000,stepsize=1.e-4) ----> 2 res = gd.run(init_params=initpar) 3 params, state = res
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/base.py in run(self, init_params, *args, **kwargs) 213 run = decorator(run) 214 --> 215 return run(init_params, *args, **kwargs) 216 217
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/implicit_diff.py in wrapped_solver_fun(*args, **kwargs) 249 args, kwargs = _signature_bind(solver_fun_signature, *args, **kwargs) 250 keys, vals = list(kwargs.keys()), list(kwargs.values()) --> 251 return make_custom_vjp_solver_fun(solver_fun, keys)(*args, *vals) 252 253 return wrapped_solver_fun
[... skipping hidden 5 frame]
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/implicit_diff.py in solver_fun_flat(*flat_args) 205 def solver_fun_flat(*flat_args): 206 args, kwargs = _extract_kwargs(kwarg_keys, flat_args) --> 207 return solver_fun(*args, **kwargs) 208 209 def solver_fun_fwd(*flat_args):
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/base.py in _run(self, init_params, *args, **kwargs) 183 cond_fun=self._cond_fun, body_fun=self._body_fun, 184 init_val=init_val, maxiter=self.maxiter - 1, jit=jit, --> 185 unroll=unroll)[0] 186 187 return tree_util.tree_map(
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/loop.py in while_loop(cond_fun, body_fun, init_val, maxiter, unroll, jit) 80 fun = jax.jit(fun, static_argnums=(0, 1, 3)) 81 ---> 82 return fun(cond_fun, body_fun, init_val, maxiter)
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/loop.py in _while_loop_lax(cond_fun, body_fun, init_val, maxiter) 58 return it+1, val 59 ---> 60 return jax.lax.while_loop(_cond_fun, _body_fun, (0, init_val))[1] 61 62
[... skipping hidden 11 frame]
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/loop.py in _body_fun(_val) 55 def _body_fun(_val): 56 it, val = _val ---> 57 val = body_fun(val) 58 return it+1, val 59
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/base.py in _body_fun(self, inputs) 143 def _body_fun(self, inputs): 144 (params, state), (args, kwargs) = inputs --> 145 return self.update(params, state, *args, **kwargs), (args, kwargs) 146 147 # TODO(frostig,mblondel): temporary workaround to accommodate line
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/gradient_descent.py in update(self, params, state, *args, **kwargs) 79 (params, state) 80 """ ---> 81 return super().update(params, state, None, *args, **kwargs) 82 83 def optimality_fun(self, params, *args, **kwargs):
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/proximal_gradient.py in update(self, params, state, hyperparams_prox, *args, **kwargs) 270 """ 271 f = self._update_accel if self.acceleration else self._update --> 272 return f(params, state, hyperparams_prox, args, kwargs) 273 274 def _fixed_point_fun(self, sol, hyperparams_prox, args, kwargs):
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/proximal_gradient.py in _update_accel(self, x, state, hyperparams_prox, args, kwargs) 239 t = state.t 240 stepsize = state.stepsize --> 241 y_fun_val, y_fun_grad = self._value_and_grad_fun(y, *args, **kwargs) 242 next_x, next_stepsize = self._iter(iter_num, y, y_fun_val, y_fun_grad, 243 stepsize, hyperparams_prox, args, kwargs)
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/proximal_gradient.py in _value_and_grad_fun(self, params, *args, **kwargs) 282 283 def _value_and_grad_fun(self, params, *args, **kwargs): --> 284 (value, aux), grad = self._value_and_grad_with_aux(params, *args, **kwargs) 285 return value, grad 286
[... skipping hidden 8 frame]
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/proximal_gradient.py in <lambda>(*a, **kw) 298 else: 299 self._fun = self.fun --> 300 fun_with_aux = lambda *a, **kw: (self.fun(*a, **kw), None) 301 302 self._value_and_grad_with_aux = jax.value_and_grad(fun_with_aux,
/tmp/ipykernel_30790/3393464750.py in objective(params) 1 def objective(params): ----> 2 f=nflux-model_c(params,boost,nusd) 3 g=jnp.dot(f,f) 4 return g
/tmp/ipykernel_30790/3599969597.py in model_c(params, boost, nu1) 35 return mu 36 ---> 37 model=obyo(nu1,nus,numatrix_CO,mdbCO,cdbH2H2) 38 return model
/tmp/ipykernel_30790/3599969597.py in obyo(nusd, nus, numatrix_CO, mdbCO, cdbH2H2) 30 F0=rtrun(dtau,sourcef)/norm 31 ---> 32 Frot=response.rigidrot(nus,F0,vsini,u1,u2) 33 #Frot=rigidrotx(nus,F0,vsini,u1,u2) 34 mu=response.ipgauss_sampling(nusd,nus,Frot,beta,RV)
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/spec/response.py in rigidrot(nus, F0, vsini, u1, u2, vsinimax) 19 """ 20 resolution = resolution_eslog(nus) ---> 21 vr_array = velocity_grid(resolution, vsinimax) 22 return convolve_rigid_rotation(F0, vr_array, vsini, u1, u2) 23
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/utils/grids.py in velocity_grid(resolution, vmax) 96 Nk = (vmax / dv) + 1 97 Nk = Nk.astype(int) ---> 98 return dv * np.arange(-Nk, Nk + 1) 99 100
~/anaconda3/envs/astroconda/lib/python3.7/site-packages/jax-0.3.17-py3.7.egg/jax/core.py in __array__(self, *args, **kw) 534 535 def __array__(self, *args, **kw): --> 536 raise TracerArrayConversionError(self) 537 538 def __dlpack__(self, *args, **kw):
TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/1)>
The error occurred while tracing the function _body_fun at /home/stevanus/anaconda3/envs/astroconda/lib/python3.7/site-packages/jaxopt-0.5-py3.7.egg/jaxopt/_src/loop.py:55 for while_loop. This value became a tracer due to JAX operations on these lines:
operation a:f32[] = log1p b from line /home/stevanus/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/utils/grids.py:113 (delta_velocity_from_resolution)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b from line /home/stevanus/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/utils/grids.py:113 (delta_velocity_from_resolution)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b from line /home/stevanus/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/utils/grids.py:96 (velocity_grid)
operation a:f32[] = convert_element_type[new_dtype=float32 weak_type=False] b from line /home/stevanus/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/utils/grids.py:96 (velocity_grid)
operation a:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b from line /home/stevanus/anaconda3/envs/astroconda/lib/python3.7/site-packages/ExoJAX-1.1.4.dev523+gabd725f-py3.7.egg/exojax/utils/grids.py:98 (velocity_grid)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError
@astrostevanus Thanks a lot. I fixed jax.example_libraries. As for the second point:
it seems that the velocity_grid(resolution, vmax) in utils/grids.py needs use jnp instead of np, but I am not sure
This looks like a more severe issue. I will look for a solution.
@astrostevanus I had already written new functions that work with VJP but forgot to switch to them:
Frot = convolve_rigid_rotation(F0, vr_array, vsini, u1, u2)
So rigidrot
is now deprecated. Thanks for pointing this out.
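The TracerArrayConversionError arises because velocity_grid builds its grid with `np.arange`, whose length is computed from a value that became a JAX tracer inside the jaxopt/NumPyro loop, and NumPy cannot turn a tracer into a concrete array. The fix above passes a precomputed `vr_array` to `convolve_rigid_rotation`. A minimal sketch of that pattern follows; `velocity_grid` here is a hypothetical NumPy grid builder taking a velocity step `dv`, and `broaden` uses a toy boxcar kernel in place of the actual rotation kernel (the limb-darkening coefficients u1, u2 are omitted).

```python
import numpy as np
import jax
import jax.numpy as jnp


def velocity_grid(dv, vmax):
    """Hypothetical velocity grid from about -vmax to +vmax in steps of dv.

    Uses plain NumPy, so dv and vmax must be concrete numbers; calling it
    on traced values inside jit would fail much like the reported error.
    """
    nk = int(vmax / dv) + 1
    return dv * np.arange(-nk, nk + 1)


# Build the grid ONCE, outside jit/grad/optimizer loops: its length is fixed.
vr_array = jnp.asarray(velocity_grid(1.0, 5.0))


@jax.jit
def broaden(flux, vr_array, vsini):
    # Only jnp operations on the traced parameter vsini; the array shapes
    # are already known when the function is traced.
    kernel = jnp.where(jnp.abs(vr_array) <= vsini, 1.0, 0.0)  # toy boxcar
    kernel = kernel / jnp.sum(kernel)
    return jnp.convolve(flux, kernel, mode="same")


out = broaden(jnp.ones(20), vr_array, 2.0)
```

Because `vsini` only enters through `jnp.where`, it can be traced freely, while the shape-determining grid is fixed before tracing begins.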
Hello, I'll have a look so that you don't get conflicts with radis.json within exojax. I can have it integrated in 1.2, but it would require one more week.
@erwanp Thanks! No problem to delay it for another week. I have several things to fix within 1.2!
@HajimeKawahara @astrostevanus Sorry, I mistakenly pushed directly to the release branch, but commit 17ddb2a fixes the bug in the nus grid that caused the error @astrostevanus found.
When I fixed the grid-setting bug (#319), I forgot that the nus grid should always be in increasing order.
I have confirmed that Forward_modeling_using_DIT.ipynb now works.
I think it would be better to state explicitly somewhere that both the wavenumber and wavelength grids returned by wavenumber_grid are in increasing order, so wav[-1], not wav[0], corresponds to the wavelength of nus[0]. Also, cross sections and spectra are both computed on that increasing-wavenumber grid. I mention this because I sometimes get confused myself. Sorry if this is already written somewhere and I just missed it. What do you think?
@ykawashima Thanks! Yes I agree. I also always get confused.
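To make the ordering convention concrete, here is a small sketch with a toy log-spaced grid (not exojax's wavenumber_grid itself), assuming wavelengths in ångströms via λ = 1e8/ν with ν in cm⁻¹. Since λ decreases as ν increases, an increasing wavelength grid is the reversed conversion, and a spectrum computed on `nus` has to be reversed the same way before it is paired with `wav`.

```python
import numpy as np

# Toy increasing wavenumber grid in cm^-1 (log-spaced, made-up range).
nus = np.logspace(np.log10(4000.0), np.log10(4500.0), 8)

# lambda [AA] = 1e8 / nu [cm^-1] decreases with nu, so reverse it
# to obtain an increasing wavelength grid.
wav = 1.0e8 / nus[::-1]

# A spectrum computed on nus must be reversed the same way to match wav.
flux_nus = np.linspace(0.0, 1.0, nus.size)  # placeholder spectrum on nus
flux_wav = flux_nus[::-1]

print(nus[0], wav[-1])  # the first wavenumber pairs with the LAST wavelength
```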
Note that the docs for the current release version are here.
@erwanp, do you need more time? Or maybe we can postpone it until v1.3.
@HajimeKawahara I think the register lines in MdbHitran and in MdbHitemp are the culprit; they should be commented out.
- ExoJAX's philosophy is to compare the paths given by users with the files on disk; if the files aren't there, they are downloaded.
- RADIS's philosophy is to work with database names that are registered in radis.json.
Both have pros and cons, but in ExoJAX there shouldn't be any registration.
The only problem is that it takes me time to set up a proper environment to test it under conditions similar to yours. @astrostevanus, could you try commenting out the lines mentioned above? You want to make sure that:
- the error is no longer raised
- HITRAN is not re-downloaded if it already exists
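The path-based approach described above, combined with the second check (no re-download when the file already exists), can be sketched as follows. `ensure_local` and the `download` callback are hypothetical names for illustration, not ExoJAX or RADIS functions: the file system is the only source of truth, and no registry such as radis.json is read or written.

```python
import tempfile
from pathlib import Path
from typing import Callable


def ensure_local(path, download: Callable[[Path], None]) -> Path:
    """Path-based lookup: use the file if it is on disk, otherwise fetch it.

    No registry is consulted. `download` is a hypothetical callback that
    writes the database to `path`.
    """
    path = Path(path).expanduser().resolve()
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        download(path)
    return path


# Usage with a stand-in downloader that records how often it is called.
calls = []


def fake_download(p: Path) -> None:
    calls.append(p)
    p.write_text("dummy line list")


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "hitran" / "CO.par"
    ensure_local(target, fake_download)
    ensure_local(target, fake_download)  # second call finds the file on disk
    file_exists = target.exists()
```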
@erwanp Thanks a lot! On my side, it worked. @astrostevanus, can you check the latest version just in case?
Hi @erwanp and @HajimeKawahara san, I will check it tomorrow night if you don't mind. I am on annual leave today and tomorrow, so I don't have access to my computer.
Hi @erwanp and @HajimeKawahara san, I have just checked it and it works well!
Perfect, thank you for trying!
Thanks a lot!