Build fails for gfx1010 architecture
🐛 Bug
PyTorch fails to build for gfx1010 architecture
To Reproduce
Steps to reproduce the behavior:
- Install ROCm 3.7
- Clone pytorch:

```
git clone --recursive https://github.com/ROCmSoftwarePlatform/pytorch
$ git rev-parse HEAD
daf055b249a90526cd5e5bdce364a31735eb3f0a
```

- Run "hipify" to prepare source code:

```
$ python tools/amd_build/build_amd.py
```

- Build and install pytorch:

```
$ export PATH=/opt/rocm-3.7.0/bin:$PATH \
         ROCM_PATH=/opt/rocm-3.7.0 \
         HIP_PATH=/opt/rocm-3.7.0/hip
$ export PYTORCH_ROCM_ARCH=gfx1010 # Navi 10
$ USE_ROCM=1 USE_DISTRIBUTED=0 python setup.py install --user
```
Build fails with:
[3680/4568] Building HIPCC object caffe2/CMakeFiles/torch_hip.dir/operators/hip/torch_hip_generated_batch_sparse_to_dense_op.hip.o
FAILED: caffe2/CMakeFiles/torch_hip.dir/operators/hip/torch_hip_generated_batch_sparse_to_dense_op.hip.o
cd /home/erkki/Downloads/rocm2/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/operators/hip && /usr/bin/cmake -E make_directory /home/erkki/Downloads/rocm2/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/operators/hip/. && /usr/bin/cmake -D verbose:BOOL=OFF -D build_configuration:STRING=RELEASE -D generated_file:STRING=/home/erkki/Downloads/rocm2/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/operators/hip/./torch_hip_generated_batch_sparse_to_dense_op.hip.o -P /home/erkki/Downloads/rocm2/pytorch/build/caffe2/CMakeFiles/torch_hip.dir/operators/hip/torch_hip_generated_batch_sparse_to_dense_op.hip.o.cmake
error: Illegal instruction detected: Invalid dpp_ctrl value: broadcasts are not supported on GFX10+
renamable $vgpr23 = V_MOV_B32_dpp killed $vgpr23(tied-def 0), $vgpr13, 322, 15, 15, 0, implicit $exec
error: Illegal instruction detected: Invalid dpp_ctrl value: broadcasts are not supported on GFX10+
renamable $vgpr47 = V_MOV_B32_dpp killed $vgpr47(tied-def 0), $vgpr14, 322, 15, 15, 0, implicit $exec
Full build output here: https://gist.github.com/rigtorp/db4241d25b753c952962e554686735dc
Expected behavior
Build succeeds and tests complete:
PYTORCH_TEST_WITH_ROCM=1 python test/run_test.py --verbose
Environment
Collecting environment information...
PyTorch version: N/A
Is debug build: N/A
CUDA used to build PyTorch: N/A

OS: Fedora 32 (Workstation Edition) (x86_64)
GCC version: (GCC) 10.2.1 20200723 (Red Hat 10.2.1-1)
Clang version: 10.0.0 (Fedora 10.0.0-2.fc32)
CMake version: version 3.17.4

Python version: 3.8 (64-bit runtime)
Is CUDA available: N/A
CUDA runtime version: Could not collect
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.4
[pip3] torch==1.7.0a0+1f0cfba
[conda] Could not collect
Additional context
Installed ROCm 3.7 on Fedora 32 with these instructions: https://rigtorp.se/notes/rocm/
The error info comes from llvm-project: https://github.com/RadeonOpenCompute/llvm-project/blob/master/llvm/lib/Target/AMDGPU/SIInstrInfo.cpp#L4123. Comparing https://developer.amd.com/wp-content/resources/RDNA_Shader_ISA.pdf with http://developer.amd.com/wordpress/media/2013/12/AMD_GCN3_Instruction_Set_Architecture_rev1.1.pdf, dpp_ctrl on RDNA does not support DPP_ROW_BCAST15 or DPP_ROW_BCAST31. So where do they come from? There is no matching text in the PyTorch codebase, so maybe LLVM generates these instructions.
Yes I also did not find any inline assembler in PyTorch with these instructions. Looks like there is an issue with LLVM codegen for GFX10+.
@xuhuisheng The correct branch for latest AMD version of LLVM is amd-stg-open
I think I have tracked this down to inline assembly in the https://github.com/ROCmSoftwarePlatform/rocPRIM library
/opt/rocm-3.7.0/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp
59:int __amdgcn_update_dpp(int old, int src, int dpp_ctrl, int row_mask, int bank_mask, bool bound_ctrl)
60: __asm("llvm.amdgcn.update.dpp.i32");
62:template<class T, int dpp_ctrl, int row_mask = 0xf, int bank_mask = 0xf, bool bound_ctrl = false>
64:T warp_move_dpp(T input)
74: words[i] = __amdgcn_update_dpp(
76: dpp_ctrl, row_mask, bank_mask, bound_ctrl
/opt/rocm-3.7.0/rocprim/include/rocprim/warp/detail/warp_scan_dpp.hpp
41:class warp_scan_dpp
59: T t = scan_op(warp_move_dpp<T, 0x111>(output), output); // row_shr:1
64: T t = scan_op(warp_move_dpp<T, 0x112>(output), output); // row_shr:2
69: T t = scan_op(warp_move_dpp<T, 0x114>(output), output); // row_shr:4
74: T t = scan_op(warp_move_dpp<T, 0x118>(output), output); // row_shr:8
79: T t = scan_op(warp_move_dpp<T, 0x142>(output), output); // row_bcast:15
84: T t = scan_op(warp_move_dpp<T, 0x143>(output), output); // row_bcast:31
/opt/rocm-3.7.0/rocprim/include/rocprim/warp/detail/warp_reduce_dpp.hpp
43:class warp_reduce_dpp
59: output = reduce_op(warp_move_dpp<T, 0xb1>(output), output);
64: output = reduce_op(warp_move_dpp<T, 0x4e>(output), output);
69: output = reduce_op(warp_move_dpp<T, 0x114>(output), output);
74: output = reduce_op(warp_move_dpp<T, 0x118>(output), output);
79: output = reduce_op(warp_move_dpp<T, 0x142>(output), output);
84: output = reduce_op(warp_move_dpp<T, 0x143>(output), output);
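For reference, the dpp_ctrl values in warp_scan_dpp implement a Hillis-Steele inclusive scan over a 64-lane wavefront: the row_shr:1/2/4/8 steps scan within each 16-lane row, and the two row_bcast steps (exactly the controls RDNA dropped) carry row totals across rows. Below is a host-side model of that data movement for a plus operator; it is an illustration only, not rocPRIM code, and the assumption that row_mask limits the broadcast writes to specific rows is mine:

```cpp
#include <array>
#include <cassert>

// Host-side model (illustration only) of the warp_scan_dpp pattern
// for a plus operator on a 64-lane GCN wavefront.
using Wave = std::array<long long, 64>;

// dpp row_shr:n -- each lane reads the lane n positions below it within
// its 16-lane row; lanes with no source contribute the identity (0).
static Wave row_shr(const Wave& v, int n) {
    Wave out{};
    for (int i = 0; i < 64; ++i)
        out[i] = (i % 16 >= n) ? v[i - n] : 0;
    return out;
}

static Wave inclusive_scan(Wave v) {
    // row_shr:1,2,4,8 -- Hillis-Steele scan inside each 16-lane row.
    for (int n : {1, 2, 4, 8}) {
        Wave t = row_shr(v, n);
        for (int i = 0; i < 64; ++i) v[i] += t[i];
    }
    // row_bcast:15 -- lane 15 of a row is broadcast into the next row
    // (assuming row_mask limits the write to rows 1 and 3).
    Wave t = v;
    for (int i = 16; i < 32; ++i) v[i] += t[15];
    for (int i = 48; i < 64; ++i) v[i] += t[47];
    // row_bcast:31 -- lane 31 is broadcast into rows 2 and 3.
    t = v;
    for (int i = 32; i < 64; ++i) v[i] += t[31];
    return v;
}
```

RDNA keeps the row shifts but removes the row broadcasts, which is why only the 0x142/0x143 steps trip the compiler backend.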
Hi @rigtorp @xuhuisheng, the PyTorch project relies on ROCm to run on AMD GPUs. At this moment, ROCm 3.7 doesn't have full support for Navi GPUs (GFX10xx); please refer to this link for the officially supported device list: https://github.com/RadeonOpenCompute/ROCm#supported-gpus
@sunway513 Yes I know.
Commenting out the offending lines in rocPRIM allows the PyTorch build to proceed much further. See https://github.com/ROCmSoftwarePlatform/rocPRIM/issues/180
Besides the issue in rocPRIM, there is also this invalid use of inline assembly:
[ 86%] Building HIPCC object caffe2/CMakeFiles/torch_hip.dir/__/aten/src/THHUNN/torch_hip_generated_RReLU.hip.o
<inline asm>:1:24: error: invalid operand for instruction
v_mad_u64_u32 v[2:3], s[10:11], s1, v42, v[8:9]
^
note: !srcloc = 13700931
<inline asm>:1:24: error: invalid operand for instruction
v_mad_u64_u32 v[4:5], s[10:11], s46, v42, v[8:9]
^
@rigtorp Thank you for reporting the issue. There is a fix planned for an upcoming release of rocPRIM
@doctorcolinsmith Please drop the fix in develop branch, PyTorch is very close to working on Navi now, we don't want to have to wait for ROCm 3.8!
I traced this down to use of inline assembly in rocRAND:
/opt/rocm-3.7.0/rocrand/include/rocrand_common.h
61: asm volatile("v_mad_u64_u32 %0, %1, %2, %3, %4"
If I comment out the above line, the PyTorch build proceeds. See https://github.com/ROCmSoftwarePlatform/rocRAND/issues/140
:D After commenting out the inline assembly from the rocPRIM and rocRAND headers, pytorch built successfully for gfx1010. But I don't have a Navi 10 card; somebody who has one could do some tests on it.
In rocPRIM, just comment out lines 74-77 in /opt/rocm-3.7.0/rocprim/include/rocprim/intrinsics/warp_shuffle.hpp:
74 words[i] = __amdgcn_update_dpp(
75 0, words[i],
76 dpp_ctrl, row_mask, bank_mask, bound_ctrl
77 );
In rocRAND, comment out lines 55-64 in /opt/rocm-3.7.0/rocrand/include/rocrand_common.h and use the default non-asm code to do the operation:
55 unsigned long long r;
56 unsigned long long c; // carry bits, SGPR, unused
57 // x has "r" constraint. This allows to use both VGPR and SGPR
58 // (to save VGPR) as input.
59 // y and z have "v" constraints, because only one SGPR or literal
60 // can be read by the instruction.
61 asm volatile("v_mad_u64_u32 %0, %1, %2, %3, %4"
62 : "=v"(r), "=s"(c) : "r"(x), "v"(y), "v"(z)
63 );
64 return r;
return static_cast<unsigned long long>(x) * static_cast<unsigned long long>(y) + z;
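The widening behaviour the asm relied on is easy to check on the host: v_mad_u64_u32 computes a 32x32 -> 64-bit multiply plus a 64-bit addend, which is exactly what the cast-based fallback expresses. A small sketch (the function name is mine, not rocRAND's):

```cpp
#include <cassert>

// Portable equivalent of the v_mad_u64_u32 pattern: widen one 32-bit
// operand before the multiply so the product is computed in 64 bits,
// then add the 64-bit carry-in.
static unsigned long long mad_u64_u32(unsigned int x, unsigned int y,
                                      unsigned long long z) {
    return static_cast<unsigned long long>(x) * y + z;
}
```

Without the cast, `x * y` would be a 32-bit multiply that silently discards the high half of the product.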
I've tested the fixes to rocPRIM and rocRAND, and the build completes with no errors; however, it does produce a lot of warnings. Once done, when I run
PYTORCH_TEST_WITH_ROCM=1 python test/run_test.py --verbose
to verify installation, it responds:
/src/external/hip-on-vdi/rocclr/hip_code_object.cpp:92: guarantee(false && "hipErrorNoBinaryForGpu: Coudn't find binary for current devices!")
And then it dumps core.
@rigtorp Have you gotten this to fully work yet? I'm running it in Docker, which I'm fairly new to, so I may have missed something.
@niculw I get the same runtime error. I followed these steps to install ROCm: https://rigtorp.se/notes/rocm/. I then replaced rocPRIM and rocRAND with the fixed versions. rocPRIM and rocRAND are header-only libraries. Some of the other libraries used by pytorch (miopen and others) are not header-only and require compiled kernels for each GPU architecture; these libraries also need to be built for GFX1010. I haven't had time to do that yet.
@niculw This is an issue in rocBLAS: the new Tensile client did not provide a kernel for gfx1010. https://github.com/ROCmSoftwarePlatform/tensorflow-upstream/issues/1106
If you have an environment to compile rocBLAS, you can try building it with BUILD_WITH_TENSILE_HOST=false. It forces rocBLAS to use the old Tensile client, which does not depend on the architecture-specific kernel. Maybe it's a workaround for gfx1010.
Compiling rocBLAS needs a huge amount of memory (AMD recommends 64 GB) and takes a lot of time, maybe hours.
I don't have a gfx1010 card, so I cannot test it. Best wishes.
@xuhuisheng rocBLAS seems to have more problems than just the new kernels. I can only complete the build by setting BUILD_TENSILE=false, disabling the entire generation. What is weird to me is that rocBLAS detects gfx1010 and reports information, yet it can't compile for the architecture. Furthermore, the build will fail if the install script is called with '-c', which should build the client. Thus our limitation now seems to be rocBLAS and the Tensile generation. I feel like we are so close.
After changing this line in Tensile: https://github.com/ROCmSoftwarePlatform/Tensile/issues/1165
I can build rocBLAS:
export ROCM_PATH=/opt/rocm-3.7.0/
export TENSILE_ROCM_ASSEMBLER_PATH=/opt/rocm-3.7.0/llvm/bin/clang++
CXX=/opt/rocm-3.7.0/bin/hipcc cmake .. -DTensile_ARCHITECTURE=gfx1010 -DTensile_COMPILER=hipcc -DHIP_CLANG_INCLUDE_PATH=/opt/rocm-3.7.0/include/hip/ -DRUN_HEADER_TESTING=0 -DCMAKE_INSTALL_PREFIX=/opt/rocm-3.7.0/
make install
Still same issue running pytorch:
python -c 'import torch; print(torch.rand(2,3).cuda())'
Warning: please export TSAN_OPTIONS='ignore_noninstrumented_modules=1' to avoid false positive reports from the OpenMP runtime.!
/data/jenkins_workspace/centos_pipeline_job_8.1_rel-3.7/rocm-rel-3.7/rocm-3.7-20-20200817/8.1/external/hip-on-vdi/rocclr/hip_code_object.cpp:92: guarantee(false && "hipErrorNoBinaryForGpu: Coudn't find binary for current devices!")
fish: “python -c 'import torch; print(…” terminated by signal SIGABRT (Abort)
After changing this line in Tensile: ROCmSoftwarePlatform/Tensile#1165
I downloaded the develop branch of Tensile, changed the line you are referring to, and indeed Tensile compiles and creates build files.
CXX=hipcc cmake .. -DTensile_ARCHITECTURE=gfx1010 -DTensile_COMPILER=hipcc -DTensile_TEST_LOCAL_PATH=~/Tensile/ -DHIP_CLANG_INCLUDE_PATH=/opt/rocm-3.7.0/include/hip/ -DRUN_HEADER_TESTING=0 -DBUILD_WITH_TENSILE_HOST=true -DCMAKE_INSTALL_PREFIX=/opt/rocm-3.7.0/
However, when I run make install, at around 100% I get this error:
[100%] Linking CXX shared library librocblas.so
Can't exec "file": No such file or directory at /opt/rocm-3.7.0/bin/hipcc line 600.
Use of uninitialized value $fileType in pattern match (m//) at /opt/rocm-3.7.0/bin/hipcc line 601.
Use of uninitialized value $fileType in pattern match (m//) at /opt/rocm-3.7.0/bin/hipcc line 601.
Use of uninitialized value $fileType in pattern match (m//) at /opt/rocm-3.7.0/bin/hipcc line 602.
...
/usr/bin/ld: /tmp/hipccAOZzQxCa/objectc.c.o: relocation R_X86_64_32S against `.rodata' can not be used when making a shared object; recompile with -fPIC
/usr/bin/ld: /tmp/hipccAOZzQxCa/unpack.c.o: relocation R_X86_64_32S against `.rodata' can not be used when making a shared object; recompile with -fPIC
/usr/bin/ld: /tmp/hipccAOZzQxCa/version.c.o: relocation R_X86_64_32 against `.rodata.str1.1' can not be used when making a shared object; recompile with -fPIC
/usr/bin/ld: final link failed: Nonrepresentable section on output
clang-11: error: linker command failed with exit code 1 (use -v to see invocation)
library/src/CMakeFiles/rocblas.dir/build.make:2566: recipe for target 'library/src/librocblas.so.0.1' failed
I can't seem to get past this point.
Running the build with -DBUILD_WITH_TENSILE_HOST=false results in different errors when running make install
[ 27%] Building CXX object library/src/CMakeFiles/rocblas.dir/blas3/rocblas_trmm_strided_batched.cpp.o
/root/rocBLAS/library/src/blas_ex/rocblas_gemm_ext2.cpp:81:9: error: unknown type name 'rocblas_union_t'; did you mean 'rocblas_int'?
rocblas_union_t alpha_h, beta_h;
^~~~~~~~~~~~~~~
rocblas_int
/root/rocBLAS/library/include/internal/rocblas-types.h:47:17: note: 'rocblas_int' declared here
typedef int32_t rocblas_int;
^
/root/rocBLAS/library/src/blas_ex/rocblas_gemm_ext2.cpp:82:33: error: use of undeclared identifier 'copy_alpha_beta_to_host_if_on_device'
RETURN_IF_ROCBLAS_ERROR(copy_alpha_beta_to_host_if_on_device(
^
2 errors generated when compiling for gfx900.
library/src/CMakeFiles/rocblas.dir/build.make:120: recipe for target 'library/src/CMakeFiles/rocblas.dir/blas_ex/rocblas_gemm_ext2.cpp.o' failed
Maybe you have rocm installed to a different path than /opt/rocm-3.7.0 ?
I created a new Docker container from scratch instead of using an old one I pulled, and successfully compiled rocBLAS.
When compiling pytorch I get:
[ 80%] Building HIPCC object caffe2/CMakeFiles/torch_hip.dir/__/aten/src/THH/generated/torch_hip_generated_THHTensorMathReduceChar.hip.o
/root/pytorch/aten/src/THH/THHBlas.hip:229:17: warning: rocblas_gemm_strided_batched_ex: The workspace_size and workspace arguments are obsolete, and will be ignored [-W#pragma-messages]
THCublasCheck(rocblas_gemm_strided_batched_ex(handle, opa, opb, (int)m, (int)n, (int)k,
^
/opt/rocm-3.7.0/rocblas/include/internal/rocblas-functions.h:15268:41: note: expanded from macro 'rocblas_gemm_strided_batched_ex'
ROCBLAS_VA_OPT_PRAGMA(GCC warning "rocblas_gemm_strided_batched_ex: The workspace_size and workspace arguments are obsolete, and will be ignored", __VA_ARGS__) \
^
/opt/rocm-3.7.0/rocblas/include/internal/rocblas-functions.h:48:5: note: expanded from macro 'ROCBLAS_VA_OPT_PRAGMA'
ROCBLAS_VA_OPT_PRAGMA_IMPL(pragma, ROCBLAS_VA_OPT_COUNT(__VA_ARGS__))
^
/opt/rocm-3.7.0/rocblas/include/internal/rocblas-functions.h:46:51: note: expanded from macro 'ROCBLAS_VA_OPT_PRAGMA_IMPL'
#define ROCBLAS_VA_OPT_PRAGMA_IMPL(pragma, count) ROCBLAS_VA_OPT_PRAGMA_IMPL2(pragma, count)
^
/opt/rocm-3.7.0/rocblas/include/internal/rocblas-functions.h:45:52: note: expanded from macro 'ROCBLAS_VA_OPT_PRAGMA_IMPL2'
#define ROCBLAS_VA_OPT_PRAGMA_IMPL2(pragma, count) ROCBLAS_VA_OPT_PRAGMA_SELECT##count(pragma)
^
<scratch space>:31:1: note: expanded from here
ROCBLAS_VA_OPT_PRAGMA_SELECTN
^
/opt/rocm-3.7.0/rocblas/include/internal/rocblas-functions.h:44:52: note: expanded from macro 'ROCBLAS_VA_OPT_PRAGMA_SELECTN'
#define ROCBLAS_VA_OPT_PRAGMA_SELECTN(pragma, ...) _Pragma(#pragma)
^
1 warning generated when compiling for gfx1010.
I also had this before, which makes me believe there is more happening in rocBLAS, but I don't have the knowledge to investigate. PyTorch compiles and returns the same error as before when running the test scripts.
/src/external/hip-on-vdi/rocclr/hip_code_object.cpp:92: guarantee(false && "hipErrorNoBinaryForGpu: Coudn't find binary for current devices!")
@niculw
You can recompile rocBLAS without the new Tensile client using a command like `bash install.sh -i -t /home/work/tmp/Tensile -a gfx1010 -r`.
- `-i` tells it to install rocBLAS as a packaged deb.
- `-t` tells it to use a local Tensile, which we clone from https://github.com/rocmsoftwareplatorm/tensile and modify. Remember to look at the tensile_tag.txt of the rocBLAS project: rocBLAS needs the specific Tensile commit listed there, so when we use `-t` we should run `git checkout <commit>` on the local Tensile repository manually.
- `-a` tells it to build for gfx1010 only; the default is gfx803,gfx900,gfx906,gfx908,gfx1010,gfx1011, which costs more time and is useless for our test (gfx1010,gfx1011 are in the default config, but building the rest is a waste of our time and won't work at runtime).
- `-r` tells cmake to use BUILD_TENSILE_HOST=false, so it uses the old Tensile client, building without TensileLibrary_gfx1010.co. The old Tensile client won't use hip_code_object.cpp to load any binaries.
After building rocblas_xx.deb, run `sudo dpkg -i build/release/*.deb`. There is no need to recompile pytorch.
Additionally, `-r` only works on rocm-3.7.0; it hit some errors on rocm-3.8.0.
Update: I rechecked, and rocBLAS can build successfully for gfx1010 with `-r` on 3.8.0, though gfx803 still can't.
@niculw I have an idea. Since there are no problems/solutions for gfx1010, ROCm won't actually use the contents of TensileLibrary_gfx1010.co. If we copy TensileLibrary_gfx900.co to TensileLibrary_gfx1010.co, the problem may be solved.
I did manage to build rocBLAS and Tensile using their WIP branch gfx10. The unit tests for rocBLAS also went well, with the GPU actually being used. Special thanks to @rigtorp for tracing these issues and the install instructions for Fedora ^_^ .
However, the hipErrorNoBinaryForGpu error still persists. It would be nice if a full stack trace of this error could be obtained, to figure out whether it is rocBLAS or something else causing the trouble.
@tuxbotix You need to recompile rocBLAS, rocFFT, rocSPARSE, hipSPARSE, rccl, and MIOpen with AMDGPU_TARGETS=gfx1010. rocSPARSE and rccl need the rocm-3.10.x branch to fix navi10 inline asm issues. I believe AMD is starting to support RDNA2, so the related fixes could apply to gfx1010 too.
hipErrorNoBinaryForGpu means we did not compile kernels for the specific ISA. I ran into this problem in the ROCm-3.9 issues on gfx803.
Aha, I was wondering if this is the exact issue. So this error means the kernels I got when building rocBLAS (and specifying gfx1010) aren't enough, and I need to build the other components for gfx1010 as well?
Yes, I also noticed their dev work for the newer GPUs; what wasn't clear to me was whether gfx1030, for example, corresponds to the new RDNA2 cards or something else.
BTW, I didn't get any inline asm errors AFAIK.
@tuxbotix There may be some other issues on navi10, like rccl: https://github.com/ROCmSoftwarePlatform/rccl/issues/289
The RX 5700 XT is not better than a Vega 56/64 at compute, let alone a Radeon VII. I am afraid AMD won't pay attention to gfx101x. Support for gfx103x, aka RDNA2, is in progress; maybe we will see RDNA2 support in ROCm-4.0.
BTW, the price of video cards is higher than it was in the middle of the year. I am still waiting for the RX 5700 to get cheap.
That's sad to hear. I'm very happy with this card for my gaming uses, and it was much cheaper than the closest Nvidia card. I got mine last December, but only now have I had some time to look at compute usage 😅
At least it is nice to see they are working on RDNA 2, as the cards have almost launched.
Any update on RDNA 1?
Hello.
I have the exact same hipErrorNoBinaryForGpu as you; did you finally find a workaround for that? I followed rigtorp's instructions to install rocm and pytorch on Fedora (thank you, btw).
@Colosu After applying some patches, we could run ROCm on navi10/navi14 successfully. But the loss on MNIST won't change. Please refer here: https://github.com/RadeonOpenCompute/ROCm/issues/887#issuecomment-751443281 https://github.com/RadeonOpenCompute/ROCm/issues/1306#issuecomment-760062894
The good news is that AMD said ROCm will provide official navi support in 2021.