CFA_for_anomaly_localization
CFA_for_anomaly_localization copied to clipboard
model save
I'm confused about why this model is saved here. From the code, it seems the model has not been trained; it is just a pre-trained model.
# Write the "best.pt" checkpoint.
# The training loop builds its optimizer from `loss_fn.parameters()` only and
# never optimizes `model`, so `model.state_dict()` is the unchanged
# pre-trained backbone. The weights that training actually updates live in
# `loss_fn`, so they are saved too; without them this checkpoint cannot
# restore the trained state. The original keys are kept so existing loaders
# keep working.
torch.save(
    {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "loss_fn_state_dict": loss_fn.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    os.path.join(checkpoint_dir, "best.pt"),
)
Training code:
import torch  # for torch.no_grad(); the script already relies on torch elsewhere

# Backbone feature extractor: moved to the device and switched to eval mode.
# Its parameters are never given to the optimizer below, so it stays frozen.
model = model.to(device)
model.eval()

# DSVDD is built from the backbone and train_loader, and it owns every
# parameter this loop optimizes.
loss_fn = DSVDD(model, train_loader, args.cnn, args.gamma_c, args.gamma_d, device)
loss_fn = loss_fn.to(device)

epochs = 30
params = [{"params": loss_fn.parameters()}]
optimizer = optim.AdamW(
    params=params,
    lr=1e-3,
    weight_decay=5e-4,
    amsgrad=True,
)

# Best pixel-level score seen so far. Presumably it is compared after the test
# phase to decide when to write the "best" checkpoint (TODO: confirm).
best_pxl_pro = -1

for epoch in tqdm(range(epochs), f"{class_name} -->"):
    # ---- TRAIN PHASE: fit loss_fn on features from the frozen backbone ----
    loss_fn.train()
    for x, _, _ in train_loader:
        optimizer.zero_grad()
        # The backbone is not optimized, so don't record its autograd graph.
        # Otherwise backward() also fills model's .grad buffers, which
        # optimizer.zero_grad() never clears because those parameters are not
        # in the optimizer. This assumes DSVDD's loss needs no gradient with
        # respect to the features themselves; verify if DSVDD changes.
        with torch.no_grad():
            p = model(x.to(device))
        loss, _ = loss_fn(p)
        loss.backward()
        optimizer.step()
    loss_fn.eval()

    # ---- TEST PHASE: per-epoch evaluation buffers ----
    test_imgs = list()
    gt_mask_list = list()
    gt_list = list()
    heatmaps = None