DeepSpeed
[BUG] Install on AMD ROCm system but fails to build on CUDA dependencies
Describe the bug When I try to install DeepSpeed 0.8.3 on an AMD GPU, it detects HIP but fails to build because it cannot find the CUDA headers.
To Reproduce Steps to reproduce the behavior: If installed with DS_BUILD_OPS=1
$ DS_BUILD_OPS=1 ./install.sh
Attempting to remove deepspeed/git_version_info_installed.py
Attempting to remove dist
Attempting to remove build
Attempting to remove deepspeed.egg-info
No hostfile exists at /job/hostfile, installing locally
Building deepspeed wheel
DS_BUILD_OPS=1
Traceback (most recent call last):
File "/tmp/DeepSpeed/setup.py", line 167, in <module>
ext_modules.append(builder.builder())
File "/tmp/DeepSpeed/op_builder/builder.py", line 623, in builder
self.build_for_cpu = not assert_no_cuda_mismatch(self.name)
File "/tmp/DeepSpeed/op_builder/builder.py", line 91, in assert_no_cuda_mismatch
cuda_major, cuda_minor = installed_cuda_version(name)
File "/tmp/DeepSpeed/op_builder/builder.py", line 44, in installed_cuda_version
assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
AssertionError: CUDA_HOME does not exist, unable to compile CUDA op(s)
Error on line 155
Fail to install deepspeed
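For context, the assertion above comes from DeepSpeed trying to locate a CUDA toolkit before building the ops. A minimal sketch of that check, assuming it ultimately relies on torch.utils.cpp_extension.CUDA_HOME (the real logic lives in op_builder/builder.py and may differ; installed_cuda_version_sketch is just an illustrative name), looks roughly like this:
import torch.utils.cpp_extension as cpp_ext

def installed_cuda_version_sketch():
    # On a ROCm-only machine no CUDA toolkit is installed, so CUDA_HOME is None
    # and the assertion fires, matching the traceback above.
    cuda_home = cpp_ext.CUDA_HOME
    assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
    # ... the real code then goes on to parse the toolkit version from nvcc ...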
If installed without DS_BUILD_OPS=1, it installs correctly. However, when an op is then JIT-compiled at use time, it fails because cuda_fp16.h cannot be found:
FAILED: cpu_adam_hip.o
c++ -MMD -MF cpu_adam_hip.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -I/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/csrc/includes -I/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/include -I/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/include/rocrand -I/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/include/hiprand -isystem /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/torch/include -isystem /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/torch/include/TH -isystem /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/torch/include/THC -isystem /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/torch/include/THH -isystem include -isystem /home/liberty/micromamba/envs/torch2-rocm/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++17 -O3 -std=c++14 -g -Wno-reorder -Llib -lcudart -lcublas -g -march=native -fopenmp -D__AVX256__ -D__ENABLE_CUDA__ -c /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/csrc/adam/cpu_adam_hip.cpp -o cpu_adam_hip.o -fPIC -D__HIP_PLATFORM_HCC__=1 -DUSE_ROCM=1
In file included from /home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/csrc/adam/cpu_adam_hip.cpp:2:
/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed/ops/csrc/includes/cpu_adam.h:15:10: fatal error: cuda_fp16.h: No such file or directory
15 | #include <cuda_fp16.h>
| ^~~~~~~~~~~~~
compilation terminated.
ninja: build stopped: subcommand failed.
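As a sanity check (just an illustrative snippet, not part of DeepSpeed), printing what this PyTorch build was compiled against shows why the header is missing: on a ROCm wheel torch.version.cuda is None and there is no CUDA toolkit on the machine, so cuda_fp16.h does not exist anywhere:
import torch
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME

print("torch.version.cuda:", torch.version.cuda)  # None on ROCm builds of PyTorch
print("torch.version.hip: ", torch.version.hip)   # e.g. 5.4.22803-474e8620
print("CUDA_HOME:", CUDA_HOME)                    # None, so no cuda_fp16.h to include
print("ROCM_HOME:", ROCM_HOME)                    # e.g. /opt/rocm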
Expected behavior DeepSpeed should install and build its ops on a ROCm system without requiring a CUDA toolkit or CUDA_HOME.
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn is not compatible with ROCM
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/torch']
torch version .................... 2.0.0+rocm5.4.2
deepspeed install path ........... ['/home/liberty/micromamba/envs/torch2-rocm/lib/python3.10/site-packages/deepspeed']
deepspeed info ................... 0.8.3+b3ec1c97, b3ec1c97, master
torch cuda version ............... None
torch hip version ................ 5.4.22803-474e8620
nvcc version ..................... None
deepspeed wheel compiled w. ...... torch 2.0, hip 5.4
Screenshots N/A
System info (please complete the following information):
- OS: Arch
- GPU count and types: 1 AMD W6800 PRO (ROCm officially supported GFX1030)
- Python version 3.10
I was just going to open a similar issue...
I ran into the same error while trying to pre-compile the ops for the wheel. The original code expects $CUDA_HOME to be set in the environment at this point:
https://github.com/microsoft/DeepSpeed/blob/b3ec1c9712e1f954288811f78f2340d69abe84d1/op_builder/builder.py#L38-L43
However, this isn't the case when building via HIP on AMD systems.
The patch below seems to work around the problem, though I don't think it's the best fix.
diff --git a/op_builder/builder.py b/op_builder/builder.py
index 4e13e204..9ec094ea 100644
--- a/op_builder/builder.py
+++ b/op_builder/builder.py
@@ -88,6 +88,8 @@ cuda_minor_mismatch_ok = {
 def assert_no_cuda_mismatch(name=""):
+    if OpBuilder.is_rocm_pytorch():
+        return True
     cuda_major, cuda_minor = installed_cuda_version(name)
     if cuda_minor == 0 and cuda_major == 0:
         return False
Since assert_no_cuda_mismatch is checking CUDA version numbers, I suppose the symmetric fix would be to instead check ROCm version numbers.
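Something along these lines, perhaps (only a rough sketch; installed_rocm_version_sketch is a hypothetical name, not an existing DeepSpeed function):
import torch

def installed_rocm_version_sketch():
    # Mirror installed_cuda_version, but read the HIP version PyTorch was built with.
    hip_version = torch.version.hip  # e.g. "5.4.22803-474e8620" on ROCm builds
    assert hip_version is not None, "ROCm/HIP not found, unable to compile HIP op(s)"
    rocm_major, rocm_minor = hip_version.split('.')[:2]
    return int(rocm_major), int(rocm_minor)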
Yes, I can verify that the patch suggested by @adammoody works around the issue. I do encounter some other issues with ROCm 5.4.2, but I will report those separately.
Hi @helloworld1 - does this only occur if you pre-compile the ops? Could you let me know whether it also happens if you just use JIT to build the ops?
I'm curious whether it's linked to this PR as well, and whether it's an artifact of us not supporting cross-compilation well or something unique to CUDA vs ROCm.
@loadams I remember getting an error that it couldn't find "cuda/cuda_fp16.h", so it clearly doesn't recognize that ROCm is available.
@helloworld1 - thanks, I was able to repro on my side as well now. I'll take a look at fixing this.
When just returning true as in the code above in assert_no_cuda_mismatch, you also have to specify DS_BUILD_SPARSE_ATTN=0, so the command becomes DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 ./install.sh. However, that still results in the following error on my side:
csrc/includes/conversion_utils_hip.h:270:12: error: use of undeclared identifier '__double2half'; did you mean '__double2hiint'?
return __double2half(val);
^~~~~~~~~~~~~
__double2hiint
/opt/rocm-5.4.0/hip/include/hip/amd_detail/../../../../include/hip/amd_detail/amd_device_functions.h:440:30: note: '__double2hiint' declared here
__device__ static inline int __double2hiint(double x) {
^
@helloworld1 - is that part of the ROCm 5.4 errors you've seen?
@loadams Yes, I encountered the same errors regarding the data type conversion. ROCm does not have that half conversion there.
The ROCm half type issue should be fixed in #3236. The fix for building the ops when CUDA_HOME cannot be found is still in progress.
Nice work! Thanks. Will wait for the full fix to test out.
cc @rraminen
@loadams , I am running into further compile errors like those reported in: https://github.com/microsoft/DeepSpeed/issues/3548 and https://github.com/microsoft/DeepSpeed/issues/3653
I have found that I can complete a build by disabling a number of ops:
DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_QUANTIZER=0 DS_BUILD_RANDOM_LTD=0 DS_BUILD_TRANSFORMER_INFERENCE=0 python3 setup.py bdist_wheel
It seems that HIP may be missing a number of CUDA cooperative_groups functions that these ops require, like meta_group_size().
https://github.com/microsoft/DeepSpeed/blob/f5dde36c1af9fe6b166ed05d381f546f53dcebce/csrc/includes/reduction_utils.h#L410
Cooperative group types and functions: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#explicit-groups
vs the "Cooperative Groups Functions" section from: https://docs.amd.com/bundle/HIP-Programming-Guide-v5.5/page/Programming_with_HIP.html
Are these ops not yet supported on AMD?
Yes, these are not yet supported in ROCm; we are working on adding that support. Additionally, we are considering adding a way to disable unsupported extensions by default, so that we can enable them as support lands.
@adammoody - since I think we have all DeepSpeed related issues solved, I'm going to close this for now. If you hit other issues, please re-open.
Thanks, @loadams . All good by me.
I know @helloworld1 originally opened this ticket.