JAX-GalSim
validating class instantiation and jitting
In many classes in GalSim there is code that validates the instantiation of a class object. This is difficult to maintain if we want to make the class jittable and vmappable. For instance, in the Bounds class we had to get rid of code from GalSim that did the following:
```python
if (self.xmin != int(self.xmin) or self.xmax != int(self.xmax) or
        self.ymin != int(self.ymin) or self.ymax != int(self.ymax)):
    raise TypeError("BoundsI must be initialized with integer values")
```
I don't think it's urgent, but I would like to discuss whether there is a workaround for this that would bring us closer to the GalSim code, since these checks are arguably useful for the user.
@jecampagne did you have some ideas regarding this issue?
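The check fails under `jax.jit` because, inside a traced function, the bounds are abstract tracers and `int(tracer)` raises a `ConcretizationTypeError`. One partial workaround is to run the GalSim check only when every value is concrete and skip it when any value is traced. The sketch below assumes a stripped-down, hypothetical `BoundsI` and a hypothetical helper `_is_concrete`; it is not the JAX-GalSim class.

```python
import jax


def _is_concrete(x):
    # Hypothetical helper: Python numbers and concrete arrays can be inspected,
    # but tracers (inside jit/vmap/grad) cannot.
    return not isinstance(x, jax.core.Tracer)


class BoundsI:
    """Hypothetical, stripped-down stand-in for GalSim's BoundsI."""

    def __init__(self, xmin, xmax, ymin, ymax):
        self.xmin, self.xmax, self.ymin, self.ymax = xmin, xmax, ymin, ymax
        vals = (xmin, xmax, ymin, ymax)
        # Run the GalSim-style integer check only when no value is traced.
        if all(_is_concrete(v) for v in vals):
            if any(v != int(v) for v in vals):
                raise TypeError("BoundsI must be initialized with integer values")

    def area(self):
        return (self.xmax - self.xmin + 1) * (self.ymax - self.ymin + 1)


@jax.jit
def area_from_origin(x0):
    # x0 is a tracer here, so the integer check is skipped instead of crashing.
    return BoundsI(x0, x0 + 4, 0, 2).area()
```

The trade-off is that invalid values passed through a jitted or vmapped function are no longer caught; only eager construction keeps GalSim's behavior.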
Hi Ismael
This is not my final answer, but I think it is related to https://github.com/jecampagne/JaxTutos/blob/main/JAX_PyTree_initialisation.ipynb (section "A more complete ex with different kinds of variables").
The point is that when you want to call a function to define some initial value, you certainly need the "explicit init which is a classmethod" (see the _y variable).
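One reading of the "explicit init" suggestion (an assumption; the notebook itself is not reproduced here) is to keep the user-facing validation in `__init__` and have the pytree `tree_unflatten` classmethod build instances without calling `__init__`. JAX rebuilds objects through `tree_unflatten` while tracing, with tracers as leaves, so the check then only runs when a user constructs the object directly. A minimal sketch with a hypothetical `BoundsI` and a hypothetical `jitted_area` function:

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class BoundsI:
    """Hypothetical pytree Bounds class; not JAX-GalSim's actual implementation."""

    def __init__(self, xmin, xmax, ymin, ymax):
        # User-facing constructor: validate eagerly, as GalSim does.
        for v in (xmin, xmax, ymin, ymax):
            if v != int(v):
                raise TypeError("BoundsI must be initialized with integer values")
        self.xmin, self.xmax, self.ymin, self.ymax = xmin, xmax, ymin, ymax

    def tree_flatten(self):
        # All four corners are traced leaves; no static auxiliary data.
        return (self.xmin, self.xmax, self.ymin, self.ymax), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # JAX calls this with tracers during jit, so bypass __init__ and its checks.
        obj = cls.__new__(cls)
        obj.xmin, obj.xmax, obj.ymin, obj.ymax = children
        return obj

    def area(self):
        return (self.xmax - self.xmin + 1) * (self.ymax - self.ymin + 1)


@jax.jit
def jitted_area(bounds):
    return bounds.area()
```

Here `BoundsI(0.5, 9, 0, 4)` still raises `TypeError`, while `jitted_area(BoundsI(0, 9, 0, 4))` traces without touching the check.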
Here it is a bit different, since you are asking for a test, but I guess it is quite similar because this test is common to all BoundsI objects. Another point: xmin... are essentially "shape" variables that can (and should) be treated like fshape in
https://github.com/jecampagne/JaxTutos/blob/main/JAX_static_traced_var_func.ipynb (section "More advanced exo").
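Treating the bounds like `fshape` could look like the following sketch, which is our interpretation rather than code from the notebook or from JAX-GalSim. The corners are stored as static pytree auxiliary data instead of traced leaves, so they stay plain Python ints under `jit` and the GalSim check can always run. The class name `StaticBoundsI` and the function `blank_image` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class StaticBoundsI:
    """Hypothetical BoundsI whose corners are static (compile-time) metadata."""

    def __init__(self, xmin, xmax, ymin, ymax):
        vals = (xmin, xmax, ymin, ymax)
        # The corners are never traced, so the GalSim check always runs.
        if any(v != int(v) for v in vals):
            raise TypeError("BoundsI must be initialized with integer values")
        self.xmin, self.xmax, self.ymin, self.ymax = (int(v) for v in vals)

    def tree_flatten(self):
        # No traced leaves; the corners live in hashable aux_data instead.
        return (), (self.xmin, self.xmax, self.ymin, self.ymax)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # aux_data holds plain ints, so re-running __init__ (and its check) is safe.
        return cls(*aux_data)


@jax.jit
def blank_image(bounds, fill):
    # Array shapes must be concrete at trace time; the static bounds provide them.
    shape = (bounds.ymax - bounds.ymin + 1, bounds.xmax - bounds.xmin + 1)
    return jnp.full(shape, fill)
```

The cost is that every distinct set of bounds triggers a recompilation, and the corners can no longer be differentiated or vmapped over.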
There is a takeaway in the JAX docs: "A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time)." (https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations)
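A small illustration of that pattern, adapted from the same JAX page: the array shape is static, so `np.prod` yields a concrete integer at trace time and an ordinary Python check on it works inside `jit`, whereas the reshape itself goes through `jax.numpy` and is traced. The function name `checked_flatten` and its error message are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def checked_flatten(x):
    # Static: x.shape is known at trace time, so NumPy gives a concrete int.
    n = int(np.prod(x.shape))
    # A plain Python check on a static value raises during tracing, as in GalSim.
    if n == 0:
        raise ValueError("checked_flatten expects a non-empty array")
    # Traced: the reshape is compiled and executed at run time.
    return jnp.reshape(x, (n,))
```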
We certainly need a skeleton model for the Bounds-checking problem.
I have not had time to set up this demo yet. Sorry.
JE
(Sent by email in reply to Ismael Mendoza's message of 5/5/23, 11:07 PM.)
I don't think there is a perfect workaround for this. Our tests now jit and autodiff all of the objects, so some of the checks work, but not all of them. We have also noted in the docstrings where things are not checked. Shall we close this one?
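For the checks that do not survive `jit`, one option not used in this thread is `jax.experimental.checkify`, which turns a traced predicate into an error value returned alongside the result instead of failing during tracing. A hedged sketch, assuming a hypothetical `make_bounds` helper that stores the corners as an array:

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def make_bounds(xmin, xmax, ymin, ymax):
    # Hypothetical helper: pack the corners into one (traceable) array.
    vals = jnp.array([xmin, xmax, ymin, ymax])
    # Runtime check that works on tracers: recorded as an error, not raised here.
    checkify.check(jnp.all(vals == jnp.round(vals)),
                   "BoundsI must be initialized with integer values")
    return vals


# checkify first, then jit, so the error value is threaded through compilation.
checked = jax.jit(checkify.checkify(make_bounds))

err, bounds = checked(0.0, 3.0, 0.0, 2.0)
err.throw()  # integral inputs: nothing is raised

err, _ = checked(0.5, 3.0, 0.0, 2.0)
message = err.get()  # a string containing the check message
```

The caller has to inspect or `throw()` the returned error; nothing is raised automatically, which differs from GalSim's eager `TypeError`.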
That's OK with me, since this sounds like something intrinsic to jitting in JAX. As long as we note in the docs what we do check, I'd be OK with closing this issue.
OK. I'll go through and update the docstrings.