jetstream-pytorch
[RFC] Formalizing commandline arguments.
Recently we added a new CLI, jpt (https://github.com/google/jetstream-pytorch/pull/178), that massively simplified the command-line arguments a user needs to specify. However, there are other command-line arguments that are optional but that a user might still need to specify for other reasons, such as customizing the workload itself.
We want to split and document the current set of flags in the following way:
- Distinguish the flags that are useful for a user to override (say, --max_input_len and --max_output_len) from the ones we added to test different features (say, --ring_buffer, used to enable the ragged attention kernel). The second class shall start with the --internal_ prefix and have a default value that is supposed to deliver the best performance. Users are not expected to change it; we can retain the flag for testing / validation purposes.
- Have some flags take automatically inferred values. For example, --max_cache_len is usually set to max_input_len + max_output_len, so its default value should be inferred instead of erroring out (or defaulting to a fixed, hardcoded integer).
- Provide flag files for different model / hardware combinations. Because we are using absl flags, we can pass a --flagfile containing flags, so we can prepare a few flag files with "common good case" flags. Users can then run jpt serve --flagfile v5e_llama3_8b.txt, etc.
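A minimal sketch of how the three proposals could look with absl flags, assuming a private FlagValues registry so it does not touch the global absl.flags.FLAGS. The flag name internal_ring_buffer, every default value, and the helpers resolve_max_cache_len, parse and parse_with_flagfile are hypothetical illustrations, not jetstream-pytorch's actual code.

```python
import os
import tempfile

from absl import flags

# Private registry so this sketch does not modify the global absl.flags.FLAGS.
FV = flags.FlagValues()

# User-facing flags. The default values are placeholders, not project defaults.
flags.DEFINE_integer("max_input_len", 1024, "Maximum input length.", flag_values=FV)
flags.DEFINE_integer("max_output_len", 1024, "Maximum output length.", flag_values=FV)
# No hardcoded default: None means "infer from max_input_len + max_output_len".
flags.DEFINE_integer("max_cache_len", None, "Cache length; inferred when unset.", flag_values=FV)

# Internal flag (hypothetical rename of --ring_buffer). The default is a
# placeholder for "whatever delivers the best performance".
flags.DEFINE_bool("internal_ring_buffer", True, "Internal: testing/validation only.", flag_values=FV)


def resolve_max_cache_len(fv: flags.FlagValues) -> int:
    """Return max_cache_len, inferring max_input_len + max_output_len if unset."""
    if fv.max_cache_len is not None:
        return fv.max_cache_len
    return fv.max_input_len + fv.max_output_len


def parse(argv):
    """Parse argv (argv[0] is the program name); absl expands --flagfile itself."""
    FV.unparse_flags()  # reset to defaults so the sketch can be called repeatedly
    FV(argv)
    return resolve_max_cache_len(FV)


def parse_with_flagfile(flag_lines, extra_args=()):
    """Write flag_lines to a temporary flag file, then parse it via --flagfile.

    extra_args come after --flagfile, so they override values from the file.
    """
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "v5e_llama3_8b.txt")
        with open(path, "w") as f:
            f.write("\n".join(flag_lines) + "\n")
        return parse(["jpt", f"--flagfile={path}", *extra_args])
```

Leaving max_cache_len at None instead of a fixed integer lets the inference rule apply only when the user did not set it. Placing explicit arguments after --flagfile lets a command-line value override the file, because absl keeps the last value given for a non-repeated flag.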