point-e
MPS Support
This PR introduces Metal GPU support, at the cost of slightly lowering accuracy on the gaussian_diffusion step (changing float64 to float32, only when running on mps).
I think this works well
Perhaps something like this, to preserve float64 precision on CUDA?
diff --git a/point_e/util/precision_compatibility.py b/point_e/util/precision_compatibility.py
new file mode 100644
--- /dev/null
+++ b/point_e/util/precision_compatibility.py
@@ -0,0 +1,5 @@
+import torch
+import numpy as np
+
+NP_FLOAT32_64 = np.float32 if torch.backends.mps.is_available() else np.float64
+TH_FLOAT32_64 = torch.float32 if torch.backends.mps.is_available() else torch.float64
\ No newline at end of file
diff --git a/point_e/diffusion/gaussian_diffusion.py b/point_e/diffusion/gaussian_diffusion.py
--- point_e/diffusion/gaussian_diffusion.py
+++ point_e/diffusion/gaussian_diffusion.py
@@ -6,8 +6,9 @@
 from typing import Any, Dict, Iterable, Optional, Sequence, Union
 import numpy as np
 import torch as th
+from point_e.util.precision_compatibility import NP_FLOAT32_64, TH_FLOAT32_64
 def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
     """
@@ -15,9 +16,9 @@
     See get_named_beta_schedule() for the new library of schedules.
     """
     if beta_schedule == "linear":
-        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float32)
+        betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=NP_FLOAT32_64)
     else:
         raise NotImplementedError(beta_schedule)
     assert betas.shape == (num_diffusion_timesteps,)
     return betas
@@ -159,9 +160,9 @@
         self.channel_scales = channel_scales
         self.channel_biases = channel_biases
         # originally uses float64 for accuracy, moving to float32 for mps compatibility
-        betas = np.array(betas, dtype=np.float32)
+        betas = np.array(betas, dtype=NP_FLOAT32_64)
         self.betas = betas
         assert len(betas.shape) == 1, "betas must be 1-D"
         assert (betas > 0).all() and (betas <= 1).all()
@@ -1012,9 +1013,9 @@
     :param broadcast_shape: a larger shape of K dimensions with the batch
                             dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(dtype=th.float32, device=timesteps.device)[timesteps].to(th.float32)
+    res = th.from_numpy(arr).to(dtype=TH_FLOAT32_64, device=timesteps.device)[timesteps].to(TH_FLOAT32_64)
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res + th.zeros(broadcast_shape, device=timesteps.device)
I love it!
@henrycunh Added!
Just tried it on a MacBook Air M2.
It worked very well. For reference:
- about 30 s on Windows 11 with CUDA on an RTX 2080
- 3 min 53 s on the MacBook Air M2
- more than 20 minutes on CPU with an 11th-gen Intel i9
The only problem is PyTorch's current MPS implementation, which emits this warning: UserWarning: The operator 'aten::linalg_vector_norm' is not currently supported on the MPS backend and will fall back to run on the CPU.
I apologize for my question, but how noticeable is the change to float32?
I'm pretty confident that using higher precision, like float64, will almost always give us tighter, more accurate results when we're smoothing out noise with a Gaussian diffusion algorithm. It's true that using higher precision can be a bit more computationally intensive, but the benefits are usually worth it. Plus, it's always nice to have the extra accuracy and stability in our results!
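For a rough sense of how noticeable the change is, here is a minimal NumPy sketch (not part of the PR) that builds a linear beta schedule in both precisions and compares the cumulative product of alphas, which is where rounding error can accumulate across timesteps. The schedule endpoints (1e-4 to 0.02) and the 1000 steps are assumed, typical DDPM-style values and not necessarily what point-e uses; the helper name `alphas_cumprod` is made up for illustration.

```python
import numpy as np


def alphas_cumprod(dtype, beta_start=1e-4, beta_end=0.02, num_steps=1000):
    """Cumulative product of (1 - beta_t) for a linear schedule in the given dtype.

    The default schedule values are assumptions for illustration only.
    """
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=dtype)
    return np.cumprod(1.0 - betas, axis=0)


acp64 = alphas_cumprod(np.float64)  # reference precision
acp32 = alphas_cumprod(np.float32)  # what an mps run would use

# Relative error of the float32 schedule against the float64 reference.
rel_err = np.abs(acp32.astype(np.float64) - acp64) / acp64
print(f"max relative error over {len(acp64)} steps: {rel_err.max():.2e}")
```

This only measures rounding in the schedule constants. It says nothing by itself about the visual quality of the sampled point clouds, which would need a side-by-side run to judge.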
Could we make that a parameter that defaults to 64, with an option to set it to 32?
^ agree
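A minimal sketch of what such a parameter could look like, assuming a hypothetical `precision` argument (64 by default, 32 as opt-in) and a hypothetical `device_type` string. Neither the argument names nor the helpers `resolve_np_dtype` and `linear_betas` exist in point-e. It uses NumPy only and rejects 64-bit on mps, since the MPS backend does not support float64 tensors.

```python
import numpy as np

# Hypothetical mapping from a user-facing precision value to a NumPy dtype.
_NP_DTYPES = {32: np.float32, 64: np.float64}


def resolve_np_dtype(precision=64, device_type="cpu"):
    """Return the NumPy dtype for the requested precision (hypothetical helper)."""
    if precision not in _NP_DTYPES:
        raise ValueError(f"precision must be 32 or 64, got {precision!r}")
    if device_type == "mps" and precision == 64:
        # Fail early instead of silently downgrading the user's choice.
        raise ValueError("float64 is not supported on mps; pass precision=32")
    return _NP_DTYPES[precision]


def linear_betas(beta_start, beta_end, num_steps, *, precision=64, device_type="cpu"):
    """Linear beta schedule whose dtype follows the precision parameter."""
    dtype = resolve_np_dtype(precision, device_type)
    return np.linspace(beta_start, beta_end, num_steps, dtype=dtype)


# float64 stays the default; mps users opt in to float32 explicitly.
betas_default = linear_betas(1e-4, 0.02, 10)
betas_mps = linear_betas(1e-4, 0.02, 10, precision=32, device_type="mps")
```

Raising on the mps plus 64 combination is one possible design choice; falling back to float32 with a warning would be another.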