xla
[XLA:GPU] Add force-inline and no-preserve-locals options to get better LLVM splits
Add an option to XLA to force inlining before LLVM's splitModule, or to set preserveLocals=False, so that parallel compilation produces more balanced splits.
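To illustrate why these options help, here is a minimal sketch (in Python, not XLA's actual C++ implementation) of the load-balancing problem in splitModule-style partitioning. With preserveLocals, an internal (local-linkage) symbol and all of its users must land in the same part, which can glue many functions into one large cluster and unbalance the split; forced inlining (or preserveLocals=False) breaks those constraints apart. The function names and costs below are made up for illustration.

```python
def split(cost_by_cluster, n_parts):
    """Greedily assign clusters (groups of functions that must stay
    together) to the currently lightest part."""
    parts = [0.0] * n_parts
    assignment = {}
    # Largest clusters first, so smaller ones can fill the gaps.
    for name, cost in sorted(cost_by_cluster.items(), key=lambda kv: -kv[1]):
        i = min(range(n_parts), key=lambda p: parts[p])
        parts[i] += cost
        assignment[name] = i
    return parts, assignment

# preserveLocals: local symbols glue most functions into one giant cluster.
glued = {"big_cluster": 90.0, "f1": 4.0, "f2": 3.0, "f3": 3.0}
# After forced inlining (or with preserveLocals=False): independent functions.
free = {"f%d" % i: 10.0 for i in range(10)}

print(split(glued, 4)[0])  # imbalanced: one part carries 90% of the work
print(split(free, 4)[0])   # roughly balanced across the 4 parts
```

The imbalanced case is why parallel compilation barely helps without inlining: the wall-clock time is dominated by the one part that received the giant cluster.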
Some data from a GPT-3 5B model with different settings:
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.754429084s (parallel + inline), runtime: 1.0s
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 4.676450341s (parallel + no inline), runtime: 1.0s
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.051018704s (parallel + preserve_locals=False), runtime: 1.4s
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 4.862938161s (serial), runtime: 1.0s
However, the per-step runtime of preserve_locals=False vs. the other three setups is 1.4s vs. 1.0s.
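The compile-time speedups implied by the timings above, relative to the serial baseline, can be computed directly:

```python
# Compile times in seconds, taken from the XlaCompile measurements above.
serial = 4.862938161
timings = {
    "parallel + inline": 3.754429084,
    "parallel + no inline": 4.676450341,
    "parallel + preserve_locals=False": 3.051018704,
}
for name, t in timings.items():
    print(f"{name}: {100 * (1 - t / serial):.1f}% faster than serial")
# parallel + inline comes out at roughly 23% faster than serial,
# parallel + preserve_locals=False at roughly 37% faster.
```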
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.754429084 (parallel + inline)
What are the units, seconds?
Mentioning both runtime and compile time in the bug description is a bit confusing, what is the overall effect of the change? No runtime change and ~20% faster compilation?
Units: seconds. I updated the description to make the compilation time and runtime of each setup clear. This fix is for speeding up compilation, since XLA is currently producing imbalanced splits; I mentioned the runtime as well to confirm that the changes do not cause a slowdown at runtime. I mainly want to push for inline + parallel compilation, which gives a ~20% speedup in compilation and no change in runtime. However, preserveLocals=False + parallel compilation is another option: it gives a larger compilation speedup but a slowdown in runtime. I was wondering whether sacrificing performance for faster compilation is acceptable in some workloads? I know a 1.4x slowdown is a lot, but in other workloads it might be smaller. So I think at least level 1 (inline + parallel) should be the default for XLA.
I really don't see a case for preserveLocals, unless we set up a ThinLTO-like backend that allows cross-importing while still inlining.
preserveLocals is also a mode where the result of the compilation actually depends on the number of threads, which does not seem desirable to me.
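The determinism concern can be sketched as follows: which symbols get compiled together depends on how many parts (i.e., worker threads) the module is split into, so any optimization that only sees one part may fire or not fire depending on thread count. The hash function below is an illustrative stand-in (Python's zlib.crc32), not LLVM's actual hash, and the symbol names are made up.

```python
import zlib

def part_of(name, n_parts):
    # Deterministic name-hash partitioning modulo the number of parts,
    # as a stand-in for how a module splitter might assign symbols.
    return zlib.crc32(name.encode()) % n_parts

funcs = ["fusion_1", "fusion_2", "fusion_3"]
print({f: part_of(f, 2) for f in funcs})
print({f: part_of(f, 4) for f in funcs})  # grouping changes with thread count
```

Two symbols that land in the same part at 2 threads may be separated at 4 threads, so the per-part optimization context, and hence the generated code, varies with the thread count.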
Updated with compilation results for some more models with these changes: https://docs.google.com/spreadsheets/d/1uIRf66UT9hOBOge3nvRZebDintgM0zmozNts0tOiXQA/edit?usp=sharing. It seems preserveLocals=False does no better than parallel + inline, so I will just remove it.
@cheshire Hi, any updates on this?