FasterTransformer copied to clipboard
How do I convert weights from a PyTorch ViT to the TensorRT plugin?
Environment: Torch 1.10.2+cu113, TRT (version not stated). I trained a ViT model and want to convert its PyTorch weights for TensorRT. I wrote a function that transforms the weights to match loadWeightsPtr in ViTPlugin.cpp, but there is a problem: the difference between the torch and TRT outputs is huge. My weight-transform function is below. Is there a dedicated conversion tool in FasterTransformer?
# Original (non-working) converter as posted in the question: it loads a PyTorch
# state dict with torch.load, renames each key to a 'Transformer/...' style key
# and passes the tensor through the nested th2np helper. The thread reports a
# large torch-vs-TRT output diff with this version.
# NOTE(review): indentation, the code fence and some lines (e.g. `else:` branches
# and the right-hand side after some `\` continuations) appear to have been lost
# when this snippet was extracted -- confirm against the original post.
''' def vit_bin_file_th_2_fw(self, path): # path is the weight of pytorch def th2np(weights, conv=False, tp=False): """Possibly convert HWIO to OIHW.""" # """Possibly convert tf to torch.""" """Possibly convert 2310 to 0123.""" if conv and (not tp): return weights.permute(2, 3, 1, 0) # return weights.permute(3, 2, 1, 0) elif (not conv) and tp: return weights.t() elif (not conv) and (not tp): return weights
assert 'bin' in path
print('[INFO] weight is transform from torch_key, torch_weight to tf_key, torch_weight...')
torch_key_weight = torch.load(path)
tf_key_weight = {}
tf_key_root = 'Transformer/'
for k in torch_key_weight.keys():
if 'cls_token' in k:
tf_key_weight['cls'] = th2np(torch_key_weight[k])
elif 'position_embeddings' in k:
tf_key_weight['Transformer/posembed_input/pos_embedding'] = th2np(torch_key_weight[k])
elif 'patch_embeddings.weight' in k:
tf_key_weight['embedding/kernel'] = th2np(torch_key_weight[k], conv=True)
elif 'patch_embeddings.bias' in k:
tf_key_weight['embedding/bias'] = th2np(torch_key_weight[k])
elif 'head.weight' in k:
tf_key_weight['head/kernel'] = th2np(torch_key_weight[k], tp=True)
elif 'head.bias' in k:
tf_key_weight['head/bias'] = th2np(torch_key_weight[k], tp=True)
elif 'encoder_norm.weight' in k:
tf_key_weight[tf_key_root+'encoder_norm/scale'] = th2np(torch_key_weight[k])
elif 'encoder_norm.bias' in k:
tf_key_weight[tf_key_root+'encoder_norm/bias'] = th2np(torch_key_weight[k])
elif 'layer' in k:
k_list = k.split('.')
layer_num = int(k_list[3])
layer_name = '.'.join(k_list[4:])
sub_tf_key_root = 'Transformer/encoderblock_'
if layer_name == 'attention_norm.weight':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_0/scale'] = th2np(torch_key_weight[k])
elif layer_name == 'attention_norm.bias':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_0/bias'] = th2np(torch_key_weight[k])
elif layer_name == 'ffn_norm.weight':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_2/scale'] = th2np(torch_key_weight[k])
elif layer_name == 'ffn_norm.bias':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/LayerNorm_2/bias'] = th2np(torch_key_weight[k])
elif layer_name == 'ffn.fc1.weight':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_0/kernel'] = th2np(torch_key_weight[k], tp=True)
elif layer_name == 'ffn.fc1.bias':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_0/bias'] = th2np(torch_key_weight[k], tp=True)
elif layer_name == 'ffn.fc2.weight':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_1/kernel'] = th2np(torch_key_weight[k], tp=True)
elif layer_name == 'ffn.fc2.bias':
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MlpBlock_3/Dense_1/bias'] = th2np(torch_key_weight[k], tp=True)
elif 'attn' in layer_name:
sub_layer_name = layer_name.split('.')[1]
if 'weight' in layer_name:
if 'out' in layer_name:
# NOTE(review): both assignments below write the same key; an `else:` between
# them was presumably lost. The 12 x 64 split is hard-coded (presumably
# heads x head size) -- TODO confirm.
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/kernel'] = \
th2np(torch_key_weight[k].t().reshape(12, 64, self.hidden_size))
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/kernel'] = \
th2np(torch_key_weight[k].t().reshape(self.hidden_size, 12, 64))
elif 'bias' in layer_name:
if 'out' in layer_name:
# NOTE(review): the value for the 'out' bias and the following `else:` seem to
# be missing after the first continuation line -- TODO confirm.
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/bias'] = \
tf_key_weight[sub_tf_key_root+ str(layer_num)+'/MultiHeadDotProductAttention_1/'+sub_layer_name+'/bias'] = \
th2np(torch_key_weight[k].t().reshape(64, 12))
# th2np(torch_key_weight[k].reshape(12, 64))
print('[ERROR] attn wrong ? The name is {}'.format(k))
print('[ERROR] layer wrong ? The name is {}'.format(k))
print('[INFO] source dict: {} / new dict: {}'.format(len(torch_key_weight), len(tf_key_weight)))
# NOTE: str.strip('.bin') removes any of the characters '.', 'b', 'i', 'n' from
# both ends of the path; it does not remove the '.bin' suffix as such.
save_path = path.strip('.bin')+'_tf.bin'
# np.savez(tf_key_weight, save_path), save_path)
My ViT has been modified to use a custom input size; the input image in this code is 384x128.
You can refer to the following two examples (the links were lost when this thread was extracted).
Thanks. I have looked at the two examples above, and the shapes of my torch weights are consistent with ViT-B_16.npz, but I still get avg diff: 1.1434973, max diff: 4.220703. Does ViT-B_16.npz come from TF? I didn't find a tool that converts torch weights directly to the TRT plugin format.
In ViT-B_16.npz, the shape of MultiHeadDotProductAttention_1/out/bias is 768, while the shape of MultiHeadDotProductAttention_1/query/bias is (12, 64). I think that is a little strange.
There may be a problem in the th2np function. Try adding .contiguous() after the reshape (permute/transpose), like this:
def th2np(weights, conv=False, tp=False):
    """Return ``weights`` in the layout used by the converted checkpoint.

    The result is always made contiguous, so that a permuted or transposed
    view is materialised in its new memory order instead of remaining a
    strided view over the original storage.

    Args:
        weights: Tensor to convert; must provide ``permute``, ``t`` and
            ``contiguous``.
        conv: If true, reorder the four axes with ``permute(2, 3, 1, 0)``.
            For an input laid out as OIHW this yields HWIO.
        tp: If true, transpose the tensor with ``t()``.

    Returns:
        The permuted (``conv``), transposed (``tp``) or unchanged tensor,
        made contiguous.

    Raises:
        ValueError: If both ``conv`` and ``tp`` are set. No conversion is
            defined for that combination; it previously fell through every
            branch and silently returned ``None``.
    """
    if conv and tp:
        raise ValueError('conv and tp cannot both be set')
    if conv:
        return weights.permute(2, 3, 1, 0).contiguous()
    if tp:
        return weights.t().contiguous()
    return weights.contiguous()
Yes, .contiguous() is required. There were also some bugs in my MultiHeadDotProductAttention handling: some tensors do not need to be transposed. Now it works. Here is my code :)
def th2np(weights, conv=False, tp=False):
    """Return a contiguous copy of ``weights`` in the converted layout.

    Args:
        weights: Tensor to convert; must provide ``permute``, ``t`` and
            ``contiguous``.
        conv: If true, reorder the four axes with ``permute(2, 3, 1, 0)``
            (an OIHW input becomes HWIO).
        tp: If true, transpose the tensor with ``t()``.

    Returns:
        The permuted (``conv``), transposed (``tp``) or unchanged tensor.
        ``contiguous()`` is applied in every case so the data is stored in
        the new axis order rather than as a strided view.

    Raises:
        ValueError: If both ``conv`` and ``tp`` are set; that combination
            matched no branch before and silently returned ``None``.
    """
    if conv and tp:
        raise ValueError('conv and tp cannot both be set')
    if conv:
        return weights.permute(2, 3, 1, 0).contiguous()
    if tp:
        return weights.t().contiguous()
    return weights.contiguous()
def vit_bin_file_th_2_fw(self, path):
    """Convert a PyTorch ViT checkpoint to ``Transformer/...`` style keys and save it.

    The state dict at ``path`` is read with ``torch.load``. Every recognised
    parameter is passed through ``th2np`` and stored under its new key.
    Per-layer names that match no rule are reported with an ``[ERROR]`` print
    and skipped; top-level keys that match no rule are skipped silently.

    Args:
        path: File name of the PyTorch checkpoint; must contain ``'bin'``.

    Returns:
        The output file name: ``path`` without a trailing ``.bin``, followed
        by ``_tf.bin``.

    Raises:
        AssertionError: If ``'bin'`` does not occur in ``path``.
    """
    assert 'bin' in path
    torch_key_weight = torch.load(path)
    tf_key_weight = {}
    tf_key_root = 'Transformer/'
    sub_tf_key_root = 'Transformer/encoderblock_'

    # Per-layer parameters that need no head reshaping:
    # source name -> (target key suffix, transpose with th2np?).
    plain_layer_params = {
        'attention_norm.weight': ('/LayerNorm_0/scale', False),
        'attention_norm.bias': ('/LayerNorm_0/bias', False),
        'ffn_norm.weight': ('/LayerNorm_2/scale', False),
        'ffn_norm.bias': ('/LayerNorm_2/bias', False),
        'ffn.fc1.weight': ('/MlpBlock_3/Dense_0/kernel', True),
        'ffn.fc1.bias': ('/MlpBlock_3/Dense_0/bias', True),
        'ffn.fc2.weight': ('/MlpBlock_3/Dense_1/kernel', True),
        'ffn.fc2.bias': ('/MlpBlock_3/Dense_1/bias', True),
    }

    for k, weight in torch_key_weight.items():
        if 'cls_token' in k:
            tf_key_weight['cls'] = th2np(weight)
        elif 'position_embeddings' in k:
            tf_key_weight['Transformer/posembed_input/pos_embedding'] = th2np(weight)
        elif 'patch_embeddings.weight' in k:
            tf_key_weight['embedding/kernel'] = th2np(weight, conv=True)
        elif 'patch_embeddings.bias' in k:
            tf_key_weight['embedding/bias'] = th2np(weight)
        elif 'head.weight' in k:
            tf_key_weight['head/kernel'] = th2np(weight, tp=True)
        elif 'head.bias' in k:
            tf_key_weight['head/bias'] = th2np(weight, tp=True)
        elif 'encoder_norm.weight' in k:
            tf_key_weight[tf_key_root + 'encoder_norm/scale'] = th2np(weight)
        elif 'encoder_norm.bias' in k:
            tf_key_weight[tf_key_root + 'encoder_norm/bias'] = th2np(weight)
        elif 'layer' in k:
            # The block index is read from dotted position 3 of the key and the
            # remainder is the parameter name inside that block.
            k_list = k.split('.')
            layer_num = int(k_list[3])
            layer_name = '.'.join(k_list[4:])
            block_root = sub_tf_key_root + str(layer_num)

            if layer_name in plain_layer_params:
                suffix, transpose = plain_layer_params[layer_name]
                tf_key_weight[block_root + suffix] = th2np(weight, tp=transpose)
            elif 'attn' in layer_name:
                sub_layer_name = layer_name.split('.')[1]
                attn_root = block_root + '/MultiHeadDotProductAttention_1/' + sub_layer_name
                is_out = 'out' in layer_name
                # The 12 x 64 split is hard-coded (presumably heads x head
                # size), so the biases reshaped below must hold 768 values.
                if 'weight' in layer_name:
                    if is_out:
                        # Output projection: no transpose, only split and
                        # move the leading axis last -> (12, 64, hidden_size).
                        tf_key_weight[attn_root + '/kernel'] = \
                            th2np(weight.view(self.hidden_size, 12, 64).permute(1, 2, 0))
                    else:
                        # Other projections: transpose first, then split the
                        # second axis -> (hidden_size, 12, 64).
                        tf_key_weight[attn_root + '/kernel'] = \
                            th2np(weight.t().view(self.hidden_size, 12, 64))
                elif 'bias' in layer_name:
                    if is_out:
                        # NOTE(review): this value was truncated in the
                        # original post. The flat, untransposed bias is
                        # assumed, matching the 768-element out/bias of
                        # ViT-B_16.npz mentioned in the thread -- confirm.
                        tf_key_weight[attn_root + '/bias'] = th2np(weight)
                    else:
                        tf_key_weight[attn_root + '/bias'] = th2np(weight.view(12, 64))
                else:
                    print('[ERROR] attn wrong ? The name is {}'.format(k))
            else:
                print('[ERROR] layer wrong ? The name is {}'.format(k))

    print('[INFO] source dict: {} / new dict: {}'.format(len(torch_key_weight), len(tf_key_weight)))

    # Remove only a real '.bin' suffix. str.strip('.bin') strips any of the
    # characters '.', 'b', 'i', 'n' from both ends and can mangle the name.
    stem = path[:-len('.bin')] if path.endswith('.bin') else path
    save_path = stem + '_tf.bin'
    # NOTE(review): the save call is truncated in the original post (only
    # ", save_path)" survives). torch.save is assumed because the input is
    # read with torch.load -- confirm the format the plugin loader expects.
    torch.save(tf_key_weight, save_path)
    return save_path
If I want to implement some custom layers, such as batchnorm or IBN, which I haven't found in FasterTransformer, which projects could I refer to for the implementation?
Closing this issue because it is inactive. Feel free to re-open it if you still have any problems.