Add JAX integration
This commit adds a ZenML integration for JAX, a Python ML research framework.
Describe changes
See above.
Pre-requisites
Please ensure you have done the following:
- [x] I have read the CONTRIBUTING.md document.
- [ ] If my change requires a change to docs, I have updated the documentation accordingly.
- [ ] If I have added an integration, I have updated the integrations table.
- [ ] I have added tests to cover my changes.
Types of changes
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to change)
- [ ] Other (add details above)
TODOs:
- Add an example for using JAX together with one of its associated DL frameworks, possibly https://github.com/nicholasjng/jax-styletransfer.
- Add a model materializer for loading weights from a file, e.g. HDF5 (also implemented in the style transfer example I linked above).
- Figure out how to enable the installation of CUDA-ready jaxlib packages if requested. jaxlib is the associated computational backend of JAX, and comes in different wheel flavors depending on what type of installation you want (CPU-only, or with CUDA/ROCm support on Linux/Windows). Supporting all different options will require some work.
Thanks for the PR @nicholasjng!
I think this is a good draft but needs some work. Your TODO list is actually quite intriguing and is more in line with what a JAX integration could look like for us. Do you need any support in drafting the example and materializers?
I think the materializer API is pretty well documented. I would probably go for something that saves and loads a model to/from disk, like I mentioned earlier.
I can reuse the code from my repo in the example (after some reorganization), for everything else, I will check the "Add your own example" guide.
What is your opinion on the installation point I mentioned? The process is described in the JAX readme, by the way, if you want to make up your mind independently of what I wrote there. Happy to hear your thoughts, going back to work :)
Sounds to me like the JAX installation can exist outside of the ZenML integration. We can simply put jax[gpu] in our requirements and indicate in the documentation that we automate pip install jax[gpu] for you. There is no magic to what we do there anyway; it's simply a pip install.
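For concreteness, I am imagining something roughly like this in the integration definition (a rough sketch only; the `Integration` base class usage and attribute names here are assumptions, not a prescription):

```python
# Rough sketch, not the final implementation -- class path and attribute
# names are assumptions.
from zenml.integrations.integration import Integration


class JaxIntegration(Integration):
    """Declares the JAX integration and the pip packages it pulls in."""

    NAME = "jax"
    # Mirrors the suggestion above; note that CUDA/ROCm jaxlib wheels may
    # need the extra find-links index from the JAX README on some setups.
    REQUIREMENTS = ["jax[gpu]"]
```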
Okay. I noticed that my style transfer example is most likely >400 LOC, which would make it more complex than other existing examples. Should I settle for something easier instead?
If I do that, it might still require some more boilerplate than usual due to the functional style of JAX (small sketch below). Do let me know what your preference is.
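To illustrate the boilerplate point, even a minimal Haiku model goes through the functional init/apply dance before any training can happen (a toy sketch, not code from the actual example):

```python
# Toy sketch of the functional init/apply boilerplate; not part of the example.
import haiku as hk
import jax
import jax.numpy as jnp


def forward(x):
    # Haiku modules may only be instantiated inside a transformed function.
    return hk.nets.MLP([64, 10])(x)


model = hk.transform(forward)      # pure (init, apply) pair
rng = jax.random.PRNGKey(0)
dummy = jnp.zeros((1, 784))

params = model.init(rng, dummy)    # hk.Params: nested mapping of arrays
logits = model.apply(params, rng, dummy)
```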
Would be happy to see a draft and go with what you like!
Apologies for the delay. I worked on it some more and decided to implement a custom model materializer for Haiku (that way, plugging in different JAX NN libraries can also be supported by implementing a materializer for each).
Still left TODO:
- Reorganize functions in the example (model load -> then augmentation)
- Finish the materializer implementation (especially the return)
Really sorry for the slow turnaround on this review. Somehow the last message escaped my inbox.
I think it's hard to review the current state, as the ZenML repo has diverged sharply from the point where development started. Also, please note that we use develop as our trunk branch, meaning your work should branch off of develop and the PR should therefore also target develop, not main.
Would it be possible to rebase and resolve the conflicts from this point? I'll make sure to review faster once that's done. Apologies for the delay!!
Hey, apologies for the delay. I had some difficulties at work, so this fell by the wayside.
In principle, I hope to continue with it next week at the latest (this week is almost fully occupied by my work in the bakery). I still need to implement the HDF5 writing part of the materializer. It should, obviously, reproduce the names of the layer groups and the layers themselves, so I will first need to check how the repo's author saved the pretrained weights.
Also, I don't have an environment ready for testing the pipeline (M1 user...). Any chance I can scrape by and test this locally without the infamous ml-metadata dependency for now? (Otherwise I'd need to install all the x86-emulated stuff like Python, brew, etc. side by side, which is something I would really like to avoid.)
EDIT: Rebased on current develop.
@nicholasjng No worries, thanks for contributing! Unfortunately there is no way of running ZenML without ml-metadata (for now). I guess you could test the read/write logic of the model outside of ZenML to make sure that works. Once that is ready, feel free to ping me here and I'll make sure that either I or someone else runs the pipeline.
Thanks! (Also thank you for cleaning up the branch).
As a quick point of discussion (perhaps also relevant for @htahir1):
In principle, it might be that the materializer contract of reading and writing the same type is violated in this example. Let me explain:
The arrays read into memory from the HDF5 file are used directly to instantiate the style transfer neural network (so the return type of handle_input would be something like Dict[str, Dict[str, jnp.ndarray]]). However, due to JAX's functional design, this stateful initialization needs to be expressed in a functional way, resulting in the hacky pattern of using an init method inside the training script.
This method returns a params object, which holds the weights in a structure similar to the object resulting from the HDF5 file (an hk.FlatMapping), but which is strictly speaking not the same type (for JAX's intents and purposes it is, however, as both objects are valid PyTrees of the same structure).
So the input to handle_return would be an hk.Params object (a FlatMapping, if you will). Due to the above, it behaves the same for our intents and purposes. Is it fine to just silently return the params in this type then, or should I explicitly cast to a bare Python dict? (I think there might be a Haiku API for that - I'll check before implementing.)
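For concreteness, the cast variant I have in mind looks roughly like this (a sketch only, written against the BaseMaterializer API with handle_input/handle_return and the artifact URI; the class name, file name and the hk.data_structures.to_mutable_dict cast are my assumptions, nothing is settled yet):

```python
# Sketch only -- class name, file name and the cast are not settled yet.
import os
from typing import Any, Dict, Type

import h5py
import haiku as hk
import jax.numpy as jnp
import numpy as np

from zenml.materializers.base_materializer import BaseMaterializer

WEIGHTS_FILE = "params.h5"


class HaikuParamsMaterializer(BaseMaterializer):
    """Mirrors the layer group / layer structure of the params in an HDF5 file."""

    ASSOCIATED_TYPES = (dict,)

    def handle_input(self, data_type: Type[Any]) -> Dict[str, Dict[str, jnp.ndarray]]:
        """Reads the nested {layer: {param_name: array}} mapping back from HDF5."""
        params: Dict[str, Dict[str, jnp.ndarray]] = {}
        with h5py.File(os.path.join(self.artifact.uri, WEIGHTS_FILE), "r") as f:
            for layer_name, group in f.items():
                params[layer_name] = {
                    name: jnp.asarray(ds[...]) for name, ds in group.items()
                }
        return params

    def handle_return(self, params: hk.Params) -> None:
        """Casts hk.Params to a plain nested dict and writes one HDF5 group per layer."""
        # Caveat: Haiku module names can contain "/", which h5py would treat as
        # nested groups; the real implementation needs to account for that.
        mutable = hk.data_structures.to_mutable_dict(params)
        with h5py.File(os.path.join(self.artifact.uri, WEIGHTS_FILE), "w") as f:
            for layer_name, layer_params in mutable.items():
                group = f.create_group(layer_name)
                for name, value in layer_params.items():
                    group.create_dataset(name, data=np.asarray(value))
```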
I added the return materializer implementation, at least the way I believe it is supposed to work. Will revisit the whole thing in the coming evenings, and test it on my local Linux machine in a pipeline towards the weekend.
Wanna give this a go, @schustmi? I think the implementation stands as it is right now; the failures are related to linting and spelling.
Edit: Should I add a config.yaml file like for other configurable trainers in the examples? I don't know if this is important to get it to work, please advise.
Hi @nicholasjng, sorry for the late reply. Pretty busy at the moment, but I hope to give this a try by Friday!
Regarding the config.yaml: Yes, that would be awesome, as we use these to run integration tests for all examples.