Device error occurred when using `AutoGuideList`
Issue Description
When using `AutoGuideList`, the following error occurred: `Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!`
Environment
- Ubuntu 18.04, Python 3.7.13
- PyTorch 1.13
- Pyro version: 1.8.4 (installed from dev branch)
Code Snippet
My code is very complex, so I can only provide the essential information about this error. First, the `AutoGuideList` in my code is as follows:
guide = AutoGuideList(self._model)
guide.append(AutoDelta(poutine.block(self._model, expose_fn=lambda msg: msg["name"].startswith("phi_"))))
guide.append(
    AutoLowRankMultivariateNormal(
        poutine.block(self._model,
                      hide_fn=lambda msg: msg["name"].startswith("z_") or msg["name"].startswith("state_")
                      or msg["name"].startswith("phi_"))))
I implemented my model in a class named `Model`. The error occurs at the following point:
with pyro.plate(f"browsed_card_{t}", max_browsed_cards_num_t, dim=-2,
                device=self._device) as browsed_cards_plate:
    with poutine.mask(mask=browsed_cards_plate.unsqueeze(-1) < browsed_cards_t.unsqueeze(0)):
        z_b_t = pyro.sample(f"z_b_{t}",
                            pyro_dist.Categorical(theta_b_t),
                            infer={"enumerate": "parallel"})
...
I checked the values of the corresponding variables and found that `browsed_cards_plate` is not on device `cuda:1`. My model works well with `AutoDelta` or a custom guide. In debug mode there are no other error messages that would let me trace the source of the error. Could you look into this problem?
@hjnnjh can you provide a minimal example that reproduces the error?
To work around this, you could try passing a custom `create_plates` function to your guide, as in e.g. this unit test.
To fix this, it might be sufficient to identify and add the right `device` argument here: https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/autoguide/guides.py#L143
- name, full_size, dim=frame.dim, subsample_size=frame.size
+ name, full_size, dim=frame.dim, subsample_size=frame.size, device=...
@eb8680 Thanks for your reply! I'll check it and try to fix this error. Then I'll provide a minimal example that reproduces the error.
@eb8680 After adding the following code, `AutoGuideList` works properly in my model. Thanks for your essential advice!
def _create_plates(self):
    motivation_plate = pyro.plate("motivations", self._hyper_params["M"], dim=-1, device=self._device)
    user_plate = pyro.plate("users", self._data_dims["User"], dim=-1, device=self._device)
    batched_user_plate = pyro.plate("batched_users",
                                    self._data_dims["User"],
                                    self._args.batch_size,
                                    dim=-1,
                                    device=self._device)
    plates = [motivation_plate, user_plate, batched_user_plate]
    for t in pyro.markov(range(self._data_dims["Session"].max())):
        browsed_card_t_plate = pyro.plate(f"browsed_cards_{t}",
                                          self._obs_data_size["Browsed"][t].max(),
                                          dim=-2,
                                          device=self._device)
        clicked_card_t_plate = pyro.plate(f"clicked_cards_{t}",
                                          self._obs_data_size["Clicked"][t].max(),
                                          dim=-2,
                                          device=self._device)
        plates.append(browsed_card_t_plate)
        plates.append(clicked_card_t_plate)
    return plates
@eb8680 I just uploaded a minimal example here. Running this file raises the same error, but after switching to `AutoDelta` it works fine.
@hjnnjh great, thanks!