StyTR-2
About the metric score of StyTr2: Image Style Transfer with Transformers
Thanks for sharing your code. It's wonderful work, I think!
I have one question about the content loss score. I applied StyTr2 to a dataset of 800 images using your pre-trained model. To be consistent with your test settings, I resized all images to 256x256 before computing the content loss. However, the content loss values I get differ significantly from those reported in your paper.
I understand that some variation is expected when different images are used. Still, my style loss scores follow a similar trend to yours, while the content loss scores show noticeable discrepancies. Could you explain how you calculate the content loss? Would it be possible to share your metric code, or to point out where I am going wrong? My metric script is below.
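For reference, the script computes the content loss as the average MSE between the relu4_1 and relu5_1 features of the normalised VGG encoder, with both images resized to 256x256 and scaled to [0, 1]:

loss_c = 0.5 * ( MSE(relu4_1(stylized), relu4_1(content)) + MSE(relu5_1(stylized), relu5_1(content)) )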
#!/usr/bin/env python3
import argparse
import os
import torch
import torch.nn as nn
from tqdm import tqdm
import cv2
parser = argparse.ArgumentParser()
parser.add_argument("--resize", type=int, default=256, help="resize_image_size")
parser.add_argument("--content_dir", default=r'\input\content', help="the directory of content images")
parser.add_argument("--style_dir", default=r'\input\style', help="the directory of style images")
parser.add_argument("--stylized_dir", default=r\StyTR-2-main\output', required=False, help="the directory of stylized images")
parser.add_argument("--log_path", default=r't\metrics', required=False, help="the directory of stylized images")
parser.add_argument('--mode', type=int, default=1, help="0 for style loss, 1 for content loss, 2 for both")
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1, the last layer used for the content loss
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)
vgg.load_state_dict(torch.load("../models/vgg_normalised.pth"))
vgg.eval()
enc_1 = nn.Sequential(*list(vgg.children())[:4]) # input -> relu1_1
enc_2 = nn.Sequential(*list(vgg.children())[4:11]) # relu1_1 -> relu2_1
enc_3 = nn.Sequential(*list(vgg.children())[11:18]) # relu2_1 -> relu3_1
enc_4 = nn.Sequential(*list(vgg.children())[18:31]) # relu3_1 -> relu4_1
enc_5 = nn.Sequential(*list(vgg.children())[31:44]) # relu4_1 -> relu5_1
enc_1.to(device)
enc_2.to(device)
enc_3.to(device)
enc_4.to(device)
enc_5.to(device)
def calc_content_loss(input, target):
    assert (input.size() == target.size())
    return torch.nn.MSELoss()(input, target)
content_dir = args.content_dir
style_dir = args.style_dir
stylized_dir = args.stylized_dir
log_dir = args.log_path
stylized_files = os.listdir(stylized_dir)
folder_components = stylized_dir.split(os.path.sep)
name = folder_components[-2]
sub_name = folder_components[-1]
log_path = os.path.join(args.log_path, name + '_log.txt')
with torch.no_grad():
    if args.mode == 1 or args.mode == 2:
        loss_c_sum = 0.
        count = 0
        for i, stylized in enumerate(tqdm(stylized_files)):
            stylized_img = cv2.imread(stylized_dir + os.sep + stylized)  # stylized image
            if stylized_img is None or stylized_img.size == 0:
                print('Failed to load the image:', stylized_dir + os.sep + stylized)
                continue
            stylized_img = cv2.resize(stylized_img, (args.resize, args.resize))
            name = stylized.split("_stylized_")  # parse the content image's name
            content_img = cv2.imread(content_dir + os.sep + name[0] + '.jpg')  # content image
            if content_img is None or content_img.size == 0:
                print('Failed to load the image:', content_dir + os.sep + name[0] + '.jpg')
                continue
            content_img = cv2.resize(content_img, (args.resize, args.resize))
            # note: cv2.imread returns BGR; no BGR -> RGB conversion is applied here
            stylized_img = torch.tensor(stylized_img, dtype=torch.float)
            stylized_img = stylized_img / 255
            stylized_img = torch.unsqueeze(stylized_img, dim=0)
            stylized_img = stylized_img.permute([0, 3, 1, 2])
            stylized_img = stylized_img.to(device)
            content_img = torch.tensor(content_img, dtype=torch.float)
            content_img = content_img / 255
            content_img = torch.unsqueeze(content_img, dim=0)
            content_img = content_img.permute([0, 3, 1, 2])
            content_img = content_img.to(device)
            loss_c = 0.
            o1 = enc_4(enc_3(enc_2(enc_1(stylized_img))))  # relu4_1 features of the stylized image
            c1 = enc_4(enc_3(enc_2(enc_1(content_img))))   # relu4_1 features of the content image
            loss_c += calc_content_loss(o1, c1)
            o2 = enc_5(o1)  # relu5_1 features
            c2 = enc_5(c1)
            loss_c += calc_content_loss(o2, c2)
            print("Content Loss: {}".format(loss_c / 2))
            loss_c_sum += float(loss_c / 2)
            count += 1
        print("Total num: {}".format(count))
        print("Average Content Loss: {}".format(loss_c_sum / count))