JetStream
Refactor JetStream to allow different tokenizers
Issue
Currently we assume a few things in JetStream which hinder its generalization:
- The tokenizer is SentencePiece-based.
- `pad_id` is 0.
- After encode, we pad to the nearest power of 2.
- `ResultTokens` itself is JAX-specific (the `@struct.dataclass` annotation requires that it is JAX-pytreeable).
These assumptions hinder generalization (i.e., supporting a wider variety of models). Examples:
- llama3 uses tiktoken instead of SentencePiece.
- llama3 uses a `pad_id` of -1.
- PyTorch GPU does NOT need to pad to the nearest power of 2.
- A PyTorch GPU version of JetStream would like to use `torch.Tensor` to hold the data, which is not JAX-pytreeable.
Proposal:
- `EngineAPI.get_tokenizer` returns the tokenizer, which should be any object that implements the following interface:

  ```python
  def encode
  def decode

  @property
  def pad_id

  @property
  def eos_id
  ```

  Uses of the tokenizer should be restricted to only these methods.
  In particular, `encode` should do both encoding and padding, so JetStream doesn't do any padding itself; the engine can choose how to pad (or not pad) by returning a custom tokenizer object whose `encode` also does the padding.
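The interface above could be sketched as a `typing.Protocol`, with a toy implementation whose `encode` also pads. This is a hypothetical sketch, not the final JetStream API: the method signatures, the `Tokenizer` name, and the `ToyPaddingTokenizer` class are all assumptions for illustration.

```python
# Hypothetical sketch of the proposed tokenizer interface; signatures are
# assumptions, not the actual JetStream definitions.
from typing import List, Protocol


class Tokenizer(Protocol):
    def encode(self, text: str) -> List[int]:
        """Tokenize AND pad; the implementation decides the padding scheme."""
        ...

    def decode(self, token_ids: List[int]) -> str:
        ...

    @property
    def pad_id(self) -> int:
        ...

    @property
    def eos_id(self) -> int:
        ...


class ToyPaddingTokenizer:
    """Toy implementation: byte-level 'tokenization' padded to a power of 2."""

    @property
    def pad_id(self) -> int:
        return 0

    @property
    def eos_id(self) -> int:
        return 1

    def encode(self, text: str) -> List[int]:
        ids = [b + 2 for b in text.encode("utf-8")]  # offset past pad/eos ids
        target = 1
        while target < len(ids):
            target *= 2
        # Padding happens inside encode, so the caller never pads.
        return ids + [self.pad_id] * (target - len(ids))

    def decode(self, token_ids: List[int]) -> str:
        return bytes(i - 2 for i in token_ids if i > 1).decode("utf-8")


tok: Tokenizer = ToyPaddingTokenizer()
padded = tok.encode("hello")  # 5 bytes, padded up to length 8
assert len(padded) == 8 and tok.decode(padded) == "hello"
```

An engine that doesn't want padding (e.g. the PyTorch GPU case above) would simply return a tokenizer whose `encode` skips the padding step; JetStream's behavior doesn't change.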
- Allow a different implementation to be used for `ResultTokens`, same as for `Prefix` and `DecodeState`. Implementations of the Engine can choose their own implementation of `ResultTokens`; JetStream should interact with it only through its 3 public methods (https://github.com/google/JetStream/blob/main/jetstream/engine/engine_api.py#L83):

  ```python
  def get_result_at_slot
  def convert_to_numpy
  def copy_to_host_async
  ```

  and should not access its fields directly.
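To illustrate the duck-typing this enables, here is a hedged sketch of an engine-specific `ResultTokens` backed by NumPy (a `torch.Tensor`-backed one would look the same shape). The class name, constructor, and method signatures are assumptions for illustration and are not verified against the linked `engine_api.py`.

```python
# Hypothetical NumPy-backed ResultTokens sketch exposing only the 3 public
# methods JetStream is allowed to call; all names here are illustrative.
import numpy as np


class NumpyResultTokens:
    """Holds a (num_slots, data_len) array of per-slot result data."""

    def __init__(self, data: np.ndarray):
        self._data = data  # private: JetStream must not touch this directly

    def get_result_at_slot(self, slot: int) -> np.ndarray:
        # Return the row of result data for one decode slot.
        return self._data[slot]

    def convert_to_numpy(self) -> "NumpyResultTokens":
        # Already NumPy-backed; a torch.Tensor version would call .numpy() here.
        return self

    def copy_to_host_async(self) -> None:
        # No-op on host memory; a device-backed version would start an
        # asynchronous device-to-host copy here.
        pass


results = NumpyResultTokens(np.arange(6).reshape(2, 3))
assert results.get_result_at_slot(1).tolist() == [3, 4, 5]
```

Since JetStream only calls these 3 methods, it never needs the object to be a JAX pytree, which removes the `@struct.dataclass` constraint.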
#53 and #67 add the interface for the tokenizer. `ResultTokens` still needs work.