
Using pytreeclass with jax and pytorch without specifying the backend via an environment variable

Open · mfinzi opened this issue 1 year ago · 0 comments

Hi @ASEM000 ,

I really like your library compared to some other automated pytree alternatives and would love to see more people using it. I was interested in using pytreeclass in CoLA, a numerical linear algebra library that I have been involved in developing. One of our design constraints is that CoLA must support both jax and pytorch, whether jax is installed, pytorch is installed, or both. Which backend is used depends on the LinearOperator objects that the user creates, and there can even be scenarios where jax and pytorch objects exist simultaneously.
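One way to make the backend a per-object property rather than a global setting is to inspect each leaf's type. The sketch below is hypothetical (the helper `backend_of` is not part of pytreeclass or CoLA) and assumes jax ≥ 0.4, where `jax.Array` exists. It checks `sys.modules` instead of importing, so a framework the user never loaded is never imported and a missing install never raises.

```python
import sys


def backend_of(x):
    """Hypothetical helper: guess which framework owns the array-like x.

    If x is a torch.Tensor, torch must already be in sys.modules, so
    looking there avoids importing (or failing to import) a framework
    the user never loaded.
    """
    # getattr(None, ..., None) returns None when the module is absent
    tensor_type = getattr(sys.modules.get("torch"), "Tensor", None)
    if tensor_type is not None and isinstance(x, tensor_type):
        return "torch"
    array_type = getattr(sys.modules.get("jax"), "Array", None)
    if array_type is not None and isinstance(x, array_type):
        return "jax"
    return None  # plain Python, NumPy, or an unknown type
```

Because the check is made per leaf, a jax-backed operator and a torch-backed operator can live in the same process, each dispatching to its own routines.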

We were hoping to use pytreeclass as the base pytree for the LinearOperator objects, but have run into some issues with this cross-framework support. We know that pytreeclass was designed with support for both jax and pytorch in mind, but I couldn't find details on this topic in the docs.

Looking at pytreeclass/_src/backend/init.py, is the backend selected using an environment variable? Is there any way for pytreeclass to work whether or not jax or pytorch is installed, choosing a backend based on whether the imports succeed or fail? Also, do you have any thoughts on whether it would be possible to have jax and pytorch pytrees existing at the same time?
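A minimal sketch of the import-based alternative, assuming nothing about pytreeclass internals: try each optional import, then register a class with every pytree registry that is available, so the same class works under jax, torch, or both at once. The names `Dense` and `register_everywhere` are hypothetical. `jax.tree_util.register_pytree_node` is public, but `torch.utils._pytree` is a private module whose registration function has been renamed across releases (`_register_pytree_node` in older versions, `register_pytree_node` in newer ones), so the torch branch is an assumption to check against the installed version.

```python
# Hypothetical sketch: no environment variable; each backend is enabled
# only if its import succeeds. Not pytreeclass's actual API.
try:
    import jax.tree_util as jax_tree  # optional dependency
except ImportError:
    jax_tree = None

try:
    import torch.utils._pytree as torch_tree  # private torch module (assumption)
except ImportError:
    torch_tree = None


def available_backends():
    """Names of the backends whose import succeeded."""
    pairs = (("jax", jax_tree), ("torch", torch_tree))
    return [name for name, mod in pairs if mod is not None]


class Dense:
    """Hypothetical stand-in for a LinearOperator-like container."""

    def __init__(self, matrix, name="dense"):
        self.matrix = matrix  # array leaf (jax array or torch tensor)
        self.name = name      # static metadata, not a leaf

    def flatten(self):
        # (children, static metadata) as expected by both registries
        return [self.matrix], self.name

    @classmethod
    def unflatten(cls, leaves, name):
        return cls(leaves[0], name)


def register_everywhere(cls):
    if jax_tree is not None:
        # jax: flatten -> (children, aux_data); unflatten(aux_data, children)
        jax_tree.register_pytree_node(
            cls,
            lambda obj: obj.flatten(),
            lambda aux, children: cls.unflatten(list(children), aux),
        )
    if torch_tree is not None:
        # torch: flatten -> (children, context); unflatten(values, context)
        register = getattr(torch_tree, "register_pytree_node", None) or torch_tree._register_pytree_node
        register(
            cls,
            lambda obj: obj.flatten(),
            lambda values, context: cls.unflatten(list(values), context),
        )
    return cls


register_everywhere(Dense)
```

The two registries are independent, so registering with both lets jax and torch pytrees of the same class coexist; note that the two libraries order the unflatten arguments differently.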

Cheers, Marc

mfinzi, Oct 06 '23 18:10