Device error occurred when using `AutoGuideList`

Open hjnnjh opened this issue 1 year ago • 5 comments

Issue Description

When using AutoGuideList, the following error occurs: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cpu!

Environment

  • Ubuntu 18.04, Python 3.7.13
  • PyTorch 1.13
  • Pyro version: 1.8.4 (installed from dev branch)

Code Snippet

My code is quite complex, so I can only provide the essential information about this error. First, the AutoGuideList in my code is set up as follows:

guide = AutoGuideList(self._model)
guide.append(AutoDelta(poutine.block(self._model, expose_fn=lambda msg: msg["name"].startswith("phi_"))))
guide.append(
    AutoLowRankMultivariateNormal(
        poutine.block(self._model,
                      hide_fn=lambda msg: msg["name"].startswith("z_") or msg["name"].startswith("state_")
                      or msg["name"].startswith("phi_"))))

My model is implemented in a class named Model. The error occurs in the following part of the model:

with pyro.plate(f"browsed_card_{t}", max_browsed_cards_num_t, dim=-2,
                device=self._device) as browsed_cards_plate:
    with poutine.mask(mask=browsed_cards_plate.unsqueeze(-1) < browsed_cards_t.unsqueeze(0)):
        z_b_t = pyro.sample(f"z_b_{t}",
                            pyro_dist.Categorical(theta_b_t),
                            infer={"enumerate": "parallel"})
...

I checked the values of the variables involved and found that browsed_cards_plate is not on device cuda:1. My model works well with AutoDelta or with a custom guide. In debug mode there are no other error messages that would let me trace the source of the error. Could you look into this problem?

hjnnjh avatar May 18 '23 12:05 hjnnjh

@hjnnjh can you provide a minimal example that reproduces the error?

To work around this you could try passing a custom create_plates function to your guide, as in e.g. this unit test.

To fix this it might be sufficient to identify and add the right device argument here: https://github.com/pyro-ppl/pyro/blob/dev/pyro/infer/autoguide/guides.py#L143

-                         name, full_size, dim=frame.dim, subsample_size=frame.size
+                         name, full_size, dim=frame.dim, subsample_size=frame.size, device=...

eb8680 avatar May 18 '23 18:05 eb8680

@eb8680 Thanks for your reply! I'll check it and try to fix this error. Then I'll provide a minimal example that reproduces the error.

hjnnjh avatar May 20 '23 09:05 hjnnjh

@eb8680 After adding the following code, AutoGuideList now works properly in my model. Thanks for the helpful advice!

def _create_plates(self):
    motivation_plate = pyro.plate("motivations", self._hyper_params["M"], dim=-1, device=self._device)
    user_plate = pyro.plate("users", self._data_dims["User"], dim=-1, device=self._device)
    batched_user_plate = pyro.plate("batched_users",
                                    self._data_dims["User"],
                                    self._args.batch_size,
                                    dim=-1,
                                    device=self._device)
    plates = [motivation_plate, user_plate, batched_user_plate]
    for t in pyro.markov(range(self._data_dims["Session"].max())):
        browsed_card_t_plate = pyro.plate(f"browsed_cards_{t}",
                                          self._obs_data_size["Browsed"][t].max(),
                                          dim=-2,
                                          device=self._device)
        clicked_card_t_plate = pyro.plate(f"clicked_cards_{t}",
                                          self._obs_data_size["Clicked"][t].max(),
                                          dim=-2,
                                          device=self._device)
        plates.append(browsed_card_t_plate)
        plates.append(clicked_card_t_plate)
    return plates

hjnnjh avatar May 23 '23 12:05 hjnnjh

@eb8680 I've just uploaded a minimal example here. Running this file raises the same error, but when I switch to AutoDelta it works fine.

hjnnjh avatar May 23 '23 14:05 hjnnjh

@hjnnjh great, thanks!

eb8680 avatar May 24 '23 16:05 eb8680