
JAX-compatible API

Open • borisdayma opened this issue

Presentation of the new feature

It would be great to have a JAX-compatible API in the form of a logits processor.

The input would be the current vocabulary probabilities, and the output would simply mask out the invalid ones according to the grammar at the current state.
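Here is a minimal JAX sketch of such a processor. It assumes the grammar state has already been compiled into a boolean mask over the vocabulary. The `allowed` argument and the `mask_logits` name are hypothetical and are not an Outlines API. The sketch works on logits rather than probabilities: setting invalid entries to `-inf` gives them zero probability after softmax, which is equivalent to zeroing the invalid probabilities and renormalizing.

```python
import jax
import jax.numpy as jnp


@jax.jit
def mask_logits(logits: jax.Array, allowed: jax.Array) -> jax.Array:
    """Set the logits of tokens the grammar does not allow to -inf.

    logits:  float array of shape (batch, vocab_size)
    allowed: bool array of the same shape, True where the token is valid
             in the current grammar state (hypothetical input, assumed to
             be computed elsewhere from the grammar)
    """
    return jnp.where(allowed, logits, -jnp.inf)


logits = jnp.array([[1.0, 2.0, 3.0, 4.0]])
allowed = jnp.array([[True, False, True, False]])
masked = mask_logits(logits, allowed)
# Disallowed tokens end up with probability 0 after softmax.
probs = jax.nn.softmax(masked, axis=-1)
```

At least one token must remain allowed in each row. Otherwise every logit is `-inf` and softmax returns NaNs.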

Where does it fit in Outlines?

I have used Outlines with transformers. A similar experience with JAX would be great, as there is currently no equivalent functionality.

Are you willing to open a PR?

Yes, I'd love a hint on where to start (for example, the recommended high-level functions if it were done with NumPy or torch tensors).

My goal is to find out how to integrate it easily with JAX sampling functions, such as MaxText's: https://github.com/google/maxtext/blob/5bc40298530c7b5acaa42a366da1e6c2d413fac9/MaxText/inference_utils.py#L30
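MaxText's sampling signature is not reproduced here. The following is a hedged sketch of where the constraint could slot into a generic jitted sampling step: apply the mask first, then apply temperature and draw with `jax.random.categorical`. `constrained_sample` and the `allowed` mask are hypothetical names.

```python
import jax
import jax.numpy as jnp


@jax.jit
def constrained_sample(rng: jax.Array, logits: jax.Array, allowed: jax.Array,
                       temperature: float = 1.0) -> jax.Array:
    """Sample one token per batch row, restricted to grammar-allowed tokens.

    `allowed` is a hypothetical (batch, vocab_size) boolean mask derived
    from the grammar state; it is not part of MaxText or Outlines.
    """
    # Mask before any sampling transformation so invalid tokens can never
    # be drawn (-inf / temperature stays -inf).
    masked = jnp.where(allowed, logits, -jnp.inf)
    return jax.random.categorical(rng, masked / temperature, axis=-1)


rng = jax.random.PRNGKey(0)
logits = jnp.zeros((2, 5))
allowed = jnp.array([[False, True, False, False, False],
                     [False, False, False, True, True]])
tokens = constrained_sample(rng, logits, allowed)
```

For greedy decoding, `jnp.argmax(masked, axis=-1)` would replace the categorical draw. Because the mask is passed as a traced array, the step compiles once, and only the mask values change between decoding steps.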

borisdayma • Jul 09 '24 01:07

outlines.processors supports a number of array frameworks through copy-free DLPack type conversions. These are very efficient, with near-zero overhead.

JAX appears to support DLPack: https://jax.readthedocs.io/en/latest/jax.dlpack.html
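Below is a minimal sketch of the round trip. It assumes recent JAX and PyTorch versions in which both array types implement the `__dlpack__` protocol. `jax_to_torch` and `torch_to_jax` are hypothetical helper names. When both arrays live on the same device, the conversions typically share the underlying buffer. JAX arrays are meant to be immutable, so the torch side should not modify a shared buffer in place.

```python
import jax
import jax.numpy as jnp
import torch


def jax_to_torch(x: jax.Array) -> torch.Tensor:
    # JAX arrays implement __dlpack__, so torch can wrap the same buffer.
    return torch.from_dlpack(x)


def torch_to_jax(t: torch.Tensor) -> jax.Array:
    # Recent JAX accepts any object implementing __dlpack__; older versions
    # needed a capsule from torch.utils.dlpack.to_dlpack(t) instead.
    return jax.dlpack.from_dlpack(t)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
t = jax_to_torch(x)
y = torch_to_jax(t * 2)  # t * 2 is a new torch tensor, so x is untouched
```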

I'm glad you're interested in contributing! The only changes needed to support JAX are:

  • Update OutlinesLogitsProcessor to support jax -> torch and torch -> jax. https://github.com/outlines-dev/outlines/blob/main/outlines/processors/base_logits_processor.py#L91-L135
  • Write unit tests (I suggest creating a new tests/processors/test_base_processor.py)
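The first bullet comes down to converting at the boundary: accept a JAX array, run the existing torch-based logic, and return a JAX array. The sketch below simplifies this. The helpers `to_torch`, `from_torch`, `call_processor` and the toy `process_logits` are hypothetical stand-ins, not the actual `OutlinesLogitsProcessor` code, which also handles other array types.

```python
import jax
import jax.numpy as jnp
import torch


def to_torch(x):
    """Convert a supported array to torch.Tensor (hypothetical helper)."""
    if isinstance(x, torch.Tensor):
        return x
    if isinstance(x, jax.Array):
        return torch.from_dlpack(x)  # jax -> torch via DLPack
    raise TypeError(f"unsupported array type: {type(x)!r}")


def from_torch(t: torch.Tensor, like):
    """Convert back to the array type of `like` (hypothetical helper)."""
    if isinstance(like, torch.Tensor):
        return t
    if isinstance(like, jax.Array):
        return jax.dlpack.from_dlpack(t)  # torch -> jax via DLPack
    raise TypeError(f"unsupported array type: {type(like)!r}")


def process_logits(input_ids: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Toy stand-in for grammar-based masking: forbid token 0 everywhere.
    masked = logits.clone()
    masked[..., 0] = float("-inf")
    return masked


def call_processor(input_ids, logits):
    """Accept torch or JAX inputs and return the same type as `logits`."""
    out = process_logits(to_torch(input_ids), to_torch(logits))
    return from_torch(out, like=logits)


jax_out = call_processor(jnp.array([[1, 2]]), jnp.ones((1, 4)))
torch_out = call_processor(torch.tensor([[1, 2]]), torch.ones(1, 4))
```

A unit test along these lines would check that the output type matches the input type and that the masking survives the round trip for every supported framework.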

Please let me know if you have any questions.

lapp0 • Jul 13 '24 21:07

Hi @lapp0, I made the changes you mentioned in base_logits_processor.py and also wrote the tests. However, I am getting an error from the mypy pre-commit hook: it could not find stubs for jax and jaxlib.

Any pointers on how to resolve this?

sky-2002 • Oct 13 '24 11:10

Yes, you can simply ignore them by adding them to this list in pyproject.toml.

rlouf • Oct 13 '24 13:10

Hey @rlouf, thanks. I have created a draft PR and am open to feedback.

sky-2002 • Oct 13 '24 16:10