# TubeViT: export the pretrained PyTorch model to ONNX for inference.
import torch
import torchvision.models as models
import sys
import os
# Make the repository root (the parent of this script's directory) importable
# so the project-local `tubevit` package resolves when this file is run
# directly as a script. `__file__` is the path of this script; the original
# `file` was an undefined name and raised NameError.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
sys.path.insert(0, parent_dir)
from tubevit.model import TubeViTLightningModule
# Load a pretrained PyTorch model.
# The checkpoint path is anchored to the repository root (`parent_dir`, the
# same directory added to sys.path above) instead of the current working
# directory, so the script works regardless of where it is launched from.
# NOTE(review): the original "../weights/..." resolved to this same file only
# when the script was launched from its own directory — confirm the layout.
video_shape = [3, 1, 448, 224]
weight_path = os.path.join(
    parent_dir, "weights", "tubevit_vitbase_nc3_fpc3_448_224.pt"
)
model = TubeViTLightningModule(
    num_classes=3,
    video_shape=video_shape,
    num_layers=12,
    num_heads=12,
    hidden_dim=768,
    mlp_dim=3072,
    weight_path=weight_path,
    test_each_epoch=False,
)
# Put the model in evaluation mode before tracing it for export.
model.eval()

# Example input used to trace the model: a batch dimension of 1 followed by
# exactly the dimensions passed as `video_shape`, so the two cannot drift apart.
# This is identical to the original torch.randn(1, 3, 1, 448, 224).
dummy_input = torch.randn(1, *video_shape)

# Path of the ONNX file to write (relative to the current working directory,
# as in the original).
onnx_file_path = "./weights/test.onnx"
# Writing the file fails if the target directory is missing, so create it first.
os.makedirs(os.path.dirname(onnx_file_path) or ".", exist_ok=True)

# Export the model to ONNX format.
# NOTE(review): the original author remarked that only opset versions 7-16
# were available to them and that none of them supported this model, i.e. the
# export may still fail here — confirm with the installed torch version.
torch.onnx.export(
    model,
    dummy_input,
    onnx_file_path,
    verbose=True,
    opset_version=12,
)