POT
ot.solve uses GPU even though tensors are on CPU?
Describe the bug
Running ot.solve with tensors on the CPU allocates memory on the GPU (this is documented in get_backend_list), but it also seems to use the GPU for computation, since the power draw (in watts) increases. See the attached screencast:
Screencast from 08-03-2024 11:24:44.webm
Is this expected behavior?
Script
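import torch
import ot

n_samples = 5_000

# All tensors are created on the CPU (PyTorch's default device)
x = torch.randn(n_samples, 2)
y = torch.randn(n_samples, 2)

# Random weights normalised to sum to 1
a = torch.rand(n_samples)
a /= a.sum()
b = torch.rand(n_samples)
b /= b.sum()

# Squared Euclidean cost matrix, then entropic OT
M = ot.dist(x, y)
res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")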