JAX-GalSim

validating class instantiation and jitting

Open ismael-mendoza opened this issue 1 year ago • 4 comments

In many classes in GalSim there is code that validates the instantiation of a class object. This is difficult to maintain if we want to make the class jittable and vmappable. For instance, in the Bounds class we had to drop GalSim code that did the following:

        if (self.xmin != int(self.xmin) or self.xmax != int(self.xmax) or
            self.ymin != int(self.ymin) or self.ymax != int(self.ymax)):
            raise TypeError("BoundsI must be initialized with integer values")

I don't think it's urgent, but I would like to discuss whether there is a workaround we can use to bring us closer to the GalSim code, since these checks are arguably useful for the user.
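To make the problem concrete: under jax.jit (or jax.vmap) the constructor arguments are tracers, so int(self.xmin) has no number to return and JAX raises an error instead. One partial workaround is to run the GalSim-style check only when the values are concrete and skip it when they are traced. The sketch below assumes that approach; the helpers _concrete_is_integer and validate_int_bounds are hypothetical and are not part of JAX-GalSim.

```python
import jax
import jax.numpy as jnp


def _concrete_is_integer(value):
    """Return True/False for concrete values, or None if `value` is traced.

    Hypothetical helper: float()/int() need a concrete value, so under
    jax.jit or jax.vmap they raise a JAX error instead of returning a number.
    """
    try:
        return float(value) == int(value)
    except (jax.errors.ConcretizationTypeError,
            jax.errors.TracerIntegerConversionError):
        return None  # traced: the check cannot be evaluated here


def validate_int_bounds(xmin, xmax, ymin, ymax):
    """Hypothetical BoundsI-style check that only fires on concrete inputs."""
    for v in (xmin, xmax, ymin, ymax):
        # `is False` skips the traced case (None) instead of raising.
        if _concrete_is_integer(v) is False:
            raise TypeError("BoundsI must be initialized with integer values")
    return jnp.array([xmin, xmax, ymin, ymax])
```

The trade-off is that invalid values passed through jit or vmap are accepted silently, so the check only protects eager (untransformed) calls.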

@jecampagne did you have some ideas regarding this issue?

ismael-mendoza, May 05 '23 21:05

Hi Ismael

This is not my final answer, but I think it is related to https://github.com/jecampagne/JaxTutos/blob/main/JAX_PyTree_initialisation.ipynb (section "A more complete ex with different kinds of variables").

The point is that when you want to call a function to define some initial value, you certainly need the "explicit init, which is a classmethod" (see the _y variable).

Here it is a bit different, since you are asking for a check, but I guess it is quite similar because this check is common to all BoundsI objects. Another point: xmin, xmax, etc. are certainly "shape" variables that can (and should) be treated like fshape in https://github.com/jecampagne/JaxTutos/blob/main/JAX_static_traced_var_func.ipynb (section "More advanced exo").
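A minimal sketch of the "shape variables" idea, assuming the corners are stored as static pytree aux_data rather than as traced leaves, so they stay plain Python ints and the GalSim integer check in __init__ keeps working under jit. StaticBoundsI and sum_in_bounds are hypothetical illustrations, not JAX-GalSim's Bounds implementation; tree_unflatten bypasses __init__ following the custom-pytree advice in the JAX documentation, which is not necessarily the exact classmethod pattern used in the linked notebook.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class StaticBoundsI:
    """Hypothetical BoundsI-like pytree whose corners are static aux_data."""

    def __init__(self, xmin, xmax, ymin, ymax):
        vals = (xmin, xmax, ymin, ymax)
        # Same check as GalSim; it works because the values are never tracers.
        if any(v != int(v) for v in vals):
            raise TypeError("BoundsI must be initialized with integer values")
        self.xmin, self.xmax, self.ymin, self.ymax = (int(v) for v in vals)

    def tree_flatten(self):
        # No traced children; the corners travel as hashable static aux_data.
        return (), (self.xmin, self.xmax, self.ymin, self.ymax)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so validation is not re-run
        # whenever JAX reconstructs the object during a transformation.
        obj = object.__new__(cls)
        obj.xmin, obj.xmax, obj.ymin, obj.ymax = aux_data
        return obj


@jax.jit
def sum_in_bounds(bounds, image):
    # The corners are Python ints here, so they can define static slice shapes.
    nx = bounds.xmax - bounds.xmin + 1
    ny = bounds.ymax - bounds.ymin + 1
    return jnp.sum(image[:ny, :nx])
```

The cost of this design is that jit recompiles for every distinct set of bounds, and the corners can no longer be traced, differentiated, or vmapped over.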

There is a takeaway in the JAX docs: "A useful pattern is to use numpy for operations that should be static (i.e. done at compile-time), and use jax.numpy for operations that should be traced (i.e. compiled and executed at run-time)." (https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#static-vs-traced-operations)
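Applied to validation, this pattern means that checks on static information such as shapes, ndim, and dtypes keep working under jit, because they run once at trace time with ordinary Python and NumPy. A small illustration; flatten_and_sum is a hypothetical function, not part of JAX-GalSim:

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def flatten_and_sum(x):
    # x.ndim and x.shape are static under jit, so this Python check runs
    # at trace time and still raises a normal error for the user.
    if x.ndim != 2:
        raise ValueError(f"expected a 2-d array, got ndim={x.ndim}")
    n = int(np.prod(x.shape))           # NumPy: evaluated at compile time
    return jnp.reshape(x, (n,)).sum()   # jax.numpy: traced, run at run time
```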

We certainly need a skeleton model for the Bounds-checking problem.

I have no time yet to set up this demo, sorry.

JE


jecampagne, May 06 '23 16:05

I don't think there is a perfect workaround for this. Our tests now jit and autodiff all of the objects, so some of the checks work, but not all of them. We have also noted in the docstrings where things are not checked. Shall we close this one?
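One tool not mentioned in this thread that covers part of the gap is jax.experimental.checkify, which lets a check depend on traced values and reports the failure after the jitted call returns. A minimal sketch, assuming that approach; make_bounds is a hypothetical stand-in for a BoundsI constructor, not JAX-GalSim code.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


def make_bounds(xmin, xmax, ymin, ymax):
    """Hypothetical BoundsI constructor with a run-time integer check."""
    vals = jnp.array([xmin, xmax, ymin, ymax])
    # Unlike a Python `if`, checkify.check accepts a traced boolean predicate.
    checkify.check(jnp.all(vals == jnp.round(vals)),
                   "BoundsI must be initialized with integer values")
    return vals


# checkify wraps the function first, then jit compiles the checked version.
checked_make_bounds = jax.jit(checkify.checkify(make_bounds))

err, vals = checked_make_bounds(0.0, 3.0, 0.0, 4.0)
err.throw()  # no-op: all values are integral

err, vals = checked_make_bounds(0.0, 3.5, 0.0, 4.0)
# Calling err.throw() here would raise, carrying the message above.
```

The catch is that the caller has to receive and handle the returned error object, which is an API change compared with GalSim, where the constructor simply raises.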

beckermr, Sep 16 '24 15:09

That's OK with me, since this sounds like something intrinsic to jitting in JAX. As long as we note in the docs what we do check, I'd be OK with closing this issue.

ismael-mendoza, Sep 16 '24 18:09

OK. I'll go through and update the docstrings.

beckermr, Sep 16 '24 18:09