HIP error: invalid device function on ROCm RX 7600
🐛 Describe the bug
When attempting to perform any GPU compute task using PyTorch with the ROCm/HIP backend, I encounter the following error:
torch.AcceleratorError: HIP error: invalid device function
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Example:
import torch
print("PyTorch detects GPU:", torch.cuda.is_available())
print(f"ROCm device detected: {torch.cuda.get_device_name(0)}")
print(f"VRAM available: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
device = torch.device("cuda")
print("Allocating tensors on GPU...")
a = torch.randn((1000, 1000), device=device, dtype=torch.float32)
b = torch.randn((1000, 1000), device=device, dtype=torch.float32)
print("Running matrix multiplication...")
result = torch.matmul(a, b)
torch.cuda.synchronize()
print("✅ PyTorch HIP execution successful!")
Results in:
PyTorch detects GPU: True
ROCm device detected: AMD Radeon RX 7600
VRAM available: 7.98 GB
Allocating tensors on GPU...
Traceback (most recent call last):
File "/app/tensor.py", line 10, in <module>
a = torch.randn((1000, 1000), device=device, dtype=torch.float32)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: HIP error: invalid device function
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Versions
root@fb554eea83ff:/app# python collect_env.py
Collecting environment information...
PyTorch version: 2.8.0+rocm7.0.0.git64359f59
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 7.0.51831-a3e329ad8
OS: Ubuntu 24.04.3 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04) 13.3.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.14.0-33-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Radeon RX 7600 (gfx1102)
Nvidia driver version: Could not collect
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: 7.0.51831
MIOpen runtime version: 3.5.0
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 12
On-line CPU(s) list: 0-11
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 5 3600 6-Core Processor
CPU family: 23
Model: 113
Thread(s) per core: 2
Core(s) per socket: 6
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU(s) scaling MHz: 86%
CPU max MHz: 4208.0000
CPU min MHz: 550.0000
BogoMIPS: 7199.74
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 192 KiB (6 instances)
L1i cache: 192 KiB (6 instances)
L2 cache: 3 MiB (6 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Ghostwrite: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==2.3.3
[pip3] pytorch-triton-rocm==3.4.0+rocm7.0.0.gitf9e5bf54
[pip3] torch==2.8.0+rocm7.0.0.lw.git64359f59
[pip3] torchaudio==2.8.0+rocm7.0.0.git6e1c7fe9
[pip3] torchvision==0.23.0+rocm7.0.0.git824e8c87
[pip3] triton==3.5.0
[conda] Could not collect
Hi @Amund, would you mind posting the output of rocminfo here when you get a chance? Thanks!
Yep, here it is:
root@fb554eea83ff:/app# rocminfo
ROCk module is loaded
=====================
HSA System Attributes
=====================
Runtime Version: 1.18
Runtime Ext Version: 1.11
System Timestamp Freq.: 1000.000000MHz
Sig. Max Wait Duration: 18446744073709551615 (0xFFFFFFFFFFFFFFFF) (timestamp count)
Machine Model: LARGE
System Endianness: LITTLE
Mwaitx: DISABLED
XNACK enabled: NO
DMAbuf Support: YES
VMM Support: YES
==========
HSA Agents
==========
*******
Agent 1
*******
Name: AMD Ryzen 5 3600 6-Core Processor
Uuid: CPU-XX
Marketing Name: AMD Ryzen 5 3600 6-Core Processor
Vendor Name: CPU
Feature: None specified
Profile: FULL_PROFILE
Float Round Mode: NEAR
Max Queue Number: 0(0x0)
Queue Min Size: 0(0x0)
Queue Max Size: 0(0x0)
Queue Type: MULTI
Node: 0
Device Type: CPU
Cache Info:
L1: 32768(0x8000) KB
Chip ID: 0(0x0)
ASIC Revision: 0(0x0)
Cacheline Size: 64(0x40)
Max Clock Freq. (MHz): 4208
BDFID: 0
Internal Node ID: 0
Compute Unit: 12
SIMDs per CU: 0
Shader Engines: 0
Shader Arrs. per Eng.: 0
WatchPts on Addr. Ranges:1
Memory Properties:
Features: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: FINE GRAINED
Size: 48192132(0x2df5a84) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 48192132(0x2df5a84) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 3
Segment: GLOBAL; FLAGS: KERNARG, FINE GRAINED
Size: 48192132(0x2df5a84) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
Pool 4
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 48192132(0x2df5a84) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:4KB
Alloc Alignment: 4KB
Accessible by all: TRUE
ISA Info:
*******
Agent 2
*******
Name: gfx1102
Uuid: GPU-XX
Marketing Name: AMD Radeon RX 7600
Vendor Name: AMD
Feature: KERNEL_DISPATCH
Profile: BASE_PROFILE
Float Round Mode: NEAR
Max Queue Number: 128(0x80)
Queue Min Size: 64(0x40)
Queue Max Size: 131072(0x20000)
Queue Type: MULTI
Node: 1
Device Type: GPU
Cache Info:
L1: 32(0x20) KB
L2: 2048(0x800) KB
Chip ID: 29824(0x7480)
ASIC Revision: 0(0x0)
Cacheline Size: 128(0x80)
Max Clock Freq. (MHz): 2356
BDFID: 2560
Internal Node ID: 1
Compute Unit: 32
SIMDs per CU: 2
Shader Engines: 2
Shader Arrs. per Eng.: 2
WatchPts on Addr. Ranges:4
Coherent Host Access: FALSE
Memory Properties:
Features: KERNEL_DISPATCH
Fast F16 Operation: TRUE
Wavefront Size: 32(0x20)
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Max Waves Per CU: 32(0x20)
Max Work-item Per CU: 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
Max fbarriers/Workgrp: 32
Packet Processor uCode:: 542
SDMA engine uCode:: 21
IOMMU Support:: None
Pool Info:
Pool 1
Segment: GLOBAL; FLAGS: COARSE GRAINED
Size: 8372224(0x7fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 2
Segment: GLOBAL; FLAGS: EXTENDED FINE GRAINED
Size: 8372224(0x7fc000) KB
Allocatable: TRUE
Alloc Granule: 4KB
Alloc Recommended Granule:2048KB
Alloc Alignment: 4KB
Accessible by all: FALSE
Pool 3
Segment: GROUP
Size: 64(0x40) KB
Allocatable: FALSE
Alloc Granule: 0KB
Alloc Recommended Granule:0KB
Alloc Alignment: 0KB
Accessible by all: FALSE
ISA Info:
ISA 1
Name: amdgcn-amd-amdhsa--gfx1102
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
ISA 2
Name: amdgcn-amd-amdhsa--gfx11-generic
Machine Models: HSA_MACHINE_MODEL_LARGE
Profiles: HSA_PROFILE_BASE
Default Rounding Mode: NEAR
Default Rounding Mode: NEAR
Fast f16: TRUE
Workgroup Max Size: 1024(0x400)
Workgroup Max Size per Dimension:
x 1024(0x400)
y 1024(0x400)
z 1024(0x400)
Grid Max Size: 4294967295(0xffffffff)
Grid Max Size per Dimension:
x 2147483647(0x7fffffff)
y 65535(0xffff)
z 65535(0xffff)
FBarrier Max Size: 32
*** Done ***
@Amund Can you please clarify where you are downloading your PyTorch installation from? Thanks
@naromero77amd Sure, I am using a Docker image based on rocm/pytorch:latest (seems to be rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0)
@taylding-amd @naromero77amd
I created a repo with a streamlined example: https://github.com/Amund/rocm-test-2730
Are you using the PyTorch version that comes with that image or are you building a new PyTorch version and installing into the image?
@naromero77amd No build or installation, only what is available in the image.
Tested with ROCm 7.0.2 and 7.1: https://github.com/Amund/rocm-test-2730
Hi @Amund, I tried with rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 and was not able to reproduce the issue. How did you run the Docker image? I have a feeling that the error could be a group permissions issue. Here's the command I used:
sudo docker run -it --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --device=/dev/kfd --device=/dev/dri --group-add video --ipc=host --shm-size 8G rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0
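If group permissions are the suspect, one quick check from inside the container is whether the device nodes passed with --device actually exist there and are readable and writable by the current user. The sketch below is a hypothetical helper (not part of ROCm or PyTorch); it uses only the Python standard library and assumes render nodes follow the usual /dev/dri/renderD* naming. When running as root, os.access will almost always report read/write access, so for root this mainly detects missing nodes.

```python
import grp
import os
from pathlib import Path


def group_names(gids):
    """Map numeric group IDs to names, falling back to the raw number."""
    names = []
    for gid in gids:
        try:
            names.append(grp.getgrgid(gid).gr_name)
        except KeyError:
            # The container's /etc/group may not know every host GID.
            names.append(str(gid))
    return names


def check_device_nodes(paths):
    """Return {path: (exists, readable_and_writable)} for each device node."""
    report = {}
    for path in paths:
        exists = os.path.exists(path)
        usable = exists and os.access(path, os.R_OK | os.W_OK)
        report[path] = (exists, usable)
    return report


if __name__ == "__main__":
    # /dev/kfd is the ROCm compute interface; GPU render nodes live under /dev/dri.
    dri = Path("/dev/dri")
    renders = sorted(str(p) for p in dri.glob("renderD*")) if dri.is_dir() else []
    print("groups:", ", ".join(group_names(os.getgroups())))
    for path, (exists, usable) in check_device_nodes(["/dev/kfd"] + renders).items():
        print(f"{path}: exists={exists} read/write={usable}")
```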
I've tested it:
$ sudo docker run -it --rm \
--cap-add=SYS_PTRACE \
--security-opt seccomp=unconfined \
--device=/dev/kfd \
--device=/dev/dri \
--group-add video \
--ipc=host \
--shm-size 8G \
-v ./src:/app \
-w /app \
rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 \
python tensor.py
PyTorch detects GPU: True
ROCm device detected: AMD Radeon RX 7600
VRAM available: 7.98 GB
Allocating tensors on GPU...
Traceback (most recent call last):
File "/app/tensor.py", line 11, in <module>
a = torch.randn((1000, 1000), device=device, dtype=torch.float32)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Actually, that's what my docker-compose.yml file already does, but you're right, the plain docker run command is even more streamlined.
services:
  app:
    container_name: py-rocm-test
    image: rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0
    # image: rocm/pytorch:rocm7.1_ubuntu24.04_py3.12_pytorch_release_2.8.0
    restart: unless-stopped
    volumes:
      - ./src:/app
    working_dir: /app
    entrypoint: ["sleep", "infinity"]
    devices:
      - /dev/kfd
      - /dev/dri
    group_add:
      - video
    cap_add:
      - SYS_PTRACE
    security_opt:
      - seccomp=unconfined
    ipc: host
    shm_size: 8G
@taylding-amd
I've checked root groups in the container
root@5e0bcdad7c36:/app# groups
root video
Are there any other tests I can run?
@Amund From outside the container (bare metal), can you send us the output of:
rocm-smi --showpids
From inside the container, can you send the output of:
rocm-smi
Also, please send the output of env (your current environment).
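Since the full env output also contains unrelated noise (such as LS_COLORS), a filtered view can make the relevant settings easier to spot. A minimal sketch, where the prefix list is an assumption about which variables are likely to matter, not an official list:

```python
import os

# Assumed-relevant prefixes for ROCm/HIP/PyTorch debugging (not exhaustive).
PREFIXES = ("HIP_", "HSA_", "ROCM", "AMD_", "PYTORCH_", "LD_LIBRARY_PATH", "PATH")


def relevant_env(environ, prefixes=PREFIXES):
    """Return only the variables whose names start with one of `prefixes`."""
    return {k: v for k, v in sorted(environ.items()) if k.startswith(prefixes)}


if __name__ == "__main__":
    for key, value in relevant_env(os.environ).items():
        print(f"{key}={value}")
```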
@naromero77amd
From outside the container:
# rocm-smi --showpids
============================ ROCm System Management Interface ============================
===================================== KFD Processes ======================================
No KFD PIDs currently running
==========================================================================================
================================== End of ROCm SMI Log ===================================
From inside the container:
# rocm-smi
======================================== ROCm System Management Interface ========================================
================================================== Concise Info ==================================================
Device Node IDs Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
(DID, GUID) (Edge) (Avg) (Mem, Compute, ID)
==================================================================================================================
0 1 0x7480, 38054 42.0°C 12.0W N/A, N/A, 0 827Mhz 456Mhz 0% auto 145.0W 12% 0%
==================================================================================================================
============================================== End of ROCm SMI Log ===============================================
And env inside the container:
# env
HOSTNAME=5e0bcdad7c36
PWD=/app
HOME=/root
LS_COLORS=rs=0:di=01;34:ln=01;36:mh=00:pi=40;33:so=01;35:do=01;35:bd=40;33;01:cd=40;33;01:or=40;31;01:mi=00:su=37;41:sg=30;43:ca=00:tw=30;42:ow=34;42:st=37;44:ex=01;32:*.tar=01;31:*.tgz=01;31:*.arc=01;31:*.arj=01;31:*.taz=01;31:*.lha=01;31:*.lz4=01;31:*.lzh=01;31:*.lzma=01;31:*.tlz=01;31:*.txz=01;31:*.tzo=01;31:*.t7z=01;31:*.zip=01;31:*.z=01;31:*.dz=01;31:*.gz=01;31:*.lrz=01;31:*.lz=01;31:*.lzo=01;31:*.xz=01;31:*.zst=01;31:*.tzst=01;31:*.bz2=01;31:*.bz=01;31:*.tbz=01;31:*.tbz2=01;31:*.tz=01;31:*.deb=01;31:*.rpm=01;31:*.jar=01;31:*.war=01;31:*.ear=01;31:*.sar=01;31:*.rar=01;31:*.alz=01;31:*.ace=01;31:*.zoo=01;31:*.cpio=01;31:*.7z=01;31:*.rz=01;31:*.cab=01;31:*.wim=01;31:*.swm=01;31:*.dwm=01;31:*.esd=01;31:*.avif=01;35:*.jpg=01;35:*.jpeg=01;35:*.mjpg=01;35:*.mjpeg=01;35:*.gif=01;35:*.bmp=01;35:*.pbm=01;35:*.pgm=01;35:*.ppm=01;35:*.tga=01;35:*.xbm=01;35:*.xpm=01;35:*.tif=01;35:*.tiff=01;35:*.png=01;35:*.svg=01;35:*.svgz=01;35:*.mng=01;35:*.pcx=01;35:*.mov=01;35:*.mpg=01;35:*.mpeg=01;35:*.m2v=01;35:*.mkv=01;35:*.webm=01;35:*.webp=01;35:*.ogm=01;35:*.mp4=01;35:*.m4v=01;35:*.mp4v=01;35:*.vob=01;35:*.qt=01;35:*.nuv=01;35:*.wmv=01;35:*.asf=01;35:*.rm=01;35:*.rmvb=01;35:*.flc=01;35:*.avi=01;35:*.fli=01;35:*.flv=01;35:*.gl=01;35:*.dl=01;35:*.xcf=01;35:*.xwd=01;35:*.yuv=01;35:*.cgm=01;35:*.emf=01;35:*.ogv=01;35:*.ogx=01;35:*.aac=00;36:*.au=00;36:*.flac=00;36:*.m4a=00;36:*.mid=00;36:*.midi=00;36:*.mka=00;36:*.mp3=00;36:*.mpc=00;36:*.ogg=00;36:*.ra=00;36:*.wav=00;36:*.oga=00;36:*.opus=00;36:*.spx=00;36:*.xspf=00;36:*~=00;90:*#=00;90:*.bak=00;90:*.crdownload=00;90:*.dpkg-dist=00;90:*.dpkg-new=00;90:*.dpkg-old=00;90:*.dpkg-tmp=00;90:*.old=00;90:*.orig=00;90:*.part=00;90:*.rej=00;90:*.rpmnew=00;90:*.rpmorig=00;90:*.rpmsave=00;90:*.swp=00;90:*.tmp=00;90:*.ucf-dist=00;90:*.ucf-new=00;90:*.ucf-old=00;90:
LESSCLOSE=/usr/bin/lesspipe %s %s
TERM=xterm
LESSOPEN=| /usr/bin/lesspipe %s
SHLVL=1
LD_LIBRARY_PATH=/opt/rocm/lib
PATH=/opt/venv/bin:/opt/rocm/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
DEBIAN_FRONTEND=noninteractive
_=/usr/bin/env
@naromero77amd @taylding-amd
Note: I updated my Ubuntu PC from 25.04 to 25.10 and removed the rocm7.01 driver from the host to revert to the generic amdgpu driver. As a result, the rocm-smi command outside the container no longer works.
# lshw -c video
*-display
description: VGA compatible controller
product: Navi 33 [Radeon RX 7600/7600 XT/7600M XT/7600S/7700S / PRO W7600]
vendor: Advanced Micro Devices, Inc. [AMD/ATI]
physical id: 0
bus info: pci@0000:0a:00.0
logical name: /dev/fb0
version: cf
width: 64 bits
clock: 33MHz
capabilities: vga_controller bus_master cap_list rom fb
configuration: depth=32 driver=amdgpu latency=0 resolution=3440,1440
resources: irq:79 memory:d0000000-dfffffff memory:e0000000-e01fffff ioport:e000(size=256) memory:fcd00000-fcdfffff memory:fce00000-fce1ffff
# rocm-smi --showpids
bash: rocm-smi: command not found
However, everything else inside the container remains the same, and the error is still present.
@Amund Can you run which rocm-smi outside the container and confirm that it is installed?
@naromero77amd
As I said, I removed the ROCm driver from the host, so I can confirm rocm-smi is not installed there (which rocm-smi returns nothing).
But I'm starting to have doubts here...
Does the rocm/pytorch image contain both ROCm and PyTorch, not just PyTorch? Do I still need to install the ROCm drivers on the host if I use this image?
@Amund You must have the GPU driver (kernel modules) installed and loaded on the host, otherwise the container will not work properly.
If you believe that you have the drivers installed, you would still need to load the kernel modules. I am less familiar with consumer grade cards, but it is likely similar to data center cards.
sudo modprobe amdgpu
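To confirm that the module is actually loaded without relying on lsmod, /proc/modules can be read directly. Because containers share the host kernel, this file normally shows the same modules inside the container as on the host. The following sketch parses the standard space-separated /proc/modules layout (name, size, reference count, comma-terminated list of dependents or "-", state, address); loaded_modules is a hypothetical helper name.

```python
from pathlib import Path


def loaded_modules(text):
    """Parse /proc/modules content into {name: (size, refcount, dependents)}."""
    modules = {}
    for line in text.splitlines():
        fields = line.split()
        if len(fields) < 4:
            continue
        name, size, refcount, used_by = fields[0], int(fields[1]), int(fields[2]), fields[3]
        # Dependents are comma-terminated ("a,b,"), or "-" when there are none.
        deps = [] if used_by == "-" else [d for d in used_by.split(",") if d]
        modules[name] = (size, refcount, deps)
    return modules


if __name__ == "__main__":
    try:
        text = Path("/proc/modules").read_text()
    except OSError:
        text = ""
    modules = loaded_modules(text)
    for name in ("amdgpu", "amdxcp", "gpu_sched"):
        print(f"{name}: {'loaded' if name in modules else 'not loaded'}")
```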
@naromero77amd OK, I'm sure the amdgpu driver is loaded on the host; there is simply no ROCm driver or rocm-smi command installed there.
# lsmod | grep "amdgpu"
amdgpu 20373504 41
amdxcp 12288 1 amdgpu
drm_panel_backlight_quirks 12288 1 amdgpu
gpu_sched 65536 1 amdgpu
drm_buddy 28672 1 amdgpu
drm_ttm_helper 16384 1 amdgpu
ttm 131072 2 amdgpu,drm_ttm_helper
drm_exec 12288 1 amdgpu
drm_suballoc_helper 24576 1 amdgpu
drm_display_helper 294,912 1 amdgpu
cec 106,496 2 drm_display_helper,amdgpu
i2c_algo_bit 16,384 1 amdgpu
video 77,824 2 asus_wmi,amdgpu
Are you able to run any simple HIP program to completion? (It should be straightforward to ask ChatGPT to write one for you.)
Ideally, it should run on the host and in the container.
In the container:
I don't know C++, but ChatGPT provided me with this code for vector_add.cpp:
#include <hip/hip_runtime.h>
#include <iostream>
#define N 100
__global__ void vector_add(const float *A, const float *B, float *C, int n) {
    int idx = blockDim.x * blockIdx.x + threadIdx.x;
    if (idx < n)
        C[idx] = A[idx] + B[idx];
}
int main() {
    float *A, *B, *C;
    float *dA, *dB, *dC;
    A = new float[N];
    B = new float[N];
    C = new float[N];
    for (int i = 0; i < N; i++) {
        A[i] = i;
        B[i] = 2 * i;
    }
    hipMalloc(&dA, N * sizeof(float));
    hipMalloc(&dB, N * sizeof(float));
    hipMalloc(&dC, N * sizeof(float));
    hipMemcpy(dA, A, N * sizeof(float), hipMemcpyHostToDevice);
    hipMemcpy(dB, B, N * sizeof(float), hipMemcpyHostToDevice);
    int threadsPerBlock = 64;
    int blocks = (N + threadsPerBlock - 1) / threadsPerBlock;
    hipLaunchKernelGGL(vector_add, dim3(blocks), dim3(threadsPerBlock), 0, 0, dA, dB, dC, N);
    hipMemcpy(C, dC, N * sizeof(float), hipMemcpyDeviceToHost);
    std::cout << "Result (first 10 elements):\n";
    for (int i = 0; i < 10; i++)
        std::cout << A[i] << " + " << B[i] << " = " << C[i] << "\n";
    hipFree(dA);
    hipFree(dB);
    hipFree(dC);
    delete[] A;
    delete[] B;
    delete[] C;
    return 0;
}
I was able to compile and run it, and it works flawlessly:
# hipcc vector_add.cpp -o vector_add
# ./vector_add
Result (first 10 elements):
0 + 0 = 0
1 + 2 = 3
2 + 4 = 6
3 + 6 = 9
4 + 8 = 12
5 + 10 = 15
6 + 12 = 18
7 + 14 = 21
8 + 16 = 24
9 + 18 = 27
Outside the container:
Once again, there is no longer a ROCm driver on the host, nor any rocm-smi or hipcc command, so I cannot compile or run the program there. When I try to run the binary anyway (because I'm stubborn), I get an error, as expected, since no HIP shared libraries (libamdhip64.so.7 in this case) are installed.
And that's great, because that's exactly why I want to use Docker, so I don't have to install stuff on the host. I just want to fine-tune my local LLM in the container! 😅
Can you try running your simple PyTorch example while setting this environment variable:
PYTORCH_NO_HIP_MEMORY_CACHING=1 python simple.py
and post the output here.
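An equivalent way to run the reproducer with the variable set, from Python, is to launch a fresh interpreter in a child process, so the variable is guaranteed to be in place before torch initializes. A sketch; run_python is a hypothetical helper name and tensor.py is the reproducer from this thread:

```python
import os
import subprocess
import sys


def run_python(args, extra_env):
    """Run the current interpreter with `args` in a child process whose
    environment is the parent's environment plus `extra_env`."""
    env = {**os.environ, **extra_env}
    return subprocess.run([sys.executable, *args], env=env,
                          capture_output=True, text=True)


if __name__ == "__main__":
    # tensor.py is the reproducer from this thread; adjust the path as needed.
    if os.path.exists("tensor.py"):
        result = run_python(["tensor.py"], {"PYTORCH_NO_HIP_MEMORY_CACHING": "1"})
        print("exit code:", result.returncode)
        print(result.stdout, end="")
        print(result.stderr, end="", file=sys.stderr)
```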
Here it is:
# PYTORCH_NO_HIP_MEMORY_CACHING=1 python tensor.py
PyTorch detects GPU: True
ROCm device detected: AMD Radeon RX 7600
VRAM available: 7.98 GB
Allocating tensors on GPU...
Traceback (most recent call last):
File "/app/tensor.py", line 11, in <module>
a = torch.randn((1000, 1000), device=device, dtype=torch.float32)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Does this program run for you in the docker image?
import torch
# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
# Create a tensor on the GPU
x = torch.tensor([1.0, 2.0, 3.0], device=device)
# Perform a simple operation
y = x + 1
Sorry, no better luck:
# python simple.py
cuda
Traceback (most recent call last):
File "/app/simple.py", line 11, in <module>
y = x + 1
~~^~~
torch.AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
Can you re-run the last example and set AMD_LOG_LEVEL=3?
Yes, and there is a hint:
# AMD_LOG_LEVEL=3 python simple.py
:3:rocdevice.cpp :480 : 46756822561 us: Initalizing runtime stack, Enumerated GPU agents = 1
:3:rocdevice.cpp :233 : 46756822612 us: Numa selects cpu agent[0]=0x3c263620(fine=0x3c8df8e0,coarse=0x3c2fff50) for gpu agent=0x3c300490 CPU<->GPU XGMI=0
:3:rocsettings.cpp :282 : 46756822630 us: Using dev kernel arg wa = 0
:3:comgrctx.cpp :127 : 46756822663 us: Loaded COMGR library version 3.0.
:3:rocdevice.cpp :1677: 46756823300 us: Gfx Major/Minor/Stepping: 11/0/2
:3:rocdevice.cpp :1679: 46756823307 us: HMM support: 1, XNACK: 0, Direct host access: 0
:3:rocdevice.cpp :1681: 46756823312 us: Max SDMA Read Mask: 0x3, Max SDMA Write Mask: 0x3
:3:runtime.cpp :82 : 46756824109 us: ROCclr version: 7c9236b16
:3:hip_context.cpp :56 : 46756824117 us: HIP Version: 7.0.51831.7c9236b16, Direct Dispatch: 1
:3:os_posix.cpp :966 : 46756824125 us: HIP Library Path: /opt/rocm/lib/libamdhip64.so.7
:3:hip_device_runtime.cpp :702 : 46756830725 us: hipGetDeviceCount ( 0x7ffd73263d00 )
:3:hip_device_runtime.cpp :704 : 46756830742 us: hipGetDeviceCount: Returned hipSuccess :
cuda
:3:hip_device_runtime.cpp :702 : 46756831050 us: hipGetDeviceCount ( 0x7ffd73263350 )
:3:hip_device_runtime.cpp :704 : 46756831057 us: hipGetDeviceCount: Returned hipSuccess :
:3:hip_device_runtime.cpp :702 : 46756831077 us: hipGetDeviceCount ( 0x7ff4f45a828c )
:3:hip_device_runtime.cpp :704 : 46756831081 us: hipGetDeviceCount: Returned hipSuccess :
:3:hip_device.cpp :659 : 46756831090 us: hipGetDevicePropertiesR0600 ( 0x7ffd73262e68, 0 )
:3:hip_device.cpp :661 : 46756831097 us: hipGetDevicePropertiesR0600: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756831109 us: hipGetDevice ( 0x7ffd73262f4c )
:3:hip_device_runtime.cpp :698 : 46756831114 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :702 : 46756831121 us: hipGetDeviceCount ( 0x7ffd7326305c )
:3:hip_device_runtime.cpp :704 : 46756831125 us: hipGetDeviceCount: Returned hipSuccess :
:3:hip_context.cpp :361 : 46756831364 us: hipDevicePrimaryCtxGetState ( 0, 0x7ffd73262f68, 0x7ffd73262f6c )
:3:hip_context.cpp :375 : 46756831372 us: hipDevicePrimaryCtxGetState: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756831380 us: hipGetDevice ( 0x7ffd73262f9c )
:3:hip_device_runtime.cpp :698 : 46756831384 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_context.cpp :361 : 46756831390 us: hipDevicePrimaryCtxGetState ( 0, 0x7ffd73262fb8, 0x7ffd73262fbc )
:3:hip_context.cpp :375 : 46756831395 us: hipDevicePrimaryCtxGetState: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756831402 us: hipGetDevice ( 0x7ffd73262f3c )
:3:hip_device_runtime.cpp :698 : 46756831406 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_context.cpp :361 : 46756831411 us: hipDevicePrimaryCtxGetState ( 0, 0x7ffd73262f58, 0x7ffd73262f5c )
:3:hip_context.cpp :375 : 46756831417 us: hipDevicePrimaryCtxGetState: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756831903 us: hipGetDevice ( 0x7ffd73263754 )
:3:hip_device_runtime.cpp :698 : 46756831910 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756831958 us: hipGetDevice ( 0x7ffd73262b84 )
:3:hip_device_runtime.cpp :698 : 46756831962 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756831969 us: hipGetDevice ( 0x7ffd73262a44 )
:3:hip_device_runtime.cpp :698 : 46756831974 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756831991 us: hipGetDevice ( 0x7ffd7326285c )
:3:hip_device_runtime.cpp :698 : 46756831995 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_stream.cpp :313 : 46756832032 us: hipDeviceGetStreamPriorityRange ( 0x7ffd73262810, 0x7ffd73262830 )
:3:hip_stream.cpp :321 : 46756832038 us: hipDeviceGetStreamPriorityRange: Returned hipSuccess :
:3:hip_error.cpp :36 : 46756832051 us: hipGetLastError ( )
:3:hip_device_runtime.cpp :686 : 46756832055 us: hipGetDevice ( 0x7ffd732620fc )
:3:hip_device_runtime.cpp :698 : 46756832059 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_graph.cpp :1065: 46756832066 us: hipStreamIsCapturing ( stream:<null>, 0x7ffd73262360 )
:3:hip_graph.cpp :1066: 46756832071 us: hipStreamIsCapturing: Returned hipSuccess :
:3:hip_memory.cpp :770 : 46756832093 us: hipMalloc ( 0x7ffd73262440, 2097152 )
:3:rocdevice.cpp :2327: 46756832345 us: Device=0x3cad0bb0, freeMem_ = 0x1fee00000
:3:hip_memory.cpp :772 : 46756832354 us: hipMalloc: Returned hipSuccess : 0x7ff35aa00000: duration: 261 us
:3:hip_device_runtime.cpp :717 : 46756832373 us: hipSetDevice ( 0 )
:3:hip_device_runtime.cpp :721 : 46756832379 us: hipSetDevice: Returned hipSuccess :
:3:hip_device_runtime.cpp :717 : 46756832383 us: hipSetDevice ( 0 )
:3:hip_device_runtime.cpp :721 : 46756832386 us: hipSetDevice: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756832443 us: hipGetDevice ( 0x7ffd73262ac4 )
:3:hip_device_runtime.cpp :698 : 46756832448 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756832454 us: hipGetDevice ( 0x7ffd732628dc )
:3:hip_device_runtime.cpp :698 : 46756832461 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_memory.cpp :820 : 46756832480 us: hipMemcpyWithStream ( 0x7ff35aa00000, 0x3cd58d40, 12, hipMemcpyHostToDevice, stream:<null> )
:3:rocdevice.cpp :2967: 46756832490 us: Number of allocated hardware queues with low priority: 0, with normal priority: 0, with high priority: 0, maximum per priority is: 4
:3:rocdevice.cpp :3048: 46756839285 us: Created SWq=0x7ff47fbee000 to map on HWq=0x7ff359400000 with size 16384 with priority 1, cooperative: 0
:3:rocdevice.cpp :3141: 46756839312 us: acquireQueue refCount: 0x7ff359400000 (1)
:3:devprogram.cpp :2621: 46756978651 us: Using Code Object V5.
:3:rocvirtual.cpp :774 : 46756981149 us: Arg0: uchar* src = ptr:0x7ff353a00000
:3:rocvirtual.cpp :774 : 46756981159 us: Arg1: uchar* dst = ptr:0x7ff35aa00000
:3:rocvirtual.cpp :883 : 46756981166 us: Arg2: ulong size = val:0xc (size:0x8)
:3:rocvirtual.cpp :883 : 46756981171 us: Arg3: uint remainder = val:0xc (size:0x4)
:3:rocvirtual.cpp :883 : 46756981175 us: Arg4: uint aligned_size = val:0x10 (size:0x4)
:3:rocvirtual.cpp :883 : 46756981181 us: Arg5: ulong end_ptr = val:0x7ff35aa00000 (size:0x8)
:3:rocvirtual.cpp :883 : 46756981185 us: Arg6: uint next_chunk = val:0x200 (size:0x4)
:3:rocvirtual.cpp :883 : 46756981189 us: Arg7: uint workgroup_size = val:0x200 (size:0x4)
:3:rocvirtual.cpp :3351: 46756981195 us: ShaderName : __amd_rocclr_copyBuffer
:3:rocvirtual.cpp :3549: 46756981216 us: KernargSegmentByteSize = 48 KernargSegmentAlignment = 128
:3:hip_memory.cpp :837 : 46756981260 us: hipMemcpyWithStream: Returned hipSuccess : : duration: 148780 us
:3:hip_device_runtime.cpp :717 : 46756981274 us: hipSetDevice ( 0 )
:3:hip_device_runtime.cpp :721 : 46756981278 us: hipSetDevice: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756981467 us: hipGetDevice ( 0x7ffd732632ec )
:3:hip_device_runtime.cpp :698 : 46756981472 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756981481 us: hipGetDevice ( 0x7ffd73262fd4 )
:3:hip_device_runtime.cpp :698 : 46756981486 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756981494 us: hipGetDevice ( 0x7ffd73262c7c )
:3:hip_device_runtime.cpp :698 : 46756981499 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :717 : 46756981512 us: hipSetDevice ( 0 )
:3:hip_device_runtime.cpp :721 : 46756981516 us: hipSetDevice: Returned hipSuccess :
:3:hip_device_runtime.cpp :686 : 46756981538 us: hipGetDevice ( 0x7ffd73262efc )
:3:hip_device_runtime.cpp :698 : 46756981541 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device_runtime.cpp :686 : 46756981546 us: hipGetDevice ( 0x7ffd73262f3c )
:3:hip_device_runtime.cpp :698 : 46756981551 us: hipGetDevice: Returned hipSuccess : 0
:3:hip_device.cpp :659 : 46756981562 us: hipGetDevicePropertiesR0600 ( 0x7ffd73262a40, 0 )
:3:hip_device.cpp :661 : 46756981568 us: hipGetDevicePropertiesR0600: Returned hipSuccess :
:3:hip_platform.cpp :240 : 46756981588 us: __hipPushCallConfiguration ( {1,1,1}, {256,1,1}, 0, stream:<null> )
:3:hip_platform.cpp :244 : 46756981595 us: __hipPushCallConfiguration: Returned hipSuccess :
:3:hip_platform.cpp :249 : 46756981603 us: __hipPopCallConfiguration ( {256,0,2295673471}, {1931882752,32765,262}, 0x7ffd73263160, 0x7ffd73263130 )
:3:hip_platform.cpp :258 : 46756981610 us: __hipPopCallConfiguration: Returned hipSuccess :
:3:hip_module.cpp :812 : 46756981624 us: hipLaunchKernel ( 0x7ff587231b08, {1,1,1}, {256,1,1}, 0x7ffd73263190, 0, stream:<null> )
:3:hip_code_object.cpp :957 : 46756985319 us: amd::Comgr::get_data() return 0 size for agent_triple_target_ids[0]=amdgcn-amd-amdhsa--gfx1102
:3:hip_code_object.cpp :957 : 46756985339 us: amd::Comgr::get_data() return 0 size for agent_triple_target_ids[0]=amdgcn-amd-amdhsa--gfx1102
:1:hip_fatbin.cpp :736 : 46756985350 us: Cannot find CO in the bundle /opt/venv/lib/python3.12/site-packages/torch/lib/libtorch_hip.so for ISA: amdgcn-amd-amdhsa--gfx1102
:3:hip_module.cpp :813 : 46756985361 us: hipLaunchKernel: Returned hipErrorInvalidDeviceFunction : : duration: 3737 us
:3:hip_error.cpp :36 : 46756985370 us: hipGetLastError ( )
:3:hip_error.cpp :36 : 46756985376 us: hipGetLastError ( )
:3:hip_device_runtime.cpp :717 : 46756985634 us: hipSetDevice ( 0 )
:3:hip_device_runtime.cpp :721 : 46756985640 us: hipSetDevice: Returned hipSuccess :
Traceback (most recent call last):
File "/app/simple.py", line 11, in <module>
y = x + 1
~~^~~
torch.AcceleratorError: HIP error: invalid device function
HIP kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing AMD_SERIALIZE_KERNEL=3
Compile with `TORCH_USE_HIP_DSA` to enable device-side assertions.
:3:rocdevice.cpp :3173: 46757303054 us: releaseQueue refCount:0x7ff359400000 (0)
:3:rocdevice.cpp :286 : 46757304592 us: Deleting hardware queue 0x7ff359400000 with refCount 0
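The traceback recommends AMD_SERIALIZE_KERNEL=3, and the lines above come from the HIP runtime's level-3 logging (AMD_LOG_LEVEL=3). Below is a minimal sketch for rerunning a failing script with both variables set. It assumes a small Python wrapper is convenient. Only the two environment variables come from this thread; the helper name run_with_hip_logging is hypothetical.

```python
import os
import subprocess
import sys


def run_with_hip_logging(script_args, log_level=3, serialize=True):
    """Rerun `python <script_args...>` with HIP runtime logging enabled.

    Hypothetical helper: only AMD_LOG_LEVEL and AMD_SERIALIZE_KERNEL
    come from the log and traceback in this thread.
    """
    env = dict(os.environ)
    env["AMD_LOG_LEVEL"] = str(log_level)
    if serialize:
        # Value suggested by the traceback so that asynchronous kernel
        # errors surface closer to the call that caused them.
        env["AMD_SERIALIZE_KERNEL"] = "3"
    else:
        env.pop("AMD_SERIALIZE_KERNEL", None)
    return subprocess.run(
        [sys.executable, *script_args],
        env=env,
        capture_output=True,
        text=True,
    )


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print(f"usage: {sys.argv[0]} SCRIPT [ARGS...]")
    else:
        proc = run_with_hip_logging(sys.argv[1:])
        # Forward both captured streams; the HIP log and the Python
        # traceback may appear on either one.
        sys.stdout.write(proc.stdout)
        sys.stderr.write(proc.stderr)
        print("exit code:", proc.returncode)
```

For example, saving this as run_logged.py (an illustrative name) and running `python run_logged.py simple.py` would replay the reproducer with logging on.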
You found it!
Cannot find CO in the bundle /opt/venv/lib/python3.12/site-packages/torch/lib/libtorch_hip.so for ISA: amdgcn-amd-amdhsa--gfx1102
Just a sanity check: can you confirm that the file exists? Please also run ls -la on it so we can see its size.
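In a log this long, the decisive message is easy to miss. The sketch below extracts failed code-object lookups from AMD_LOG_LEVEL=3 output. It assumes the two message formats seen in this thread: ROCm 7.0's "Cannot find CO in the bundle ... for ISA: ..." and ROCm 7.1's "No compatible code objects found for: ...". Other ROCm releases may word these differently. The function name find_code_object_errors is hypothetical.

```python
import re

# Message formats copied from the ROCm 7.0 and ROCm 7.1 logs in this thread.
_MISSING_CO = re.compile(r"Cannot find CO in the bundle (\S+) for ISA: (\S+)")
_NO_COMPATIBLE = re.compile(r"No compatible code objects found for: ([^,\s]+)")


def find_code_object_errors(log_text):
    """Return (library, isa) pairs for failed code-object lookups.

    library is None when the message does not name the file.
    """
    errors = []
    for line in log_text.splitlines():
        m = _MISSING_CO.search(line)
        if m:
            errors.append((m.group(1), m.group(2)))
            continue
        m = _NO_COMPATIBLE.search(line)
        if m:
            errors.append((None, m.group(1)))
    return errors


if __name__ == "__main__":
    sample = (
        ":1:hip_fatbin.cpp :736 : 46756985350 us: Cannot find CO in the bundle "
        "/opt/venv/lib/python3.12/site-packages/torch/lib/libtorch_hip.so "
        "for ISA: amdgcn-amd-amdhsa--gfx1102"
    )
    for library, isa in find_code_object_errors(sample):
        print(isa, library or "(library not named)")
```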
Sorry, I thought that was good news, but I admit I don't understand the underlying implications.
# ls -la /opt/venv/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
-rwxr-xr-x 1 root root 318427912 Oct 21 18:27 /opt/venv/lib/python3.12/site-packages/torch/lib/libtorch_hip.so
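The file exists and is about 318 MB. This suggests the library itself is present but lacks a gfx1102 code object. One way to check from Python is to compare the device's target with the targets the wheel reports it was built for. This sketch assumes a ROCm build of PyTorch, where torch.cuda.get_arch_list() returns gfx target names and the device properties expose gcnArchName. The helpers base_arch and has_exact_target are hypothetical. The arch list is only a hint, not proof of what each bundled library actually contains.

```python
def base_arch(name):
    """Strip target-feature suffixes, e.g. 'gfx90a:sramecc+:xnack-' -> 'gfx90a'."""
    return name.split(":", 1)[0]


def has_exact_target(device_arch, compiled_archs):
    """True if the build lists the device's exact gfx target.

    Generic targets such as gfx11-generic (which the ROCm 7.1 runtime
    looks up in this thread) are deliberately not considered here.
    """
    wanted = base_arch(device_arch)
    return any(base_arch(a) == wanted for a in compiled_archs)


def main():
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
        return
    if not torch.cuda.is_available():
        print("No ROCm/CUDA device visible to PyTorch")
        return
    props = torch.cuda.get_device_properties(0)
    # gcnArchName may be absent or meaningless on non-ROCm builds.
    device_arch = getattr(props, "gcnArchName", "")
    compiled = torch.cuda.get_arch_list()
    print("device target:", device_arch or "(unknown)")
    print("build targets:", " ".join(compiled))
    if device_arch and not has_exact_target(device_arch, compiled):
        print("This build does not list", base_arch(device_arch))


if __name__ == "__main__":
    main()
```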
Since I couldn't resist, I also tested the script with the latest image, which uses ROCm 7.1. The error is different, and more severe.
# AMD_LOG_LEVEL=3 python simple.py
:3:rocdevice.cpp :420 : 48627338869 us: Initalizing runtime stack, Enumerated GPU agents = 1
:3:rocdevice.cpp :187 : 48627339128 us: Numa selects cpu agent[0]=0x34b087a0(fine=0x35183060,coarse=0x34ba1ee0) for gpu agent=0x34ba3720 CPU<->GPU XGMI=0
:3:rocsettings.cpp :277 : 48627339141 us: Using dev kernel arg wa = 0
:3:comgrctx.cpp :127 : 48627339179 us: Loaded COMGR library version 3.0.
:3:rocdevice.cpp :1590: 48627342705 us: addressableNumVGPRs=256, totalNumVGPRs=1024, vGPRAllocGranule=16, availableRegistersPerCU_=131072
:3:rocdevice.cpp :1604: 48627342721 us: imageSupport=1
:3:rocdevice.cpp :1635: 48627342725 us: Gfx Major/Minor/Stepping: 11/0/2
:3:rocdevice.cpp :1637: 48627342729 us: HMM support: 1, XNACK: 0, Direct host access: 0
:3:rocdevice.cpp :1639: 48627342735 us: Max SDMA Read Mask: 0x3, Max SDMA Write Mask: 0x3
:3:hip_context.cpp :60 : 48627346075 us: HIP Version: 7.1.25424.4179531dcd, Direct Dispatch: 1
:3:os_posix.cpp :961 : 48627346087 us: HIP Library Path: /opt/rocm/lib/libamdhip64.so.7
...
:3:hip_fatbin.cpp :511 : 48627995665 us: Looking up generic name of : amdgcn-amd-amdhsa--gfx1102 - amdgcn-amd-amdhsa--gfx11-generic
:3:hip_fatbin.cpp :537 : 48628103387 us: Device name: amdgcn-amd-amdhsa--gfx1102 Generic name: amdgcn-amd-amdhsa--gfx11-generic
:1:hip_fatbin.cpp :694 : 48628103406 us: No compatible code objects found for: gfx1102, value of HIP_FORCE_SPIRV_CODEOBJECT: 0
Segmentation fault (core dumped)
I hope this helps you.
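The ROCm 7.1 log confirms the target. Gfx Major/Minor/Stepping 11/0/2 corresponds to gfx1102, and the runtime also tried the gfx11-generic fallback before reporting that no compatible code object was found. The sketch below extracts both facts from such a log. It assumes the target name is formed as "gfx" + major in decimal + minor and stepping as one hex digit each, so 9/0/10 would give gfx90a. The function names are hypothetical.

```python
import re

# Formats taken from the ROCm 7.1 AMD_LOG_LEVEL=3 output in this thread.
_GFX_VERSION = re.compile(r"Gfx Major/Minor/Stepping: (\d+)/(\d+)/(\d+)")
_GENERIC = re.compile(r"Device name: (\S+) Generic name: (\S+)")


def gfx_target_from_log(log_text):
    """Return e.g. 'gfx1102' from the Major/Minor/Stepping line, or None.

    Assumes minor and stepping are each rendered as one hex digit.
    """
    m = _GFX_VERSION.search(log_text)
    if m is None:
        return None
    major, minor, stepping = (int(g) for g in m.groups())
    return f"gfx{major}{minor:x}{stepping:x}"


def generic_fallback_from_log(log_text):
    """Return (device ISA, generic ISA) the runtime looked up, or None."""
    m = _GENERIC.search(log_text)
    return (m.group(1), m.group(2)) if m else None


if __name__ == "__main__":
    sample = (
        "Gfx Major/Minor/Stepping: 11/0/2\n"
        "Device name: amdgcn-amd-amdhsa--gfx1102 "
        "Generic name: amdgcn-amd-amdhsa--gfx11-generic\n"
    )
    print(gfx_target_from_log(sample))
    print(generic_fallback_from_log(sample))
```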
Thanks for hanging in there and running all these tests. Definitely progress.
In the ROCm 7.0 Docker image, can you try installing the PyTorch nightly wheel? It bundles its own copy of the ROCm libraries.
Here are the instructions:
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0
# pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.0
Looking in indexes: https://download.pytorch.org/whl/nightly/rocm7.0
Requirement already satisfied: torch in /opt/venv/lib/python3.12/site-packages (2.8.0+rocm7.0.2.lw.git245bf6ed)
Requirement already satisfied: torchvision in /opt/venv/lib/python3.12/site-packages (0.23.0+rocm7.0.2.git824e8c87)
Requirement already satisfied: filelock in /opt/venv/lib/python3.12/site-packages (from torch) (3.20.0)
Requirement already satisfied: typing-extensions>=4.10.0 in /opt/venv/lib/python3.12/site-packages (from torch) (4.15.0)
Requirement already satisfied: setuptools in /opt/venv/lib/python3.12/site-packages (from torch) (80.9.0)
Requirement already satisfied: sympy>=1.13.3 in /opt/venv/lib/python3.12/site-packages (from torch) (1.14.0)
Requirement already satisfied: networkx in /opt/venv/lib/python3.12/site-packages (from torch) (3.5)
Requirement already satisfied: jinja2 in /opt/venv/lib/python3.12/site-packages (from torch) (3.1.6)
Requirement already satisfied: fsspec in /opt/venv/lib/python3.12/site-packages (from torch) (2025.9.0)
Requirement already satisfied: triton==3.4.0+rocm7.0.2.gitf9e5bf54 in /opt/venv/lib/python3.12/site-packages (from torch) (3.4.0+rocm7.0.2.gitf9e5bf54)
Requirement already satisfied: numpy in /opt/venv/lib/python3.12/site-packages (from torchvision) (2.3.4)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/venv/lib/python3.12/site-packages (from torchvision) (12.0.0)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/venv/lib/python3.12/site-packages (from sympy>=1.13.3->torch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/venv/lib/python3.12/site-packages (from jinja2->torch) (3.0.3)
Apparently everything was already in the image. I tested the script again, but with the same result.
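pip reported "Requirement already satisfied" for torch, and the version it printed (2.8.0+rocm7.0.2.lw.git245bf6ed) has no .dev segment. This suggests the nightly may not have been installed at all. When a requirement has no version constraint, pip leaves an already-installed package in place unless --upgrade or --force-reinstall is passed. The sketch below checks which build is actually installed. It assumes that nightly wheels carry a PEP 440 ".devNNNNNNNN" segment and that ROCm wheels carry a "+rocmX.Y[.Z]" local label. Both are observed conventions, not guarantees, and describe_torch_version is a hypothetical helper.

```python
import re
from importlib import metadata

_DEV = re.compile(r"\.dev\d+")
_ROCM = re.compile(r"\+rocm(\d+(?:\.\d+)*)")


def describe_torch_version(version):
    """Return (is_nightly, rocm_version) parsed from a torch version string."""
    is_nightly = _DEV.search(version) is not None
    m = _ROCM.search(version)
    return is_nightly, (m.group(1) if m else None)


if __name__ == "__main__":
    try:
        installed = metadata.version("torch")
    except metadata.PackageNotFoundError:
        print("torch is not installed")
    else:
        nightly, rocm = describe_torch_version(installed)
        print(f"torch {installed}: nightly={nightly}, rocm={rocm}")
```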