
MPS Support

Open m1guelpf opened this issue 1 year ago • 9 comments

This PR introduces Metal GPU support, at the cost of slightly lower accuracy in the gaussian_diffusion step (changing float64 to float32, but only when running on MPS).

m1guelpf avatar Dec 21 '22 03:12 m1guelpf
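For context, a minimal sketch of the device/dtype selection this implies, assuming a PyTorch build with MPS enabled (the variable names here are illustrative, not from the PR):

import torch

# The MPS backend cannot allocate float64 tensors, so precision is
# dropped to float32 only when an Apple GPU is the active device.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
dtype = torch.float32 if device.type == "mps" else torch.float64

x = torch.zeros(4, dtype=dtype, device=device)
print(x.dtype, x.device)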

I think this works well

jameshennessytempus avatar Dec 21 '22 16:12 jameshennessytempus

Perhaps something like this to preserve float64 precision on CUDA?

diff --git a/point_e/util/precision_compatibility.py b/point_e/util/precision_compatibility.py
new file mode 100644
--- /dev/null
+++ b/point_e/util/precision_compatibility.py
@@ -0,0 +1,5 @@
+import torch
+import numpy as np
+
+NP_FLOAT32_64 = np.float32 if torch.backends.mps.is_available() else np.float64
+TH_FLOAT32_64 = torch.float32 if torch.backends.mps.is_available() else torch.float64
\ No newline at end of file
diff --git a/point_e/diffusion/gaussian_diffusion.py b/point_e/diffusion/gaussian_diffusion.py
--- point_e/diffusion/gaussian_diffusion.py
+++ point_e/diffusion/gaussian_diffusion.py
@@ -6,8 +6,9 @@
 from typing import Any, Dict, Iterable, Optional, Sequence, Union
 
 import numpy as np
 import torch as th
+from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64
 
 
 def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
     """
@@ -15,9 +16,9 @@
 
     See get_named_beta_schedule() for the new library of schedules.
     """
     if beta_schedule == "linear":
-        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32)
+        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=NP_FLOAT32_64)
     else:
         raise NotImplementedError(beta_schedule)
     assert betas.shape == (num_diffusion_timesteps,)
     return betas
@@ -159,9 +160,9 @@
         self.channel_scales = channel_scales
         self.channel_biases = channel_biases
 
         # originally uses float64 for accuracy, moving to float32 for mps compatibility
-        betas = np.array(betas, dtype=np.float32)
+        betas = np.array(betas, dtype=NP_FLOAT32_64)
         self.betas = betas
         assert len(betas.shape) == 1, "betas must be 1-D"
         assert (betas > 0).all() and (betas <= 1).all()
 
@@ -1012,9 +1013,9 @@
     :param broadcast_shape: a larger shape of K dimensions with the batch
                             dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(dtype=th.float32, device=timesteps.device)[timesteps].to(th.float32)
+    res = th.from_numpy(arr).to(dtype=TH_FLOAT32_64, device=timesteps.device)[timesteps].to(TH_FLOAT32_64)
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res + th.zeros(broadcast_shape, device=timesteps.device)
 

henrycunh avatar Dec 21 '22 20:12 henrycunh
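For illustration, other call sites could adopt the proposed constants the same way. A hypothetical sketch, assuming the module lands at the path shown in the diff above:

import numpy as np
import torch as th

from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64

# Hypothetical call site: the schedule stays float64 on CUDA/CPU and
# drops to float32 on MPS, without branching at each use.
betas = np.linspace(1e-4, 0.02, 1024, dtype=NP_FLOAT32_64)
alphas_cumprod = np.cumprod(1.0 - betas, dtype=NP_FLOAT32_64)
res = th.from_numpy(alphas_cumprod).to(dtype=TH_FLOAT32_64)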

I love it!

jamesthesnake avatar Dec 22 '22 01:12 jamesthesnake

@henrycunh Added!

m1guelpf avatar Dec 22 '22 03:12 m1guelpf

Tried it just now on a MacBook Air M2.

It worked very well. For reference:

  • circa 30s on Windows 11 with CUDA on an RTX 2080
  • 3m 53s on a MacBook Air M2
  • more than 20 minutes using the CPU on an 11th-gen Intel i9

The only problem is PyTorch's current MPS implementation, which emits: UserWarning: The operator 'aten::linalg_vector_norm' is not currently supported on the MPS backend and will fall back to run on the CPU.

xmario3 avatar Dec 22 '22 22:12 xmario3
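For anyone who hits a hard NotImplementedError instead of the warning above: PyTorch's CPU fallback for operators missing on MPS is opt-in through an environment variable, which has to be set before torch is imported. A sketch:

import os

# Opt in to PyTorch's CPU fallback for ops not yet implemented on MPS,
# such as aten::linalg_vector_norm. Must be set before importing torch.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(8, device=device)
print(torch.linalg.vector_norm(x))  # falls back to CPU on MPS, with a warning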

I apologize for my question, but how noticeable is the change to float32?

peruginiandrea avatar Jan 01 '23 16:01 peruginiandrea

I apologize for my question, but how noticeable is the change to float32?

I'm pretty confident that higher precision like float64 will almost always give tighter, more accurate results when smoothing out noise with a Gaussian diffusion algorithm, since the schedule is built from long chains of cumulative products where float32 rounding error compounds. Higher precision is a bit more computationally intensive, but the benefits are usually worth it. Plus, it's always nice to have the extra accuracy and stability in our results!

henrycunh avatar Jan 02 '23 18:01 henrycunh
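To put a rough number on the difference, one can compute the cumulative alpha products both ways and compare. An illustrative check, using typical linear-schedule values rather than point-e's exact configuration:

import numpy as np

# Accumulate 1000 multiplications of (1 - beta) in float64 and float32;
# the max absolute gap shows the drift float32 introduces.
betas64 = np.linspace(1e-4, 0.02, 1000, dtype=np.float64)
acp64 = np.cumprod(1.0 - betas64)

betas32 = betas64.astype(np.float32)
acp32 = np.cumprod(1.0 - betas32, dtype=np.float32)

print("max abs drift:", np.abs(acp64 - acp32.astype(np.float64)).max())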

Could we set that as a parameter that defaults to float64, with another option for float32?

jameshennessytempus avatar Jan 02 '23 20:01 jameshennessytempus

^ agree

jamesthesnake avatar Jan 02 '23 21:01 jamesthesnake
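A hedged sketch of what that suggestion could look like (hypothetical helper, not from the PR): default to float64, accept float32 explicitly, and still downgrade automatically on MPS, which lacks float64 support.

import numpy as np
import torch

def resolve_diffusion_dtype(requested: str = "float64"):
    """Hypothetical helper: defaults to float64, allows opting into
    float32, and downgrades on MPS, which has no float64 support."""
    if requested == "float32" or torch.backends.mps.is_available():
        return np.float32, torch.float32
    return np.float64, torch.float64

np_dtype, th_dtype = resolve_diffusion_dtype()                # float64 off-MPS
np_dtype32, th_dtype32 = resolve_diffusion_dtype("float32")   # always float32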