
CUDA not used when `device=torch.device("cuda")`?

Open · opened by sjfleming

Here is a small demo that appears to show that `pymde.preserve_neighbors` uses CUDA only if the user specifies `device="cuda"`, and not if the user specifies `device=torch.device("cuda")`. I see that the type hint for that argument is `str`, but I still think this behavior is unexpected. Example:

```python
import pymde
import torch

device = "cuda"  # device given as a plain string

mnist = pymde.datasets.MNIST()
embedding = pymde.preserve_neighbors(mnist.data, embedding_dim=2, verbose=True, device=device).embed()
pymde.plot(embedding, color_by=mnist.attributes['digits'])
```

This works fine and seems to run on CUDA.

But this:

```python
import pymde
import torch

device = torch.device("cuda")  # device given as a torch.device object

mnist = pymde.datasets.MNIST()
embedding = pymde.preserve_neighbors(mnist.data, embedding_dim=2, verbose=True, device=device).embed()
pymde.plot(embedding, color_by=mnist.attributes['digits'])
```

fails with `ArpackError: ARPACK error -9: Starting vector is zero.` (see #82; this error appears when ARPACK is used on the CPU), so this call does not seem to be running on CUDA.
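One plausible explanation, assuming pymde decides whether to take the CUDA path by comparing `device` against the string `"cuda"` (not confirmed here), is that a `torch.device` object never compares equal to a plain string. The snippet below is a small check of that behavior in PyTorch itself, plus a possible user-side workaround (converting the device to a string before passing it in). It does not need a GPU, because constructing a `torch.device` only parses the name.

```python
import torch

device = torch.device("cuda")  # constructing a device object does not require a GPU

# A torch.device does not compare equal to a plain string, so a check
# written as `device == "cuda"` evaluates to False for torch.device("cuda").
print(device == "cuda")       # False
print(device.type == "cuda")  # True
print(str(device))            # cuda

# Possible workaround (an assumption, not a confirmed fix): pass the
# device to pymde as a string instead of a torch.device object.
device_arg = str(device)      # "cuda"; device.type would also drop any ":index"
```

With `device_arg` in place of `device`, the second example reduces to the first one, which already works.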

I think a potential fix would be to change this line https://github.com/cvxgrp/pymde/blob/40472bc47a6d4b53b2a196ed2b6741471a04e830/pymde/recipes.py#L366 so that it also accepts `torch.device("cuda")`.

If you think I'm on the right track, I'd be happy to write a PR.

sjfleming, Feb 10 '25 22:02