
Fix propagation of 'device' setting from 'models.load_network' to `la…

Open · egorssed opened this issue 2 years ago · 0 comments

Hey, thanks for your paper and code!

I was trying to run it on my Mac M1, which has only a CPU, but setting the device via `load_network` led to an error:

```python
>>> device = 'cpu'
>>> net = models.load_network(args, dim, union_tp, device=device)
>>> qz, hidden = net.encode(context_x, context_y)
AssertionError: Torch not compiled with CUDA enabled
```

This happens because `net.to(device)` alone is not enough. `layer.TimeEmbedding` explicitly moves data to `self.device`, and `net.to(device)` only moves the module's parameters and buffers; it does not update the plain `device` attribute stored on `layer.TimeEmbedding`: https://github.com/reml-lab/hetvae/blob/78296e28f0daa592b4bfa63076d23fd6d75b1c9f/src/layers.py#L32
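A minimal sketch of the failure mode, assuming a simplified layer that caches its device as a plain string attribute. `StaleDeviceEmbedding` is hypothetical and is not the actual hetvae `TimeEmbedding`; it only reproduces the pattern described above.

```python
import torch
import torch.nn as nn


class StaleDeviceEmbedding(nn.Module):
    """Hypothetical stand-in for a layer that caches a device string."""

    def __init__(self, dim, device="cuda"):
        super().__init__()
        self.device = device              # plain attribute: ignored by .to()
        self.linear = nn.Linear(1, dim)   # parameters: moved by .to()

    def forward(self, t):
        # Inputs are sent to the cached device, not to where the weights live.
        return self.linear(t.to(self.device).unsqueeze(-1))


emb = StaleDeviceEmbedding(dim=4).to("cpu")
print(emb.linear.weight.device)  # cpu: parameters were moved
print(emb.device)                # cuda: the cached attribute was not updated
```

Calling `emb(...)` here would still try to move the input to `"cuda"`, which is what fails on a CPU-only machine.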

The fix is to propagate the `device` argument from `models.load_network` all the way down to `layer.TimeEmbedding`, as done in the pull request.
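A rough sketch of that propagation, assuming simplified, hypothetical signatures for `TimeEmbedding`, the model class `Net`, and `load_network`; the real hetvae code takes more arguments (`args`, `union_tp`, etc.) and may be structured differently.

```python
import torch
import torch.nn as nn


class TimeEmbedding(nn.Module):
    # Hypothetical, simplified signature; the real layer may differ.
    def __init__(self, dim, device="cuda"):
        super().__init__()
        self.device = device
        self.linear = nn.Linear(1, dim)

    def forward(self, t):
        return self.linear(t.to(self.device).unsqueeze(-1))


class Net(nn.Module):
    # Hypothetical model: forwards `device` to every layer that caches it.
    def __init__(self, dim, device="cuda"):
        super().__init__()
        self.embed_time = TimeEmbedding(dim, device=device)

    def encode(self, t):
        return self.embed_time(t)


def load_network(dim, device="cuda"):
    # Hypothetical loader: passes `device` to the constructor AND calls .to().
    return Net(dim, device=device).to(device)


net = load_network(dim=4, device="cpu")
out = net.encode(torch.linspace(0.0, 1.0, 5))
print(out.shape)  # torch.Size([5, 4])
```

Both steps are needed: the constructor argument keeps the cached attribute correct, and `.to(device)` moves the parameters.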

Alternatively, remove the explicit move to `self.device` in `layer.TimeEmbedding` if it is no longer needed.
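A minimal sketch of the alternative, assuming the layer can infer its device from its own weights instead of caching it. `DeviceAgnosticTimeEmbedding` is hypothetical; with no cached attribute, a single `.to(device)` on the parent model is enough.

```python
import torch
import torch.nn as nn


class DeviceAgnosticTimeEmbedding(nn.Module):
    # Hypothetical variant with no cached device attribute.
    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(1, dim)

    def forward(self, t):
        # Follow wherever the weights currently live, so .to() keeps this in sync.
        t = t.to(self.linear.weight.device)
        return self.linear(t.unsqueeze(-1))


emb = DeviceAgnosticTimeEmbedding(dim=4).to("cpu")
out = emb(torch.linspace(0.0, 1.0, 5))
print(out.device)  # cpu
```

If the layer creates new tensors internally, passing `device=t.device` when creating them (e.g. in `torch.arange`) follows the same idea.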

egorssed, Jun 22 '22 13:06