
How to convert weights from a PyTorch ViT to the TensorRT plugin?

HollrayChan opened this issue 2 years ago · 7 comments

Torch: 1.10.2+cu113, TRT: 8.2.1.8. I trained a model with ViT and want to convert its weights from PyTorch to TensorRT. I wrote a function to transform the weights to match the loadWeightsPtr of ViTPlugin.cpp, but there is a problem: the diff between the torch and trt outputs is huge. Here is my weight-transform function. Is there a dedicated conversion tool in FasterTransformer?

def vit_bin_file_th_2_fw(self, path):  # path is the pytorch weight file
    def th2np(weights, conv=False, tp=False):
        """Possibly convert HWIO to OIHW."""
        # """Possibly convert tf to torch."""
        """Possibly convert 2310 to 0123."""
        if conv and (not tp):
            return weights.permute(2, 3, 1, 0)
            # return weights.permute(3, 2, 1, 0)
        elif (not conv) and tp:
            return weights.t()
        elif (not conv) and (not tp):
            return weights

    assert 'bin' in path
    print('[INFO] weight is transform from torch_key, torch_weight to tf_key, torch_weight...')
    torch_key_weight = torch.load(path)
    tf_key_weight = {}
    tf_key_root = 'Transformer/'

    for k in torch_key_weight.keys():
        if 'cls_token' in k:
            tf_key_weight['cls'] = th2np(torch_key_weight[k])
        elif 'position_embeddings' in k:
            tf_key_weight['Transformer/posembed_input/pos_embedding'] = th2np(torch_key_weight[k])

        elif 'patch_embeddings.weight' in k:
            tf_key_weight['embedding/kernel'] = th2np(torch_key_weight[k], conv=True)
        elif 'patch_embeddings.bias' in k:
            tf_key_weight['embedding/bias'] = th2np(torch_key_weight[k])

        elif 'head.weight' in k:
            tf_key_weight['head/kernel'] = th2np(torch_key_weight[k], tp=True)
        elif 'head.bias' in k:
            tf_key_weight['head/bias'] = th2np(torch_key_weight[k], tp=True)
            
        elif 'encoder_norm.weight' in k:
            tf_key_weight[tf_key_root+'encoder_norm/scale'] = th2np(torch_key_weight[k])
        elif 'encoder_norm.bias' in k:
            tf_key_weight[tf_key_root+'encoder_norm/bias'] = th2np(torch_key_weight[k])
            
        elif 'layer' in k:
            k_list = k.split('.')
            layer_num = int(k_list[3])
            layer_name = '.'.join(k_list[4:])
            sub_tf_key_root = 'Transformer/encoderblock_'

            if layer_name == 'attention_norm.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_0/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'attention_norm.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_0/bias'] = th2np(torch_key_weight[k])

            elif layer_name == 'ffn_norm.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_2/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn_norm.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_2/bias'] = th2np(torch_key_weight[k])

            elif layer_name == 'ffn.fc1.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_0/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc1.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_0/bias'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_1/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_1/bias'] = th2np(torch_key_weight[k], tp=True)
            
            elif 'attn' in layer_name:
                sub_layer_name = layer_name.split('.')[1]
                if 'weight' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/kernel'] = \
                        th2np(torch_key_weight[k].t().reshape(12, 64, self.hidden_size))
                    else:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/kernel'] = \
                        th2np(torch_key_weight[k].t().reshape(self.hidden_size, 12, 64))
                elif 'bias' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/bias'] = \
                        th2np(torch_key_weight[k])
                    else:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/bias'] = \
                        th2np(torch_key_weight[k].t().reshape(64, 12))
                        # th2np(torch_key_weight[k].reshape(12, 64))
                else:
                    print('[ERROR] attn wrong ? The name is {}'.format(k))
                    exit()
        else:
            print('[ERROR] layer wrong ? The name is {}'.format(k))
            exit()

    print('[INFO] source dict: {} / new dict: {}'.format(len(torch_key_weight), len(tf_key_weight)))
    save_path = path[:-len('.bin')] + '_tf.bin'  # note: str.strip removes a set of characters, not a suffix
    # np.savez(tf_key_weight, save_path)
    torch.save(tf_key_weight, save_path)

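As a sanity check on such a conversion, the converted dict can be diffed against the reference Google checkpoint key by key. This is a minimal sketch, not part of the original post; 'model_tf.bin' and 'ViT-B_16.npz' are placeholder paths:

import numpy as np
import torch

# Compare key sets and shapes between the converted dict and the
# reference checkpoint; both file names are placeholders.
converted = torch.load('model_tf.bin')
reference = np.load('ViT-B_16.npz')

for key in sorted(set(converted) | set(reference.files)):
    c = tuple(converted[key].shape) if key in converted else None
    r = tuple(reference[key].shape) if key in reference.files else None
    if c != r:
        print('{}: converted {} vs reference {}'.format(key, c, r))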

HollrayChan commented Apr 27 '22 08:04

My ViT has been modified to use a custom input size; the image input in the code is 384x128.
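For a non-square input like this, the position-embedding length has to be recomputed. A quick check, assuming 16x16 patches, a class token, and hidden_size 768 (all assumptions, since they are not stated here):

img_h, img_w, patch = 384, 128, 16                 # 16x16 patches is an assumption
num_patches = (img_h // patch) * (img_w // patch)  # 24 * 8 = 192
# pos_embedding should then be (1, num_patches + 1, hidden_size),
# i.e. (1, 193, 768) if hidden_size is 768 (ViT-B)
print((1, num_patches + 1, 768))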

HollrayChan commented Apr 27 '22 08:04

You can refer to the example at https://github.com/NVIDIA/FasterTransformer/blob/main/docs/vit_guide.md#run-with-tensorrt-plugin and https://github.com/NVIDIA/FasterTransformer/blob/main/examples/tensorrt/vit/plugin_loader.py.

byshiue commented Apr 27 '22 09:04

Thanks, I have referred to the two examples above, and the shapes of my torch weights are consistent with ViT-B_16.npz, but I get avg diff: 1.1434973, max diff: 4.220703. I wonder whether ViT-B_16.npz comes from TF? I didn't find a tool to convert torch weights directly for the trt plugin.

In ViT-B_16.npz, the shape of MultiHeadDotProductAttention_1/out/bias is (768,), while the shape of MultiHeadDotProductAttention_1/query/bias is (12, 64); I find that a little strange.
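Those shapes can be confirmed directly from the checkpoint: the query/key/value biases are stored per head, while the output bias is flat. A short inspection sketch, assuming the standard ViT-B_16.npz file:

import numpy as np

w = np.load('ViT-B_16.npz')
base = 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/'
for name in ('query/bias', 'key/bias', 'value/bias', 'out/bias'):
    print(name, w[base + name].shape)
# query/key/value biases come out as (12, 64); out/bias as (768,)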

HollrayChan commented Apr 27 '22 09:04

There may be some problem in the th2np function. Try adding .contiguous() after the reshape, like:

def th2np(weights, conv=False, tp=False):
    """Possibly convert HWIO to OIHW."""
    # """Possibly convert tf to torch."""
    """Possibly convert 2310 to 0123."""
    if conv and (not tp):
        return weights.permute(2, 3, 1, 0).contiguous()
    elif (not conv) and tp:
        return weights.t().contiguous()
    elif (not conv) and (not tp):
        return weights.contiguous()
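The reason this matters: .t() and .permute() only change strides, not the memory layout, so a consumer that reads the raw buffer in row-major order (presumably what happens via loadWeightsPtr) sees the elements in the wrong order unless .contiguous() materializes the transposed data. A minimal demonstration:

import torch

x = torch.arange(6).reshape(2, 3)
y = x.t()                                # a view: strides change, memory doesn't
print(y.is_contiguous())                 # False
print(y.contiguous().is_contiguous())    # True: elements physically reordered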

byshiue commented Apr 27 '22 12:04

Yes, .contiguous() is required, and there were some bugs in my MultiHeadDotProductAttention handling: some places don't need to be transposed. It works now. Here is my code :)

def th2np(weights, conv=False, tp=False):
    if conv and (not tp):
        return weights.permute(2, 3, 1, 0).contiguous()
    elif (not conv) and tp:
        return weights.t().contiguous()
    elif (not conv) and (not tp):
        return weights.contiguous()

def vit_bin_file_th_2_fw(self, path):
    assert 'bin' in path
    torch_key_weight = torch.load(path)
    tf_key_weight = {}
    tf_key_root = 'Transformer/'

    for k in torch_key_weight.keys():
        if 'cls_token' in k:
            tf_key_weight['cls'] = th2np(torch_key_weight[k])
        elif 'position_embeddings' in k:
            tf_key_weight['Transformer/posembed_input/pos_embedding'] = th2np(torch_key_weight[k])
        elif 'patch_embeddings.weight' in k:
            tf_key_weight['embedding/kernel'] = th2np(torch_key_weight[k], conv=True)
        elif 'patch_embeddings.bias' in k:
            tf_key_weight['embedding/bias'] = th2np(torch_key_weight[k])

        elif 'head.weight' in k:
            tf_key_weight['head/kernel'] = th2np(torch_key_weight[k], tp=True)
        elif 'head.bias' in k:
            tf_key_weight['head/bias'] = th2np(torch_key_weight[k], tp=True)

        elif 'encoder_norm.weight' in k:
            tf_key_weight[tf_key_root+'encoder_norm/scale'] = th2np(torch_key_weight[k])
        elif 'encoder_norm.bias' in k:
            tf_key_weight[tf_key_root+'encoder_norm/bias'] = th2np(torch_key_weight[k])

        elif 'layer' in k:
            k_list = k.split('.')
            layer_num = int(k_list[3])
            layer_name = '.'.join(k_list[4:])
            sub_tf_key_root = 'Transformer/encoderblock_'

            if layer_name == 'attention_norm.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_0/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'attention_norm.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_0/bias'] = th2np(torch_key_weight[k])


            elif layer_name == 'ffn_norm.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_2/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn_norm.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_2/bias'] = th2np(torch_key_weight[k])

            elif layer_name == 'ffn.fc1.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_0/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc1.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_0/bias'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.weight':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_1/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.bias':
                tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_1/bias'] = th2np(torch_key_weight[k], tp=True)
            
            elif 'attn' in layer_name:
                sub_layer_name = layer_name.split('.')[1]
                if 'weight' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/kernel'] = \
                        th2np(torch_key_weight[k].view(self.hidden_size, 12, 64).permute(1,2,0))
                    else:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/kernel'] = \
                        th2np(torch_key_weight[k].t().view(self.hidden_size, 12, 64))
                elif 'bias' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/bias'] = \
                        th2np(torch_key_weight[k])
                    else:
                        tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/bias'] = \
                        th2np(torch_key_weight[k].view(12, 64))
                else:
                    print('[ERROR] attn wrong ? The name is {}'.format(k))
                    exit()
        else:
            print('[ERROR] layer wrong ? The name is {}'.format(k))
            exit()

    print('[INFO] source dict: {} / new dict: {}'.format(len(torch_key_weight), len(tf_key_weight)))
    save_path = path[:-len('.bin')] + '_tf.bin'  # str.strip removes characters, not a suffix
    torch.save(tf_key_weight, save_path)
    return save_path
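A hypothetical driver for the function above; since vit_bin_file_th_2_fw only touches self.hidden_size, a namespace can stand in for the real class here ('my_vit.bin' and hidden_size=768 are assumptions):

import torch
from types import SimpleNamespace

ctx = SimpleNamespace(hidden_size=768)           # 768 assumes ViT-B
tf_bin_path = vit_bin_file_th_2_fw(ctx, 'my_vit.bin')
weights = torch.load(tf_bin_path)                # dict keyed in the TF/JAX scheme
print(len(weights), 'tensors written to', tf_bin_path)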

HollrayChan commented Apr 27 '22 14:04

If I want to implement some custom layers, such as batchnorm or IBN, which I haven't seen in FasterTransformer, which projects could I refer to in order to implement them?

HollrayChan commented Apr 28 '22 15:04

https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
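On the Python side, that tutorial boils down to compiling the custom op into a shared library and loading it. A rough sketch, where the library path and the op name ('my_ops::ibn') are placeholders:

import torch

# Load the shared library built from the C++ op in the tutorial;
# the path and the op namespace/name below are placeholders.
torch.ops.load_library('build/libcustom_ops.so')
out = torch.ops.my_ops.ibn(torch.randn(1, 64, 56, 56))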

byshiue commented May 02 '22 03:05

Closing this issue because it is inactive. Feel free to re-open it if you still have any problem.

byshiue commented Sep 06 '22 01:09