
[ROCm] Triton performance fixes

Open zoranjovanovic-ns opened this issue 8 months ago • 14 comments

zoranjovanovic-ns avatar Mar 13 '25 13:03 zoranjovanovic-ns

@zoranjovanovic-ns Could you please help fix the build failures in XLA Linux x86 CPU CI? (Other CIs failed too.)

./xla/backends/gpu/codegen/triton/compilation_pipeline.h:44:36: error: use of undeclared identifier 'se'
   44 |     mlir::OpPassManager* pm, const se::DeviceDescription& device_info, int num_warps, int num_ctas,
      |                                    ^
xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc:26:36: error: use of undeclared identifier 'se'
   26 |     mlir::OpPassManager* pm, const se::DeviceDescription& device_info, int num_warps, int num_ctas,
      |                                    ^
2 errors generated.

penpornk avatar Mar 14 '25 18:03 penpornk

One more CUDA build error please:

xla/pjrt/triton_cuda.cc:81:43: error: no viable conversion from 'std::string' (aka 'basic_string<char>') to 'const stream_executor::DeviceDescription'
   81 |       xla::gpu::CreateTritonPipeline(&pm, std::string(arch_name), num_warps,
      |                                           ^~~~~~~~~~~~~~~~~~~~~~

penpornk avatar Mar 17 '25 04:03 penpornk

Hi @xla-rotation, wondering if this is still pending approval?

i-chaochen avatar Mar 25 '25 17:03 i-chaochen

@i-chaochen @zoranjovanovic-ns Please see Penporn's previous comment about more issues that still need to be resolved.

dimitar-asenov avatar Mar 31 '25 06:03 dimitar-asenov

@i-chaochen @zoranjovanovic-ns Please see Penporn's previous comment about more issues that still need to be resolved.

At the moment I am working on setting threadsPerWarp in all the required places in the code; when this is finished, I will update this PR accordingly.

zoranjovanovic-ns avatar Mar 31 '25 10:03 zoranjovanovic-ns

@zoranjovanovic-ns I think the changes in the emitters are incorrect. It has to be WarpSize(). And it has to be adapted so that it works on AMD. Are you planning to do that in a follow-up?

pifon2a avatar Apr 14 '25 10:04 pifon2a

@zoranjovanovic-ns I think the changes in the emitters are incorrect. It has to be WarpSize(). And it has to be adapted so that it works on AMD. Are you planning to do that in a follow-up?

Yes, we are planning to modify the emitters in a follow-up PR, but for now the idea was to avoid test failures.

zoranjovanovic-ns avatar Apr 17 '25 13:04 zoranjovanovic-ns

Rebased, resolved conflicts.

zoranjovanovic-ns avatar Apr 22 '25 10:04 zoranjovanovic-ns

Hi @zoranjovanovic-ns, the changes are breaking the PJRT API. Is there a reason you use a DeviceDescription instead of the arch_name in

absl::Status CreateTritonPipeline(
    mlir::OpPassManager* pm,
    const stream_executor::DeviceDescription& device_info, int num_warps,
    int num_ctas, int num_stages,
    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
    bool is_xla_fusion);

Error:

third_party/tensorflow/compiler/xla/pjrt/triton_cuda.cc:82:43: error: no viable conversion from 'std::string' (aka 'basic_string<char>') to 'const stream_executor::DeviceDescription'
   82 |       xla::gpu::CreateTritonPipeline(&pm, std::string(arch_name), num_warps,
      |                                           ^~~~~~~~~~~~~~~~~~~~~~

derdrdirk avatar Apr 23 '25 08:04 derdrdirk

Hi @zoranjovanovic-ns, the changes are breaking the PJRT API. Is there a reason you use a DeviceDescription instead of the arch_name in

absl::Status CreateTritonPipeline(
    mlir::OpPassManager* pm,
    const stream_executor::DeviceDescription& device_info, int num_warps,
    int num_ctas, int num_stages,
    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
    bool is_xla_fusion);

Error:

third_party/tensorflow/compiler/xla/pjrt/triton_cuda.cc:82:43: error: no viable conversion from 'std::string' (aka 'basic_string<char>') to 'const stream_executor::DeviceDescription'
   82 |       xla::gpu::CreateTritonPipeline(&pm, std::string(arch_name), num_warps,
      |                                           ^~~~~~~~~~~~~~~~~~~~~~

Hi @derdrdirk, we use DeviceDescription to set the correct (architecture-specific) value of threadsPerWarp, which several Triton passes require to generate optimal code. We also believe that in the future there will be more situations like this, where we need more device-specific data. We discussed this on our AMD/Google OpenXLA channel (with Sergei Lebedev, Benjamin Chetioui, and Henning Becker) and got approval to proceed with this change.

zoranjovanovic-ns avatar Apr 23 '25 08:04 zoranjovanovic-ns

Hi @zoranjovanovic-ns, the changes are breaking the PJRT API. Is there a reason you use a DeviceDescription instead of the arch_name in

absl::Status CreateTritonPipeline(
    mlir::OpPassManager* pm,
    const stream_executor::DeviceDescription& device_info, int num_warps,
    int num_ctas, int num_stages,
    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
    bool is_xla_fusion);

Error:

third_party/tensorflow/compiler/xla/pjrt/triton_cuda.cc:82:43: error: no viable conversion from 'std::string' (aka 'basic_string<char>') to 'const stream_executor::DeviceDescription'
   82 |       xla::gpu::CreateTritonPipeline(&pm, std::string(arch_name), num_warps,
      |                                           ^~~~~~~~~~~~~~~~~~~~~~

Hi @derdrdirk, we use DeviceDescription to set the correct (architecture-specific) value of threadsPerWarp, which several Triton passes require to generate optimal code. We also believe that in the future there will be more situations like this, where we need more device-specific data. We discussed this on our AMD/Google OpenXLA channel (with Sergei Lebedev, Benjamin Chetioui, and Henning Becker) and got approval to proceed with this change.

The problem is that we cannot merge a PR that is breaking our build. Can you please check whether you can reproduce this compile error, and work on a fix?

Edit: the callsite that is broken is here: https://github.com/openxla/xla/blob/main/xla/pjrt/triton_cuda.cc

akuegel avatar Apr 23 '25 09:04 akuegel

@zoranjovanovic-ns thanks for the quick reply. I am all for submitting this change, but am having a hard time figuring out how to propagate the DeviceDescription.

Could you use the arch_name to get the threadsPerWarp? IIUC the arch_name comes from the PJRT interface https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_triton_extension.h, which would be hard to update.
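
For illustration, a minimal sketch of what deriving it from the arch_name alone could look like — the helper name and the gfx9 prefix match are assumptions for this sketch, not code from the PR (CDNA-class gfx9 targets use 64-wide wavefronts, while RDNA and CUDA targets use 32):

#include <cstdint>
#include <string>

#include "absl/strings/match.h"

// Hypothetical helper: derive threads-per-warp from the PJRT arch_name
// string alone. CDNA-class AMD GPUs (gfx90a, gfx942, ...) use 64-wide
// wavefronts; RDNA and NVIDIA GPUs use 32.
inline int64_t ThreadsPerWarpFromArchName(const std::string& arch_name) {
  return absl::StartsWith(arch_name, "gfx9") ? 64 : 32;
}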

derdrdirk avatar Apr 23 '25 09:04 derdrdirk

@zoranjovanovic-ns thanks for the quick reply. I am all for submitting this change, but am having a hard time figuring out how to propagate the DeviceDescription.

Could you use the arch_name to get the threadsPerWarp? IIUC the arch_name comes from the PJRT interface https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_triton_extension.h, which would be hard to update.

We can probably do that for threads per warp. Do you want me to modify this PR so that it does not change the CreateTritonPipeline interface?

zoranjovanovic-ns avatar Apr 23 '25 09:04 zoranjovanovic-ns

That would be great. If https://github.com/openxla/xla/blob/main/xla/pjrt/triton_cuda.cc builds then we merge.

derdrdirk avatar Apr 23 '25 10:04 derdrdirk

This actually breaks newer XLA on MI300X, because when MFMA is enabled there is an assertion that checks that threadsPerWarp is 64 on that architecture.

The main point is that the device description feels a bit more robust, since it calls into HIP to get that info. On CUDA it is hardcoded to 32.

I feel that the PJRT interface might not be the best fit in that case. How could we deduce it from the arch_name alone?
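
For context, a rough sketch (assuming a HIP build; the helper name is made up for this sketch) of the runtime query that the device-description path relies on, as opposed to deducing the value from a string:

#include <hip/hip_runtime.h>

// Ask the HIP runtime for the wavefront size of a device, falling back
// to 64 (the CDNA default) if the query fails. On CUDA the value is
// effectively a constant 32.
inline int GetWarpSizeFromHip(int device_ordinal) {
  int warp_size = 0;
  if (hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize,
                            device_ordinal) != hipSuccess) {
    warp_size = 64;
  }
  return warp_size;
}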

steeve avatar Jun 09 '25 19:06 steeve

That would be great. If https://github.com/openxla/xla/blob/main/xla/pjrt/triton_cuda.cc builds then we merge.

@derdrdirk @xla-rotation Updated the PR, keeping the original CreateTritonPipeline signature.

zoranjovanovic-ns avatar Jun 10 '25 11:06 zoranjovanovic-ns

Added a small fix for GetThreadsPerWarp in rocm_executor.cc.

zoranjovanovic-ns avatar Jun 10 '25 12:06 zoranjovanovic-ns

@dimitar-asenov Added a small fix for GetThreadsPerWarp in rocm_executor.cc. Sorry for the inconvenience.

zoranjovanovic-ns avatar Jun 10 '25 12:06 zoranjovanovic-ns

@zoranjovanovic-ns: We just noticed that this PR does not add any tests. It would be easy to regress. Therefore, please add a test that uses two different AMD architectures and results in two different threads_per_warp numbers. Up to you exactly what kind of test that should be.
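
For example, something along these lines would be enough — a sketch using the hypothetical ThreadsPerWarpFromArchName helper from the earlier sketch; the real test can of course hook into DeviceDescription or the ROCm executor instead:

#include <gtest/gtest.h>

// Two different AMD architectures should yield two different
// threads-per-warp values: CDNA3 (gfx942) is wave64, RDNA3 (gfx1100)
// is wave32.
TEST(ThreadsPerWarpTest, DiffersAcrossAmdArchitectures) {
  EXPECT_EQ(ThreadsPerWarpFromArchName("gfx942"), 64);
  EXPECT_EQ(ThreadsPerWarpFromArchName("gfx1100"), 32);
}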

dimitar-asenov avatar Jun 10 '25 12:06 dimitar-asenov

@zoranjovanovic-ns Also, could you please look into the AMD ROCm -- Community CI build failures below?

dimitar-asenov avatar Jun 10 '25 13:06 dimitar-asenov

Hi @dimitar-asenov, the ROCm CI issue is fixed on our side. It should be green now. Thanks again!

i-chaochen avatar Jun 10 '25 14:06 i-chaochen

@zoranjovanovic-ns:

Regarding the CI: it's still red here. If you fixed something with the integration (as opposed to the code), perhaps it will be green when we next rerun it. If there was a code fix, I don't see any new commits, perhaps you still need to upload it?

Regarding the tests: I still don't see any new commits adding tests. As soon as we have those, I can try to merge this internally. Thanks!

dimitar-asenov avatar Jun 10 '25 19:06 dimitar-asenov

Hi @dimitar-asenov, the CI being red has nothing to do with this PR. The PR causes no problems in the ROCm build or the ROCm unit tests; it is red because of an unrelated CI node issue. Our DevOps team is working on it now. Sorry for the confusion!

i-chaochen avatar Jun 10 '25 23:06 i-chaochen

Fixed an issue in CreateTritonPipeline; still working on the test, will provide it by the end of the day.

zoranjovanovic-ns avatar Jun 11 '25 10:06 zoranjovanovic-ns

Hi @zoranjovanovic-ns

Just wanted to drop into the discussion because we are impacted by this PR. We have a patch for our ROCm build of PJRT.

Concerning the CreateTritonPipeline method, I think a better fix would be to pass a DeviceDescription instead of a std::string. Wherever CreateTritonPipeline is called, you have access to a device_info variable. The advantages are twofold:

  • You can directly access device_info.(cuda|rocm)_compute_capability() instead of rebuilding it from the string
  • You have access to device_info.threads_per_warp() instead of relying on a brittle match on a character of the arch_name string

You can take a look at our patch
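
For reference, a rough sketch of the call shape this enables — sketch only: the function body is a placeholder, and the accessors are the existing DeviceDescription getters mentioned above:

#include "absl/status/status.h"
#include "mlir/Pass/PassManager.h"
#include "xla/stream_executor/device_description.h"

// Sketch only: the device-specific values come straight from the
// DeviceDescription instead of being re-derived from an arch_name string.
absl::Status CreateTritonPipelineSketch(
    mlir::OpPassManager* pm,
    const stream_executor::DeviceDescription& device_info, int num_warps,
    int num_ctas, int num_stages) {
  const int threads_per_warp = device_info.threads_per_warp();
  const auto& gpu_cc = device_info.gpu_compute_capability();
  // ... register the Triton passes on pm, parameterized by num_warps,
  // threads_per_warp, num_ctas, num_stages, and gpu_cc ...
  (void)pm;
  (void)threads_per_warp;
  (void)gpu_cc;
  return absl::OkStatus();
}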

Corendos avatar Jun 11 '25 14:06 Corendos

Concerning the CreateTritonPipeline method, I think a better fix would be to pass a DeviceDescription instead of a std::string

The original PR version had DeviceDescription instead of std::string; I modified it because of these comments: https://github.com/openxla/xla/pull/23688#issuecomment-2823624862 https://github.com/openxla/xla/pull/23688#issuecomment-2823859136

zoranjovanovic-ns avatar Jun 11 '25 14:06 zoranjovanovic-ns

I see

My initial thought was that it would be better to make XLA's core robust and not have it blocked by an extension, but I don't really have a solution for the extension...

Corendos avatar Jun 11 '25 15:06 Corendos

We should probably consider that type of change, but I would like to keep it separate from our Triton performance fixes (this PR).

zoranjovanovic-ns avatar Jun 11 '25 15:06 zoranjovanovic-ns

Hi @xla-rotation @dimitar-asenov, added unit test.

zoranjovanovic-ns avatar Jun 11 '25 15:06 zoranjovanovic-ns