[BUG] Problem with CPUAdam compilation on AMD CPUs
Describe the bug
CPUAdam fails to compile on AMD CPUs. The errors show the compiler is invoked with -march=x86-64-v3 (which does not enable AVX-512) while the code uses _mm512* intrinsics (and defines __AVX512__).
I'm using an AMD Zen 4 CPU. I was able to fix the issue by changing the return value of this function: https://github.com/deepspeedai/DeepSpeed/blob/0e859aa0d37fda1160918421fcf7f36528cae2e8/op_builder/builder.py#L403
It defaults to returning the -march=x86-64-v3 flag, even though in my case I needed the -march=skylake-avx512 flag.
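For reference, this is roughly the local workaround I applied (a minimal hypothetical sketch, not the actual implementation of cpu_arch(), which inspects the host CPU):

# Hypothetical sketch of my local patch in op_builder/builder.py.
# Instead of falling back to -march=x86-64-v3, hard-code the flag my
# Zen 4 machine needs so the __AVX512__ code path actually compiles.
def cpu_arch(self):
    return '-march=skylake-avx512'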
Error Logs
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/3] c++ -MMD -MF cpu_adam_impl.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -I/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -isystem /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include -isystem /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/include -isystem /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v4/Compiler/gcccore/python/3.11.5/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/lib64 -lcudart -lcublas -g -march=x86-64-v3 -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -UC10_USE_GLOG -DBF16_AVAILABLE -c /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp -o cpu_adam_impl.o
FAILED: cpu_adam_impl.o
c++ -MMD -MF cpu_adam_impl.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -I/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -isystem /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include -isystem /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/include -isystem /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v4/Compiler/gcccore/python/3.11.5/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/lib64 -lcudart -lcublas -g -march=x86-64-v3 -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -UC10_USE_GLOG -DBF16_AVAILABLE -c /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp -o cpu_adam_impl.o
In file included from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/cpu_adam.h:14,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp:14:
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h: In function ‘__m512 load_16_bf16_as_f32(const void*)’:
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:47:51: warning: AVX512F vector return without AVX512F enabled changes the ABI [-Wpsabi]
47 | static __m512 load_16_bf16_as_f32(const void* data)
| ^
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h: In function ‘void simd_mul(AVX_Data*, AVX_Data*, AVX_Data) [with int span = 8]’:
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:231:13: note: the ABI for passing parameters with 64-byte alignment has changed in GCC 4.6
231 | inline void simd_mul(AVX_Data* dst, AVX_Data* src_a_l, AVX_Data src_a_r)
| ^~~~~~~~
In file included from /cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/immintrin.h:55,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/c10/util/Half.h:59,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/c10/util/Float8_e5m2.h:17,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/c10/core/ScalarType.h:8,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/c10/core/Scalar.h:9,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/ATen/core/TensorBody.h:16,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/ATen/core/Tensor.h:3,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/ATen/Tensor.h:3,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/autograd/function_hook.h:3,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/autograd/cpp_hook.h:2,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/autograd/variable.h:6,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/autograd/autograd.h:3,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/autograd.h:3,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include/torch/all.h:7,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/extension.h:5,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam_impl.cpp:6:
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h: In function ‘__m512 load_16_bf16_as_f32(const void*)’:
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h:5034:1: error: inlining failed in call to ‘always_inline’ ‘__m512i _mm512_slli_epi32(__m512i, unsigned int)’: target specific option mismatch
5034 | _mm512_slli_epi32 (__m512i __A, unsigned int __B)
| ^~~~~~~~~~~~~~~~~
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:51:34: note: called from here
51 | __m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
| ~~~~~~~~~~~~~~~~~^~~~~~~
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h:5811:1: error: inlining failed in call to ‘always_inline’ ‘__m512i _mm512_cvtepu16_epi32(__m256i)’: target specific option mismatch
5811 | _mm512_cvtepu16_epi32 (__m256i __A)
| ^~~~~~~~~~~~~~~~~~~~~
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:50:38: note: called from here
50 | __m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
| ~~~~~~~~~~~~~~~~~~~~~^~~
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h:5034:1: error: inlining failed in call to ‘always_inline’ ‘__m512i _mm512_slli_epi32(__m512i, unsigned int)’: target specific option mismatch
5034 | _mm512_slli_epi32 (__m512i __A, unsigned int __B)
| ^~~~~~~~~~~~~~~~~
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:51:34: note: called from here
51 | __m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
| ~~~~~~~~~~~~~~~~~^~~~~~~
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h:5811:1: error: inlining failed in call to ‘always_inline’ ‘__m512i _mm512_cvtepu16_epi32(__m256i)’: target specific option mismatch
5811 | _mm512_cvtepu16_epi32 (__m256i __A)
| ^~~~~~~~~~~~~~~~~~~~~
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:50:38: note: called from here
50 | __m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
| ~~~~~~~~~~~~~~~~~~~~~^~~
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h:5034:1: error: inlining failed in call to ‘always_inline’ ‘__m512i _mm512_slli_epi32(__m512i, unsigned int)’: target specific option mismatch
5034 | _mm512_slli_epi32 (__m512i __A, unsigned int __B)
| ^~~~~~~~~~~~~~~~~
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:51:34: note: called from here
51 | __m512i c = _mm512_slli_epi32(b, 16); // logical shift left of all u32 by
| ~~~~~~~~~~~~~~~~~^~~~~~~
/cvmfs/soft.computecanada.ca/gentoo/2023/x86-64-v3/usr/lib/gcc/x86_64-pc-linux-gnu/14/include/avx512fintrin.h:5811:1: error: inlining failed in call to ‘always_inline’ ‘__m512i _mm512_cvtepu16_epi32(__m256i)’: target specific option mismatch
5811 | _mm512_cvtepu16_epi32 (__m256i __A)
| ^~~~~~~~~~~~~~~~~~~~~
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:50:38: note: called from here
50 | __m512i b = _mm512_cvtepu16_epi32(a); // convert 8 u16 to 8 u32
| ~~~~~~~~~~~~~~~~~~~~~^~~
[2/3] c++ -MMD -MF cpu_adam.o.d -DTORCH_EXTENSION_NAME=cpu_adam -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1018\" -I/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes -isystem /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include -isystem /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/include/torch/csrc/api/include -isystem /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/include -isystem /cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v4/Compiler/gcccore/python/3.11.5/include/python3.11 -D_GLIBCXX_USE_CXX11_ABI=1 -fPIC -std=c++17 -O3 -std=c++17 -g -Wno-reorder -L/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/lib64 -lcudart -lcublas -g -march=x86-64-v3 -fopenmp -D__AVX512__ -D__ENABLE_CUDA__ -UC10_USE_GLOG -DBF16_AVAILABLE -c /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp -o cpu_adam.o
In file included from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/cpu_adam.h:14,
from /scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/adam/cpu_adam.cpp:6:
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h: In function ‘__m512 load_16_bf16_as_f32(const void*)’:
/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/csrc/includes/simd.h:47:51: warning: AVX512F vector return without AVX512F enabled changes the ABI [-Wpsabi]
47 | static __m512 load_16_bf16_as_f32(const void* data)
| ^
ninja: build stopped: subcommand failed.
[rank0]: Traceback (most recent call last):
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2506, in _run_ninja_build
[rank0]: subprocess.run(
[rank0]: File "/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v4/Compiler/gcccore/python/3.11.5/lib/python3.11/subprocess.py", line 571, in run
[rank0]: raise CalledProcessError(retcode, process.args,
[rank0]: subprocess.CalledProcessError: Command '['ninja', '-v']' returned non-zero exit status 1.
[rank0]: The above exception was the direct cause of the following exception:
[rank0]: Traceback (most recent call last):
[rank0]: File "/home/alis/links/scratch/repos/perf-pilot/llama/train.py", line 371, in <module>
[rank0]: main()
[rank0]: File "/home/alis/links/scratch/repos/perf-pilot/llama/train.py", line 346, in main
[rank0]: trainer.train(resume_from_checkpoint= True if args.continue_from_dir else False)
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/transformers/trainer.py", line 2328, in train
[rank0]: return inner_training_loop(
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/transformers/trainer.py", line 2483, in _inner_training_loop
[rank0]: model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/accelerate/accelerator.py", line 1547, in prepare
[rank0]: result = self._prepare_deepspeed(*args)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/accelerate/accelerator.py", line 2279, in _prepare_deepspeed
[rank0]: optimizer = map_pytorch_optim_to_deepspeed(optimizer)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/accelerate/utils/deepspeed.py", line 97, in map_pytorch_optim_to_deepspeed
[rank0]: return optimizer_class(optimizer.param_groups, **defaults)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/adam/cpu_adam.py", line 94, in __init__
[rank0]: self.ds_opt_adam = CPUAdamBuilder().load()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/op_builder/builder.py", line 539, in load
[rank0]: return self.jit_load(verbose)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/op_builder/builder.py", line 588, in jit_load
[rank0]: op_module = load(name=self.name,
[rank0]: ^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 1623, in load
[rank0]: return _jit_compile(
[rank0]: ^^^^^^^^^^^^^
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2076, in _jit_compile
[rank0]: _write_ninja_file_and_build_library(
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2222, in _write_ninja_file_and_build_library
[rank0]: _run_ninja_build(
[rank0]: File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch/utils/cpp_extension.py", line 2522, in _run_ninja_build
[rank0]: raise RuntimeError(message) from e
[rank0]: RuntimeError: Error building extension 'cpu_adam'
Exception ignored in: <function DeepSpeedCPUAdam.__del__ at 0x14af92646fc0>
Traceback (most recent call last):
File "/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed/ops/adam/cpu_adam.py", line 102, in __del__
AttributeError: 'DeepSpeedCPUAdam' object has no attribute 'ds_opt_adam'
To Reproduce
Steps to reproduce the behavior:
- Start fine-tuning Llama 3.1 8B with CPU offload
- DeepSpeed tries to JIT compile CPU Adam
- Compilation fails with the errors above
Expected behavior
CPU Adam should compile on all CPU architectures.
ds_report output
--------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
runtime if needed. Op compatibility means that your system
meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
async_io ............... [NO] ....... [OKAY]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
[WARNING] Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
[WARNING] FP Quantizer is using an untested triton version (3.3.1), only 2.3.(0, 1) and 3.0.0 are known to be compatible with these kernels
fp_quantizer ........... [NO] ....... [NO]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gcc -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=x86-64-v4 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=x86-64-v4 -fno-math-errno -fPIC -fPIC -c /tmp/tmpkdakuhy7/test.c -o /tmp/tmpkdakuhy7/test.o
gcc /tmp/tmpkdakuhy7/test.o -L/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2 -L/cvmfs/soft.computecanada.ca/easybuild/software/2023/x86-64-v3/Core/cudacore/12.6.2/lib64 -lcufile -o /tmp/tmpkdakuhy7/a.out
gcc -DNDEBUG -g -fwrapv -O3 -Wall -O2 -ftree-vectorize -march=x86-64-v4 -fno-math-errno -fPIC -O2 -ftree-vectorize -march=x86-64-v4 -fno-math-errno -fPIC -fPIC -c /tmp/tmpe5xejuj0/test.c -o /tmp/tmpe5xejuj0/test.o
gcc /tmp/tmpe5xejuj0/test.o -laio -o /tmp/tmpe5xejuj0/a.out
gds .................... [NO] ....... [OKAY]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
[WARNING] sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.7
[WARNING] using untested triton version (3.3.1), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/torch']
torch version .................... 2.7.1
deepspeed install path ........... ['/scratch/alis/repos/perf-pilot/venv/lib/python3.11/site-packages/deepspeed']
deepspeed info ................... 0.17.5, unknown, unknown
torch cuda version ............... 12.6
torch hip version ................ None
nvcc version ..................... 12.6
deepspeed wheel compiled w. ...... torch 2.7, cuda 12.6
shared memory (/dev/shm) size .... 566.00 GB
System info (please complete the following information):
- transformers version: 4.56.1
- Platform: Linux-5.14.0-570.25.1.el9_6.x86_64-x86_64-AMD_EPYC_9654_96-Core_Processor-with-glibc2.37
- Python version: 3.11.5
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.5.3
- Accelerate version: 1.10.1
- Accelerate config: not found
- DeepSpeed version: 0.17.5
- PyTorch version (accelerator?): 2.7.1 (CUDA)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?:
- Using GPU in script?:
- GPU type: NVIDIA H100 80GB HBM3
Launcher context
Accelerate config:
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
Hi @Ali-Sayed-Salehi, thanks for reporting this. Could you share the exact command you used to build/run so I can try to reproduce on my side? On my AMD CPU I didn’t hit this issue. Since you mentioned it happens with JIT compilation, does the same error also occur if you build ahead of time with DS_BUILD_CPU_ADAM=1 pip install deepspeed?
I normally hit this during fine-tuning with accelerate, but I can reproduce it minimally with this command:
python - <<'PY'
# Force a JIT build of the CPUAdam extension, the same path the trainer takes.
from deepspeed.ops.op_builder import CPUAdamBuilder
CPUAdamBuilder().load(verbose=True)
print("CPUAdam built and loaded")
PY
I still couldn’t reproduce this issue on the AMD hardware I can reach. @delock @tohtana, do you have any idea what might be happening here?
This is my exact CPU model, if it helps: AMD EPYC 9654 (Zen 4) @ 2.4 GHz, 384 MB L3 cache.
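Zen 4 does support AVX-512; for completeness, the feature flags the kernel advertises can be checked with something like this (a minimal sketch reading /proc/cpuinfo directly):
python - <<'PY'
# Print every avx512* feature flag the kernel reports for this CPU.
with open('/proc/cpuinfo') as f:
    flags = next(line for line in f if line.startswith('flags')).split()
print(sorted({fl for fl in flags if fl.startswith('avx512')}))
PY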
I tried on an AMD Zen 4 CPU, and the code works OK. The cpu_arch function returns -march=native for me. I don't understand how the cpu_arch function could return -march=x86-64-v3 instead of -march=native. Can you check the source code and see whether the return value of this function has been modified?
The implementation of cpu_arch() tends to return -march=native, so x86-64-v3 looks abnormal to me. Hi @Ali-Sayed-Salehi, some debugging inside this function should reveal the exact line that returns -march=x86-64-v3. Since you already see this function return an undesired value, the root cause should be very close.
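Without running a full build, the flag the builder resolves can be printed directly (assuming cpu_arch() is exposed on the builder instance, as in the source linked above):
python - <<'PY'
# Print the -march flag the CPUAdam builder will pass to the compiler.
from deepspeed.ops.op_builder import CPUAdamBuilder
print(CPUAdamBuilder().cpu_arch())
PY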
I encountered the same issue with an AMD EPYC 7742 64-Core Processor.