FasterTransformer
How to convert weights from a PyTorch ViT to the TensorRT plugin?
Torch: 1.10.2+cu113, TRT: 8.2.1.8. I trained a model with ViT and want to convert the weights from PyTorch ViT to TensorRT. I wrote a function to transform the weights so they match the loadWeightsPtr of ViTPlugin.cpp, but there is a problem: the diff between torch and TRT is huge. Here is my weight-transform function. Is there a dedicated conversion tool in FasterTransformer?
```python
def vit_bin_file_th_2_fw(self, path):  # path is the weight of pytorch
    def th2np(weights, conv=False, tp=False):
        """Possibly convert HWIO to OIHW."""
        # """Possibly convert tf to torch."""
        """Possibly convert 2310 to 0123."""
        if conv and (not tp):
            return weights.permute(2, 3, 1, 0)
            # return weights.permute(3, 2, 1, 0)
        elif (not conv) and tp:
            return weights.t()
        elif (not conv) and (not tp):
            return weights

    assert 'bin' in path
    print('[INFO] weight is transform from torch_key, torch_weight to tf_key, torch_weight...')
    torch_key_weight = torch.load(path)
    tf_key_weight = {}
    tf_key_root = 'Transformer/'
    for k in torch_key_weight.keys():
        if 'cls_token' in k:
            tf_key_weight['cls'] = th2np(torch_key_weight[k])
        elif 'position_embeddings' in k:
            tf_key_weight['Transformer/posembed_input/pos_embedding'] = th2np(torch_key_weight[k])
        elif 'patch_embeddings.weight' in k:
            tf_key_weight['embedding/kernel'] = th2np(torch_key_weight[k], conv=True)
        elif 'patch_embeddings.bias' in k:
            tf_key_weight['embedding/bias'] = th2np(torch_key_weight[k])
        elif 'head.weight' in k:
            tf_key_weight['head/kernel'] = th2np(torch_key_weight[k], tp=True)
        elif 'head.bias' in k:
            tf_key_weight['head/bias'] = th2np(torch_key_weight[k], tp=True)
        elif 'encoder_norm.weight' in k:
            tf_key_weight[tf_key_root + 'encoder_norm/scale'] = th2np(torch_key_weight[k])
        elif 'encoder_norm.bias' in k:
            tf_key_weight[tf_key_root + 'encoder_norm/bias'] = th2np(torch_key_weight[k])
        elif 'layer' in k:
            k_list = k.split('.')
            layer_num = int(k_list[3])
            layer_name = '.'.join(k_list[4:])
            sub_tf_key_root = 'Transformer/encoderblock_'
            if layer_name == 'attention_norm.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_0/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'attention_norm.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_0/bias'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn_norm.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_2/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn_norm.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_2/bias'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn.fc1.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_0/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc1.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_0/bias'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_1/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_1/bias'] = th2np(torch_key_weight[k], tp=True)
            elif 'attn' in layer_name:
                sub_layer_name = layer_name.split('.')[1]
                if 'weight' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/kernel'] = \
                            th2np(torch_key_weight[k].t().reshape(12, 64, self.hidden_size))
                    else:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/kernel'] = \
                            th2np(torch_key_weight[k].t().reshape(self.hidden_size, 12, 64))
                elif 'bias' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/bias'] = \
                            th2np(torch_key_weight[k])
                    else:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/bias'] = \
                            th2np(torch_key_weight[k].t().reshape(64, 12))
                        # th2np(torch_key_weight[k].reshape(12, 64))
                else:
                    print('[ERROR] attn wrong ? The name is {}'.format(k))
                    exit()
        else:
            print('[ERROR] layer wrong ? The name is {}'.format(k))
            exit()
    print('[INFO] source dict: {} / new dict: {}'.format(len(torch_key_weight), len(tf_key_weight)))
    save_path = path.strip('.bin') + '_tf.bin'
    # np.savez(tf_key_weight, save_path)
    torch.save(tf_key_weight, save_path)
```
My ViT has been changed to use a custom input size; the image input in my code is 384x128. A rough sanity check of the resulting sequence length is sketched below.
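This is only a sketch for my setup, assuming 16x16 patches, a cls token, and a ViT-B hidden size of 768 (none of which comes from the plugin docs):

```python
# Rough sanity check for a 384x128 input (assumes 16x16 patches and a cls token).
img_h, img_w, patch, hidden = 384, 128, 16, 768
num_patches = (img_h // patch) * (img_w // patch)   # 24 * 8 = 192
seq_len = num_patches + 1                           # +1 for the cls token -> 193

# The converted 'Transformer/posembed_input/pos_embedding' entry should then have
# shape (1, seq_len, hidden), e.g. (1, 193, 768) for a ViT-B backbone.
print(num_patches, seq_len)
```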
You can refer to the examples at https://github.com/NVIDIA/FasterTransformer/blob/main/docs/vit_guide.md#run-with-tensorrt-plugin and https://github.com/NVIDIA/FasterTransformer/blob/main/examples/tensorrt/vit/plugin_loader.py.
Thanks, I have referred to the above two examples, and the shapes of my torch weights are consistent with ViT-B_16.npz, but I still get avg diff: 1.1434973, max diff: 4.220703. I wonder if ViT-B_16.npz comes from TF? I didn't find a tool to convert torch weights directly to the trt_plugin format.
In ViT-B_16.npz, the shape of MultiHeadDotProductAttention_1/out/bias is (768,), while MultiHeadDotProductAttention_1/query/bias has shape (12, 64). I think that is a little strange.
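Those shapes can be checked directly from the checkpoint; a minimal sketch, assuming ViT-B_16.npz has been downloaded to the working directory:

```python
import numpy as np

# Inspect the attention parameter shapes in the reference checkpoint.
weights = np.load('ViT-B_16.npz')
prefix = 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/'
for name in ('query/kernel', 'query/bias', 'out/kernel', 'out/bias'):
    print(prefix + name, weights[prefix + name].shape)
# query/bias is stored per head as (12, 64), while out/bias is a flat (768,):
# that is simply how the original JAX/TF ViT checkpoint lays them out.
```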
There may be a problem in the th2np function. Try adding .contiguous() after the reshape, like:
```python
def th2np(weights, conv=False, tp=False):
    """Possibly convert HWIO to OIHW."""
    # """Possibly convert tf to torch."""
    """Possibly convert 2310 to 0123."""
    if conv and (not tp):
        return weights.permute(2, 3, 1, 0).contiguous()
    elif (not conv) and tp:
        return weights.t().contiguous()
    elif (not conv) and (not tp):
        return weights.contiguous()
```
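A small demonstration of why this matters (a sketch, not part of the conversion code): .t() and .permute() only change strides, so the underlying buffer that loadWeightsPtr eventually reads still holds the old layout until .contiguous() copies it.

```python
import torch

# .t() returns a view: the storage is still in the original (row-major) order.
w = torch.arange(6, dtype=torch.float32).reshape(2, 3)
wt = w.t()
print(wt.is_contiguous())                    # False
print(w.flatten().tolist())                  # storage wt still points at: [0, 1, 2, 3, 4, 5]
print(wt.contiguous().flatten().tolist())    # what we actually want saved: [0, 3, 1, 4, 2, 5]
```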
Yes, .contiguous() is required, and there were also some bugs in my MultiHeadDotProductAttention handling: some places don't need to be transposed. Now it works. Here is my code :)
```python
import torch


def th2np(weights, conv=False, tp=False):
    if conv and (not tp):
        return weights.permute(2, 3, 1, 0).contiguous()
    elif (not conv) and tp:
        return weights.t().contiguous()
    elif (not conv) and (not tp):
        return weights.contiguous()


def vit_bin_file_th_2_fw(self, path):
    assert 'bin' in path
    torch_key_weight = torch.load(path)
    tf_key_weight = {}
    tf_key_root = 'Transformer/'
    for k in torch_key_weight.keys():
        if 'cls_token' in k:
            tf_key_weight['cls'] = th2np(torch_key_weight[k])
        elif 'position_embeddings' in k:
            tf_key_weight['Transformer/posembed_input/pos_embedding'] = th2np(torch_key_weight[k])
        elif 'patch_embeddings.weight' in k:
            tf_key_weight['embedding/kernel'] = th2np(torch_key_weight[k], conv=True)
        elif 'patch_embeddings.bias' in k:
            tf_key_weight['embedding/bias'] = th2np(torch_key_weight[k])
        elif 'head.weight' in k:
            tf_key_weight['head/kernel'] = th2np(torch_key_weight[k], tp=True)
        elif 'head.bias' in k:
            tf_key_weight['head/bias'] = th2np(torch_key_weight[k], tp=True)
        elif 'encoder_norm.weight' in k:
            tf_key_weight[tf_key_root + 'encoder_norm/scale'] = th2np(torch_key_weight[k])
        elif 'encoder_norm.bias' in k:
            tf_key_weight[tf_key_root + 'encoder_norm/bias'] = th2np(torch_key_weight[k])
        elif 'layer' in k:
            k_list = k.split('.')
            layer_num = int(k_list[3])
            layer_name = '.'.join(k_list[4:])
            sub_tf_key_root = 'Transformer/encoderblock_'
            if layer_name == 'attention_norm.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_0/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'attention_norm.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_0/bias'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn_norm.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_2/scale'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn_norm.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/LayerNorm_2/bias'] = th2np(torch_key_weight[k])
            elif layer_name == 'ffn.fc1.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_0/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc1.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_0/bias'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.weight':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_1/kernel'] = th2np(torch_key_weight[k], tp=True)
            elif layer_name == 'ffn.fc2.bias':
                tf_key_weight[sub_tf_key_root + str(layer_num) + '/MlpBlock_3/Dense_1/bias'] = th2np(torch_key_weight[k], tp=True)
            elif 'attn' in layer_name:
                sub_layer_name = layer_name.split('.')[1]
                if 'weight' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/kernel'] = \
                            th2np(torch_key_weight[k].view(self.hidden_size, 12, 64).permute(1, 2, 0))
                    else:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/kernel'] = \
                            th2np(torch_key_weight[k].t().view(self.hidden_size, 12, 64))
                elif 'bias' in layer_name:
                    if 'out' in layer_name:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/bias'] = \
                            th2np(torch_key_weight[k])
                    else:
                        tf_key_weight[sub_tf_key_root + str(layer_num) + '/MultiHeadDotProductAttention_1/' + sub_layer_name + '/bias'] = \
                            th2np(torch_key_weight[k].view(12, 64))
                else:
                    print('[ERROR] attn wrong ? The name is {}'.format(k))
                    exit()
        else:
            print('[ERROR] layer wrong ? The name is {}'.format(k))
            exit()
    print('[INFO] source dict: {} / new dict: {}'.format(len(torch_key_weight), len(tf_key_weight)))
    save_path = path.strip('.bin') + '_tf.bin'
    torch.save(tf_key_weight, save_path)
    return save_path
```
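A hypothetical usage sketch (the input .bin path is a placeholder; hidden_size is assumed to be 768 for a ViT-B backbone, and since vit_bin_file_th_2_fw only needs self.hidden_size, a SimpleNamespace stands in for the original class):

```python
import types
import torch

# Hypothetical usage of the converter above.
ctx = types.SimpleNamespace(hidden_size=768)
save_path = vit_bin_file_th_2_fw(ctx, 'vit_custom_384x128.bin')   # placeholder checkpoint name

converted = torch.load(save_path)
key = 'Transformer/encoderblock_0/MultiHeadDotProductAttention_1/query/kernel'
print(key, tuple(converted[key].shape))   # expected: (768, 12, 64), same layout as ViT-B_16.npz
```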
If I want to implement some custom layers, such as batchnorm or IBN, which I haven't seen in FasterTransformer, what projects could I refer to for the implementation?
https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html
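For the Python side of that tutorial, a hedged sketch might look like the following; the source file and op name are placeholders for your own extension, not an existing FasterTransformer API:

```python
import torch
import torch.utils.cpp_extension

# Build and load a custom C++/CUDA op, then call it through torch.ops.
torch.utils.cpp_extension.load(
    name='my_ops',
    sources=['ibn_op.cpp'],     # placeholder: C++ file that registers the op (e.g. via TORCH_LIBRARY)
    is_python_module=False,     # register the ops globally instead of returning a Python module
    verbose=True,
)
x = torch.randn(2, 64, 48, 16, device='cuda')
y = torch.ops.my_ops.ibn_forward(x)          # placeholder op name
```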
Closing this issue because it is inactive. Feel free to re-open it if you still have any problem.