Wave-U-Net-Pytorch
Support for Apple Metal (MPS) backend
Please add support for the MPS backend, the same way CUDA is already supported:
import torch

if torch.backends.mps.is_available():
    mps = torch.device("mps")
    # Wrap the model the same way the existing CUDA path does
    model = model_utils.DataParallel(model)
    model.to(mps)
    # ... and so on...
    x = x.to(mps)
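Rather than adding a third branch next to every CUDA check, one option is to pick the device once and pass it around. The sketch below assumes PyTorch 1.12 or later (where `torch.backends.mps` exists). `pick_device_type` and `get_device` are hypothetical helper names, not part of Wave-U-Net-Pytorch. The preference order (CUDA, then MPS, then CPU) is also an assumption.

```python
def pick_device_type(cuda_available: bool, mps_available: bool) -> str:
    """Return the preferred backend name: CUDA first, then Apple MPS, else CPU.

    Hypothetical helper: the preference order is an assumption, not the repo's.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device():
    """Build a torch.device from what this machine actually supports.

    torch is imported lazily so the selection logic above stays testable
    without PyTorch installed. torch.backends.mps needs PyTorch >= 1.12.
    """
    import torch

    return torch.device(
        pick_device_type(
            torch.cuda.is_available(),
            torch.backends.mps.is_available(),
        )
    )
```

With this, the training and prediction code could call `model.to(device)` and `x = x.to(device)` with `device = get_device()` instead of branching on CUDA. The `DataParallel` wrapper is aimed at multi-GPU CUDA setups, so it is probably unnecessary on MPS, though that has not been verified against this repo. Some operators may also be missing on MPS; PyTorch's `PYTORCH_ENABLE_MPS_FALLBACK=1` environment variable makes those run on the CPU instead.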
Has anyone looked into implementing MPS support for this yet?