
[FEATURE] Support for AMD GPUs (and others) on Windows

Open Teranis opened this issue 8 months ago • 5 comments

Describe the solution you'd like

Hi, I'm working on Cell ACDC. My personal computer runs Windows and has an AMD GPU. To run Cellpose on this GPU, I looked into DirectML, which uses DirectX to run PyTorch. Since DirectML doesn't support sparse tensors, I had to patch PyTorch slightly. I'll provide the code under Additional context; it moves every operation and tensor that DirectML does not support to the CPU. Although I had my doubts about the performance, it still seems A LOT faster than just running on the CPU. The only problem is that one needs to use Python 3.11.

Describe alternatives you've considered

Waiting 5-10x longer while segmenting on the CPU.

Additional context

Code for making Cellpose use DirectML:

def setup_custom_device(model, device):
    model.gpu = True
    model.device = device
    model.mkldnn = False
    if hasattr(model, 'cp'):
        model.cp.gpu = True
        model.cp.device = device
        model.cp.mkldnn = False
        if hasattr(model.cp, 'net'):
            model.cp.net.to(device)
            model.cp.net.mkldnn = False
    if hasattr(model, 'net'):
        model.net.to(device)
        model.net.mkldnn = False
    if hasattr(model, 'sz'):
        model.sz.device = device

def setup_directML(model):
    print(
        'Using DirectML GPU for Cellpose model inference'
    )
    import torch_directml
    directml_device = torch_directml.device()
    setup_custom_device(model, directml_device)

Code for fixing sparse tensor implementation:

def fix_sparse_directML(verbose=True):
    """DirectML does not support sparse tensors, so we need to fallback to CPU
    """
    import torch
    import functools
    import warnings

    def fallback_to_cpu_on_sparse_error(func, verbose=True):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            device_arg = kwargs.get('device', None)

            # Ensure indices are int64 if args[0] looks like indices
            if len(args) >= 1 and isinstance(args[0], torch.Tensor):
                if args[0].dtype != torch.int64:
                    args = (args[0].to(dtype=torch.int64),) + args[1:]

            try: # try to move result to dml
                result = func(*args, **kwargs)
                if device_arg is not None and str(device_arg).lower() == "dml":
                    try:
                        result.to("dml")
                    except RuntimeError as e:
                        if verbose:
                            warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                        kwargs['device'] = torch.device("cpu")
                        return func(*args, **kwargs)
                return result

            except RuntimeError as e:  # the call failed on DirectML, so fall back to CPU
                if "sparse" in str(e).lower() or "not implemented" in str(e).lower():
                    if verbose:
                        warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                    kwargs['device'] = torch.device("cpu")

                    # Re-apply indices dtype correction before retrying on CPU
                    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
                        if args[0].dtype != torch.int64:
                            args = (args[0].to(dtype=torch.int64),) + args[1:]

                    return func(*args, **kwargs)
                else:
                    raise

        return wrapper

    # --- Patch Sparse Tensor Constructors ---

    # High-level API
    torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error(torch.sparse_coo_tensor, verbose=verbose)

    # Low-level API
    if hasattr(torch._C, "_sparse_coo_tensor_unsafe"):
        torch._C._sparse_coo_tensor_unsafe = fallback_to_cpu_on_sparse_error(torch._C._sparse_coo_tensor_unsafe, verbose=verbose)

    if hasattr(torch._C, "_sparse_coo_tensor_with_dims_and_tensors"):
        torch._C._sparse_coo_tensor_with_dims_and_tensors = fallback_to_cpu_on_sparse_error(
            torch._C._sparse_coo_tensor_with_dims_and_tensors, verbose=verbose
        )

    if hasattr(torch.sparse, 'SparseTensor'):
        torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error(torch.sparse.SparseTensor, verbose=verbose)
    
    # Show the fallback warning only once instead of on every sparse call
    # (`warnings` is already imported at the top of this function)
    warnings.filterwarnings("once", message="Sparse op failed on DirectML")
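The wrapping pattern used by fix_sparse_directML can be hard to see among the DirectML-specific details. The following is a minimal sketch of the same idea in plain Python, assuming a hypothetical constructor make_sparse that stands in for a torch sparse constructor (it is not part of torch or Cellpose). It only illustrates "try the requested device, then retry on the CPU"; it is not the code from the fork.

```python
import functools
import warnings


def retry_on_cpu(func):
    """Wrap func so that a RuntimeError triggers one retry with device='cpu'."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as err:
            warnings.warn(f"{func.__name__} failed ({err}); retrying on CPU")
            # Assumes the device is passed as a keyword argument.
            kwargs["device"] = "cpu"
            return func(*args, **kwargs)

    return wrapper


# Hypothetical stand-in for a constructor that the backend does not support.
def make_sparse(values, device="dml"):
    if device != "cpu":
        raise RuntimeError("sparse tensors are not implemented on this device")
    return {"values": list(values), "device": device}


# Monkey-patch: replace the original name with the wrapped version.
make_sparse = retry_on_cpu(make_sparse)

with warnings.catch_warnings():
    warnings.simplefilter("ignore")  # keep the demo output quiet
    result = make_sparse([1, 2, 3], device="dml")

print(result["device"])
```

Because the original name is rebound to the wrapper, callers do not need to change, which is why the torch code itself can stay untouched.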

Teranis avatar May 13 '25 16:05 Teranis

Hi can you make a PR and put this in the contrib folder? And provide some documentation on how/where to edit the pytorch code?

mrariden avatar May 13 '25 18:05 mrariden

Hi, thanks for getting back so quickly! I have now created a fork and added a file. However, I cannot get the working example to run. This is probably because v4 is now used instead of v3, for which I originally wrote the code. I will have another look tomorrow. If you want to have a stab at it, here's the link to the file.

Teranis avatar May 14 '25 15:05 Teranis

Hi, I have fixed the remaining issues and done some quick benchmarking. I also added extensive documentation and a working example to the file. The torch code itself is not changed; the function fix_sparse_directML replaces some torch functions with wrapped versions so that sparse operations are carried out on the CPU. The only place where I edited Cellpose's own code is in dynamics.py, where I set the device to CPU for the remove_bad_flow_masks() function. Aside from bringing a significant performance improvement, this also stops masks from being deleted due to high errors, which were probably caused by rounding errors or an incompatibility.

In the future, the function setup_custom_device can probably be better integrated into cellpose, as it only moves everything to the DirectML device and sets the device attribute accordingly, turns off mkldnn and sets the GPU flag to true.

I also wanted to note that our implementation of cellpose v3 with custom weights didn't run properly on CUDA until we started running the function setup_custom_device() with the normal CUDA device.

Benchmarking

10 images, 1000x1000 px, passed in a for loop, "cpsam" segmentation:
CUDA: ~1.8 s/frame
DirectML: ~4.7 s/frame
CPU: ~55 s/frame

Specs: RTX 3090, AMD 7900X3D, 64 GB DDR5-6000 RAM

I also confirmed that this works on the Radeon 6900 XT, and that the results are the same for GPU and DirectML. I also noticed that Cellpose is slower on DirectML when given a list of arrays for segmentation.
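For anyone who wants to reproduce numbers like these, here is a small sketch of how per-frame timings and relative speedups could be computed. The helper seconds_per_frame and its segment argument are hypothetical placeholders for whatever segmentation call is being measured; the only data used are the approximate figures reported above.

```python
import time


def seconds_per_frame(segment, frames):
    """Average wall-clock time of segment(frame) over frames, one call per frame."""
    start = time.perf_counter()
    for frame in frames:
        segment(frame)
    return (time.perf_counter() - start) / len(frames)


# Approximate per-frame timings reported in this thread (seconds).
reported = {"CUDA": 1.8, "DirectML": 4.7, "CPU": 55.0}

# Speedup of each backend relative to the CPU run.
speedup_vs_cpu = {name: reported["CPU"] / t for name, t in reported.items()}

for name, factor in speedup_vs_cpu.items():
    print(f"{name}: {factor:.1f}x the CPU speed")
```

With the reported figures, DirectML comes out at roughly 12x the CPU speed and CUDA at roughly 31x, so DirectML is about 2.6x slower than CUDA on this setup.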

Teranis avatar May 16 '25 15:05 Teranis

I also wanted to note that our implementation of cellpose v3 with custom weights didn't run properly on CUDA, until we now run the function setup_custom_device() with the normal CUDA device.

@Teranis Can you elaborate on this?

mrariden avatar May 16 '25 21:05 mrariden

So, for our implementation of loading a custom model, there was a problem where model.gpu = True, but just by comparing computation time and GPU/CPU utilisation, it was clear that it was running on CPU. The function setup_custom_device fixed this.
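A quicker check than comparing run times is to look at where the network's parameters actually live, since a flag such as model.gpu only records intent. The sketch below is duck-typed and uses a hypothetical stand-in network so that it runs without torch installed; with Cellpose, one would presumably pass the torch module referenced as model.net in setup_custom_device, which is an assumption about the model object and not a documented API.

```python
def parameter_devices(module):
    """Return the set of device names that hold the module's parameters.

    Works with any object exposing .parameters() whose items have a .device
    attribute, which is the interface of torch.nn.Module.
    """
    return {str(p.device) for p in module.parameters()}


# Hypothetical stand-ins so the example is self-contained.
class FakeParam:
    def __init__(self, device):
        self.device = device


class FakeNet:
    def __init__(self, devices):
        self._params = [FakeParam(d) for d in devices]

    def parameters(self):
        return iter(self._params)


# A model flagged for GPU use whose weights never left the CPU:
net = FakeNet(["cpu", "cpu"])
devices = parameter_devices(net)
if devices == {"cpu"}:
    print("All parameters are on the CPU; the network was never moved.")
```

If the returned set contains only "cpu" while the GPU flag is set, the network was not moved, which matches the symptom described above.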

We are still working on integrating v4 into Cell ACDC, but if the problem persists for the newer versions, we'll let you know!

Teranis avatar May 18 '25 12:05 Teranis