FSRNET_pytorch
The test script is empty, can you provide it?
You can check my modified test.py:
```python
from __future__ import print_function
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data as data  # aliased so data.DataLoader below resolves
import torchvision.datasets as dset
import torchvision.utils as vutils
from torch.autograd import Variable
import torchvision

import argparse
import random
import time
import glob
from math import log10

import numpy as np
import cv2
import skimage
import scipy.io
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import Image

from dataset import *
from networks import *
from data.dataloader import *

parser = argparse.ArgumentParser()
parser.add_argument('--test', default='True', action='store_true', help='enables test during training')
parser.add_argument('--mse_avg', action='store_true', help='enables mse avg')
parser.add_argument('--num_layers_res', type=int, default=2, help='number of layers in the residual block')
parser.add_argument('--nrow', type=int, default=1, help='number of rows used when saving images')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--test_batchSize', type=int, default=64, help='test batch size')
parser.add_argument('--save_iter', type=int, default=10, help='interval (iterations) for saving models')
parser.add_argument('--test_iter', type=int, default=500, help='interval (iterations) for testing')
parser.add_argument('--cdim', type=int, default=3, help='number of channels of the input image')
parser.add_argument('--nEpochs', type=int, default=1000, help='number of epochs to train for')
parser.add_argument('--start_epoch', type=int, default=0, help='manual epoch number (useful on restarts)')
parser.add_argument('--lr', type=float, default=0.7 * 2.5 * 10 ** (-4), help='learning rate')
parser.add_argument('--cuda', default='True', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--outf', default='./results/1_4/', help='folder to output images')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--pretrained', default='./model/sr_1_4_0model_epoch_812_iter_0.pth', type=str,
                    help='path to pretrained model (default: none)')
parser.add_argument('--batch_size', default=20, type=int, help='test batch size')
parser.add_argument('--image_dir', default='./data/CelebA-HQ-img/', type=str, help='directory of the test images')
parser.add_argument('--image_list', default='./data/test_fileList.txt', help='file list of the test images')


def main():
    global opt, model
    opt = parser.parse_args()
    print(opt)

    try:
        os.makedirs(opt.outf)
    except OSError:
        pass

    if opt.manualSeed is None:
        opt.manualSeed = random.randint(1, 10000)
    print("Random Seed: ", opt.manualSeed)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)
    if opt.cuda:
        torch.cuda.manual_seed_all(opt.manualSeed)

    cudnn.benchmark = True
    # if torch.cuda.is_available() and not opt.cuda:
    #     print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    ngpu = int(opt.ngpu)

    with torch.no_grad():
        srnet = NetSR(num_layers_res=opt.num_layers_res)
        if opt.cuda:
            srnet = srnet.cuda()

        if opt.pretrained:
            if os.path.isfile(opt.pretrained):
                print("=> loading model '{}'".format(opt.pretrained))
                weights = torch.load(opt.pretrained, map_location='cpu')
                pretrained_dict = weights['model'].state_dict()
                model_dict = srnet.state_dict()
                # keep only the weights whose names exist in the current model
                pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
                model_dict.update(pretrained_dict)
                srnet.load_state_dict(model_dict)
            else:
                print("=> no model found at '{}'".format(opt.pretrained))

        # leftover settings, not used below
        size = 128
        batch_size = 14
        save_freq = 2
        result_dir = './results/'

        demo_dataset = ImageDatasetFromFile(opt.image_list, opt.image_dir, is_parsing_map=False)
        test_data_loader = data.DataLoader(dataset=demo_dataset, batch_size=opt.batch_size,
                                           num_workers=8, drop_last=True, pin_memory=True)

        for iteration, batch in enumerate(test_data_loader):
            input0, target0 = Variable(batch[0]), Variable(batch[1])
            if opt.cuda:
                input0 = input0.cuda()
                target0 = target0.cuda()

            try:
                with torch.no_grad():
                    # coarse SR image, face parsing maps, and final SR output
                    output0, parsing_maps, output = srnet(input0)
            except RuntimeError as exception:
                if "out of memory" in str(exception):
                    print("Warning: out of memory")
                    continue  # skip the batch that did not fit in memory
                else:
                    raise exception

            # NCHW -> NHWC for saving
            output11 = output.permute(0, 2, 3, 1).cpu().data.numpy()
            for n in range(opt.batch_size):
                output01 = output11[n, :, :, :]
                inputXX = input0[n, :, :, :].permute(1, 2, 0).cpu().data.numpy()   # CHW -> HWC
                targetXX = target0[n, :, :, :].permute(1, 2, 0).cpu().data.numpy()  # CHW -> HWC
                # save input / target / output side by side
                temp = np.concatenate((inputXX * 255, targetXX * 255, output01 * 255), axis=1)
                Image.fromarray(temp.astype('uint8')).save(result_dir + 'lr_%d_%d.jpg' % (iteration, n))


if __name__ == "__main__":
    main()
```
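To run it, assuming your checkpoint and data paths match the defaults above, something like `python test.py --pretrained ./model/sr_1_4_0model_epoch_812_iter_0.pth --image_dir ./data/CelebA-HQ-img/ --image_list ./data/test_fileList.txt` should write the side-by-side comparison images under `./results/`.

The script only saves comparison images; if you also want a quick PSNR number per test image, here is a minimal sketch of my own (assuming the arrays are floats in [0, 1], as in the saving code above):

```python
import numpy as np
from math import log10

def psnr(sr, hr, max_val=1.0):
    """PSNR in dB between two float images in [0, max_val]."""
    mse = np.mean((sr.astype(np.float64) - hr.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10 * log10(max_val ** 2 / mse)

# e.g. inside the per-image loop above:
# print('PSNR lr_%d_%d: %.2f dB' % (iteration, n, psnr(output01, targetXX)))
```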