jetstream-pytorch

[RFC] Formalizing command-line arguments.

Open • qihqi opened this issue • 0 comments

Recently we added a new CLI, jpt (https://github.com/google/jetstream-pytorch/pull/178), that massively simplified the command-line args the user needs to specify. However, there are other optional command-line args that users may still need to specify, for example to customize the workload itself.

We want to split and document the current set of flags as follows:

  1. Distinguish the flags that are useful for a user to override, such as --max_input_len and --max_output_len, from the ones we added to test different features, such as --ring_buffer (used to enable the ragged attention kernel). The second class shall start with an --internal_ prefix and have a default value that is meant to deliver the best performance. Users are not expected to change these flags; we can retain them for testing / validation purposes.

  2. Let some flags have automatically inferred values. For example, --max_cache_len is usually set to max_input_len + max_output_len, so its default value should be inferred instead of erroring out (or defaulting to a fixed, hardcoded integer).

  3. Provide flag files for different model / hardware combinations. Because we are using absl flags, which accept a --flagfile argument pointing to a file of flags, we can prepare a few flagfiles with "common good case" flags, so users can run jpt serve --flagfile v5e_llama3_8b.txt, etc.
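Point 1 could look roughly like this with absl-py: user-facing flags keep plain names, while test-only toggles live behind the internal_ prefix. This is a minimal sketch under assumptions: the name internal_ring_buffer, its default of True, the help strings, and the helpers define_flags, parse_flags, and user_facing_flags are hypothetical, not current jetstream-pytorch definitions. A fresh FlagValues is used so the example does not touch the global FLAGS.

```python
from absl import flags

INTERNAL_PREFIX = "internal_"


def define_flags(fv: flags.FlagValues) -> None:
    """Registers illustrative flags on `fv`; names and defaults are placeholders."""
    # User-facing flags: expected to be overridden for a given workload.
    flags.DEFINE_integer("max_input_len", 1024, "Maximum input length.", flag_values=fv)
    flags.DEFINE_integer("max_output_len", 1024, "Maximum output length.", flag_values=fv)
    # Internal flag (hypothetical rename of --ring_buffer): kept for
    # testing/validation. Its default should be the best-performing setting;
    # True here is only a placeholder.
    flags.DEFINE_bool(
        "internal_ring_buffer", True,
        "Internal: ring buffer toggle, kept for testing/validation.",
        flag_values=fv)


def parse_flags(argv: list[str]) -> flags.FlagValues:
    """Defines the flags on a fresh FlagValues and parses `argv` into it."""
    fv = flags.FlagValues()
    define_flags(fv)
    fv(argv)  # argv[0] is treated as the program name
    return fv


def user_facing_flags(fv: flags.FlagValues) -> list[str]:
    """Names of flags meant for users, i.e. those without the internal_ prefix."""
    return sorted(name for name in fv if not name.startswith(INTERNAL_PREFIX))
```

Because the class is encoded in the flag name itself, help output or documentation tooling can separate the two groups with a plain string check, and the internal toggles stay available (for example via --nointernal_ring_buffer) for validation runs.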
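For point 2, one way to infer --max_cache_len, assuming absl-py, is to give it a default of None and resolve it after parsing. This is a sketch only: resolve_max_cache_len and parse_flags are hypothetical helpers, and the 1024 defaults are placeholders rather than the project's real values.

```python
from absl import flags


def define_flags(fv: flags.FlagValues) -> None:
    """Registers illustrative length flags on `fv` (defaults are placeholders)."""
    flags.DEFINE_integer("max_input_len", 1024, "Maximum input length.", flag_values=fv)
    flags.DEFINE_integer("max_output_len", 1024, "Maximum output length.", flag_values=fv)
    # No fixed default: None means "infer from the two flags above".
    flags.DEFINE_integer(
        "max_cache_len", None,
        "Cache length; defaults to max_input_len + max_output_len.",
        flag_values=fv)


def resolve_max_cache_len(fv: flags.FlagValues) -> int:
    """Returns the explicit --max_cache_len, or the inferred value if unset."""
    if fv.max_cache_len is not None:
        return fv.max_cache_len
    return fv.max_input_len + fv.max_output_len


def parse_flags(argv: list[str]) -> flags.FlagValues:
    """Defines the flags on a fresh FlagValues and parses `argv` into it."""
    fv = flags.FlagValues()
    define_flags(fv)
    fv(argv)
    return fv
```

Using None as the sentinel keeps an explicit --max_cache_len authoritative while replacing the hard error (or a hardcoded integer) with a value derived from the user's own input and output lengths.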
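Point 3 relies on absl's built-in --flagfile handling, which expands the file's lines in place on the command line. The sketch below writes a hypothetical flagfile named after the v5e_llama3_8b.txt example and parses it with absl-py; the file contents and the helper parse_with_flagfile are illustrative assumptions, not tuned settings shipped with the project.

```python
import os
import tempfile

from absl import flags

# Hypothetical contents for a file like v5e_llama3_8b.txt; the values are
# placeholders, not recommended settings. Lines starting with '#' are skipped.
EXAMPLE_FLAGFILE = """\
# Common good-case settings (illustrative values only).
--max_input_len=1024
--max_output_len=1024
"""


def parse_with_flagfile(contents: str, extra_args: list[str]) -> flags.FlagValues:
    """Parses `--flagfile=<tmpfile>` followed by `extra_args` into fresh flags."""
    fv = flags.FlagValues()
    flags.DEFINE_integer("max_input_len", 256, "Maximum input length.", flag_values=fv)
    flags.DEFINE_integer("max_output_len", 256, "Maximum output length.", flag_values=fv)
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "v5e_llama3_8b.txt")
        with open(path, "w") as f:
            f.write(contents)
        # The file's flags are expanded in place, so flags given after
        # --flagfile on the command line override the file's values.
        fv(["jpt", f"--flagfile={path}", *extra_args])
    return fv
```

Since later flags win, a user can start from a shared "common good case" file and still override a single setting, e.g. passing --max_output_len=2048 after the --flagfile argument.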

qihqi • Sep 10 '24 17:09