ADD GLIP

kyakuno opened this issue 2 years ago • 4 comments

Grounded Language-Image Pre-training (GLIP): https://github.com/microsoft/GLIP (MIT license)

kyakuno, May 21 '22 11:05

  • backbone.onnx

○ maskrcnn_benchmark/modeling/detector/generalized_vl_rcnn.py

Before:

class GeneralizedVLRCNN(nn.Module):
    ...
    def forward(...):
        ...
        if 'vl' in self.cfg.MODEL.SWINT.VERSION:
            ...
        else:
            visual_features = self.backbone(images.tensors)

After:

class GeneralizedVLRCNN(nn.Module):
    ...
    def forward(...):
        ...
        if 'vl' in self.cfg.MODEL.SWINT.VERSION:
            ...
        else:
            print("------>")
            from torch.autograd import Variable
            x = Variable(images.tensors)
            torch.onnx.export(
                self.backbone, x, 'backbone.onnx',
                input_names=["images"],
                output_names=["out0", "out1", "out2", "out3", "out4"],
                dynamic_axes={
                    'images': {0: 'n', 2: 'h', 3: 'w'},
                    'out0': {0: 'n', 2: 'h0', 3: 'w0'},
                    'out1': {0: 'n', 2: 'h1', 3: 'w1'},
                    'out2': {0: 'n', 2: 'h2', 3: 'w2'},
                    'out3': {0: 'n', 2: 'h3', 3: 'w3'},
                    'out4': {0: 'n', 2: 'h4', 3: 'w4'},
                },
                verbose=False, opset_version=12
            )
            print("<------")
○ maskrcnn_benchmark/modeling/backbone/swint.py

Before:

class SwinTransformerBlock(nn.Module):
    ...
    def forward(self, x, mask_matrix):
        ...
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        ...
        if pad_r > 0 or pad_b > 0:
            x = x[:, :H, :W, :].contiguous()

After:

class SwinTransformerBlock(nn.Module):
    ...
    def forward(self, x, mask_matrix):
        ...
        if self.shift_size > 0:
            shifted_x = torch.cat([x[:, self.shift_size:, :, :], x[:, :self.shift_size, :, :]], dim=1)
            shifted_x = torch.cat([shifted_x[:, :, self.shift_size:, :], shifted_x[:, :, :self.shift_size, :]], dim=2)
        ...
        if True:
            x = x[:, :H, :W, :].contiguous()
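
The torch.roll call is rewritten as two slice-and-concat operations, and the pad_r/pad_b branch is replaced with an unconditional slice (a no-op when no padding was added), presumably so the traced graph only contains ops and control flow that export cleanly at opset 12. A minimal standalone check of the rewrite, assuming nothing beyond PyTorch itself (the shift_size value is an arbitrary example):

# Sanity check: the two-concat rewrite matches torch.roll for a negative shift.
import torch

x = torch.arange(2 * 6 * 6 * 3, dtype=torch.float32).reshape(2, 6, 6, 3)
shift_size = 2  # example value; in Swin blocks the shifted case uses window_size // 2

rolled = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))

shifted = torch.cat([x[:, shift_size:, :, :], x[:, :shift_size, :, :]], dim=1)
shifted = torch.cat([shifted[:, :, shift_size:, :], shifted[:, :, :shift_size, :]], dim=2)

assert torch.equal(rolled, shifted)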

○ maskrcnn_benchmark/modeling/backbone/swint.py

Before:

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))

After:

def window_reverse(windows, window_size, H, W):
    B = (windows.shape[0] / (H * W / window_size / window_size)).type(torch.int64)
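
The window_reverse change keeps the batch-size computation in tensor form and casts it with .type(torch.int64) instead of collapsing it to a Python int, which would otherwise be frozen as a constant in the traced graph. Once backbone.onnx has been exported, it can be sanity-checked with onnxruntime; a minimal sketch, assuming onnxruntime is installed and using an arbitrary input resolution:

# Run the exported backbone on a dummy image and print the five FPN output shapes.
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("backbone.onnx")
dummy = np.random.rand(1, 3, 800, 1216).astype(np.float32)
outputs = session.run(None, {"images": dummy})
for meta, out in zip(session.get_outputs(), outputs):
    print(meta.name, out.shape)  # out0 ... out4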

ooe1123, Jul 09 '22 05:07

  • rpn.onnx

○ maskrcnn_benchmark/modeling/rpn/vldyhead.py

Before:

class VLDyHead(torch.nn.Module):
    ...
    def forward(self, x, language_dict_features, embedding, swint_feature_c4):
        ...
        return logits, bbox_reg, centerness, t_logits, proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features

After:

class VLDyHead(torch.nn.Module):
    ...
    def forward(self, feature0, feature1, feature2, feature3, feature4, hidden, masks, embedding):
        ...
        x = (feature0, feature1, feature2, feature3, feature4)
        language_dict_features = {
            "hidden": hidden,
            "masks": masks,
        }
        swint_feature_c4 = None
        ...
        return (
            logits[0], logits[1], logits[2], logits[3], logits[4],
            bbox_reg[0], bbox_reg[1], bbox_reg[2], bbox_reg[3], bbox_reg[4],
            centerness[0], centerness[1], centerness[2], centerness[3], centerness[4],
            dot_product_logits[0], dot_product_logits[1], dot_product_logits[2], dot_product_logits[3], dot_product_logits[4],
        )

Before:

class VLDyHeadModule(torch.nn.Module):
    ...
    def forward(...):
        ...
        box_cls, box_regression, centerness, token_logits, \
        proj_tokens, contrastive_logits, dot_product_logits, mlm_logits, shallow_img_emb_feats, fused_visual_features = self.head(features,
                                                                        language_dict_features,
                                                                        embedding,
                                                                        swint_feature_c4
                                                                        )

After:

class VLDyHeadModule(torch.nn.Module):
    ...
    def forward(...):
        ...
        print("------>")
        from torch.autograd import Variable
        xx = (Variable(features[0]), Variable(features[1]), Variable(features[2]), Variable(features[3]), Variable(features[4]), Variable(language_dict_features["hidden"]), Variable(language_dict_features["masks"]), Variable(embedding))
        torch.onnx.export(
            self.head, xx, 'rpn.onnx',
            input_names=["feat0", "feat1", "feat2", "feat3", "feat4", "hidden", "masks", "embedding"],
            output_names=["logits0", "logits1", "logits2", "logits3", "logits4", "bbox_reg0", "bbox_reg1", "bbox_reg2", "bbox_reg3", "bbox_reg4", "centerness0", "centerness1", "centerness2", "centerness3", "centerness4", "dot_product_logits0", "dot_product_logits1", "dot_product_logits2", "dot_product_logits3", "dot_product_logits4"],
            dynamic_axes={
            'feat0' : {0 : 'n', 2 : 'h0', 3 : 'w0'}, 'feat1' : {0 : 'n', 2 : 'h1', 3 : 'w1'}, 'feat2' : {0 : 'n', 2 : 'h2', 3 : 'w2'}, 'feat3' : {0 : 'n', 2 : 'h3', 3 : 'w3'}, 'feat4' : {0 : 'n', 2 : 'h4', 3 : 'w4'}, 'hidden' : {0 : 'n'}, 'masks' : {0 : 'n'}, 'embedding' : {0 : 'n'},
            'logits0' : {0 : 'n', 2 : 'h0', 3 : 'w0'}, 'logits1' : {0 : 'n', 2 : 'h1', 3 : 'w1'}, 'logits2' : {0 : 'n', 2 : 'h2', 3 : 'w2'}, 'logits3' : {0 : 'n', 2 : 'h3', 3 : 'w3'}, 'logits4' : {0 : 'n', 2 : 'h4', 3 : 'w4'},
            'bbox_reg0' : {0 : 'n', 2 : 'h0', 3 : 'w0'}, 'bbox_reg1' : {0 : 'n', 2 : 'h1', 3 : 'w1'}, 'bbox_reg2' : {0 : 'n', 2 : 'h2', 3 : 'w2'}, 'bbox_reg3' : {0 : 'n', 2 : 'h3', 3 : 'w3'}, 'bbox_reg4' : {0 : 'n', 2 : 'h4', 3 : 'w4'},
            'centerness0' : {0 : 'n', 2 : 'h0', 3 : 'w0'}, 'centerness1' : {0 : 'n', 2 : 'h1', 3 : 'w1'}, 'centerness2' : {0 : 'n', 2 : 'h2', 3 : 'w2'}, 'centerness3' : {0 : 'n', 2 : 'h3', 3 : 'w3'}, 'centerness4' : {0 : 'n', 2 : 'h4', 3 : 'w4'},
            'dot_product_logits0' : {0 : 'n', 1 : 'l0'}, 'dot_product_logits1' : {0 : 'n', 1 : 'l1'}, 'dot_product_logits2' : {0 : 'n', 1 : 'l2'}, 'dot_product_logits3' : {0 : 'n', 1 : 'l3'}, 'dot_product_logits4' : {0 : 'n', 1 : 'l4'},
            },
            verbose=False, opset_version=12
        )
        print("<------")

○ maskrcnn_benchmark/layers/deform_conv.py

Before:

class ModulatedDeformConv(nn.Module):
    def __init__(...):
        ...

    def forward(self, input, offset, mask):
        return modulated_deform_conv(
            input, offset, mask, self.weight, self.bias, self.stride,
            self.padding, self.dilation, self.groups, self.deformable_groups)

After:

class ModulatedDeformConv(nn.Module):
    def __init__(...):
        ...
        self.deform_conv = DeformConv2d(
            self.in_channels, self.out_channels, kernel_size=3, 
            padding=self.padding, stride=self.stride, bias=True, modulation=True).cuda()
        self.deform_conv.conv.weight = self.weight
        self.deform_conv.conv.bias = self.bias

    def forward(self, input, offset, mask):
        return self.deform_conv(input, offset, mask)

ooe1123, Jul 10 '22 13:07

○ deform_conv (DeformConv2d used by the ModulatedDeformConv replacement above)

class DeformConv2d(nn.Module):
    def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, bias=None, modulation=False):
        """
        Args:
            modulation (bool, optional): If True, Modulated Defomable Convolution (Deformable ConvNets v2).
        """
        super(DeformConv2d, self).__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride
        self.zero_padding = nn.ZeroPad2d(padding)
        self.conv = nn.Conv2d(inc, outc, kernel_size=kernel_size, stride=kernel_size, bias=bias)
        self.modulation = modulation

    @staticmethod
    def _set_lr(module, grad_input, grad_output):
        grad_input = (grad_input[i] * 0.1 for i in range(len(grad_input)))
        grad_output = (grad_output[i] * 0.1 for i in range(len(grad_output)))

    def forward(self, x, offset, mask):
        m = mask

        dtype = offset.data.type()
        ks = self.kernel_size
        N = offset.size(1) // 2

        b, _, h, w = x.shape
        if h < offset.size(2):
            n = b * ks*ks*2 * h * w
            offset = offset.flatten()[:n].reshape(1, ks*ks*2, h, w)
            n = b * ks*ks * h * w
            m = m.flatten()[:n].reshape(1, ks*ks, h, w)

        x1 = offset[:, :N*2:2, :, :]
        x2 = offset[:, 1:N*2:2, :, :]
        offset = torch.cat([x1, x2], dim=1)

        if self.padding:
            x = self.zero_padding(x)

        # (b, 2N, h, w)
        p = self._get_p(offset, dtype)

        # (b, h, w, 2N)
        p = p.contiguous().permute(0, 2, 3, 1)
        q_lt = p.detach().floor()
        q_rb = q_lt + 1

        q_lt = torch.cat([torch.clamp(q_lt[..., :N], 0, x.size(2)-1), torch.clamp(q_lt[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_rb = torch.cat([torch.clamp(q_rb[..., :N], 0, x.size(2)-1), torch.clamp(q_rb[..., N:], 0, x.size(3)-1)], dim=-1).long()
        q_lb = torch.cat([q_lt[..., :N], q_rb[..., N:]], dim=-1)
        q_rt = torch.cat([q_rb[..., :N], q_lt[..., N:]], dim=-1)

        # clip p
        p = torch.cat([torch.clamp(p[..., :N], 0, x.size(2)-1), torch.clamp(p[..., N:], 0, x.size(3)-1)], dim=-1)

        # bilinear kernel (b, h, w, N)
        g_lt = (1 + (q_lt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_lt[..., N:].type_as(p) - p[..., N:]))
        g_rb = (1 - (q_rb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_rb[..., N:].type_as(p) - p[..., N:]))
        g_lb = (1 + (q_lb[..., :N].type_as(p) - p[..., :N])) * (1 - (q_lb[..., N:].type_as(p) - p[..., N:]))
        g_rt = (1 - (q_rt[..., :N].type_as(p) - p[..., :N])) * (1 + (q_rt[..., N:].type_as(p) - p[..., N:]))

        # (b, c, h, w, N)
        x_q_lt = self._get_x_q(x, q_lt, N)
        x_q_rb = self._get_x_q(x, q_rb, N)
        x_q_lb = self._get_x_q(x, q_lb, N)
        x_q_rt = self._get_x_q(x, q_rt, N)

        # (b, c, h, w, N)
        x_offset = g_lt.unsqueeze(dim=1) * x_q_lt + \
                   g_rb.unsqueeze(dim=1) * x_q_rb + \
                   g_lb.unsqueeze(dim=1) * x_q_lb + \
                   g_rt.unsqueeze(dim=1) * x_q_rt

        # modulation
        if self.modulation:
            m = m.contiguous().permute(0, 2, 3, 1)
            m = m.unsqueeze(dim=1)
            m = torch.cat([m for _ in range(x_offset.size(1))], dim=1)
            x_offset *= m

        x_offset = self._reshape_x_offset(x_offset, ks)

        out = self.conv(x_offset)

        return out

    def _get_p_n(self, N, dtype):
        p_n_x, p_n_y = torch.meshgrid(
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1),
            torch.arange(-(self.kernel_size-1)//2, (self.kernel_size-1)//2+1))
        # (2N, 1)
        p_n = torch.cat([torch.flatten(p_n_x), torch.flatten(p_n_y)], 0)
        p_n = p_n.view(1, 2*N, 1, 1).type(dtype)

        return p_n

    def _get_p_0(self, h, w, N, dtype):
        p_0_x, p_0_y = torch.meshgrid(
            torch.arange(1, h*self.stride+1, self.stride),
            torch.arange(1, w*self.stride+1, self.stride))
        p_0_x = torch.flatten(p_0_x).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0_y = torch.flatten(p_0_y).view(1, 1, h, w).repeat(1, N, 1, 1)
        p_0 = torch.cat([p_0_x, p_0_y], 1).type(dtype)

        return p_0

    def _get_p(self, offset, dtype):
        N, h, w = offset.size(1)//2, offset.size(2), offset.size(3)

        # (1, 2N, 1, 1)
        p_n = self._get_p_n(N, dtype)
        # (1, 2N, h, w)
        p_0 = self._get_p_0(h, w, N, dtype)
        p = p_0 + p_n + offset
        return p

    def _get_x_q(self, x, q, N):
        b, h, w, _ = q.size()
        padded_w = x.size(3)
        c = x.size(1)
        # (b, c, h*w)
        x = x.contiguous().view(b, c, -1)

        # (b, h, w, N)
        index = q[..., :N]*padded_w + q[..., N:]  # offset_x*w + offset_y
        # (b, c, h*w*N)
        index = index.contiguous().unsqueeze(dim=1).expand(-1, c, -1, -1, -1).contiguous().view(b, c, -1)

        x_offset = x.gather(dim=-1, index=index).contiguous().view(b, c, h, w, N)

        return x_offset

    @staticmethod
    def _reshape_x_offset(x_offset, ks):
        b, c, h, w, N = x_offset.size()
        x_offset = torch.cat([x_offset[..., s:s+ks].contiguous().view(b, c, h, w*ks) for s in range(0, N, ks)], dim=-1)
        x_offset = x_offset.contiguous().view(b, c, h*ks, w*ks)

        return x_offset
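
The pure-PyTorch DeformConv2d above replaces the custom CUDA modulated deformable convolution so that the operation can be traced by torch.onnx.export. It can be exercised in isolation with dummy tensors; a minimal sketch (channel and spatial sizes are arbitrary, and the class is assumed to be importable as defined above):

# Zero offsets and all-ones modulation reduce to a plain 3x3 convolution sampling grid.
import torch

deform = DeformConv2d(inc=64, outc=64, kernel_size=3, padding=1, stride=1,
                      bias=True, modulation=True)

x = torch.randn(1, 64, 32, 32)
offset = torch.zeros(1, 2 * 3 * 3, 32, 32)  # 2*k*k offset channels
mask = torch.ones(1, 3 * 3, 32, 32)         # k*k modulation channels
out = deform(x, offset, mask)
print(out.shape)  # torch.Size([1, 64, 32, 32])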

ooe1123, Jul 12 '22 14:07

○ BertEncoder

[Reference] https://huggingface.co/docs/transformers/serialization#selecting-features-for-different-model-topologies

  1. Install the export module: pip install transformers[onnx]
  2. Save the model and tokenizer

○ maskrcnn_benchmark/engine/predictor_glip.py

Before:

class GLIPDemo(object):
    def compute_prediction(self, original_image, original_caption, custom_entity = None):
        ...
        if isinstance(original_caption, list):
            ...
            tokenized = self.tokenizer([caption_string], return_tensors="pt")

After:

class GLIPDemo(object):
    def compute_prediction(self, original_image, original_caption, custom_entity = None):
        ...
        if isinstance(original_caption, list):
            ...
            tokenized = self.tokenizer([caption_string], return_tensors="pt")
            self.tokenizer.save_pretrained("bert-checkpoint")

○ maskrcnn_benchmark/modeling/language_backbone/bert_model.py

Before:

class BertEncoder(nn.Module):
    def __init__(self, cfg):
        ...
        if self.bert_name == "bert-base-uncased":
            config = BertConfig.from_pretrained(self.bert_name)

    def forward(self, x):
        ...
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            outputs = self.model(...)

After:

class BertEncoder(nn.Module):
    def __init__(self, cfg):
        ...
        if self.bert_name == "bert-base-uncased":
            config = BertConfig.from_pretrained(self.bert_name, output_hidden_states=True)

    def forward(self, x):
        ...
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            outputs = self.model(...)
            self.model.save_pretrained("bert-checkpoint")

  3. Export from the saved checkpoint: python -m transformers.onnx --model=bert-checkpoint onnx/
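
As a quick sanity check, the exported language backbone can be run on its own; a minimal sketch, assuming onnxruntime is installed, that the export above wrote onnx/model.onnx, and using an example caption string:

# Run the exported BERT encoder with the tokenizer saved in bert-checkpoint/.
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-checkpoint")
session = ort.InferenceSession("onnx/model.onnx")

inputs = tokenizer(["person. bicycle. car."], return_tensors="np")
# Bind only the inputs the graph declares, cast to int64 as ONNX BERT models expect.
feed = {i.name: inputs[i.name].astype("int64") for i in session.get_inputs()}

outputs = session.run(None, feed)
for meta, out in zip(session.get_outputs(), outputs):
    print(meta.name, out.shape)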

ooe1123, Jul 12 '22 15:07