axlearn
Add Dataflow Inference Examples
This PR adds 2 examples for running batch inference on Dataflow:
- Using a Custom Model Handler for JAX models
- Using a Built-in HuggingFace Model Handler
These pipelines can run on either CPUs or GPUs. To run on GPUs, see Link: users need to build their own custom container image with the necessary libraries and pass an additional command-line flag.
Future work for this PR:
- Share a single model instance across threads on the same machine. See Link.
- Add unit tests.