
[XLA:GPU] add force inline and no preserve local option to get better llvm splits

Open Cjkkkk opened this issue 1 year ago • 3 comments

Add an option to XLA to force inlining before LLVM's splitModule, or to set preserveLocals=False, in order to get more balanced splits in the parallel compilation case.
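For reference, a minimal sketch (not XLA's actual code path; the helper name is mine and it assumes a recent LLVM where `llvm::SplitModule` takes a `Module&`) of driving the split with `PreserveLocals=false`:

```cpp
// Sketch: split one llvm::Module into `num_parts` modules for parallel codegen.
// PreserveLocals=false lets the partitioner distribute globals more evenly
// instead of keeping every local symbol next to its users.
#include <memory>
#include <utility>
#include <vector>

#include "llvm/IR/Module.h"
#include "llvm/Transforms/Utils/SplitModule.h"

std::vector<std::unique_ptr<llvm::Module>> SplitForParallelCodegen(
    llvm::Module& module, unsigned num_parts) {
  std::vector<std::unique_ptr<llvm::Module>> parts;
  llvm::SplitModule(
      module, num_parts,
      [&](std::unique_ptr<llvm::Module> part) {
        parts.push_back(std::move(part));  // each part is handed to one codegen thread
      },
      /*PreserveLocals=*/false);
  return parts;
}
```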

Some data from a GPT-3 5B model with different settings:

Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.754429084s (parallel + inline), runtime: 1.0s
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 4.676450341s (parallel + no inline), runtime: 1.0s
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.051018704s (parallel + preserve_locals=False), runtime: 1.4s
Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 4.862938161s (serial), runtime: 1.0s

However, the per-step runtime with preserve_locals=False is 1.4s, versus 1.0s for the other three setups.

Cjkkkk avatar May 13 '24 23:05 Cjkkkk

Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.754429084 (parallel + inline)

What are the units, seconds?

Mentioning both runtime and compile time in the bug description is a bit confusing; what is the overall effect of the change? No runtime change and ~20% faster compilation?

cheshire avatar May 15 '24 07:05 cheshire

Compilation: TSL:XlaCompile:#module=pjit__wrapped_step_fn,program_id=24#: 3.754429084 (parallel + inline)

What are the units, seconds?

Mentioning both runtime and compile time in the bug description is a bit confusing; what is the overall effect of the change? No runtime change and ~20% faster compilation?

Units: seconds. I updated the description to make the compilation time and runtime of each setup clear. This fix is for speeding up compilation, since XLA currently produces imbalanced splits; I mention the runtime as well to confirm the change does not cause a runtime slowdown. I mainly want to push for inline + parallel compilation, which gives a ~20% speedup in compilation and no change in runtime. However, preserveLocals=False + parallel compilation is another option: it gives a larger compilation speedup but a runtime slowdown. I was wondering whether sacrificing runtime performance for faster compilation is acceptable in some workloads? I know a 1.4x slowdown is a lot, but in other workloads it might be smaller. So I think at least level 1 (inline + parallel) should be the default for XLA.
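For concreteness, a hedged sketch of what "level 1" (force inlining before the split) could look like at the LLVM level. This is an illustration under my own assumptions (the helper name and the policy of skipping declarations and `noinline` functions), not the actual XLA change:

```cpp
// Sketch: mark every definable function alwaysinline and run LLVM's
// AlwaysInlinerPass, so the later splitModule call sees fewer cross-function
// edges and can produce more balanced partitions.
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"

void ForceInlineBeforeSplit(llvm::Module& module) {
  for (llvm::Function& func : module) {
    // Skip declarations and functions that must not be inlined
    // (alwaysinline + noinline is rejected by the IR verifier).
    if (func.isDeclaration() || func.hasFnAttribute(llvm::Attribute::NoInline))
      continue;
    func.addFnAttr(llvm::Attribute::AlwaysInline);
  }

  // Standard new-pass-manager boilerplate to run a single module pass.
  llvm::LoopAnalysisManager lam;
  llvm::FunctionAnalysisManager fam;
  llvm::CGSCCAnalysisManager cgam;
  llvm::ModuleAnalysisManager mam;
  llvm::PassBuilder pb;
  pb.registerModuleAnalyses(mam);
  pb.registerCGSCCAnalyses(cgam);
  pb.registerFunctionAnalyses(fam);
  pb.registerLoopAnalyses(lam);
  pb.crossRegisterProxies(lam, fam, cgam, mam);

  llvm::ModulePassManager mpm;
  mpm.addPass(llvm::AlwaysInlinerPass());
  mpm.run(module, mam);
}
```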

Cjkkkk avatar May 15 '24 18:05 Cjkkkk

I really don't see a case for preserveLocals, unless we set up a ThinLTO-like backend that allows cross-importing and still inlining.

preserveLocals is also a mode where the result of the compilation actually depends on the number of threads, which is not something that seems desirable to me.
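To make the thread-count dependence concrete: the partition a function lands in is a function of N, so per-partition optimization and codegen can differ with the degree of parallelism. A small illustration, reusing the hypothetical SplitForParallelCodegen helper sketched above:

```cpp
// Sketch: dump which partition each function lands in for a given N.
// Comparing the output for, e.g., N=4 vs N=8 on the same module generally
// shows different groupings, hence potentially different compiled code.
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

void DumpPartitioning(llvm::Module& module, unsigned n) {
  auto parts = SplitForParallelCodegen(module, n);
  for (unsigned i = 0; i < parts.size(); ++i) {
    for (llvm::Function& func : *parts[i]) {
      if (!func.isDeclaration())
        llvm::errs() << "N=" << n << " part " << i << ": " << func.getName() << "\n";
    }
  }
}
```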

joker-eph avatar May 15 '24 18:05 joker-eph

Updated with compilation results for some more models with these changes: https://docs.google.com/spreadsheets/d/1uIRf66UT9hOBOge3nvRZebDintgM0zmozNts0tOiXQA/edit?usp=sharing. It seems preserveLocals=False is not doing any better than parallel + inline, so I will just remove that.

Cjkkkk avatar May 16 '24 20:05 Cjkkkk

@cheshire Hi, any updates on this?

Cjkkkk avatar May 20 '24 20:05 Cjkkkk