
Fix device placement so the net can run on a specific GPU.

Open · Hui-Xie opened this issue 6 years ago · 1 comment

I want the whole net to run on GPU2 instead of GPU0, but the current code always creates some tensors on GPU0, so the sum `p = p_0 + p_n + offset` fails with a device mismatch, which is not what I want.
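A minimal sketch of the failure (my own reproduction, not code from the repo; it assumes a machine where cuda:2 exists):

```python
import torch

# Hypothetical minimal reproduction of the device mismatch.
offset = torch.zeros(1, 18, 4, 4, device='cuda:2')  # the net's tensors live on cuda:2
p_n = torch.arange(18).view(1, 18, 1, 1)            # coordinate grid is created on the CPU
p_n = p_n.type(offset.data.type())                  # .type('torch.cuda.FloatTensor') lands on cuda:0

p = p_n + offset  # RuntimeError: tensors are on different devices (cuda:0 vs cuda:2)
```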

Modified point:

   In the _get_p_n and _get_p_0 functions, add a device parameter.

In file: https://github.com/4uiiurz1/pytorch-deform-conv-v2/blob/master/deform_conv_v2.py

Original code:

```python
def _get_p_n(self, N, dtype):
    p_n_x, p_n_y = torch.meshgrid(
        torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
        torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
    # (2N, 1)
    p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
    p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

    return p_n

def _get_p_0(self, h, w, N, dtype):
    p_0_x, p_0_y = torch.meshgrid(
        torch.arange(1, h*self.stride+1, self.stride),
        torch.arange(1, w*self.stride+1, self.stride))
    p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
    p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
    p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

    return p_0

def _get_p(self, offset, dtype):
    N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

    # (1, 2N, 1, 1)
    p_n = self._get_p_n(N, dtype)
    # (1, 2N, h, w)
    p_0 = self._get_p_0(h, w, N, dtype)
    p = p_0 + p_n + offset
    return p
```
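The root cause, as far as I can tell: `dtype` here is the string returned by `offset.data.type()` (e.g. `'torch.cuda.FloatTensor'`), and `Tensor.type()` with a CUDA type string moves the tensor to the *current* CUDA device, which stays at cuda:0 unless changed explicitly:

```python
import torch

t = torch.arange(4)                    # created on the CPU
t = t.type('torch.cuda.FloatTensor')   # moves to the current CUDA device
print(t.device)                        # cuda:0 by default, even if the net is on cuda:2
```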

Suggested modified code:

```python
def _get_p_n(self, N, device, dataType):
    p_n_x, p_n_y = torch.meshgrid(
        torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
        torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
    # (2N, 1)
    p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
    p_n = p_n.view(1, 2*N, 1, 1).to(device, dtype=dataType)

    return p_n

def _get_p_0(self, h, w, N, device, dataType):
    p_0_x, p_0_y = torch.meshgrid(
        torch.arange(1, h*self.stride+1, self.stride),
        torch.arange(1, w*self.stride+1, self.stride))
    p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
    p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
    p_0 = torch.cat([p_0_x, p_0_y], 1).to(device, dtype=dataType)

    return p_0

def _get_p(self, offset, dtype):
    # dtype parameter is now unused; device and dtype are read from offset instead
    N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

    # (1, 2N, 1, 1)
    p_n = self._get_p_n(N, offset.device, offset.dtype)
    # (1, 2N, h, w)
    p_0 = self._get_p_0(h, w, N, offset.device, offset.dtype)
    p = p_0 + p_n + offset
    return p
```
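With this change, everything `_get_p` builds follows `offset`'s device, so the layer can live on any GPU. A rough usage sketch (assuming the `DeformConv2d` constructor from this repo and a machine where cuda:2 exists):

```python
import torch
from deform_conv_v2 import DeformConv2d  # this repo's module

device = torch.device('cuda:2')
conv = DeformConv2d(3, 16, kernel_size=3, padding=1, modulation=True).to(device)
x = torch.randn(2, 3, 32, 32, device=device)
y = conv(x)  # p_0, p_n and offset now all live on cuda:2
print(y.shape, y.device)
```

On PyTorch versions that support it, one could also pass `device=` (and `dtype=`) directly to the `torch.arange` calls, so the coordinate grids are created on the right device in the first place.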

Hui-Xie avatar Aug 23 '19 15:08 Hui-Xie


Great work!

wcyjerry avatar Nov 04 '22 14:11 wcyjerry