
[BUG] CuTEDSL with tvm-ffi is globally resetting device

Open ngimel opened this issue 1 month ago • 20 comments

Which component has the problem?

CuTe DSL

Bug Report

Describe the bug When CuTeDSL works on tensors that are not on the current device, it globally resets the device to match the device of the input tensors but doesn't set it back, leading to an unexpected device change.

Steps/Code to reproduce bug

import cutlass.cute as cute
import torch

@cute.kernel
def device_add_one(a: cute.Tensor, b: cute.Tensor):
   threads_per_block = 128
   cta_x_, _, _ = cute.arch.block_idx()
   tid_x, _, _ = cute.arch.thread_idx()
   tid = cta_x_ * threads_per_block + tid_x
   if tid < a.shape[0]:
      b[tid] = (a[tid] + 2.5).to(cute.BFloat16)

@cute.jit
def add_one(a: cute.Tensor, b: cute.Tensor):
   n = a.shape[0]
   threads_per_block = 128
   blocks = (n + threads_per_block - 1) // threads_per_block
   device_add_one(a, b).launch(
      grid=(blocks, 1, 1),
      block=(threads_per_block, 1, 1),
   )

def example_add_one():
   n = cute.sym_int()
   a_cute = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (n,))
   b_cute = cute.runtime.make_fake_compact_tensor(cute.BFloat16, (n,))
   # compile the kernel with "--enable-tvm-ffi" option and example input tensors
   compiled_add_one = cute.compile(add_one, a_cute, b_cute, options="--enable-tvm-ffi")
   # now compiled_add_one is a TVM-FFI function that can be called with torch.Tensor as input
   print(torch.cuda.current_device())
   aa=torch.randn(1024*1024*256, device="cuda")
   a_torch = torch.arange(1024*1024*128, dtype=torch.bfloat16, device="cuda:1")
   b_torch = torch.full((1024*1024*128,), float('nan'), dtype=torch.bfloat16, device="cuda:1")
   #with torch.profiler.profile(with_stack=True) as p:
   for _ in range(8):
         a_torch = a_torch.clone()
         b_torch = b_torch.clone()
         compiled_add_one(a_torch, b_torch)
         print(torch.cuda.current_device())
         aa=torch.randn(1024*1024*256, device="cuda")
         print(a_torch.device, aa.device) # aa.device is cuda:1 even though it should be cuda:0


if __name__ == "__main__":
   example_add_one()

Expected behavior Calling a CuTeDSL pre-compiled function should not have side effects

ngimel avatar Nov 25 '25 00:11 ngimel

Thanks a lot for the note. We have a fix for this and will work to ship it in an upcoming patch release.

tqchen avatar Nov 25 '25 02:11 tqchen

Thank you! Is there a link to the fix?

ngimel avatar Nov 25 '25 03:11 ngimel

As of now it is still internal to NV, but it should come out around 4.3.1.

tqchen avatar Nov 25 '25 03:11 tqchen

Thank you! Is there a link to the fix?

It's patched internally, because the issue originates from internal LLVM/MLIR driver code which invokes TVM-FFI, not the TVM-FFI codebase.

@tqchen is there a timeline for cuteDSL 4.3.1 release?

Update: actually this patch will go public as 4.3.1 releases

junrushao avatar Nov 25 '25 08:11 junrushao

I thought the issue originates in https://github.com/NVIDIA/cutlass/blob/8cd5bef43a2b0d3f9846b026c271593c6e4a8e8a/python/CuTeDSL/cutlass/cutlass_dsl/tvm_ffi_provider.py#L256, which inserts only one cudaSetDevice and doesn't restore the device?

ngimel avatar Nov 25 '25 17:11 ngimel

That is right, the fix is to update it to restore the device.
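
For illustration only, a minimal sketch of the restore-device pattern (not the actual internal patch), using the cuda.bindings runtime API that also appears later in this thread; launch_fn here is just a placeholder for the generated launcher:

from cuda.bindings import runtime as cudart

def run_with_device_restore(target_device_id, launch_fn):
    # Sketch only: remember whichever device the caller currently has set.
    err, old_device_id = cudart.cudaGetDevice()
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"cudaGetDevice failed: {err}")
    try:
        if old_device_id != target_device_id:
            cudart.cudaSetDevice(target_device_id)
        launch_fn()
    finally:
        # Restore the caller's device even if the launch raises.
        if old_device_id != target_device_id:
            cudart.cudaSetDevice(old_device_id)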

tqchen avatar Nov 25 '25 17:11 tqchen

When will our internal patch be reflected in this open-source repo? Sorry, I haven't figured out the mechanism yet.

junrushao avatar Nov 25 '25 17:11 junrushao

So it's not internal LLVM/MLIR driver code; the fix should go to the cutlass repo.

ngimel avatar Nov 25 '25 18:11 ngimel

@ngimel correct - I misunderstood the process; it will be public when 4.3.1 releases.

junrushao avatar Nov 25 '25 18:11 junrushao

I tested the patch just now and it prints:

>>> python main.py
0
0
cuda:1 cuda:0
0
cuda:1 cuda:0
0
cuda:1 cuda:0
0
cuda:1 cuda:0
0
cuda:1 cuda:0
0
cuda:1 cuda:0
0
cuda:1 cuda:0
0
cuda:1 cuda:0

where aa.device is cuda:0. It seems the issue is fixed with the latest patch.

junrushao avatar Nov 25 '25 19:11 junrushao

Also tested exception handling, e.g. throwing an error inside cute.jit - the latest patch seems to handle it properly
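
For reference, a minimal sketch of that kind of check (failing_fn is a placeholder for any compiled call that raises):

import torch

def check_device_restored_on_error(failing_fn, *args):
    # Sketch: the current device should be unchanged even if the call raises.
    before = torch.cuda.current_device()
    try:
        failing_fn(*args)
    except Exception:
        pass
    assert torch.cuda.current_device() == before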

junrushao avatar Nov 25 '25 20:11 junrushao

One super annoying thing about cudaSetDevice is that it initializes a context, and PyTorch goes through a lot of pain to prevent that. So e.g.

a=torch.randn(4, device="cuda:1")
b=torch.randn(4, device="cuda:1")
a+b

won't initialize a context on device 0. But

a=torch.randn(4, device="cuda:1")
b=torch.randn(4, device="cuda:1")
cute_tvm_ffi_fn(a,b)

will, if it calls cudaSetDevice(0) at the end.
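
As an aside, here is a small sketch (using the cuda-python driver bindings) of how to observe whether device 0's primary context got created as a side effect; cuDevicePrimaryCtxGetState only queries the state and does not create a context:

from cuda.bindings import driver as cuda

def primary_context_active(device_ordinal):
    # Sketch: query whether the primary context on `device_ordinal` has been created.
    err, = cuda.cuInit(0)
    assert err == cuda.CUresult.CUDA_SUCCESS
    err, dev = cuda.cuDeviceGet(device_ordinal)
    assert err == cuda.CUresult.CUDA_SUCCESS
    err, _flags, active = cuda.cuDevicePrimaryCtxGetState(dev)
    assert err == cuda.CUresult.CUDA_SUCCESS
    return bool(active)

# primary_context_active(0) should stay False until something
# (e.g. a stray cudaSetDevice(0)) initializes device 0.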

ngimel avatar Nov 25 '25 22:11 ngimel

Thanks @ngimel for the note! This is indeed something worth looking into that I hadn't thought about; thanks for bringing it up.

If I understand correctly, the main goal is: if we only use cuda:1, we want to avoid creating a CUDA context on cuda:0, since calling cudaSetDevice(0) will trigger context creation.

I did some investigation into PyTorch's behavior and found that torch.randn(4, device="cuda:1") will initialize device 1, and cudaGetDevice() will return 1 after this.

So in your example:


# cudaGetDevice() will return 0 here, but won't initialize a context

a=torch.randn(4, device="cuda:1")
b=torch.randn(4, device="cuda:1")

# cudaGetDevice() will return 1 after the randn

# now at the function call, we will query old_index = cudaGetDevice()
cute_tvm_ffi_fn(a,b)
# at exit it will set to old_index which is 1, calling cudaSetDevice(1)

As a result, it will always switch back to a device that already has a created context (in this case by calling cudaSetDevice(1)), and it will not call cudaSetDevice(0) in this case.

The following code is what I used to confirm the context situation and the current device id.

from cuda.bindings import driver as cuda
from cuda.bindings import runtime as cudart
import torch

class CudaHelper:
    def __init__(self):
        """
        Initializes the CUDA Driver API on instantiation.
        Raises RuntimeError if the driver cannot be initialized.
        """
        err, = cuda.cuInit(0)
        if err != cuda.CUresult.CUDA_SUCCESS:
            raise RuntimeError(f"Failed to initialize CUDA Driver API: {err}")

    def is_context_active(self):
        """
        Returns True if a CUDA context is currently bound to the calling thread.
        Safe to call even if no context exists.
        """
        try:
            # cuCtxGetCurrent peeks at the stack without modifying it
            err, ctx = cuda.cuCtxGetCurrent()

            # If the API call fails, assume no context
            if err != cuda.CUresult.CUDA_SUCCESS:
                return False

            # If ctx is None or the handle value is 0, no context is active
            if ctx is None:
                return False

            return int(ctx) != 0
        except Exception:
            return False

    def set_device(self, device_id=0):
        """
        Sets the device using CUDA Runtime API.
        """
        err, = cudart.cudaSetDevice(device_id)
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"Failed to set device {device_id}: {err}")

    def get_current_device(self):
        """
        Gets the current device using CUDA Runtime API.
        Returns the device ID.
        """
        err, device_id = cudart.cudaGetDevice()
        if err != cudart.cudaError_t.cudaSuccess:
            raise RuntimeError(f"Failed to get current device: {err}")
        return device_id

def main():
    helper = CudaHelper()
    print("\nCurrent device: {helper.get_current_device()=}")
    print(f"Is context active: {helper.is_context_active()}")
    tensor = torch.tensor([1.0, 2.0, 3.0], device='cuda:1')
    print(f"Tensor created: {tensor}")
    print(f"Is context active after torch tensor creation? {helper.is_context_active()}")
    print(f"Tensor device: {tensor.device}")
    print(f"current cuda device: {helper.get_current_device()=}")

if __name__ == "__main__":
    main()

tqchen avatar Nov 26 '25 00:11 tqchen

Perfect - Thanks @ngimel for communicating the expected behaviors and @tqchen for always bringing the most timely fixes! We all want cuteDSL + tvm-ffi to work the best for PyTorch users.

I'm not part of the cuteDSL team, so I may not be the best person to raise this question, but here's one out of personal curiosity as a tvm-ffi developer: it looks to me that this MLIR/LLVM driver is doing some magic, i.e. if we launch a kernel:

a = torch.randn(4, device="cuda:1")
cute_tvm_ffi_fn(a)

under a different CUDA context than tensor `a` is on, this MLIR/LLVM driver code will try to be smart by invoking CUDA runtime APIs, which can be undesirable, as Natalia pointed out.

Let's think about that. @ngimel, from PyTorch's perspective, do you think we should not call any CUDA runtime/driver API at all in the MLIR/LLVM driver code? It's definitely doable if we just throw an error about a mismatched CUDA context, i.e.

a = torch.randn(4, device="cuda:1")  # `a` is under context 1 

# suppose `cudaGetDevice()` returns 0 here
# which means users asked to launch at context 0
cute_tvm_ffi_fn(a)  # just throw an error
# saying mismatched context, please use `torch.cuda.set_device`

junrushao avatar Nov 26 '25 08:11 junrushao

Since PyTorch handles this situation by calling cudaSetDevice, I think it would be good for tvm-ffi to do the same. The only minor issue, as I said, is avoiding initializing a context on the original device if it wasn't initialized otherwise, but I think we are getting lucky there, as @tqchen pointed out. Alternatively, we might expect the integration on the PyTorch side to handle device guards, and in that case erroring out in the tvm-ffi function if devices don't match would be reasonable. For robust handling we need extra runtime checks anyway; e.g. in the example above, something at runtime should be checking that a and b are on the same device, because otherwise we would error with an IMA (illegal memory access). I was thinking that putting this responsibility on the PyTorch-side integration, along with taking care of device guards, would also be reasonable. wdyt?
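
As a rough illustration of that PyTorch-side option (the names here are hypothetical, not an existing integration API):

import torch

def call_with_device_guard(compiled_fn, *tensors):
    # compiled_fn is a hypothetical tvm-ffi compiled function; this is only a sketch.
    devices = {t.device for t in tensors if isinstance(t, torch.Tensor) and t.is_cuda}
    if len(devices) != 1:
        raise ValueError(f"expected all CUDA tensors on one device, got {sorted(map(str, devices))}")
    (device,) = devices
    # torch.cuda.device is PyTorch's device guard: it switches to `device`
    # and restores the previous current device on exit.
    with torch.cuda.device(device):
        return compiled_fn(*tensors)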

ngimel avatar Nov 26 '25 19:11 ngimel

Thanks a lot for the suggestions. By default we can align with PyTorch conventions when possible: call cudaSetDevice when the input tensors' device id does not match the current device and switch back afterwards, while being careful to avoid initializing a context on the original device. We can also add device checks to ensure a and b are on the same device id. To support some advanced use cases, we can also provide options to opt out of the extra checks for expert users.

tqchen avatar Nov 26 '25 22:11 tqchen

@tqchen Do you imagine there's a way to formalize the PyTorch convention across all kernel packages tvm-ffi distributes, i.e. including but not limited to cuteDSL, flashinfer, etc.?

junrushao avatar Nov 26 '25 23:11 junrushao

We can write up guidelines and also work with the teams to make sure things are implemented this way.

tqchen avatar Nov 26 '25 23:11 tqchen

The 4.3.1 wheel is out and should contain the fix. Thanks @ngimel for all the suggestions!

tqchen avatar Nov 28 '25 01:11 tqchen

Thank you!

ngimel avatar Nov 29 '25 02:11 ngimel