VideoTransformer-pytorch
How to test my trained model?
Hello, thank you very much for sharing this wonderful project!
I have now trained my own model with this code and generated a .pth file. How can I use this .pth file to test on other data?
Looking forward to your response; I would greatly appreciate it!
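For reference, the general pattern is: rebuild the model, load the .pth state dict, switch to eval mode, and run preprocessed clips through the backbone and classification head. Below is only a minimal sketch, not the repo's official evaluation code: my_checkpoint.pth and clip are placeholder names, constructor defaults are assumed for arguments not shown, the (1, T, C, H, W) clip layout is an assumption, and the stored keys may need the same renaming that the full script posted below performs.

import torch

from video_transformer import ViViT
from transformer import ClassificationHead

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the backbone and classification head with the same settings used for training.
model = ViViT(num_frames=16, img_size=224, patch_size=16, embed_dims=768,
              in_channels=3, attention_type='fact_encoder', return_cls_token=True)
cls_head = ClassificationHead(num_classes=400, in_channels=768)

# 'my_checkpoint.pth' is a placeholder; checkpoints saved by this repo may wrap the
# weights in a 'state_dict' entry and prefix keys with 'model.' / 'cls_head.',
# in which case the key renaming from the script below is needed first.
state_dict = torch.load('my_checkpoint.pth', map_location=device)
state_dict = state_dict.get('state_dict', state_dict)
model.load_state_dict(state_dict, strict=False)
cls_head.load_state_dict(state_dict, strict=False)

model.eval().to(device)
cls_head.eval().to(device)

with torch.no_grad():
    # 'clip' stands for a preprocessed video tensor, shaped and normalized the same
    # way the training data pipeline produced it (a random stand-in is used here).
    clip = torch.randn(1, 16, 3, 224, 224).to(device)
    predicted_class = cls_head(model(clip)).argmax(dim=1)
print(predicted_class)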
Hello, have you solved this problem? @FJNU-LWP
import argparse

import torch
from torch import nn
from tqdm import tqdm

from video_transformer import ViViT
from transformer import ClassificationHead
from data_trainer import KineticsDataModule
def compute_top_k_accuracy(data_loader, model, cls_head, device, top_k=5):
    top_1_correct = 0
    top_5_correct = 0
    total_samples = 0

    model.eval()
    cls_head.eval()
    with torch.no_grad():
        for video, label in tqdm(data_loader, desc="Evaluating"):
            video = video.to(device)
            label = label.to(device)

            # Forward pass: backbone features, then classification head
            logits = model(video)
            output = cls_head(logits)
            output = output.view(output.size(0), -1)  # flatten to (batch, num_classes)

            # Indices of the top-k highest-scoring classes per sample
            _, topk_preds = output.topk(top_k, dim=1, largest=True, sorted=True)

            # Top-1: best prediction matches the label; Top-5: label appears anywhere in the top k
            top_1_correct += (topk_preds[:, 0] == label).sum().item()
            top_5_correct += (topk_preds == label.unsqueeze(1)).sum().item()
            total_samples += label.size(0)

    top_1_accuracy = top_1_correct / total_samples
    top_5_accuracy = top_5_correct / total_samples
    return top_1_accuracy, top_5_accuracy
def parse_args():
    parser = argparse.ArgumentParser(description='lr receiver')
    # Common
    parser.add_argument(
        '-objective', type=str, default='supervised',
        help='the learning objective from [mim, supervised]')
    parser.add_argument(
        '-eval_metrics', type=str, default='finetune',
        help='the eval metrics chosen from [linear_prob, finetune]')
    parser.add_argument(
        '-batch_size', type=int, default=16,
        help='the batch size of data inputs')
    parser.add_argument(
        '-num_workers', type=int, default=4,
        help='the num workers of loading data')
    # Environment
    parser.add_argument(
        '-gpus', nargs='+', type=int, default=-1,
        help='the available gpus in this experiment')
    # Data
    parser.add_argument(
        '-num_class', type=int, default=400,
        help='the num class of dataset used')
    parser.add_argument(
        '-num_samples_per_cls', type=int, default=10000,
        help='the num samples per class')
    parser.add_argument(
        '-img_size', type=int, default=224,
        help='the size of processed image')
    parser.add_argument(
        '-num_frames', type=int, default=16,
        help='the number of frames sampled')
    parser.add_argument(
        '-frame_interval', type=int, default=16,
        help='the interval of frame sampling')
    parser.add_argument(
        '-data_statics', type=str, default='kinetics',
        help='choose data statistics from [imagenet, kinetics, clip]')
    parser.add_argument(
        '-val_data_path', type=str, default="/home/hmf/vivit_clip/VideoTransformer-pytorch-main/OpenMMLab___Kinetics"
                                            "-400/raw/Kinetics-400/kinetics400_val_list_videos.txt",
        help='the path to the val set')
    parser.add_argument(
        '-test_data_path', type=str, default=None,
        help='the path to the test set')
    parser.add_argument(
        '-auto_augment', type=str, default=None,
        help='the AutoAugment policy to use')
    parser.add_argument(
        '-mixup', type=bool, default=False,
        help='whether or not to use mixup')
    parser.add_argument(
        '-multi_crop', type=bool, default=False,
        help='whether or not to use multi crop')
    # Model
    parser.add_argument(
        '-arch', type=str, default='vivit',
        help='the model arch chosen from [timesformer, vivit]')
    parser.add_argument(
        '-attention_type', type=str, default='fact_encoder',
        help='the attention type used in the model')
    parser.add_argument(
        '-pretrain_pth', type=str, default="./vivit_model.pth",
        help='the path to the pretrained weights')
    parser.add_argument(
        '-weights_from', type=str, default='kinetics',
        help='the pretrain params from [imagenet, kinetics, clip]')

    args = parser.parse_args()
    return args
def replace_state_dict(state_dict):
    for old_key in list(state_dict.keys()):
        if old_key.startswith('model'):
            new_key = old_key[6:]  # strip the 'model.' prefix
            if 'in_proj' in new_key:
                new_key = new_key.replace('in_proj_', 'qkv.')  # e.g. in_proj_weight -> qkv.weight
            elif 'out_proj' in new_key:
                new_key = new_key.replace('out_proj', 'proj')  # out_proj -> proj
            state_dict[new_key] = state_dict.pop(old_key)
        else:  # cls_head
            new_key = old_key[9:]  # strip the 'cls_head.' prefix
            state_dict[new_key] = state_dict.pop(old_key)
def init_from_kinetics_pretrain_(module, pretrain_pth):
    if torch.cuda.is_available():
        state_dict = torch.load(pretrain_pth)
    else:
        state_dict = torch.load(pretrain_pth, map_location=torch.device('cpu'))
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    replace_state_dict(state_dict)
    msg = module.load_state_dict(state_dict, strict=False)
    return msg
def test_vivit_model(test_checkpoint_path):
    args = parse_args()

    # Step 1: Build the model and classification head
    model = ViViT(num_frames=args.num_frames,
                  img_size=args.img_size,
                  patch_size=16,
                  embed_dims=768,
                  in_channels=3,
                  attention_type=args.attention_type,
                  return_cls_token=True,
                  weights_from=args.weights_from,
                  pretrain_pth=args.pretrain_pth)
    cls_head = ClassificationHead(num_classes=args.num_class, in_channels=768)

    # Load the trained weights into both modules
    msg_trans = init_from_kinetics_pretrain_(model, test_checkpoint_path)
    msg_cls = init_from_kinetics_pretrain_(cls_head, test_checkpoint_path)
    # model.eval()
    # cls_head.eval()
    print(f'Model loaded successfully. Missing keys (transformer): {msg_trans[0]}, (cls_head): {msg_cls[0]}')

    # Step 2: Prepare the validation dataset
    val_data_module = KineticsDataModule(configs=args,
                                         train_ann_path=None,  # no training data needed
                                         val_ann_path=args.val_data_path,
                                         test_ann_path=None)
    val_data_module.setup(stage='test')
    val_dataloader = val_data_module.val_dataloader()

    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.DataParallel(model)
    model = model.to(device)
    cls_head = cls_head.to(device)

    # Step 3: Evaluate
    top_1_accuracy, top_5_accuracy = compute_top_k_accuracy(val_dataloader, model, cls_head, device, top_k=5)
    print(f'Top-1 Accuracy: {top_1_accuracy * 100:.2f}%')
    print(f'Top-5 Accuracy: {top_5_accuracy * 100:.2f}%')
if __name__ == '__main__':
    # Define paths and configurations
    test_checkpoint_path = "/home/hmf/vivit_clip/results/lr_0.000625_optim_sgd_lr_schedule_cosine_weight_decay_0.0001_weights_from_imagenet_num_frames_16_frame_interval_16_mixup_False/ckpt/2024-12-17 19:12:49_ep_5_top1_acc_0.623.pth"  # path to the checkpoint
    # test_data_path = "/home/hmf/vivit_clip/VideoTransformer-pytorch-main/OpenMMLab___Kinetics-400/raw/Kinetics-400/kinetics400_val_list_videos.txt"  # path to the test dataset
    """
    # Configuration settings for the model and training
    configs = {
        'num_frames': 16,        # adjust the number of frames as needed
        'img_size': 224,
        'patch_size': 16,
        'embed_dims': 768,
        'in_channels': 3,
        'attention_type': 'fact_encoder',
        'return_cls_token': True,
        'weights_from': 'kinetics',
        'pretrain_pth': test_checkpoint_path,
        'frame_interval': 16,
        'num_class': 400,
        'objective': 'supervised',
        'data_statics': 'kinetics'
    }
    """
    # Call the function to test the model
    test_vivit_model(test_checkpoint_path)
@FJNU-LWP @call-me-akeiang This is my test file. I hope it can help you. If there are errors, please point them out. :)
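To evaluate on other data with this script, it should be enough to point the existing -val_data_path argument at your own annotation list, e.g. python test_vivit.py -val_data_path /path/to/your_list.txt (test_vivit.py is just a placeholder filename), assuming that list uses the same format as kinetics400_val_list_videos.txt; parse_args reads the path via argparse and the rest of the pipeline stays unchanged.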
By the way, what accuracy does the vivit-k400 model you trained reach on the validation set? My top-1 accuracy is only 62%.
The lr is 5e-2 and the weight_decay is 0.0001; the other parameter settings are the same as those in the readme.md.
Could I have a look at your training log file, or could you give me some advice? Looking forward to your response.