[FEATURE] Support for AMD GPUs (and others) on Windows
Describe the solution you'd like
Hi, I'm working on Cell ACDC. My personal computer runs Windows and has an AMD GPU. In order to run Cellpose on this GPU, I looked into DirectML, which uses DirectX to run PyTorch. Since DirectML doesn't support sparse tensors, I had to slightly patch PyTorch. I'll provide the code in the additional context, where I move every operation and tensor that is not supported to the CPU. Although I had my doubts about the performance, it still seems A LOT faster than just running on the CPU. The only problem is that one needs to use Python 3.11.
Describe alternatives you've considered
Waiting 5-10x longer while segmenting on CPU.
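For reference, here is a minimal check that the DirectML backend is usable. This is only a sketch and not part of the patch below; it assumes the torch-directml package is installed, which at the time of writing requires Python 3.11 or lower.

# Quick sanity check that PyTorch can see the DirectML device.
# pip install torch-directml   (needs Python <= 3.11 at the time of writing)
import torch
import torch_directml

dml = torch_directml.device()
x = torch.ones(3, 3).to(dml) * 2
print(x.device)  # DirectML devices typically report as 'privateuseone:0'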
Additional context
Code for making Cellpose use DirectML:
def setup_custom_device(model, device):
    # Point the Cellpose model (and its sub-models) at the given torch device,
    # disable mkldnn (a CPU-only optimisation) and flag the model as GPU-enabled.
    model.gpu = True
    model.device = device
    model.mkldnn = False
    if hasattr(model, 'cp'):
        model.cp.gpu = True
        model.cp.device = device
        model.cp.mkldnn = False
        if hasattr(model.cp, 'net'):
            model.cp.net.to(device)
            model.cp.net.mkldnn = False
    if hasattr(model, 'net'):
        model.net.to(device)
        model.net.mkldnn = False
    if hasattr(model, 'sz'):
        model.sz.device = device
def setup_directML(model):
    print('Using DirectML GPU for Cellpose model inference')
    import torch_directml
    directml_device = torch_directml.device()
    setup_custom_device(model, directml_device)
Code for fixing sparse tensor implementation:
def fix_sparse_directML(verbose=True):
    """DirectML does not support sparse tensors, so we need to fall back to CPU."""
    import torch
    import functools
    import warnings

    def fallback_to_cpu_on_sparse_error(func, verbose=True):
        @functools.wraps(func)  # preserve the wrapped function's name and metadata
        def wrapper(*args, **kwargs):
            device_arg = kwargs.get('device', None)
            # Ensure indices are int64 if args[0] looks like indices
            if len(args) >= 1 and isinstance(args[0], torch.Tensor):
                if args[0].dtype != torch.int64:
                    args = (args[0].to(dtype=torch.int64),) + args[1:]
            try:
                result = func(*args, **kwargs)
                if device_arg is not None and str(device_arg).lower() == "dml":
                    try:
                        # check that the result can actually live on the DirectML device
                        result = result.to("dml")
                    except RuntimeError as e:
                        if verbose:
                            warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                        kwargs['device'] = torch.device("cpu")
                        return func(*args, **kwargs)
                return result
            except RuntimeError as e:  # if the op fails on DirectML, fall back to CPU
                if "sparse" in str(e).lower() or "not implemented" in str(e).lower():
                    if verbose:
                        warnings.warn(f"Sparse op failed on DirectML, falling back to CPU: {e}")
                    kwargs['device'] = torch.device("cpu")
                    # Re-apply indices dtype correction before retrying on CPU
                    if len(args) >= 1 and isinstance(args[0], torch.Tensor):
                        if args[0].dtype != torch.int64:
                            args = (args[0].to(dtype=torch.int64),) + args[1:]
                    return func(*args, **kwargs)
                else:
                    raise
        return wrapper

    # --- Patch sparse tensor constructors ---
    # High-level API
    torch.sparse_coo_tensor = fallback_to_cpu_on_sparse_error(torch.sparse_coo_tensor, verbose=verbose)
    # Low-level API
    if hasattr(torch._C, "_sparse_coo_tensor_unsafe"):
        torch._C._sparse_coo_tensor_unsafe = fallback_to_cpu_on_sparse_error(
            torch._C._sparse_coo_tensor_unsafe, verbose=verbose
        )
    if hasattr(torch._C, "_sparse_coo_tensor_with_dims_and_tensors"):
        torch._C._sparse_coo_tensor_with_dims_and_tensors = fallback_to_cpu_on_sparse_error(
            torch._C._sparse_coo_tensor_with_dims_and_tensors, verbose=verbose
        )
    if hasattr(torch.sparse, 'SparseTensor'):
        torch.sparse.SparseTensor = fallback_to_cpu_on_sparse_error(torch.sparse.SparseTensor, verbose=verbose)

    # Show each fallback warning only once per message
    warnings.filterwarnings("once", message="Sparse op failed on DirectML*")
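Putting the two helpers together, here is a sketch of the intended workflow. The constructor and eval arguments are assumptions and vary between Cellpose v3 and v4; the image is a placeholder.

# Hypothetical end-to-end usage: patch the sparse constructors first, then move
# the model onto DirectML and segment a single image.
import numpy as np
from cellpose import models

fix_sparse_directML(verbose=True)        # patch torch before any sparse ops run

model = models.CellposeModel(gpu=False)  # constructor arguments depend on the Cellpose version
setup_directML(model)

img = np.random.randint(0, 255, (1000, 1000), dtype=np.uint8)  # placeholder image
masks, flows, styles = model.eval(img)   # eval signature varies between versions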
Hi, can you make a PR and put this in the contrib folder? And provide some documentation on how/where to edit the PyTorch code?
Hi, thanks for getting back so quickly! I have now created a fork and added a file. However, I cannot get the working example to work yet. This is probably related to the fact that v4 is now used instead of v3, for which I originally wrote the code. I will have another look tomorrow. If you want to have a stab at it, here's the link to the file.
Hi, so I have fixed the remaining issues and done some quick benchmarking. I have also added extensive documentation and a working example to the file.
The torch code itself is not changed; the function fix_sparse_directML replaces some torch functions with wrapped versions. This is done so that sparse operations are carried out on the CPU.
The only place where I edited the Cellpose code is in dynamics.py, where I set the device to CPU for the remove_bad_flow_masks() function. Aside from bringing a significant performance improvement, this also fixes an issue where masks were being deleted because of high flow errors, probably caused by rounding errors or some incompatibility.
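For illustration, here is a non-invasive sketch of the same idea that wraps the function instead of editing dynamics.py directly. It assumes remove_bad_flow_masks accepts a device keyword argument, which may differ between Cellpose versions.

# Hypothetical alternative to editing dynamics.py: wrap remove_bad_flow_masks so
# the flow-error check always runs on the CPU.
import functools
import torch
from cellpose import dynamics

_original_remove_bad_flow_masks = dynamics.remove_bad_flow_masks

@functools.wraps(_original_remove_bad_flow_masks)
def _remove_bad_flow_masks_cpu(*args, **kwargs):
    kwargs['device'] = torch.device('cpu')  # assumes a 'device' keyword argument exists
    return _original_remove_bad_flow_masks(*args, **kwargs)

dynamics.remove_bad_flow_masks = _remove_bad_flow_masks_cpu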
In the future, the function setup_custom_device can probably be integrated into Cellpose more cleanly, as all it does is move everything to the DirectML device, set the device attributes accordingly, turn off mkldnn, and set the GPU flag to True.
I also wanted to note that our implementation of Cellpose v3 with custom weights didn't run properly on CUDA until we started calling setup_custom_device() with the normal CUDA device.
Benchmarking
10 images, 1000x1000 px, passed in a for loop, "cpsam" segmentation:
CUDA: ~1.8 s/frame
DirectML: ~4.7 s/frame
CPU: ~55 s/frame
Specs: RTX 3090, AMD Ryzen 7900X3D, 64 GB DDR5-6000 RAM.
I also confirmed that this works on a Radeon 6900 XT, and I have checked that the results are the same for CUDA and DirectML. I have also noticed that Cellpose is slower with DirectML if it is provided a list of arrays for segmentation.
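For context, here is a rough sketch of the kind of per-frame timing loop described above. The images are random placeholders and 'model' is assumed to have been set up as shown earlier.

# Hypothetical timing loop: segment 10 single frames one at a time and report
# the average seconds per frame.
import time
import numpy as np

imgs = [np.random.randint(0, 255, (1000, 1000), dtype=np.uint8) for _ in range(10)]

start = time.perf_counter()
for img in imgs:
    masks, flows, styles = model.eval(img)  # assumes 'model' was set up as above
elapsed = time.perf_counter() - start
print(f"{elapsed / len(imgs):.1f} s/frame")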
I also wanted to note that our implementation of Cellpose v3 with custom weights didn't run properly on CUDA until we started calling setup_custom_device() with the normal CUDA device.
@Teranis Can you elaborate on this?
So, for our implementation of loading a custom model, there was a problem where model.gpu was True, but judging by the computation time and the GPU/CPU utilisation it was clearly running on the CPU. The function setup_custom_device fixed this.
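For reference, a minimal sketch of that fix; the weights path is a placeholder and the constructor arguments are assumptions.

# Hypothetical: load custom weights, then force the model onto the CUDA device
# with the same helper used for DirectML.
import torch
from cellpose import models

model = models.CellposeModel(pretrained_model='path/to/custom_weights')
setup_custom_device(model, torch.device('cuda'))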
We are still working on integrating v4 into Cell ACDC, but if the problem persists for the newer versions, we'll let you know!