VideoTransformer-pytorch

How to test my trained model?

FJNU-LWP opened this issue 2 years ago · 2 comments

Hello, thank you very much for sharing this wonderful project!

I have now trained my own model and generated a .pth file using this code. How can I use this .pth file to test other data?

Looking forward to your response, and I would greatly appreciate it!

FJNU-LWP · Oct 13 '23 02:10

Hello, have you solved this problem? @FJNU-LWP

call-me-akeiang · Nov 05 '24 09:11

import argparse

import torch
from torch import nn
from tqdm import tqdm

from video_transformer import ViViT
from transformer import ClassificationHead
from data_trainer import KineticsDataModule

def compute_top_k_accuracy(data_loader, model, cls_head, device, top_k=5):
    top_1_correct = 0
    top_5_correct = 0
    total_samples = 0

    model.eval()
    cls_head.eval()

    with torch.no_grad():
        for video, label in tqdm(data_loader, desc="Evaluating"):
            video = video.to(device)
            label = label.to(device)

            # Forward pass through the model
            logits = model(video)
            output = cls_head(logits)
            output = output.view(output.size(0), -1)  # Flatten logits to (B, num_classes)

            # Indices of the top-k predictions per sample
            _, topk_preds = output.topk(top_k, dim=1, largest=True, sorted=True)

            # Top-1: best prediction matches the label;
            # Top-5: label appears anywhere among the top-k predictions
            top_1_correct += (topk_preds[:, 0] == label).sum().item()
            top_5_correct += (topk_preds == label.unsqueeze(1)).sum().item()

            total_samples += label.size(0)

    top_1_accuracy = top_1_correct / total_samples
    top_5_accuracy = top_5_correct / total_samples

    return top_1_accuracy, top_5_accuracy
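As a quick sanity check of the top-k logic, you can run the function on random tensors, with nn.Identity() standing in for the backbone and a plain Linear layer for the head (the shapes below are illustrative and not from the repo):

# Sanity check with fake data: random 768-d features stand in for video clips.
from torch.utils.data import DataLoader, TensorDataset

feats = torch.randn(32, 768)            # 32 fake "clips", already featurized
labels = torch.randint(0, 400, (32,))   # fake Kinetics-400 labels
loader = DataLoader(TensorDataset(feats, labels), batch_size=8)

top1, top5 = compute_top_k_accuracy(loader, nn.Identity(), nn.Linear(768, 400),
                                    device=torch.device('cpu'), top_k=5)
print(top1, top5)  # chance level: roughly 1/400 and 5/400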

def parse_args():
    parser = argparse.ArgumentParser(description='lr receiver')

    # Common
    parser.add_argument(
        '-objective', type=str, default='supervised',
        help='the learning objective from [mim, supervised]')
    parser.add_argument(
        '-eval_metrics', type=str, default='finetune',
        help='the eval metric chosen from [linear_prob, finetune]')
    parser.add_argument(
        '-batch_size', type=int, default=16,
        help='the batch size of data inputs')
    parser.add_argument(
        '-num_workers', type=int, default=4,
        help='the number of workers for loading data')

    # Environment
    parser.add_argument(
        '-gpus', nargs='+', type=int, default=-1,
        help='the available gpus in this experiment')

    # Data
    parser.add_argument(
        '-num_class', type=int, default=400,
        help='the number of classes in the dataset')
    parser.add_argument(
        '-num_samples_per_cls', type=int, default=10000,
        help='the number of samples per class')
    parser.add_argument(
        '-img_size', type=int, default=224,
        help='the size of the processed image')
    parser.add_argument(
        '-num_frames', type=int, default=16,
        help='the number of sampled frames')
    parser.add_argument(
        '-frame_interval', type=int, default=16,
        help='the interval of frame sampling')
    parser.add_argument(
        '-data_statics', type=str, default='kinetics',
        help='choose data statistics from [imagenet, kinetics, clip]')
    parser.add_argument(
        '-val_data_path', type=str, default="/home/hmf/vivit_clip/VideoTransformer-pytorch-main/OpenMMLab___Kinetics"
                                            "-400/raw/Kinetics-400/kinetics400_val_list_videos.txt",
        help='the path to the val set')
    parser.add_argument(
        '-test_data_path', type=str, default=None,
        help='the path to the test set')
    parser.add_argument(
        '-auto_augment', type=str, default=None,
        help='the AutoAugment policy to use')
    parser.add_argument(
        '-mixup', type=bool, default=False,
        help="""Whether or not to use mixup.""")
    parser.add_argument(
        '-multi_crop', type=bool, default=False,
        help="""Whether or not to use multi crop.""")

    # Model
    parser.add_argument(
        '-arch', type=str, default='vivit',
        help='the model arch chosen from [timesformer, vivit]')
    parser.add_argument(
        '-attention_type', type=str, default='fact_encoder',
        help='the attention type used in the model')
    parser.add_argument(
        '-pretrain_pth', type=str, default="./vivit_model.pth",
        help='the path to the pretrained weights')
    parser.add_argument(
        '-weights_from', type=str, default='kinetics',
        help='the pretrained params from [imagenet, kinetics, clip]')

    args = parser.parse_args()

    return args

def replace_state_dict(state_dict):
    for old_key in list(state_dict.keys()):
        if old_key.startswith('model'):
            new_key = old_key[6:]  # skip 'model.'
            if 'in_proj' in new_key:
                new_key = new_key.replace('in_proj_', 'qkv.')  # in_proj_weight -> qkv.weight
            elif 'out_proj' in new_key:
                new_key = new_key.replace('out_proj', 'proj')  # out_proj -> proj
            state_dict[new_key] = state_dict.pop(old_key)
        else:
            # cls_head
            new_key = old_key[9:]  # skip 'cls_head.'
            state_dict[new_key] = state_dict.pop(old_key)
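To make the renaming concrete, here is a toy run on made-up keys (a real checkpoint contains many more entries; the key names below are fabricated for illustration):

sd = {
    'model.blocks.0.attn.in_proj_weight': torch.zeros(1),
    'model.blocks.0.attn.out_proj.weight': torch.zeros(1),
    'cls_head.cls_head.weight': torch.zeros(1),
}
replace_state_dict(sd)
print(sorted(sd.keys()))
# -> ['blocks.0.attn.proj.weight', 'blocks.0.attn.qkv.weight', 'cls_head.weight']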

def init_from_kinetics_pretrain_(module, pretrain_pth):
    if torch.cuda.is_available():
        state_dict = torch.load(pretrain_pth)
    else:
        state_dict = torch.load(pretrain_pth, map_location=torch.device('cpu'))
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    replace_state_dict(state_dict)
    msg = module.load_state_dict(state_dict, strict=False)
    return msg
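Because the weights are loaded with strict=False, mismatched keys fail silently, so it is worth inspecting the returned message. A small sketch of how one might check it (the checkpoint path is hypothetical):

msg = init_from_kinetics_pretrain_(model, './my_checkpoint.pth')  # hypothetical path
missing, unexpected = msg  # load_state_dict returns (missing_keys, unexpected_keys)
if missing or unexpected:
    print('missing keys:', missing)
    print('unexpected keys:', unexpected)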

def test_vivit_model(test_checkpoint_path):
    args = parse_args()

    # Step 1: Build the model
    model = ViViT(num_frames=args.num_frames,
                  img_size=args.img_size,
                  patch_size=16,
                  embed_dims=768,
                  in_channels=3,
                  attention_type=args.attention_type,
                  return_cls_token=True,
                  weights_from=args.weights_from,
                  pretrain_pth=args.pretrain_pth)

    # Load the pre-trained weights
    cls_head = ClassificationHead(num_classes=args.num_class, in_channels=768)
    msg_trans = init_from_kinetics_pretrain_(model, test_checkpoint_path)
    msg_cls = init_from_kinetics_pretrain_(cls_head, test_checkpoint_path)

    # model.eval()
    # cls_head.eval()

    print(f'Model loaded successfully. Missing keys (transformer): {msg_trans[0]}, (cls_head): {msg_cls[0]}')

    # Step 2: Prepare the validation dataset
    val_data_module = KineticsDataModule(configs=args,
                                         train_ann_path=None,  # No need for training data
                                         val_ann_path=args.val_data_path,
                                         test_ann_path=None)
    val_data_module.setup(stage='test')
    val_dataloader = val_data_module.val_dataloader()

    # Step 3: Evaluate
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = nn.DataParallel(model)
    model = model.to(device)
    cls_head = cls_head.to(device)
    top_1_accuracy, top_5_accuracy = compute_top_k_accuracy(val_dataloader, model, cls_head, device, top_k=5)
    print(f'Top-1 Accuracy: {top_1_accuracy * 100:.2f}%')
    print(f'Top-5 Accuracy: {top_5_accuracy * 100:.2f}%')

if __name__ == '__main__':
    # Define paths and configurations
    test_checkpoint_path = "/home/hmf/vivit_clip/results/lr_0.000625_optim_sgd_lr_schedule_cosine_weight_decay_0.0001_weights_from_imagenet_num_frames_16_frame_interval_16_mixup_False/ckpt/2024-12-17 19:12:49_ep_5_top1_acc_0.623.pth"  # Path to the checkpoint
    # test_data_path = "/home/hmf/vivit_clip/VideoTransformer-pytorch-main/OpenMMLab___Kinetics-400/raw/Kinetics-400/kinetics400_val_list_videos.txt"  # Path to the test dataset

    """
    # Configuration settings for the model and training
    configs = {
        'num_frames': 16,  # Adjust number of frames as needed
        'img_size': 224,
        'patch_size': 16,
        'embed_dims': 768,
        'in_channels': 3,
        'attention_type': 'fact_encoder',
        'return_cls_token': True,
        'weights_from': 'kinetics',
        'pretrain_pth': test_checkpoint_path,
        'frame_interval': 16,
        'num_class': 400,
        'objective': 'supervised',
        'data_statics': 'kinetics'
    }
    """

    # Call the function to test the model
    test_vivit_model(test_checkpoint_path)
@FJNU-LWP @call-me-akeiang This is my test file. I hope it helps you; if there are errors, please point them out. :) By the way, what accuracy does the vivit-k400 model you trained reach on the validation set? My top-1 accuracy is only 62%, with lr 5e-2 and weight_decay 0.0001; the other settings are the same as in the README.md. Could I have a look at your training log file, or could you give me some advice? Looking forward to your response.

hemengfan2002 · Dec 23 '24 12:12