Recurrent-Independent-Mechanisms
GroupLinearLayer should add "device" parameter
Hi RIM dev team,
The code fails when device = 'cuda'. It can be easily solved by adding an extra "device" parameter to all "Group" classes.
Thank you for the great RIM implementation, Ildefons
Hi @ildefons, could you share the error message with device='cuda'?
Error message:
```
RuntimeError                              Traceback (most recent call last)
~\anaconda3\envs\eg2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in forward(self, x, hs, cs)
    249
    250         # Compute input attention
--> 251         inputs, mask = self.input_attention_mask(x, hs)
    252         h_old = hs * 1.0
    253         if cs is not None:

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in input_attention_mask(self, x, h)
    177         key_layer = self.key(x)
    178         value_layer = self.value(x)
--> 179         query_layer = self.query(h)
    180
    181         key_layer = self.transpose_for_scores(key_layer, self.num_input_heads, self.input_key_size)

~\anaconda3\envs\eg2\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

~\OneDrive\Documentos\YK\eg\Recurrent-Independent-Mechanisms\RIM.py in forward(self, x)
    31         x = x.permute(1,0,2)
    32
--> 33         x = torch.bmm(x,self.w)
    34         return x.permute(1,0,2)
    35

RuntimeError: Expected object of device type cuda but got device type cpu for argument #2 'mat2' in call to _th_bmm
```
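The traceback shows `torch.bmm(x, self.w)` failing because `x` is on the GPU while `self.w` is still on the CPU. Rather than threading a `device` argument through every "Group" class, an alternative is to register the group weight as an `nn.Parameter`, so that `.to(device)` / `.cuda()` on the parent module moves it automatically. Below is a minimal sketch of that idea — it loosely mirrors the `GroupLinearLayer.forward` seen in the traceback, but the constructor arguments and initialization are assumptions, not the repo's actual code:

```python
import torch
import torch.nn as nn

class GroupLinearLayer(nn.Module):
    """Applies a separate linear map to each of `num_blocks` groups.

    Sketch only: because the weight is registered as an nn.Parameter,
    calling .to('cuda') on the module (or a parent module) moves it,
    so torch.bmm sees both operands on the same device.
    """
    def __init__(self, din, dout, num_blocks):
        super().__init__()
        # One (din, dout) weight matrix per block; init scale is an assumption.
        self.w = nn.Parameter(0.01 * torch.randn(num_blocks, din, dout))

    def forward(self, x):
        # x: (batch, num_blocks, din) -> (batch, num_blocks, dout)
        x = x.permute(1, 0, 2)       # (num_blocks, batch, din)
        x = torch.bmm(x, self.w)     # batched matmul, one weight per block
        return x.permute(1, 0, 2)    # back to (batch, num_blocks, dout)
```

With this change, `model.to('cuda')` is enough — no extra `device` parameter is needed, and the `_th_bmm` device mismatch above disappears.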