NaNs

Open repers opened this issue 2 years ago • 5 comments

Hi there, thanks for providing the code for the CKA analysis. I have tried applying it to my model; however, I keep getting NaNs in the final output matrix. Any idea why this happens? Thanks.

repers, Jul 18 '23 10:07

There could be many reasons for the NaNs; more details are needed to narrow it down.

numpee, Jul 18 '23 20:07

Hi, so the network I'm trying to apply this on is https://github.com/HyeongminLEE/AdaCoF-pytorch/blob/master/models/adacofnet.py

I use the dataloader from https://github.com/tding1/CDFI/blob/main/datasets.py to get the test set.

I have a feeling it might be due to the custom CUDA implementation at the end. Is there a way to apply hooks only to the UNet part of the architecture?

This is the main code I run:

from datareader import DBreader_Vimeo90k
from torch.utils.data import DataLoader
import argparse
from torchvision import transforms
import torch
from TestModule import Middlebury_other
import models
from trainer import Trainer
import losses
import datetime
from adacofnet1 import make_model
from datasets import Vimeo90K_interp
from cka import CKACalculator
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser(description='AdaCoF-Pytorch')

# parameters
# Model Selection
parser.add_argument('--model', type=str, default='adacofnet')

# Hardware Setting
parser.add_argument('--gpu_id', type=int, default=0)

# Directory Setting
parser.add_argument('--train', type=str, default='./db/vimeo_triplet')
parser.add_argument('--out_dir', type=str, default='./output_adacof_train')
parser.add_argument('--load', type=str, default=None)
parser.add_argument('--load2', type=str, default=None)
parser.add_argument('--test_input', type=str, default='./test_input/middlebury_others/input')
parser.add_argument('--gt', type=str, default='./test_input/middlebury_others/gt')

# Learning Options
parser.add_argument('--epochs', type=int, default=50, help='Max Epochs')
parser.add_argument('--batch_size', type=int, default=4, help='Batch size')
parser.add_argument('--loss', type=str, default='1*Charb+0.01*g_Spatial+0.005*g_Occlusion', help='loss function configuration')
parser.add_argument('--patch_size', type=int, default=128, help='Patch size')

parser.add_argument('--kernel_size', type=int, default=5)
parser.add_argument('--dilation', type=int, default=1)

transform = transforms.Compose([transforms.ToTensor()])

def main():
    args = parser.parse_args()
    torch.cuda.set_device(args.gpu_id)
    train_dataset, val_dataset = Vimeo90K_interp(args.train)
    test_loader = DataLoader(dataset=val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)
    model1 = make_model(args)
    checkpoint = torch.load(args.load)
    model1.load_state_dict(checkpoint['state_dict'])
    model2 = make_model(args)
    checkpoint2 = torch.load(args.load2)
    model2.load_state_dict(checkpoint2['state_dict'])
    calculator = CKACalculator(model1=model1, model2=model2, dataloader=test_loader)
    cka_output = calculator.calculate_cka_matrix()
    print(cka_output)
    plt.rcParams['figure.figsize'] = (7, 7)
    plt.savefig('new.png')
    for name in calculator.module_names_X:
        print(name)


if __name__ == "__main__":
    main()

repers, Jul 19 '23 15:07

One thing to note: the size of the test set does not affect the NaNs. I tried both a small subset of about 40 images and the full set, and the same issue occurred.

repers, Jul 19 '23 15:07

Hi, you can apply hooks to custom layers by passing the modules into the CKA calculator. Check the "Advanced Usage" section of the example Jupyter notebook provided.

As for the NaNs, I'm not exactly sure what's going on. There may be some underflow or overflow. Maybe you could try modifying the epsilon parameter in the CKACalculator? Also, do you normalize the input data before passing it to the model?
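
To narrow down where the NaNs first appear, one option is to hook only part of the network and check each layer's output for non-finite values. The following is a rough sketch in plain PyTorch, not the CKACalculator API: `hook_nonfinite`, the `Toy` model and the `"unet"` name prefix are all made up for illustration, so the prefix would need to be replaced with the real submodule name in your model.

```python
import torch
import torch.nn as nn


def hook_nonfinite(model: nn.Module, prefix: str = ""):
    """Hook leaf modules whose qualified name starts with `prefix`.

    Returns (report, handles). After a forward pass, report[name] is True
    if that module's output contained NaN or Inf.
    """
    report = {}
    handles = []
    for name, module in model.named_modules():
        is_leaf = len(list(module.children())) == 0
        if not is_leaf or not name.startswith(prefix):
            continue

        # Bind `name` as a default argument so each hook keeps its own name.
        def hook(_module, _inputs, output, name=name):
            if torch.is_tensor(output):
                report[name] = not bool(torch.isfinite(output).all())

        handles.append(module.register_forward_hook(hook))
    return report, handles


class Toy(nn.Module):
    """Hypothetical stand-in: a 'unet' part followed by a separate head."""

    def __init__(self):
        super().__init__()
        self.unet = nn.Sequential(nn.Conv2d(3, 4, 3, padding=1), nn.ReLU())
        self.head = nn.Conv2d(4, 3, 1)

    def forward(self, x):
        return self.head(self.unet(x))


model = Toy().eval()
report, handles = hook_nonfinite(model, prefix="unet")  # skips `head`
with torch.no_grad():
    model(torch.rand(2, 3, 8, 8))
for handle in handles:
    handle.remove()  # detach the hooks once the check is done
print(report)
```

If every hooked layer reports finite outputs, the activations are probably not the source, and the NaNs are more likely introduced later, when the similarity values are computed from them.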

numpee, Jul 21 '23 23:07

The batch size must be greater than 3.
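
A possible explanation, assuming the library follows minibatch CKA with the unbiased HSIC estimator (I have not checked this against the repository code): that estimator divides by (n - 1), (n - 2) and (n - 3), where n is the batch size, so any batch with n <= 3 hits a division by zero. In tensor arithmetic such a division does not raise an error but silently yields inf or NaN. Below is a small pure-Python sketch of the estimator with a linear kernel; the function names are illustrative and not taken from the library.

```python
import math


def gram_linear(X):
    """Linear-kernel Gram matrix K = X X^T for a list of feature rows."""
    return [[sum(a * b for a, b in zip(xi, xj)) for xj in X] for xi in X]


def unbiased_hsic(K, L):
    """Unbiased HSIC estimate from two n x n Gram matrices (needs n >= 4)."""
    n = len(K)
    if n < 4:
        # The factors (n - 1), (n - 2) or (n - 3) below would be zero.
        raise ValueError(f"need at least 4 samples per batch, got {n}")
    # The estimator works on Gram matrices with zeroed diagonals.
    K = [[0.0 if i == j else K[i][j] for j in range(n)] for i in range(n)]
    L = [[0.0 if i == j else L[i][j] for j in range(n)] for i in range(n)]
    trace_kl = sum(K[i][j] * L[j][i] for i in range(n) for j in range(n))
    sum_k = sum(map(sum, K))
    sum_l = sum(map(sum, L))
    col_k = [sum(K[i][j] for i in range(n)) for j in range(n)]
    row_l = [sum(row) for row in L]
    cross = sum(c * r for c, r in zip(col_k, row_l))  # 1^T K L 1
    total = trace_kl + sum_k * sum_l / ((n - 1) * (n - 2)) - 2.0 * cross / (n - 2)
    return total / (n * (n - 3))


def cka(X, Y):
    """CKA between two sets of features computed on the same n samples."""
    K, L = gram_linear(X), gram_linear(Y)
    return unbiased_hsic(K, L) / math.sqrt(unbiased_hsic(K, K) * unbiased_hsic(L, L))


X = [[1.0], [2.0], [3.0], [4.0]]
print(cka(X, X))  # a representation compared with itself
try:
    cka(X[:3], X[:3])  # a batch of only 3 samples
except ValueError as err:
    print("error:", err)
```

Note that the last batch of a dataloader can be smaller than the configured batch size when the dataset size is not a multiple of it. Passing `drop_last=True` to `torch.utils.data.DataLoader` drops such a short final batch.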

WenLinLliu, Aug 06 '24 08:08