Recurrent-Independent-Mechanisms

GroupLinearLayer should add "device" parameter

Status: Open. ildefons opened this issue 4 years ago (2 comments).

Hi RIM dev team,

The code fails when device = 'cuda'. This can easily be solved by adding an extra "device" parameter to all the "Group" classes (e.g. GroupLinearLayer).

Thank you for the great RIM implementation,
Ildefons

ildefons, Jan 08 '21 08:01
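A minimal sketch of the change proposed above, assuming a GroupLinearLayer that stores one (din x dout) weight matrix per block and applies it with torch.bmm, as the traceback below suggests. The constructor signature (din, dout, num_blocks) and the `device` keyword are assumptions for illustration, not the repository's actual code.

```python
import torch
import torch.nn as nn


class GroupLinearLayer(nn.Module):
    """Sketch of a block-wise linear layer: each of num_blocks blocks has its
    own (din x dout) weight. The signature and the `device` keyword are
    hypothetical, mirroring the proposal in this issue."""

    def __init__(self, din, dout, num_blocks, device=None):
        super().__init__()
        # Allocate the weight directly on `device` (hypothetical argument), so
        # it matches the inputs even if the caller never calls .to(device).
        # Wrapping it in nn.Parameter also lets Module.to() move it later.
        self.w = nn.Parameter(0.01 * torch.randn(num_blocks, din, dout, device=device))

    def forward(self, x):
        # x: (batch, num_blocks, din) -> (num_blocks, batch, din) for torch.bmm
        x = x.permute(1, 0, 2)
        # torch.bmm needs both operands on the same device:
        # (num_blocks, batch, din) @ (num_blocks, din, dout) -> (num_blocks, batch, dout)
        x = torch.bmm(x, self.w)
        # Back to (batch, num_blocks, dout)
        return x.permute(1, 0, 2)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    layer = GroupLinearLayer(din=8, dout=4, num_blocks=3, device=device)
    x = torch.randn(5, 3, 8, device=device)
    y = layer(x)
    print(y.shape, y.device)
```

Note that the `device` keyword only fixes placement at construction time; registering the weight as an nn.Parameter is what keeps a later `model.to(device)` call working as well.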

Hi @ildefons, could you share the error message with device='cuda'?

dido1998, Jan 08 '21 09:01

Error message:

RuntimeError                              Traceback (most recent call last)
in
      1 for x in xs:
      2     print(1)
----> 3     hs, cs = rim_model(x, hs, cs)

~\anaconda3\envs\eg2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in forward(self, x, hs, cs)
    249
    250         # Compute input attention
--> 251         inputs, mask = self.input_attention_mask(x, hs)
    252         h_old = hs * 1.0
    253         if cs is not None:

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in input_attention_mask(self, x, h)
    177         key_layer = self.key(x)
    178         value_layer = self.value(x)
--> 179         query_layer = self.query(h)
    180
    181         key_layer = self.transpose_for_scores(key_layer, self.num_input_heads, self.input_key_size)

~\anaconda3\envs\eg2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in forward(self, x)
     31         x = x.permute(1,0,2)
     32
---> 33         x = torch.bmm(x,self.w)
     34         return x.permute(1,0,2)
     35

RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat2' in call to _th_bmm

ildefons, Jan 08 '21 09:01
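The error says argument #2 of torch.bmm ('mat2', i.e. self.w) is still on the CPU while the input is on CUDA. Two common causes are a weight stored as a plain tensor attribute, which Module.to() does not track, or a model that was never moved with .to(device); this thread does not confirm which one applies to the RIM code. The sketch below illustrates the first mechanism. Converting to float64 goes through the same Module.to() machinery as a device move, so it runs without a GPU. The classes PlainAttr and Registered and the helper tensor_devices are hypothetical names for illustration.

```python
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain tensor attribute: NOT converted by Module.to() / .cuda()
        self.w = torch.zeros(2, 3, 3)


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # Parameter and buffer: both are converted by Module.to() / .cuda()
        self.w = nn.Parameter(torch.zeros(2, 3, 3))
        self.register_buffer("mask", torch.zeros(2, 3))


def tensor_devices(module):
    """Hypothetical helper: return {name: device} for every parameter and
    buffer that Module.to() would move."""
    named = list(module.named_parameters()) + list(module.named_buffers())
    return {name: t.device for name, t in named}


if __name__ == "__main__":
    plain = PlainAttr().to(torch.float64)
    reg = Registered().to(torch.float64)
    print(plain.w.dtype)  # torch.float32 (left behind)
    print(reg.w.dtype)    # torch.float64 (converted)
    print(tensor_devices(reg))
```

With the same logic, a GroupLinearLayer whose weight is registered as a parameter or buffer would follow `rim_model.to('cuda')`, while a plain tensor attribute would stay on the CPU and trigger exactly this bmm device mismatch.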