outlines
JAX compatible API
Presentation of the new feature
It would be great to have a JAX-compatible API in the form of a logits processor.
Its input would be the current vocabulary probabilities, and its output would mask out the tokens that are invalid under the grammar at the current state.
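A minimal sketch of the masking step such a processor would perform, assuming it receives logits rather than probabilities (as the name "logits processor" suggests) and that the grammar state can report the set of allowed token ids. The function name mask_invalid_logits and the allowed_token_ids argument are hypothetical, not part of Outlines:

```python
import jax.numpy as jnp


def mask_invalid_logits(logits, allowed_token_ids):
    """Return logits with every token the grammar forbids set to -inf.

    Hypothetical helper.
    logits: float array of shape (batch, vocab_size)
    allowed_token_ids: 1-D int array of token ids valid in the current state
    """
    vocab_ids = jnp.arange(logits.shape[-1])
    # Boolean mask of shape (vocab_size,), True where the token is allowed.
    allowed = jnp.isin(vocab_ids, allowed_token_ids)
    # Broadcasts over the batch dimension; disallowed entries become -inf,
    # which gives them zero probability after a softmax.
    return jnp.where(allowed, logits, -jnp.inf)


masked = mask_invalid_logits(jnp.array([[1.0, 2.0, 3.0, 4.0]]), jnp.array([1, 3]))
```

Working on logits instead of probabilities avoids renormalizing: a -inf logit already means zero probability once the sampler applies its softmax.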
Where does it fit in Outlines?
I have used Outlines with transformers, and a similar experience with JAX would be great, since no comparable functionality currently exists.
Are you willing to open a PR?
Yes, though I'd love a hint on where to start (for example, the recommended high-level functions if this were done with NumPy or torch tensors).
My goal is to find an easy way to integrate it with JAX sampling functions such as MaxText's: https://github.com/google/maxtext/blob/5bc40298530c7b5acaa42a366da1e6c2d413fac9/MaxText/inference_utils.py#L30
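A sketch of where a logits processor could slot into a JAX sampling step. The function sample_step and its logits_processor argument are hypothetical and do not reflect MaxText's actual API; the processor is assumed to be any callable that maps logits to logits with grammar-invalid tokens set to -inf:

```python
import jax
import jax.numpy as jnp


def sample_step(rng, logits, logits_processor=None, temperature=1.0):
    """Pick the next token for each row of `logits` (shape: batch x vocab).

    Hypothetical helper: `logits_processor` is any callable that returns
    logits with grammar-invalid tokens set to -inf.
    """
    if logits_processor is not None:
        logits = logits_processor(logits)
    if temperature == 0.0:
        # Greedy decoding: the highest-scoring allowed token.
        return jnp.argmax(logits, axis=-1)
    # Categorical sampling; -inf logits have zero probability.
    return jax.random.categorical(rng, logits / temperature, axis=-1)


# Usage: a toy processor that only allows token 3.
def only_token_3(logits):
    allowed = jnp.arange(logits.shape[-1]) == 3
    return jnp.where(allowed, logits, -jnp.inf)


rng = jax.random.PRNGKey(0)
logits = jnp.zeros((2, 5))
tokens = sample_step(rng, logits, only_token_3)
greedy = sample_step(rng, logits, only_token_3, temperature=0.0)
```

Because the processor only rewrites logits, the same hook works for both greedy and temperature sampling.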
outlines.processors supports a number of array frameworks via copy-free DLPack type conversions, which are very efficient and have near-zero overhead.
It seems JAX supports DLPack: https://jax.readthedocs.io/en/latest/jax.dlpack.html
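A minimal round trip between JAX and torch through the DLPack protocol, assuming recent versions of both libraries (older JAX releases required passing a DLPack capsule from torch.utils.dlpack.to_dlpack instead of the tensor itself). On the same device the conversion shares the underlying buffer instead of copying it:

```python
import jax
import jax.numpy as jnp
import torch

# A JAX array of logits (batch of 1, vocabulary of 4).
jax_logits = jnp.array([[0.1, 0.2, 0.3, 0.4]], dtype=jnp.float32)

# JAX -> torch: JAX arrays implement __dlpack__, so torch imports them directly.
torch_logits = torch.from_dlpack(jax_logits)

# torch -> JAX: recent JAX versions accept any object implementing __dlpack__.
back_to_jax = jax.dlpack.from_dlpack(torch_logits)
```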
I'm glad you're interested in contributing! The only changes necessary to support JAX are:
- Update OutlinesLogitsProcessor to support jax -> torch and torch -> jax conversion: https://github.com/outlines-dev/outlines/blob/main/outlines/processors/base_logits_processor.py#L91-L135
- Write unit tests (I suggest creating a new tests/processors/test_base_processor.py).
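A self-contained sketch of the conversion pattern and a matching unit test. ToyTorchProcessor is a hypothetical stand-in written for illustration, not the actual OutlinesLogitsProcessor code: its process_logits works on torch tensors, and __call__ converts JAX inputs to torch via DLPack and converts the result back. It assumes recent JAX and torch versions with DLPack support for each other's arrays:

```python
import jax
import jax.numpy as jnp
import torch


class ToyTorchProcessor:
    """Hypothetical processor: torch logic inside, JAX in and out."""

    def process_logits(self, input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
        # Toy rule: forbid token 0. Clone first, since the DLPack import
        # shares memory with the (immutable) JAX array.
        logits = logits.clone()
        logits[..., 0] = float("-inf")
        return logits

    def __call__(self, input_ids, logits):
        if isinstance(logits, jax.Array):
            # jax -> torch, run the torch logic, then torch -> jax.
            out = self.process_logits(torch.from_dlpack(input_ids), torch.from_dlpack(logits))
            return jax.dlpack.from_dlpack(out)
        return self.process_logits(input_ids, logits)


def test_jax_round_trip():
    processor = ToyTorchProcessor()
    input_ids = jnp.array([[1, 2, 3]], dtype=jnp.int32)
    logits = jnp.array([[0.5, 1.0, 2.0]], dtype=jnp.float32)
    out = processor(input_ids, logits)
    # The caller gets a JAX array back, with the same shape and the mask applied.
    assert isinstance(out, jax.Array)
    assert out.shape == logits.shape
    assert float(out[0, 0]) == float("-inf")
    assert float(out[0, 2]) == 2.0
```

Under pytest, the test_ prefix is enough for the function to be collected.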
Please let me know if you have any questions.
Hi @lapp0, I made the changes you mentioned in base_logits_processor.py and also wrote the tests, but the mypy pre-commit hook fails because it cannot find stubs for jax and jaxlib.
Any pointers on how to resolve this?
Yes, you can simply ignore them by adding them to this list in pyproject.toml.
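A sketch of what that entry might look like, assuming the list is a standard mypy per-module overrides table with ignore_missing_imports (the exact contents of the Outlines pyproject.toml may differ; if such a table already exists, only the two module patterns need to be appended):

```toml
# Hypothetical excerpt: tell mypy not to report missing stubs for these modules.
# A pattern like "jax.*" matches jax itself and all of its submodules.
[[tool.mypy.overrides]]
module = [
    "jax.*",
    "jaxlib.*",
]
ignore_missing_imports = true
```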
Hey @rlouf, thanks. I have created a draft PR and am open to feedback.