lab
B.eye doesn't support tensors on CUDA
B.eye doesn't keep its output on the GPU when given a CUDA tensor:
In [1]: import lab as B
In [2]: import lab.torch
In [3]: import torch
In [4]: B.eye(torch.ones((2, 2)).cuda())
Out[4]:
tensor([[1., 0.],
        [0., 1.]])  # isn't on CUDA
I traced this back from pinv, which fails because of this issue.
Hey @InfProbSciX, this behaviour is as intended, although admittedly it might not be the most convenient default.
What you're really after is likely the following:
with B.on_device(x):
    eye = B.eye(x)
If desired, we could change the behaviour of B.eye(x) for a tensor x so that the above is in fact what happens when you just call B.eye(x).
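For comparison, here is a minimal sketch of the proposed behaviour in plain PyTorch, outside of lab. The helper eye_like is hypothetical (it is not part of lab's API and not how lab implements B.eye); it assumes a 2-D reference tensor and builds an identity with the same shape, dtype, and device by passing dtype= and device= to torch.eye.

```python
import torch


def eye_like(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: identity matrix matching a 2-D reference tensor.

    The result has the same shape, dtype, and device as x, so it stays on
    the GPU whenever x does. This is a sketch, not lab's implementation.
    """
    if x.dim() != 2:
        raise ValueError("eye_like expects a 2-D reference tensor")
    rows, cols = x.shape
    # torch.eye accepts dtype= and device= keywords, so the identity is
    # allocated directly on x's device instead of the default (CPU) device.
    return torch.eye(rows, cols, dtype=x.dtype, device=x.device)


# Usage: for a CPU tensor the identity is created on the CPU ...
x = torch.ones((2, 2))
print(eye_like(x).device)  # cpu

# ... and for a CUDA tensor it stays on the GPU (only runs if CUDA is present).
if torch.cuda.is_available():
    x_gpu = torch.ones((2, 2)).cuda()
    print(eye_like(x_gpu).device)  # a CUDA device, e.g. cuda:0
```

Passing device= at construction time avoids allocating the identity on the CPU and then copying it over with .to(x.device).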