[FEATURE] JAX integration

Open • nicholasjng opened this issue 4 years ago • 8 comments

Hello, friends,

Is your feature request related to a problem? Please describe.

Let's bring Google's JAX to ZenML!

Describe the solution you'd like

I would like to build an example ZenML + JAX project (exactly what the example covers is still TBD). If it's cloud-ready, all the better, although as far as I understand, JAX on GCP has some sharp edges.

Ideally, this could be accomplished with just a NumPy datasource plus a JAX trainer class, but that is only a first hunch; let's hope karma doesn't strike me for this one.
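The "NumPy datasource plus a JAX trainer" idea can be sketched in plain JAX, independent of ZenML. This is a minimal illustration under stated assumptions: a toy linear-regression task on synthetic data, full-batch gradient descent, and hypothetical helper names (`load_numpy_data`, `init_params`, `sgd_step`, `train`) that are not part of ZenML or JAX. In a pipeline, the data loader and trainer would presumably become separate steps.

```python
import numpy as np
import jax
import jax.numpy as jnp


def load_numpy_data(n_samples: int = 256, seed: int = 0):
    """Hypothetical NumPy 'datasource': noisy samples of y = 3x + 2."""
    rng = np.random.default_rng(seed)
    x = rng.uniform(-1.0, 1.0, size=(n_samples, 1)).astype(np.float32)
    noise = 0.01 * rng.standard_normal((n_samples, 1)).astype(np.float32)
    y = (3.0 * x + 2.0 + noise).astype(np.float32)
    return x, y


def init_params():
    # JAX has no model class: parameters are a plain pytree (here, a dict of arrays).
    return {"w": jnp.zeros((1, 1)), "b": jnp.zeros((1,))}


def loss_fn(params, x, y):
    # Mean squared error of a linear model.
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)


@jax.jit
def sgd_step(params, x, y, lr):
    # jax.grad differentiates loss_fn w.r.t. its first argument, the params pytree.
    grads = jax.grad(loss_fn)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


def train(x, y, num_steps: int = 500, lr: float = 0.1):
    """Hypothetical 'JAX trainer': full-batch gradient descent from zero init."""
    params = init_params()
    x, y = jnp.asarray(x), jnp.asarray(y)
    for _ in range(num_steps):
        params = sgd_step(params, x, y, lr)
    return params


if __name__ == "__main__":
    params = train(*load_numpy_data())
    # The learned values should end up close to the true w = 3 and b = 2.
    print(float(params["w"][0, 0]), float(params["b"][0]))
```

Keeping the parameters as a plain pytree means the trainer's output is just arrays, which is also what any later save/load or pipeline-materialization logic would have to handle.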

Additional context

Admittedly, my mental model of many of ZenML's designs is still stuck in March, so I will need to spend some time going through the newer concepts. Once I'm ready and have made progress, I'll submit a PR with the JAX example described above.

Also, on Apple M1 I might have to unpin some requirements to build packages from source, since wheels are currently missing for several of them (e.g. JAX itself, SciPy, or pandas). I'll report back and document what works here (or just switch to Linux instead).

What do you think?

nicholasjng • Aug 27 '21 13:08

Haha @nicholasjng, awesome to read this request. Love it! I'll update this thread as soon as we're ready for that undertaking. JAX is an awesome idea, and thanks for the request. Let's do it :muscle:

htahir1 • Aug 27 '21 16:08

Closing due to inactivity. We are migrating such issues to the roadmap for further voting :-)

htahir1 • Mar 22 '22 15:03

Sorry to necro an old thread, but I looked through the roadmap and couldn't find this. What's the status of this: canceled, on hold, or completed?

Thank you :)

IanQS • May 29 '22 20:05

We didn't find enough demand, so we had to prioritize other things for now. However, I can resurface it if you're interested. Contributions are welcome here. What would a JAX integration look like? Something similar to the TensorFlow one, i.e., the ability to pass (materialize) JAX models through ZenML pipelines?

htahir1 • May 29 '22 21:05

AFAIK there is no built-in way to save/load JAX models, since JAX itself only provides the numerical machinery for transforming and differentiating functions. For my own use, I exported some models to HDF5, but that has little to do with JAX.
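Because a JAX model's parameters are typically just a pytree of arrays, one library-agnostic way to persist them is to flatten the pytree and store the leaves with NumPy. This sketch swaps in `.npz` for the HDF5 format mentioned above (to avoid an `h5py` dependency); `save_params` and `load_params` are hypothetical helper names, and the loader assumes the caller can supply a template pytree with the same structure (e.g. freshly initialized parameters), because only the leaves are stored.

```python
import io

import numpy as np
import jax
import jax.numpy as jnp


def save_params(params, file) -> None:
    """Write every leaf of a params pytree to an .npz archive, in flatten order."""
    leaves, _ = jax.tree_util.tree_flatten(params)
    np.savez(file, *[np.asarray(leaf) for leaf in leaves])


def load_params(file, template):
    """Rebuild a params pytree; `template` supplies the tree structure."""
    template_leaves, treedef = jax.tree_util.tree_flatten(template)
    with np.load(file) as archive:
        # np.savez names positional arrays arr_0, arr_1, ... in the order given.
        leaves = [jnp.asarray(archive[f"arr_{i}"]) for i in range(len(archive.files))]
    if len(leaves) != len(template_leaves):
        raise ValueError("archive and template have a different number of leaves")
    return jax.tree_util.tree_unflatten(treedef, leaves)


if __name__ == "__main__":
    params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}}
    buffer = io.BytesIO()
    save_params(params, buffer)
    buffer.seek(0)  # rewind before reading the archive back
    restored = load_params(buffer, template=params)
    print(jax.tree_util.tree_structure(restored) == jax.tree_util.tree_structure(params))  # prints True
```

Storing leaves positionally keeps the format trivial, at the cost of needing the template to recover the structure.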

It's possible that first-party NN libraries (Flax, Haiku, etc.) have some machinery for it, though. I think a small example showing how to load, train, and save a model might suffice; what would need to be implemented for that? I'll see if I can come up with something if you point me to the necessary components :)

nicholasjng • May 30 '22 05:05

@nicholasjng So sorry for the late reply; this slipped through. Happy to receive your contribution!

A good place to start would be the guide on adding your own example. IMO, the example would look something like the LightGBM example.

You might need to implement a custom materializer to get this to work.
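A custom materializer essentially has to write a step's output into its artifact storage location and read it back later. The following is a ZenML-free sketch of that logic for JAX parameters, assuming they are a flat dict of arrays. The class name, its `save`/`load` methods, and the `uri` attribute are hypothetical stand-ins; the real base class and the method names to override may differ between ZenML versions, so the materializer docs for the installed version are the authority.

```python
import os
import tempfile

import numpy as np
import jax.numpy as jnp


class JaxParamsMaterializerSketch:
    """Hypothetical, ZenML-free stand-in for a custom materializer.

    Handles a flat dict mapping parameter names to arrays and persists it
    inside the artifact directory given by `uri`.
    """

    FILENAME = "params.npz"

    def __init__(self, uri: str):
        self.uri = uri

    def save(self, params: dict) -> None:
        os.makedirs(self.uri, exist_ok=True)
        # Each dict key becomes a named entry in the archive. Keys must be strings
        # and must not clash with np.savez's own parameter names (e.g. `file`).
        arrays = {name: np.asarray(value) for name, value in params.items()}
        np.savez(os.path.join(self.uri, self.FILENAME), **arrays)

    def load(self) -> dict:
        with np.load(os.path.join(self.uri, self.FILENAME)) as archive:
            # Convert back to JAX arrays while the archive is still open.
            return {name: jnp.asarray(archive[name]) for name in archive.files}


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        materializer = JaxParamsMaterializerSketch(os.path.join(tmp, "artifact"))
        materializer.save({"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))})
        print(sorted(materializer.load()))  # prints ['b', 'w']
```

Restricting the sketch to a flat dict lets each entry be stored under its own name, so no template is needed to rebuild it; nested parameter trees, such as the nested dicts used by Flax or Haiku, would need flattening first.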

I think this is a good starting point for a JAX integration. WDYT?

htahir1 • Jun 08 '22 11:06

Totally. I have some time on the weekend and am happy to take a look then. It may not be as straightforward for me, though, depending on how usable ZenML is on M1. I'll get it done!

nicholasjng • Jun 08 '22 13:06

M1 is a bit of a problem, unfortunately. Let me know if it doesn't work, though!

htahir1 • Jun 09 '22 07:06