
B.eye doesn't support tensors on CUDA

Open • InfProbSciX opened this issue 1 year ago • 1 comment

B.eye doesn't keep tensors on the GPU:

In [1]: import lab as B

In [2]: import lab.torch

In [3]: import torch

In [4]: B.eye(torch.ones((2, 2)).cuda())
Out[4]: 
tensor([[1., 0.],
        [0., 1.]])  # not on CUDA (no device='cuda:0' in the output)

I traced this back from pinv, which fails because of it.

InfProbSciX, Mar 16 '23 18:03

Hey @InfProbSciX, this behaviour is intended, although admittedly it might not be the most convenient default.

What you're really after is likely the following:

with B.on_device(x):  # tensors created in this block go on x's device
    eye = B.eye(x)

If desired, we could change the behaviour of B.eye(x) for a tensor x so that calling B.eye(x) on its own does exactly the above.
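To make the proposed default concrete, here is a minimal sketch in plain PyTorch. It is not lab's implementation: eye_like is a hypothetical helper that reads the shape, dtype, and device from its reference matrix, which is what B.eye(x) would do under the suggested change.

```python
import torch


def eye_like(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: an identity matrix matching x's shape, dtype and device.

    This only illustrates the proposed default for B.eye(x); it is not lab's code.
    """
    if x.dim() != 2:
        raise ValueError("eye_like expects a matrix (2-D tensor)")
    n, m = x.shape
    # Passing dtype= and device= allocates the result directly where x lives,
    # so no host-to-device copy is needed afterwards.
    return torch.eye(n, m, dtype=x.dtype, device=x.device)


device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.ones((2, 3), device=device, dtype=torch.float64)
eye = eye_like(x)
print(eye.device, eye.dtype)
```

Allocating with device= avoids creating the identity on the CPU and moving it afterwards, which is also what keeps downstream operations such as pinv from mixing devices.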

wesselb, Mar 20 '23 18:03