score_sde_pytorch
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
```
RuntimeError                              Traceback (most recent call last)
Cell In[29], line 1
----> 1 x, n = sampling_fn(score_model)
      2 show_samples(x)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:407, in get_pc_sampler.

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:341, in shared_predictor_update_fn(x, t, sde, model, predictor, probability_flow, continuous)
    339 else:
    340   predictor_obj = predictor(sde, score_fn, probability_flow)
--> 341 return predictor_obj.update_fn(x, t)

File /workspace/pytorchcode/score_sde_pytorch-main/sampling.py:196, in ReverseDiffusionPredictor.update_fn(self, x, t)
    195 def update_fn(self, x, t):
--> 196   f, G = self.rsde.discretize(x, t)
    197   z = torch.randn_like(x)
    198   x_mean = x - f

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:104, in SDE.reverse.

File /workspace/pytorchcode/score_sde_pytorch-main/sde_lib.py:251, in VESDE.discretize(self, x, t)
    248 timestep = (t * (self.N - 1) / self.T).long()
    249 sigma = self.discrete_sigmas.to(t.device)[timestep]
    250 adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
--> 251                              self.discrete_sigmas[timestep - 1].to(t.device))
    252 f = torch.zeros_like(x)
    253 G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)

RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cpu)
```
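Changing `self.discrete_sigmas[timestep - 1].to(t.device)` to `self.discrete_sigmas.to(t.device)[timestep - 1]` on line 251 of `sde_lib.py` (in `VESDE.discretize`) seems to fix the problem. In the original expression, the CPU tensor `self.discrete_sigmas` is indexed with `timestep`, which is computed from `t` and therefore lives on `t`'s device (here a GPU), so the indexing fails before `.to(t.device)` is ever reached. Moving the tensor first, as line 249 already does for `sigma`, puts the tensor and its indices on the same device.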
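For reference, here is a minimal sketch of the patched computation as a standalone function. `ve_discretize` and its arguments are hypothetical stand-ins for `VESDE.discretize` and the attributes it reads (`self.discrete_sigmas`, `self.N`, `self.T`), and the geometric sigma schedule in the usage lines is an assumption chosen for illustration. The point it shows is that the sigma table is moved to `t.device` before either index is taken, so it runs on CPU-only and CUDA machines alike.

```python
import math

import torch


def ve_discretize(x, t, discrete_sigmas, N, T=1.0):
    """Hypothetical standalone version of VESDE.discretize with the device fix.

    x: batch of samples; t: 1-D tensor of times in [0, T], one per sample;
    discrete_sigmas: 1-D tensor of N noise levels (may live on the CPU).
    """
    timestep = (t * (N - 1) / T).long()  # indices live on t's device
    # Move the sigma table to t's device once, BEFORE indexing into it.
    sigmas = discrete_sigmas.to(t.device)
    sigma = sigmas[timestep]
    # For timestep == 0, index -1 wraps around, but torch.where discards it.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t),
                                 sigmas[timestep - 1])
    f = torch.zeros_like(x)  # VE SDE has zero drift
    G = torch.sqrt(sigma ** 2 - adjacent_sigma ** 2)
    return f, G


# Usage (assumed schedule: N geometrically spaced sigmas, kept on the CPU).
N = 10
sigma_min, sigma_max = 0.01, 50.0
discrete_sigmas = torch.exp(
    torch.linspace(math.log(sigma_min), math.log(sigma_max), N))

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(4, 3, 8, 8, device=device)
t = torch.tensor([0.0, 0.3, 0.7, 1.0], device=device)
f, G = ve_discretize(x, t, discrete_sigmas, N)
```

Hoisting `.to(t.device)` into a single local variable is a stylistic choice; the one-line change to line 251 described above is functionally equivalent.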