
onnx infer

Open yxl502 opened this issue 1 year ago • 0 comments

import os
import sys

import torch
import torchvision.models as models

# Make the parent directory importable so that the tubevit package resolves.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)

from tubevit.model import TubeViTLightningModule

# Load a pretrained PyTorch model

model = TubeViTLightningModule(
    num_classes=3,
    video_shape=[3, 1, 448, 224],
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    weight_path="../weights/tubevit_vitbase_nc3_fpc3_448_224.pt",
    test_each_epoch=False,
)
model.eval()

# Define an example input for the model

dummy_input = torch.randn(1, 3, 1, 448, 224)

# Path of the ONNX file to save

onnx_file_path = "./weights/test.onnx"

# Export the model to ONNX format

# Only opset versions 7-16 are available, and none of them are supported.
torch.onnx.export(model, dummy_input, onnx_file_path, verbose=True, opset_version=12)

yxl502, Jan 23 '24 06:01
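Since the report says every opset from 7 to 16 fails but does not include the error text, one way to narrow the problem down is to try each opset in turn and record the exception it raises. The sketch below is only an illustration: `try_opsets` and `fake_export` are hypothetical names, not part of TubeViT or PyTorch, and the stand-in exporter exists only so the snippet runs without the model weights.

```python
from typing import Callable, Dict, Iterable, Optional


def try_opsets(
    export_fn: Callable[[int], None],
    opsets: Iterable[int],
) -> Dict[int, Optional[str]]:
    """Call export_fn(opset) for each opset.

    Returns a dict mapping each opset to None on success,
    or to the error text if the export raised.
    """
    results: Dict[int, Optional[str]] = {}
    for opset in opsets:
        try:
            export_fn(opset)
            results[opset] = None
        except Exception as exc:  # keep the message for the bug report
            results[opset] = f"{type(exc).__name__}: {exc}"
    return results


if __name__ == "__main__":
    # Hypothetical stand-in exporter: pretends opsets below 14 are rejected.
    def fake_export(opset: int) -> None:
        if opset < 14:
            raise RuntimeError(f"operator not supported in opset {opset}")

    for opset, error in try_opsets(fake_export, range(12, 17)).items():
        print(opset, "ok" if error is None else error)
```

With the script above, the stand-in could be replaced by something like `lambda v: torch.onnx.export(model, dummy_input, onnx_file_path, opset_version=v)`. The collected messages would show whether all opsets fail on the same operator, which is usually the detail needed to diagnose an export failure.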