[ROCm] Triton performance fixes
@zoranjovanovic-ns Could you please help fix the build failures in XLA Linux x86 CPU CI? (Other CIs failed too.)
./xla/backends/gpu/codegen/triton/compilation_pipeline.h:44:36: error: use of undeclared identifier 'se'
44 | mlir::OpPassManager* pm, const se::DeviceDescription& device_info, int num_warps, int num_ctas,
| ^
xla/backends/gpu/codegen/triton/compilation_pipeline_stub.cc:26:36: error: use of undeclared identifier 'se'
26 | mlir::OpPassManager* pm, const se::DeviceDescription& device_info, int num_warps, int num_ctas,
| ^
2 errors generated.
One more CUDA build error to fix, please:
xla/pjrt/triton_cuda.cc:81:43: error: no viable conversion from 'std::string' (aka 'basic_string<char>') to 'const stream_executor::DeviceDescription'
81 | xla::gpu::CreateTritonPipeline(&pm, std::string(arch_name), num_warps,
| ^~~~~~~~~~~~~~~~~~~~~~
Hi @xla-rotation, is this PR still pending approval?
@i-chaochen @zoranjovanovic-ns Please see Penporn's previous comment about more issues that still need to be resolved.
At the moment I am working on setting threadsPerWarp in all the required places in the code; when this is finished, I will update this PR accordingly.
@zoranjovanovic-ns I think the changes in the emitters are incorrect. It has to be WarpSize(). And it has to be adapted so that it works on AMD. Are you planning to do that in a follow-up?
Yes, we are planning to modify the emitters in a follow-up PR, but for now the idea was to avoid test failures.
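For reference, the emitter-side follow-up being discussed would presumably reduce to something like the sketch below. The helper name is made up, and `threads_per_warp()` is the DeviceDescription accessor mentioned later in this thread, so treat this as an assumption rather than the actual change:

```cpp
#include <cstdint>

#include "xla/stream_executor/device_description.h"

// Hypothetical sketch of the follow-up emitter change: derive the warp
// (wavefront) width from the device description instead of hardcoding 32, so
// the same emitter code works on NVIDIA (32) and AMD CDNA GPUs like MI300X (64).
inline int64_t EmitterWarpSize(
    const stream_executor::DeviceDescription& device_info) {
  return device_info.threads_per_warp();
}
```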
Rebased, resolved conflicts.
Hi @zoranjovanovic-ns, the changes are breaking the PJRT API. Is there a reason you use a DeviceDescription instead of the arch_name in
absl::Status CreateTritonPipeline(
mlir::OpPassManager* pm,
const stream_executor::DeviceDescription& device_info, int num_warps,
int num_ctas, int num_stages,
mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
bool is_xla_fusion);
Error:
third_party/tensorflow/compiler/xla/pjrt/triton_cuda.cc:82:43: error: no viable conversion from 'std::string' (aka 'basic_string<char>') to 'const stream_executor::DeviceDescription'
82 | xla::gpu::CreateTritonPipeline(&pm, std::string(arch_name), num_warps,
| ^~~~~~~~~~~~~~~~~~~~~~
Hi @derdrdirk, we use DeviceDescription to set the correct (per-architecture) value of threadsPerWarp, which several Triton passes need in order to generate optimal code. We also believe that in the future there will be more situations like this, where we need additional device-specific data. We discussed this on our AMD/Google OpenXLA channel (with Sergei Lebedev, Benjamin Chetioui and Henning Becker) and got approval to proceed with this change.
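For context, the DeviceDescription-based variant being defended here would look roughly like the sketch below. The signature is the one quoted earlier in the thread; the body is illustrative only, it assumes the `threads_per_warp()` accessor, and Triton dialect headers are elided:

```cpp
#include <cstdint>

#include "absl/status/status.h"
#include "mlir/Pass/PassManager.h"
#include "xla/stream_executor/device_description.h"
// (Triton dialect headers for ClusterInfo elided.)

absl::Status CreateTritonPipeline(
    mlir::OpPassManager* pm,
    const stream_executor::DeviceDescription& device_info, int num_warps,
    int num_ctas, int num_stages,
    mlir::triton::nvidia_gpu::ClusterInfo& out_cluster_info,
    bool is_xla_fusion) {
  // Several Triton passes need the warp/wavefront width; reading it from the
  // device description yields 64 on AMD CDNA (e.g. MI300X) and 32 elsewhere,
  // instead of a hardcoded 32.
  const int64_t threads_per_warp = device_info.threads_per_warp();
  // ... register the Triton passes on `pm`, forwarding threads_per_warp
  //     (plus num_warps / num_ctas / num_stages) where needed ...
  (void)threads_per_warp;
  return absl::OkStatus();
}
```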
The problem is that we cannot merge a PR that is breaking our build. Can you please check whether you can reproduce this compile error, and work on a fix?
Edit: the callsite that is broken is here: https://github.com/openxla/xla/blob/main/xla/pjrt/triton_cuda.cc
@zoranjovanovic-ns thanks for the quick reply. I am all for submitting this change, but I am having a hard time seeing how to propagate the DeviceDescription.
Could you use the arch_name to get the threadsPerWarp? IIUC the arch_name comes from the PJRT interface https://github.com/openxla/xla/blob/main/xla/pjrt/c/pjrt_c_api_triton_extension.h, which would be hard to update.
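If the arch_name route is taken, the mapping could be as simple as the hypothetical helper below. This is not code from the PR; it assumes that gfx9xx (AMD CDNA, e.g. MI300X) architectures use 64-wide wavefronts and that everything else, NVIDIA included, defaults to 32, which is exactly the kind of string match later called brittle in this thread:

```cpp
#include <string>

#include "absl/strings/match.h"

// Hypothetical helper: derive the warp/wavefront width from the architecture
// name that the PJRT Triton extension already passes through.
// gfx9xx (AMD CDNA) -> 64; NVIDIA sm_* and AMD RDNA gfx10xx/gfx11xx -> 32.
inline int ThreadsPerWarpFromArchName(const std::string& arch_name) {
  return absl::StartsWith(arch_name, "gfx9") ? 64 : 32;
}
```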
We can probably do that for threads per warp. Do you want me to modify this PR so that it does not include the change to the CreateTritonPipeline interface?
That would be great. If https://github.com/openxla/xla/blob/main/xla/pjrt/triton_cuda.cc builds, then we can merge.
This actually breaks newer XLA on MI300X, because when mfma is enabled there is an assertion that checks that threadsPerWarp is 64 on that arch.
The main point is that the device description, which calls into HIP to get that info, feels a bit more robust. On CUDA it is hardcoded to 32.
I feel that the PJRT interface might not be what's best in that case. How could we deduce it from the arch_name only?
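For reference, the HIP-side query that makes the device-description path feel more robust is roughly the following. This is a sketch against the public HIP runtime API, not the actual StreamExecutor implementation:

```cpp
#include <hip/hip_runtime.h>

// Sketch: ask the HIP runtime for the wavefront width instead of inferring it
// from the architecture name. Reports 64 on CDNA GPUs (e.g. MI300X) and 32 on
// RDNA GPUs; falls back to 64 if the query fails.
int QueryWavefrontSize(int device_ordinal) {
  int warp_size = 0;
  if (hipDeviceGetAttribute(&warp_size, hipDeviceAttributeWarpSize,
                            device_ordinal) != hipSuccess) {
    return 64;
  }
  return warp_size;
}
```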
@derdrdirk @xla-rotation Updated the PR, keeping the original CreateTritonPipeline signature.
Added small fix for GetThreadsPerWarp in rocm_executor.cc
@dimitar-asenov Added small fix for GetThreadsPerWarp in rocm_executor.cc. Sorry for the inconvenience.
@zoranjovanovic-ns: We just noticed that this PR does not add any tests. It would be easy to regress. Therefore, please add a test that uses two different AMD architectures and results in two different threads_per_warp numbers. Up to you exactly what kind of test that should be.
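One possible shape for such a test, sketched against a hypothetical arch-name helper like the one above; the actual test could just as well build a DeviceDescription per architecture and inspect the resulting pipeline or emitted IR:

```cpp
#include <string>

#include <gtest/gtest.h>

// Hypothetical helper (see the earlier sketch); declared here for the example.
int ThreadsPerWarpFromArchName(const std::string& arch_name);

TEST(ThreadsPerWarpTest, DiffersBetweenAmdArchitectures) {
  // CDNA (MI300X, gfx942) uses 64-wide wavefronts, RDNA3 (gfx1100) uses 32.
  EXPECT_EQ(ThreadsPerWarpFromArchName("gfx942"), 64);
  EXPECT_EQ(ThreadsPerWarpFromArchName("gfx1100"), 32);
}
```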
@zoranjovanovic-ns Also, could you please look into the failing AMD ROCm -- Community CI Build failures below?
Hi @dimitar-asenov, the ROCm CI issue is fixed on our side. It should be green now. Thanks again!
@zoranjovanovic-ns:
Regarding the CI: it's still red here. If you fixed something with the integration (as opposed to the code), perhaps it will be green when we next rerun it. If there was a code fix, I don't see any new commits, perhaps you still need to upload it?
Regarding the tests: I still don't see any new commits adding tests. As soon as we have those, I can try to merge this internally. Thanks!
Hi @dimitar-asenov, the red CI has nothing to do with this PR. The PR causes no problems in the ROCm build or ROCm unit tests; the failure is due to a CI node issue. Our DevOps team is working on it now. Sorry for the confusion!
Fixed an issue in CreateTritonPipeline. Still working on the test; I will provide it by the end of the day.
Hi @zoranjovanovic-ns
Just wanted to drop into the discussion because we are impacted by this PR. We have a patch for our ROCm build of PJRT.
Concerning the CreateTritonPipeline method, I think a better fix would be to pass a DeviceDescription instead of a std::string. Wherever CreateTritonPipeline is called, you have access to a device_info variable. The advantages are twofold:
- You can directly access device_info.(cuda|rocm)_compute_capability() instead of rebuilding it from the string.
- You have access to device_info.threads_per_warp() instead of relying on a brittle match on a character of the arch_name string.
You can take a look at our patch
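Concretely, a call site under that proposal would boil down to something like the sketch below. The wrapper function is illustrative; only the CreateTritonPipeline signature and the DeviceDescription accessors come from this thread, and hardcoding is_xla_fusion is an assumption for the example:

```cpp
#include "absl/status/status.h"
#include "mlir/Pass/PassManager.h"
#include "xla/stream_executor/device_description.h"
// (Triton dialect and CreateTritonPipeline headers elided.)

// Illustrative only: pass the DeviceDescription straight through instead of a
// string, so the pipeline can read device_info.threads_per_warp() and
// device_info.(cuda|rocm)_compute_capability() without re-parsing arch_name.
absl::Status BuildTritonPipelineForDevice(
    mlir::OpPassManager* pm,
    const stream_executor::DeviceDescription& device_info, int num_warps,
    int num_ctas, int num_stages,
    mlir::triton::nvidia_gpu::ClusterInfo& cluster_info) {
  return xla::gpu::CreateTritonPipeline(pm, device_info, num_warps, num_ctas,
                                        num_stages, cluster_info,
                                        /*is_xla_fusion=*/true);
}
```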
The original PR version had DeviceDescription instead of std::string; I modified it because of the comments here: https://github.com/openxla/xla/pull/23688#issuecomment-2823624862 https://github.com/openxla/xla/pull/23688#issuecomment-2823859136
I see. My initial thought was that it would be better to make XLA's core robust and not be blocked by an extension, but I don't really have a solution for the extension...
We should probably consider that type of change, but I would like to keep it separate from our Triton performance fixes (this PR).
Hi @xla-rotation @dimitar-asenov, I added a unit test.