POT
ot.solve uses GPU even though tensors are on CPU?
Describe the bug
Running ot.solve with tensors on the CPU allocates memory on the GPU (this is documented in get_backend_list) but also seems to use the GPU, as the power draw in watts increases. See the attached screencast:
Screencast from 08-03-2024 11:24:44.webm
Is it normal?
Script
import torch
import ot
n_samples = 5_000
x = torch.randn(n_samples, 2)
y = torch.randn(n_samples, 2)
a = torch.rand(n_samples)
a /= a.sum()
b = torch.rand(n_samples)
b /= b.sum()
M = ot.dist(x, y)
res = ot.solve(M, a, b, reg=0.1, reg_type="entropy")
Hello @mathurinm,
This might be related to the following closed issue: https://github.com/PythonOT/POT/issues/516. Are you using POT >= 0.9.2?
Thanks for the quick reply, Cedric. This is happening with the latest dev version:
In [1]: import ot
ot.__version__
In [2]: ot.__version__
Out[2]: '0.9.3dev'
Contrary to #516, the memory consumption starts when calling a function from ot, like dist (when the backend is determined, I guess). Importing ot does not use the GPU indeed.
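For reference, here is a minimal sketch of how this could be checked step by step. It assumes that torch.cuda.is_initialized() is a reasonable proxy for "this process has created a CUDA context". The helper cuda_probe and the checkpoints list are hypothetical names used only for illustration. Note that torch.cuda.memory_allocated() only counts tensors held by PyTorch's own allocator, so it would not include the context memory that nvidia-smi reports.

```python
import importlib.util


def cuda_probe():
    """Return (initialized, bytes_allocated) as seen by PyTorch in this process."""
    import torch

    if not torch.cuda.is_available():
        return (False, 0)
    initialized = torch.cuda.is_initialized()
    # Only query the allocator once a CUDA context exists.
    allocated = torch.cuda.memory_allocated() if initialized else 0
    return (initialized, allocated)


# Each entry is (label, initialized, bytes_allocated).
checkpoints = []

# Skip the measurements if torch or POT is not installed.
if importlib.util.find_spec("torch") and importlib.util.find_spec("ot"):
    import torch

    checkpoints.append(("after import torch", *cuda_probe()))

    import ot

    checkpoints.append(("after import ot", *cuda_probe()))

    # CPU tensors only, as in the script above (smaller size for a quick check).
    x = torch.randn(100, 2)
    y = torch.randn(100, 2)
    M = ot.dist(x, y)
    checkpoints.append(("after ot.dist", *cuda_probe()))

for label, initialized, allocated in checkpoints:
    print(f"{label}: cuda initialized={initialized}, allocated bytes={allocated}")
```

If the "after ot.dist" line reports an initialized context while "after import ot" does not, that would be consistent with the GPU being touched when the backend is determined.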
Let me know if I can provide additional info.
OK, thank you for the feedback. We will look into this and get back to you ASAP.