EfficientNet-PyTorch-3D
EfficientNet-PyTorch-3D: training fails in the stem convolution with the error below.
RuntimeError: expected scalar type Double but found Float
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_18/944444650.py in
/tmp/ipykernel_18/2880739147.py in train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs)
     10 images, classes = item
     11 optimizer.zero_grad()
---> 12 output = model(images.double())
     13 loss = criterion(output, classes)
     14 accelerator.backward(loss)
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1109 or _global_forward_hooks or _global_forward_pre_hooks): -> 1110 return forward_call(*input, **kwargs) 1111 # Do not call functions when jit is used 1112 full_backward_hooks, non_full_backward_hooks = [], []
/tmp/ipykernel_18/791916627.py in forward(self, x) 6 def forward(self, x): 7 # x = nn.functional.interpolate(x, size=(224,224,224), mode='trilinear') ----> 8 x = self.effnet(x) 9 return x
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1109 or _global_forward_hooks or _global_forward_pre_hooks): -> 1110 return forward_call(*input, **kwargs) 1111 # Do not call functions when jit is used 1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/efficientnet_pytorch_3d/model.py in forward(self, inputs) 189 bs = inputs.size(0) 190 # Convolution layers --> 191 x = self.extract_features(inputs) 192 193 if self._global_params.include_top:
/opt/conda/lib/python3.7/site-packages/efficientnet_pytorch_3d/model.py in extract_features(self, inputs) 171 172 # Stem --> 173 x = self._swish(self._bn0(self._conv_stem(inputs))) 174 175 # Blocks
/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1109 or _global_forward_hooks or _global_forward_pre_hooks): -> 1110 return forward_call(*input, **kwargs) 1111 # Do not call functions when jit is used 1112 full_backward_hooks, non_full_backward_hooks = [], []
/opt/conda/lib/python3.7/site-packages/efficientnet_pytorch_3d/utils.py in forward(self, x) 144 def forward(self, x): 145 x = self.static_padding(x) --> 146 x = F.conv3d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) 147 return x 148
RuntimeError: expected scalar type Double but found Float