pytorch-deform-conv-v2
pytorch-deform-conv-v2 copied to clipboard
Solve specific GPU problem.
When I want the whole net to run on GPU 2 instead of GPU 0, the current code always creates some tensors on GPU 0, which makes the sum operation p = p_0 + p_n + offset fail with a device-mismatch error. That is not what I want.
Modified point:
In the _get_p_n and _get_p_0 functions, add a device parameter.
In file: https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py
Original code:
' def _get_p_n(self, N, dtype): p_n_x, p_n_y = torch.meshgrid( torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) # (2N, 1) p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
return p_n
def _get_p_0(self, h, w, N, dtype):
p_0_x, p_0_y = torch.meshgrid(
torch.arange(1, h*self.stride+1, self.stride),
torch.arange(1, w*self.stride+1, self.stride))
p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)
return p_0
def _get_p(self, offset, dtype): N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
# (1, 2N, 1, 1)
p_n = self._get_p_n(N, **dtype)**
# (1, 2N, h, w)
p_0 = self._get_p_0(h, w, N, **dtype)**
p = p_0 + p_n + offset
return p
'
Suggested modified code:
' def _get_p_n(self, N, device, dataType): p_n_x, p_n_y = torch.meshgrid( torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) # (2N, 1) p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) p_n = p_n.view(1, 2*N, 1, 1).to(device, dtype=dataType)
return p_n
def _get_p_0(self, h, w, N, device, dataType):
p_0_x, p_0_y = torch.meshgrid(
torch.arange(1, h*self.stride+1, self.stride),
torch.arange(1, w*self.stride+1, self.stride))
p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
p_0 = torch.cat([p_0_x, p_0_y], 1).to(device, dtype=dataType)
return p_0
def _get_p(self, offset, dtype):
N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
# (1, 2N, 1, 1)
p_n = self._get_p_n(N, offset.device, offset.dtype)
# (1, 2N, h, w)
p_0 = self._get_p_0(h, w, N, offset.device, offset.dtype)
p = p_0 + p_n + offset
return p
'
When I want the whole net to run on GPU 2 instead of GPU 0, the current code always creates some tensors on GPU 0, which makes the sum operation p = p_0 + p_n + offset fail with a device-mismatch error. That is not what I want.
Modified point:
In the _get_p_n and _get_p_0 functions, add a device parameter. In file: https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py
Original code:
' def _get_p_n(self, N, dtype): p_n_x, p_n_y = torch.meshgrid( torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) # (2N, 1) p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) p_n = p_n.view(1, 2*N, 1, 1).type(dtype)
return p_n def _get_p_0(self, h, w, N, dtype): p_0_x, p_0_y = torch.meshgrid( torch.arange(1, h*self.stride+1, self.stride), torch.arange(1, w*self.stride+1, self.stride)) p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype) return p_0def _get_p(self, offset, dtype): N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)
# (1, 2N, 1, 1) p_n = self._get_p_n(N, **dtype)** # (1, 2N, h, w) p_0 = self._get_p_0(h, w, N, **dtype)** p = p_0 + p_n + offset return p'
Suggested modified code:
' def _get_p_n(self, N, device, dataType): p_n_x, p_n_y = torch.meshgrid( torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1), torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1)) # (2N, 1) p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0) p_n = p_n.view(1, 2*N, 1, 1).to(device, dtype=dataType)
return p_n def _get_p_0(self, h, w, N, device, dataType): p_0_x, p_0_y = torch.meshgrid( torch.arange(1, h*self.stride+1, self.stride), torch.arange(1, w*self.stride+1, self.stride)) p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1) p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1) p_0 = torch.cat([p_0_x, p_0_y], 1).to(device, dtype=dataType) return p_0 def _get_p(self, offset, dtype): N, h, w = offset.size(1)//2, offset.size(2), offset.size(3) # (1, 2N, 1, 1) p_n = self._get_p_n(N, offset.device, offset.dtype) # (1, 2N, h, w) p_0 = self._get_p_0(h, w, N, offset.device, offset.dtype) p = p_0 + p_n + offset return p'
Great work!