Wanted - Developers for higher-level language wrappers

Open martinjrobins opened this issue 8 months ago • 10 comments

Diffsol is designed to be easy to use from higher-level languages like Python or R. I'd prefer not to split my focus away from the core library, so I'm looking for developers who would like to lead the development of these wrappers. If you're interested, please get in touch.

  • [x] Python (e.g. using PyO3). https://github.com/alexallmont/pydiffsol.
  • [ ] Python ML frameworks (e.g. JAX, PyTorch)
  • [ ] R (e.g. using extendr).
  • [ ] Julia
  • [ ] Matlab
  • [ ] JavaScript backend (e.g. using Neon)
  • [ ] JavaScript in browser (e.g. using wasm-pack)
  • [ ] Others, feel free to suggest your favourite language.

martinjrobins avatar Mar 17 '25 11:03 martinjrobins

Hello Martin! I would love to write the R-package for diffsol. I have been lurking (and part of the discord) just for this moment! Please feel free to communicate with me about your plans, etc.

CGMossa avatar Apr 07 '25 09:04 CGMossa

Hi Martin, I am interested in developing a high level library for the python ml frameworks!

jackbmontgomery avatar Jun 20 '25 11:06 jackbmontgomery

Great, thanks @jackbmontgomery. FYI, I was thinking that this would probably involve wrapping diffsol as a custom JAX operator. You can read more about JAX's FFI here (https://docs.jax.dev/en/latest/ffi.html), and I notice that there are already Rust bindings for XLA (https://docs.rs/xla/latest/xla/) which might be useful. You will be able to use the forward and adjoint sensitivity functionality in diffsol to implement the forward and backward gradient passes required by JAX. I would implement this as a separate Python library using PyO3 and maturin.
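To make the link between the two kinds of sensitivities and the two gradient passes concrete, here is a minimal pure-Python sketch for the scalar ODE dy/dt = -a*y. It uses neither diffsol nor JAX: the fixed-step RK4 integrator and the names `rk4`, `forward_sensitivity` and `adjoint_gradient` are assumptions made up for this illustration. The forward sensitivity plays the role of the forward (jvp-style) pass, and the adjoint plays the role of the backward (vjp-style) pass.

```python
import math


def rk4(f, state, t_end, steps=200):
    """Integrate d(state)/dt = f(state) from 0 to t_end with fixed-step RK4."""
    h = t_end / steps
    for _ in range(steps):
        k1 = f(state)
        k2 = f([s + 0.5 * h * k for s, k in zip(state, k1)])
        k3 = f([s + 0.5 * h * k for s, k in zip(state, k2)])
        k4 = f([s + h * k for s, k in zip(state, k3)])
        state = [
            s + h / 6.0 * (p + 2.0 * q + 2.0 * r + w)
            for s, p, q, r, w in zip(state, k1, k2, k3, k4)
        ]
    return state


def forward_sensitivity(a, y0, t_end):
    """Forward pass: solve for y together with s = dy/da.

    Differentiating dy/dt = -a*y with respect to a gives
    ds/dt = -a*s - y with s(0) = 0.
    """
    y, s = rk4(lambda st: [-a * st[0], -a * st[1] - st[0]], [y0, 0.0], t_end)
    return y, s


def adjoint_gradient(a, y0, t_end, dloss_dy=1.0):
    """Backward pass: dL/da for a loss L(y(t_end)) via the adjoint equation.

    In reversed time tau = t_end - t the system is
      dy/dtau   =  a*y       (state, run backwards)
      dlam/dtau = -a*lam     (adjoint, lam(tau=0) = dL/dy)
      dq/dtau   = -lam*y     (accumulates lam * d(rhs)/da)
    """
    y_end = rk4(lambda st: [-a * st[0]], [y0], t_end)[0]

    def reversed_rhs(st):
        y, lam, _ = st
        return [a * y, -a * lam, -lam * y]

    _, _, grad = rk4(reversed_rhs, [y_end, dloss_dy, 0.0], t_end)
    return grad


# Both passes should agree with the analytic derivative -t*y0*exp(-a*t).
a, y0, t_end = 0.7, 2.0, 1.5
analytic = -t_end * y0 * math.exp(-a * t_end)
_, s_forward = forward_sensitivity(a, y0, t_end)
g_adjoint = adjoint_gradient(a, y0, t_end)
```

In a real wrapper the two hand-written functions would be replaced by calls into diffsol's own sensitivity solvers, and the choice between them would depend on whether there are more parameters than outputs.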

This is just my idea, but feel free to propose anything else that you are interested in. Please feel free to lead this project however you want, and I'll be available for help and advice as needed. Once you have something I can link to it from this repo and the documentation as well. I'm happy to jump on a call as well if that is useful for you.

martinjrobins avatar Jun 20 '25 12:06 martinjrobins

That all makes sense to me @martinjrobins. I have not used jax's FFI feature before so let me start there. Thanks a lot for the suggestions and I may be in contact soon with some questions.

jackbmontgomery avatar Jun 20 '25 13:06 jackbmontgomery

Hi @martinjrobins, I wanted to follow up on what I have been reading and working on for the JAX wrapper. The Rust XLA bindings do not include the macros or the buffers that XLA requires for the arguments and results of custom calls (https://openxla.org/xla/custom_call), so I don't think the crate will work out of the box for this. I have thought of some options for what to do next:

  1. Build on the xla crate to implement the features that XLA needs for foreign calls, and then proceed as you have described with PyO3 and maturin.

  2. Use C++ as the interface between Rust and JAX, with C++ bindings created by cxx (https://cxx.rs/), then use nanobind to generate the Python package and stubs. This is what I have been experimenting with to replicate the tutorial in this repo.

  3. The final option I have seen is to use a Python package called jaxbind. This package lets you create JAX primitive operations by defining the jvp, the vjp and the behaviour under JAX's abstract evaluation. This would also allow the use of PyO3 and maturin.

I’d love to hear your thoughts on any of these approaches - whether you see any issues, or if you have preferences or suggestions. Any feedback would be greatly appreciated!

jackbmontgomery avatar Jun 25 '25 15:06 jackbmontgomery

You would have to experiment, but I suspect 2 is the most straightforward and most likely to succeed, and it is the one I would recommend you start with (as you already have), while 1 would be more work but would provide the most benefit to the wider Rust community. 3 is a bit risky because I don't know the details of jaxbind or what its restrictions and drawbacks are (e.g. do you have to go back to Python with every call to f or vjp/jvp?).

martinjrobins avatar Jun 25 '25 20:06 martinjrobins

Yes, I was thinking the same. I think I will carry on with option 2 and experiment with option 1, since I think that would be the best implementation of the JAX wrapper.

jackbmontgomery avatar Jun 26 '25 13:06 jackbmontgomery

@jackbmontgomery any updates on this?

martinjrobins avatar Oct 29 '25 22:10 martinjrobins

Hi @martinjrobins, sorry for the lack of communication. I started my master's, so I wasn't able to work on this very much. It became clear that I would need C++, and I do not know the language. I replicated the jax.ffi example with Rust here, but I did not have time to do anything more than that.

I see that jaxbind has updated its backend to use the new FFI API. As soon as I get some free time I will try to adapt that to make the JAX wrapper, but I don't think I will have enough time for this until June next year, when I finish.

jackbmontgomery avatar Nov 20 '25 12:11 jackbmontgomery

Copying some info on the JavaScript wrapper from #203:

Sure! diffsol will work with wasm out of the box in the way you are using it, but I'm also keen to get a proper javascript wrapper library, similar to pydiffsol in Python (https://github.com/alexallmont/pydiffsol). The main TODOs here are:

  1. Write an initial JavaScript wrapper using https://github.com/wasm-bindgen/wasm-bindgen. This can use callbacks into JavaScript code so that the user can provide the rhs and mass matrix functions as plain JavaScript functions. This is slow, but does not require you to build diffsl and add it to your bundle.
  2. The next step is to add optional diffsl support. This requires constructing a suitable compilation toolchain to build llvm, enzyme, diffsl and diffsol for wasm, which means building llvm and enzyme to wasm and linking them into a cargo build. I've made a start on the llvm and enzyme build here: https://github.com/martinjrobins/wasm-compilers. This requires llvm v20, and I'm in the process of adding support for this here (https://github.com/martinjrobins/diffsl/commit/4ffa31fa1f526a1dbf167342fdc5fdd73f557443).
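The callback design in step 1 can be sketched independently of the language. The Python stand-in below is only an illustration under stated assumptions: `solve_rk4`, `CountingCallback` and `logistic` are hypothetical names, not part of diffsol or any wrapper, and the fixed-step RK4 loop stands in for the real solver. It shows the user supplying the rhs as a plain function, and counts how often the solver calls back into user code, which is where the cost of crossing the wasm/JavaScript boundary would be paid.

```python
import math


class CountingCallback:
    """Wrap a user-supplied rhs function and count how often the solver calls it."""

    def __init__(self, fn):
        self.fn = fn
        self.calls = 0

    def __call__(self, t, y):
        self.calls += 1
        return self.fn(t, y)


def solve_rk4(rhs, y0, t_end, steps):
    """Fixed-step RK4 that only knows the rhs as an opaque callback rhs(t, y)."""
    h = t_end / steps
    t, y = 0.0, list(y0)
    for _ in range(steps):
        k1 = rhs(t, y)
        k2 = rhs(t + 0.5 * h, [yi + 0.5 * h * k for yi, k in zip(y, k1)])
        k3 = rhs(t + 0.5 * h, [yi + 0.5 * h * k for yi, k in zip(y, k2)])
        k4 = rhs(t + h, [yi + h * k for yi, k in zip(y, k3)])
        y = [
            yi + h / 6.0 * (p + 2.0 * q + 2.0 * r + w)
            for yi, p, q, r, w in zip(y, k1, k2, k3, k4)
        ]
        t += h
    return y


def logistic(t, y, r=1.0, k=10.0):
    """User-side model: logistic growth dy/dt = r*y*(1 - y/k)."""
    return [r * y[0] * (1.0 - y[0] / k)]


rhs = CountingCallback(logistic)
steps = 100
y_end = solve_rk4(rhs, [1.0], 5.0, steps)
# RK4 evaluates the rhs four times per step, so rhs.calls == 4 * steps here.
```

An implicit solver would also call back for Jacobian or mass matrix evaluations, so the number of boundary crossings per step would typically be higher than in this explicit sketch.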

martinjrobins avatar Nov 28 '25 18:11 martinjrobins