RuntimeError was encountered while calling torch._C._cuda_init()
Problem description
Run the SCUDA server on a GPU server, then run the commands below; a RuntimeError is encountered.
Environment information
CUDA_VERSION=12.6.2
DISTRO_VERSION=24.04
OS_DISTRO=ubuntu
CUDNN_TAG=cudnn
Reproduce steps
- Build an image using the example Dockerfile https://github.com/kevmo314/scuda/blob/29026b0dbe0716bc86762655cc706eb303f5deb5/Dockerfile.build and start the container
- Install dependencies: pip install numpy pandas torch
- Start the server: ./local.sh server
- Set the environment variable: export SCUDA_SERVER=127.0.0.1
- Start the client: LD_PRELOAD=./libscuda_12.6.so python3 test.py
My test.py looks like this:
import torch
print(torch.cuda.is_available())
print(torch.cuda.get_device_name())
Current behavior
The output looks like this:
......
dlsym: cuModuleGetGlobal_v2
dlsym: PyInit__C
dlsym: PyInit__multiarray_umath
dlsym: PyInit__contextvars
dlsym: PyInit__umath_linalg
dlsym: PyInit_mmap
dlsym: PyInit__ssl
dlsym: PyInit__asyncio
dlsym: PyInit__queue
dlsym: PyInit__hashlib
dlsym: PyInit__multiprocessing
dlsym: cuDevicePrimaryCtxGetState
dlsym: cuGetErrorString
True
Traceback (most recent call last):
File "<string>", line 1, in <module>
...
File ".../site-packages/torch/cuda/__init__.py", live 372, in _lazy_init
torch._C._cuda_init()
RuntimeError: CUDA driver error: initialization error
It seems that torch.cuda.is_available() works normally, but CUDA cannot actually be initialized.
I tried running the script above without SCUDA and it gave the correct result. However, the runtime error is encountered when running it with LD_PRELOAD=./libscuda_12.6.so. I guess there is a problem in how the RPC layer implements CUDA's C interface.
Are you able to reproduce this issue without pytorch? You may have to write a small program and call cudaInit() yourself.
I suspect pytorch is doing more than just calling cudaInit() and support for all the APIs hasn’t gotten quite that far yet.
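For anyone who wants to try that, a minimal probe might look like the sketch below (hedged: the driver API entry point is cuInit(); there is no cudaInit()). Build with something like gcc minimal_init.c -lcuda and run it under the same LD_PRELOAD:
// minimal_init.c — a minimal sketch of the probe suggested above.
#include <stdio.h>
#include <cuda.h>

int main(void) {
    // Force driver initialization, which is roughly what
    // torch._C._cuda_init() ultimately depends on.
    CUresult res = cuInit(0);
    if (res != CUDA_SUCCESS) {
        const char *msg = NULL;
        cuGetErrorString(res, &msg);
        printf("cuInit failed: %s\n", msg ? msg : "unknown error");
        return 1;
    }
    int count = 0;
    cuDeviceGetCount(&count);
    printf("cuInit succeeded; %d device(s) visible\n", count);
    return 0;
}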
The following is the small program I wrote without PyTorch. It runs normally with SCUDA.
#include <stdio.h>
#include <string.h>
#include <cuda_runtime.h>
// A function declared __global__ is launched by the CPU and executed on the GPU
__global__ void add(const int *dev_a,const int *dev_b,int *dev_c)
{
int i=threadIdx.x;
dev_c[i]=dev_a[i]+dev_b[i];
}
int test_cal() {
// Allocate host memory and initialize it
int host_a[512],host_b[512],host_c[512];
for(int i=0;i<512;i++)
{
host_a[i]=i;
host_b[i]=i<<1;
}
// Define a cudaError_t, initialized to cudaSuccess (0)
cudaError_t err = cudaSuccess;
// Allocate GPU memory
int *dev_a,*dev_b,*dev_c;
err=cudaMalloc((void **)&dev_a, sizeof(int)*512);
if(err==cudaSuccess) err=cudaMalloc((void **)&dev_b, sizeof(int)*512);
if(err==cudaSuccess) err=cudaMalloc((void **)&dev_c, sizeof(int)*512);
if(err!=cudaSuccess)
{
printf("cudaMalloc on the GPU failed\n");
return 1;
}
printf("SUCCESS\n");
// Copy the input data to the GPU with cudaMemcpy
cudaMemcpy(dev_a,host_a,sizeof(host_a),cudaMemcpyHostToDevice);
cudaMemcpy(dev_b,host_b,sizeof(host_b),cudaMemcpyHostToDevice);
// Launch the kernel on the GPU. The data is small, so one block of 512 threads is enough
add<<<1,512>>>(dev_a,dev_b,dev_c);
cudaMemcpy(host_c,dev_c,sizeof(host_c),cudaMemcpyDeviceToHost);
for(int i=0;i<512;i++)
printf("host_a[%d] + host_b[%d] = %d + %d = %d\n",i,i,host_a[i],host_b[i],host_c[i]);
cudaFree(dev_a); // Free GPU memory
cudaFree(dev_b);
cudaFree(dev_c);
printf("successfully calculate on GPU.\n");
return 0;
}
int test_device() {
int deviceCount;
cudaError_t error = cudaGetDeviceCount(&deviceCount);
if (error != cudaSuccess) {
printf("Error getting device count: %s\n", cudaGetErrorString(error));
return -1;
}
printf("Number of CUDA devices: %d\n", deviceCount);
for (int i = 0; i < deviceCount; i++) {
cudaDeviceProp prop;
error = cudaGetDeviceProperties(&prop, i);
if (error != cudaSuccess) {
printf("Error getting properties for device %d: %s\n", i, cudaGetErrorString(error));
continue;
}
printf("Device %d: %s\n", i, prop.name);
cudaDeviceReset();
}
return 0;
}
int main(int argc, char* argv[]) {
if (argc > 1 && strcmp(argv[1], "cal") == 0) {
return test_cal();
}
return test_device();
}
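For reference, the program above should build with something like nvcc test.cu -o test; running LD_PRELOAD=./libscuda_12.6.so ./test exercises the device query, and ./test cal exercises the kernel path.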
Maybe your speculation about PyTorch is correct. I have reviewed the PyTorch source code around the definition of torch._C._cuda_init().
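As a rough, hedged approximation (not PyTorch's actual source) of why the two calls behave differently: is_available() only needs a device count, while lazy initialization forces a live context on the device, which seems to be where the RPC-backed path fails:
// query_vs_init.cu — rough approximation of the two code paths.
#include <stdio.h>
#include <cuda_runtime.h>

int main(void) {
    int count = 0;
    // Roughly the torch.cuda.is_available() path: a device-count query.
    cudaError_t err = cudaGetDeviceCount(&count);
    printf("device count: %d (%s)\n", count, cudaGetErrorString(err));

    // Roughly the extra work of lazy init: cudaFree(0) is a common
    // idiom to force creation of the primary context.
    err = cudaFree(0);
    printf("context init: %s\n", cudaGetErrorString(err));
    return 0;
}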
May I ask if there are plans to improve SCUDA's support for PyTorch? PyTorch has a wide range of applications and is an important tool for scientific computing.
Yes, there are plans and it's on the roadmap, but unfortunately it's quite difficult. We could certainly use help, though! The biggest barrier is covering all the APIs that PyTorch uses: we only support a small subset of CUDA at the moment.
@kevmo314 @James-Leong Thank you for your contribution to this project — it’s a great piece of work!
We tried to run the mnist test case from pytorch/examples, but it didn't pass. It seems to be related to this issue. Are there any updates or workarounds available?
I'm afraid the project doesn't support PyTorch at present. @dyunwei
Thank you for your reply. If I want to make SCUDA compatible with PyTorch, what work would be required? Do you have any suggestions or ideas on how to approach this? Many thanks. @James-Leong @kevmo314
First, analyze the CUDA-related interfaces used by PyTorch, then implement the RPC call path for those interfaces in these two files: manual_client.cpp and manual_server.cpp. A hypothetical sketch of what a client-side stub involves is below.
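To illustrate (hedged: the rpc_* helpers and the opcode are invented for this sketch and stubbed out here; SCUDA's real transport lives in manual_client.cpp and manual_server.cpp):
// rpc_stub_sketch.cpp — hypothetical client-side stub work: serialize
// the call id and arguments, ship them to the server, return the
// server-side error code verbatim.
#include <cstddef>
#include <cstdio>

// Hypothetical transport layer, stubbed for illustration.
static int rpc_begin(int op) { printf("rpc_begin(op=%d)\n", op); return 0; }
static int rpc_write(const void *buf, size_t len) {
    (void)buf; printf("rpc_write(%zu bytes)\n", len); return 0;
}
static int rpc_end_status() { printf("rpc_end\n"); return 0; } // server's error code

// Example intercepted entry point (named _stub to avoid claiming the
// real symbol).
extern "C" int cudaSetDevice_stub(int device) {
    const int RPC_CUDA_SET_DEVICE = 1;  // hypothetical opcode
    if (rpc_begin(RPC_CUDA_SET_DEVICE) < 0) return 999; // cudaErrorUnknown
    if (rpc_write(&device, sizeof device) < 0) return 999;
    return rpc_end_status();
}

int main() { return cudaSetDevice_stub(0); }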
I previously wanted to tackle PyTorch support myself, but I found it to be a major undertaking. @dyunwei
Thanks for the quick reply. That does sound like a fair amount of work. @James-Leong
We attempted to intercept all functions in the NVML, CUDA, CUDA Runtime, cuDNN, and cuBLAS libraries. This resulted in functions that should execute on the host being forwarded and executed remotely, causing program crashes. Would it be sufficient to intercept only some functions in NVML, CUDA, and the CUDA Runtime? @James-Leong
I think the GPU-dependent methods in cuDNN should also be intercepted, but methods that should execute locally indeed should not be sent to the remote side; something like the selective resolution sketched below. @violinY
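A hedged sketch of that allowlist idea (all names here are hypothetical, and the local-only entries are illustrative, not a vetted list): forward only entry points that must run where the GPU is, and let host-side helpers resolve to the local library:
// selective_resolve.cpp — hypothetical selective symbol resolution.
#include <cstring>
#include <cstdio>

// Hypothetical resolvers; a real client would return function pointers
// into the RPC stubs or into the locally loaded library.
static void *scuda_remote_stub(const char *name) { printf("remote: %s\n", name); return nullptr; }
static void *scuda_local_symbol(const char *name) { printf("local: %s\n", name); return nullptr; }

// Host-side calls that must not cross the wire, e.g. cuDNN descriptor
// bookkeeping that only touches host memory (illustrative entries).
static const char *local_only[] = {
    "cudnnCreateTensorDescriptor",
    "cudnnSetTensor4dDescriptor",
    nullptr,
};

// GPU-touching prefixes worth forwarding (again, illustrative).
static const char *remote_prefixes[] = {"cuda", "cu", "nvml", nullptr};

void *scuda_resolve(const char *name) {
    for (int i = 0; local_only[i]; i++)
        if (strcmp(name, local_only[i]) == 0)
            return scuda_local_symbol(name);
    for (int i = 0; remote_prefixes[i]; i++)
        if (strncmp(name, remote_prefixes[i], strlen(remote_prefixes[i])) == 0)
            return scuda_remote_stub(name);
    return scuda_local_symbol(name);
}

int main() {
    scuda_resolve("cudaMalloc");                  // forwarded to the server
    scuda_resolve("cudnnCreateTensorDescriptor"); // kept local
    return 0;
}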