hetvae
Fix propagation of 'device' setting from 'models.load_network' to `la…
Hey, thanks for your paper and code!
I was trying to run it on my Mac M1, which is CPU-only, but setting the device through `load_network` led to an error:
```python
>>> device = 'cpu'
>>> net = models.load_network(args, dim, union_tp, device=device)
>>> qz, hidden = net.encode(context_x, context_y)
AssertionError: Torch not compiled with CUDA enabled
```
This happens because `net.to(device)` alone is not enough: `layers.TimeEmbedding` explicitly moves data to `self.device`, and calling `net.to(device)` does not update the `layers.TimeEmbedding.device` attribute, which keeps whatever value it was given at construction time:

https://github.com/reml-lab/hetvae/blob/78296e28f0daa592b4bfa63076d23fd6d75b1c9f/src/layers.py#L32
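The general pitfall can be reproduced outside the repo: `nn.Module.to()` moves registered parameters and buffers, but it never rewrites a plain Python attribute such as a stored device string. A minimal sketch (the class below is illustrative, not the actual hetvae layer):

```python
import torch.nn as nn

class TimeEmbeddingLike(nn.Module):
    """Mimics the pattern in hetvae: the target device is
    captured once in __init__ as a plain attribute."""
    def __init__(self, dim, device='cuda'):
        super().__init__()
        self.device = device          # plain attribute, invisible to .to()
        self.linear = nn.Linear(1, dim)

    def forward(self, tt):
        # Inputs are moved to the *stored* device, not the module's
        # current one -- on a CPU-only machine this is the line that
        # raises "Torch not compiled with CUDA enabled".
        return self.linear(tt.to(self.device))

emb = TimeEmbeddingLike(dim=4, device='cuda')
emb.to('cpu')                          # moves parameters only
print(emb.device)                      # still 'cuda'
```

So even after `net.to('cpu')`, the first forward pass tries to move data to CUDA.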
The fix is simply to propagate `device` from `models.load_network` down to `layers.TimeEmbedding`, as done in the pull request. Alternatively, the explicit move to `self.device` inside `layers.TimeEmbedding` could just be removed, if it is no longer needed.
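The propagation approach can be sketched as follows; the names `TimeEmbeddingLike` and `load_network_like` are stand-ins for the real hetvae functions, whose signatures differ:

```python
import torch
import torch.nn as nn

class TimeEmbeddingLike(nn.Module):
    """Accepts the device via __init__ instead of hard-coding a default."""
    def __init__(self, dim, device='cpu'):
        super().__init__()
        self.device = device
        self.linear = nn.Linear(1, dim)

    def forward(self, tt):
        return self.linear(tt.to(self.device))

def load_network_like(dim, device='cpu'):
    # Forward the caller's device down to every layer that stores it,
    # so the stored attribute and the parameters always agree.
    net = TimeEmbeddingLike(dim, device=device)
    return net.to(device)

net = load_network_like(4, device='cpu')
out = net(torch.zeros(3, 1))           # runs on a CPU-only machine
```

The alternative fix (dropping the explicit move) would delete `self.device` and the `tt.to(...)` call entirely, leaving callers responsible for placing inputs on the right device after `net.to(device)`.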