
Add JAX integration

Open nicholasjng opened this issue 3 years ago • 15 comments

This commit adds a ZenML integration for JAX, a Python ML research framework.

Describe changes

See above.

Pre-requisites

Please ensure you have done the following:

  • [x] I have read the CONTRIBUTING.md document.
  • [ ] If my change requires a change to docs, I have updated the documentation accordingly.
  • [ ] If I have added an integration, I have updated the integrations table.
  • [ ] I have added tests to cover my changes.

Types of changes

  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to change)
  • [ ] Other (add details above)

TODOs:

  • Add an example for using JAX together with one of its associated DL frameworks, possibly https://github.com/nicholasjng/jax-styletransfer.
  • Add a model materializer for loading weights from a file, e.g. HDF5 (also implemented in the style transfer example I linked above); a rough skeleton is sketched right after this list.
  • Figure out how to enable the installation of CUDA-ready jaxlib packages if requested. jaxlib is the associated computational backend of JAX, and comes in different wheel flavors depending on what type of installation you want (CPU-only, or with CUDA/ROCm support on Linux/Windows). Supporting all different options will require some work.
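
For the materializer TODO, here is a rough, untested skeleton of the direction I have in mind. The hook names handle_input/handle_return and the self.artifact.uri attribute are from my reading of the current materializer API, so treat everything here as provisional; handle_return is left unimplemented on purpose, matching the list above.

```python
import os
from typing import Any, Dict, Type

import h5py
import jax.numpy as jnp

from zenml.materializers.base_materializer import BaseMaterializer

WEIGHTS_FILE = "weights.h5"


class JaxWeightsMaterializer(BaseMaterializer):
    """Sketch: reads/writes nested dicts of JAX arrays via a single HDF5 file."""

    # Exact registration form may differ in the current API.
    ASSOCIATED_TYPES = (dict,)

    def handle_input(self, data_type: Type[Any]) -> Dict[str, Dict[str, jnp.ndarray]]:
        """Load {layer: {weight_name: array}} from the artifact's HDF5 file."""
        params: Dict[str, Dict[str, jnp.ndarray]] = {}
        with h5py.File(os.path.join(self.artifact.uri, WEIGHTS_FILE), "r") as f:
            for layer_name, group in f.items():
                params[layer_name] = {
                    name: jnp.asarray(dataset[...]) for name, dataset in group.items()
                }
        return params

    def handle_return(self, data: Dict[str, Dict[str, jnp.ndarray]]) -> None:
        """Write the nested dict back out; still TODO (see the list above)."""
        raise NotImplementedError
```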

nicholasjng avatar Jul 04 '22 09:07 nicholasjng

CLA assistant check
All committers have signed the CLA.

CLAassistant avatar Jul 04 '22 09:07 CLAassistant

Thanks for the PR @nicholasjng!

I think this is a good draft, but it needs some work. Your TODO list is actually quite intriguing and is more in line with what a JAX integration could look like for us. Do you need any support in drafting the example and materializers?

htahir1 avatar Jul 04 '22 12:07 htahir1

Thanks for the PR @nicholasjng!

I think this is a good draft, but it needs some work. Your TODO list is actually quite intriguing and is more in line with what a JAX integration could look like for us. Do you need any support in drafting the example and materializers?

I think the materializer API is pretty well documented. I would probably go for something that saves and loads a model to/from disk, like I mentioned earlier.

I can reuse the code from my repo in the example (after some reorganization); for everything else, I will check the "Add your own example" guide.

What is your opinion on the installation point I mentioned? The process is described in the JAX README, by the way, in case you want to form your own opinion independently of what I wrote there. Happy to hear your thoughts; going back to work :)

nicholasjng avatar Jul 04 '22 12:07 nicholasjng

What is your opinion on the installation point I mentioned? The process is described in the JAX README, by the way, in case you want to form your own opinion independently of what I wrote there. Happy to hear your thoughts; going back to work :)

Sounds to me like the JAX installation can exist outside of the ZenML integration. We can simply put jax[gpu] in our requirements and note in the documentation that we automate pip install jax[gpu] for you. There is no magic to what we do there anyway; it's just a pip install.
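
Concretely, something like this in the integration definition would be enough (sketch only; the exact extra and any version pin are up for discussion):

```python
from zenml.integrations.integration import Integration


class JaxIntegration(Integration):
    """Declares the JAX requirement so that `zenml integration install jax`
    can automate the pip install for the user."""

    NAME = "jax"
    # Plain "jax" would be CPU-only; the GPU story is just a different extra/wheel.
    REQUIREMENTS = ["jax[gpu]"]
```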

htahir1 avatar Jul 05 '22 14:07 htahir1

Okay. I noticed that my style transfer example is most likely >400 LOC, which would make it more complex than other existing examples. Should I settle for something easier instead?

If I do that, it might still require more boilerplate than usual due to the functional style of JAX. Do let me know what your preference is.

nicholasjng avatar Jul 06 '22 08:07 nicholasjng

Would be happy to see a draft and go with what you like!

htahir1 avatar Jul 06 '22 15:07 htahir1

Apologies for the delay. I worked on it some more and decided to implement a custom model materializer for Haiku (that way, other JAX NN libraries can also be supported by implementing a materializer for each).

Still left TODO:

  • Reorganize functions in the example (model load -> then augmentation)
  • Finish the materializer implementation (especially the return)

nicholasjng avatar Jul 14 '22 20:07 nicholasjng

Really sorry for the slow turnaround on this review. Somehow the last message escaped my inbox.

I think it's hard to review the current state, as the zenml repo has diverged sharply since this branch was started. Also, please note that we use develop as our trunk branch, meaning your work should branch off of develop and the PR should therefore also target develop, not main.

Would it be possible to rebase and resolve the conflicts from this point? I'll make sure to review faster once that's done. Apologies for the delay!

htahir1 avatar Jul 31 '22 18:07 htahir1

Hey, apologies for the delay. I had some difficulties at work, so this fell by the wayside.

In principle, I hope to continue with it next week at the latest (this week is almost fully occupied by my work in the bakery). I still need to implement the HDF5 writing part of the materializer. It should, obviously, reproduce the names of the layer groups and the layers themselves, so I will first need to check how the author of the original repo saved the pretrained weights.
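
For the writing part, I am thinking of a helper along these lines (untested sketch; the group and dataset names simply mirror the nested dict):

```python
import h5py
import numpy as np


def save_params(path, params):
    """Write a nested {layer: {weight_name: array}} dict to HDF5,
    reproducing the layer group names and the weight names within them."""
    with h5py.File(path, "w") as f:
        for layer_name, weights in params.items():
            group = f.create_group(layer_name)
            for weight_name, value in weights.items():
                group.create_dataset(weight_name, data=np.asarray(value))
```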

Otherwise, I don't have an environment ready for testing the pipeline (I'm an M1 user...). Any chance I can scrape by and test this locally without the infamous ml-metadata dependency for now? (Otherwise I would need to install all the x86-emulated tooling like Python, brew, etc. side by side, which is something I would really like to avoid.)

EDIT: Rebased on current develop.

nicholasjng avatar Aug 17 '22 08:08 nicholasjng


@nicholasjng No worries, thanks for contributing! Unfortunately there is no way of running ZenML without ml-metadata (for now). I guess you could test the read/write logic of the model outside of ZenML to make sure that works. Once that is ready, feel free to ping me here and I'll make sure that either I or someone else runs the pipeline.

schustmi avatar Aug 17 '22 08:08 schustmi

@nicholasjng No worries, thanks for contributing! Unfortunately there is no way of running ZenML without ml-metadata (for now). I guess you could test the read/write logic of the model outside of ZenML to make sure that works. Once that is ready, feel free to ping me here and I'll make sure that either I or someone else runs the pipeline.

Thanks! (Also thank you for cleaning up the branch).

As a quick point of discussion (perhaps also relevant for @htahir1):

In principle, this example might violate the materializer contract of reading and writing the same type. Let me explain:

The arrays read into memory from the HDF5 file are used directly to instantiate the style transfer neural network (so the return type of handle_input would be something like Dict[str, Dict[str, jnp.ndarray]]). However, due to JAX's functional design, this stateful initialization needs to be expressed in a functional way, resulting in the hacky pattern of using an init method inside the training script.
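
For illustration, the pattern looks roughly like this (toy network standing in for the style transfer net):

```python
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x):
    # Toy stand-in for the style transfer network.
    return hk.nets.MLP([128, 10])(x)


model = hk.transform(forward)
rng = jax.random.PRNGKey(0)
dummy_input = jnp.ones((1, 784))

# The "stateful" initialization, expressed functionally:
params = model.init(rng, dummy_input)   # an hk.Params (FlatMapping)
output = model.apply(params, rng, dummy_input)
```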

This method returns a params object (an hk.FlatMapping), which holds the weights in a structure similar to the object read from the HDF5 file, but which is, strictly speaking, not the same type (for JAX's purposes it effectively is, as both objects are valid PyTrees of the same structure).

So the input to handle_return would be an hk.Params object (a FlatMapping, if you will). Due to the above, it behaves the same for our purposes. Is it fine to just silently return the params in this type, or should I explicitly cast to a bare Python dict? (I think there might be a Haiku API for that; I'll check before implementing.)
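
If the Haiku API exists the way I remember, the conversion would be a one-liner (standalone sketch):

```python
import haiku as hk
import jax
import jax.numpy as jnp

model = hk.transform(lambda x: hk.Linear(4)(x))
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Convert the FlatMapping into a plain, mutable nested Python dict,
# which would satisfy the materializer contract explicitly.
plain_params = hk.data_structures.to_mutable_dict(params)
assert isinstance(plain_params, dict)
```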

nicholasjng avatar Aug 17 '22 08:08 nicholasjng

I added the return materializer implementation, at least the way I believe it is supposed to work. I will revisit the whole thing in the coming evenings and test it in a pipeline on my local Linux machine towards the weekend.

nicholasjng avatar Aug 22 '22 12:08 nicholasjng

Wanna give this a go, @schustmi? I think the implementation stands as it is right now; the remaining CI failures are related to linting and spelling.

Edit: Should I add a config.yaml file like the other configurable trainers in the examples have? I don't know if this is needed to get things working; please advise.

nicholasjng avatar Sep 14 '22 08:09 nicholasjng

Hi @nicholasjng, sorry for the late reply. Pretty busy at the moment, but I hope to give this a try by Friday!

Regarding the config.yaml: yes, that would be awesome, as we use these to run integration tests for all examples.

schustmi avatar Sep 21 '22 13:09 schustmi