improved_CcGAN
How to use
Hi there,
Thanks for your work in generative AI. I'm currently trying to implement CcGAN and was wondering: is there any way to train the model easily?
I can see there are some Python files that receive arguments through shell scripts, but there is no requirements.txt and, as far as I can tell, some configuration is needed to run them.
It'd be much appreciated if you could provide a way to use your model and see how it works with other custom datasets.
Thanks in advance.
By the way, my local machine runs Windows 11.
Hey, I am stuck on the same issue. Did you figure it out by any chance?
I managed to run the shell script and used my custom datasets.
What issue are you stuck on? @ashamy97
I am just stuck on what to do once I have cloned the repository. I am using a Jupyter notebook. Can you tell me the steps you followed? So, I have my own custom dataset that I want to train the CcGAN on, and instead of one conditional input, I have two. So I am not sure what I need to change or how to run it. Thank you so much for your help.
- Understand the structure of the repository
  - `Cell-200`, `RC-49`, `SteeringAngle`... these top-level folders share a similar structure underneath, by which the authors organized the code by image size and by the model used (if you read the paper, you will understand why it is structured like this; I guess they based the layout on their experiments). So `64x64` or `128x128` means the image size used in training, and `CcGAN`, `CcGAN-improved`, and `cGAN-concat` are the models to train. A rough sketch of the layout is below.
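To make the structure concrete, here is a rough sketch of the layout (folder names inferred from the paths quoted later in this thread; the actual repo may differ slightly):

```
improved_CcGAN/
├── Cell-200/
├── SteeringAngle/
└── RC-49/
    ├── RC-49_64x64/
    └── RC-49_128x128/
        ├── CcGAN/
        ├── cGAN-concat/
        └── CcGAN-improved/
            ├── main.py
            ├── train_ccgan.py
            └── scripts/
                └── run_train.sh
```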
- Find the `run_train.sh` in the `scripts` folder of the image dataset and the model you want to use in training
  - For example, I used the `RC-49_128x128.h5` dataset (you can download it easily through the link in `README.md`) and `CcGAN-improved`, so I went for the `run_train.sh` script under the `.../RC-49/RC-49_128x128/scripts` folder
  - The reason you need to choose the dataset and model before training is that the preprocessing code is customised for a specific dataset and model
- Prepare the dataset, if you want to use your own (see the HDF5 sketch below for the layout `main.py` expects)
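Here is a minimal sketch of packing a custom dataset into the HDF5 layout that the `main.py` below reads (`images`, `labels`, `indx_train` keys; the file name and the two-column label array follow my setup, so adjust them to yours):

```python
# Sketch: write an HDF5 file compatible with the data loader in main.py.
# Assumes uint8 images of shape (N, 3, 128, 128) and two continuous
# conditions per image (e.g. angle in [0, 360] and scale in [0, 1]).
import h5py
import numpy as np

N = 1000
images = np.zeros((N, 3, 128, 128), dtype=np.uint8)  # replace with your images
labels = np.zeros((N, 2), dtype=np.float64)          # one row of conditions per image
indx_train = np.arange(N)                            # indices used when data_split == "train"

with h5py.File("RC-49_128x128_downscale.h5", "w") as hf:
    hf.create_dataset("images", data=images)
    hf.create_dataset("labels", data=labels)
    hf.create_dataset("indx_train", data=indx_train)
```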
- Modify the `run_train.sh`
  - Especially, you might need to modify `ROOT_PATH` and `DATA_PATH` to make the shell script run
  - Also, the `python main.py ...` bit doesn't work as it is now, since `main.py` is not on the same level but one level above the script (you will see what I mean if you run the script and hit a file-not-found error), hence `python ../main.py` in my script below
  - Comment out unnecessary parts of `run_train.sh` to save training time
  - For example, I did it like below:
## Path
ROOT_PATH="../../"
DATA_PATH="/c/Users/msi/Desktop/workspace/001_practice/improved_CcGAN/RC-49/RC-49_128x128/CcGAN-improved/scripts/datasets"
EVAL_PATH="/c/Users/msi/Desktop/workspace/001_practice/improved_CcGAN/RC-49/RC-49_128x128/CcGAN-improved/output/eval_models"
SEED=2020
NUM_WORKERS=0
MIN_LABEL=0
MAX_LABEL=360
# added: range of the second condition
MIN_LABEL_SCALE=0.0
MAX_LABEL_SCALE=1.0
IMG_SIZE=128
MAX_N_IMG_PER_LABEL=25
MAX_N_IMG_PER_LABEL_AFTER_REPLICA=0
NITERS=15000
BATCH_SIZE_G=36
BATCH_SIZE_D=36
NUM_D_STEPS=2
SIGMA=-1.0
KAPPA=-2.0
LR_G=1e-4
LR_D=1e-4
GAN_ARCH="SAGAN"
LOSS_TYPE="hinge"
NUM_EVAL_LABELS=-1
NFAKE_PER_LABEL=200
SAMP_BATCH_SIZE=1000
FID_RADIUS=0
FID_NUM_CENTERS=-1
# python pretrain_AE.py \
# --root_path $ROOT_PATH --data_path $DATA_PATH --seed $SEED --num_workers $NUM_WORKERS \
# --dim_bottleneck 512 --epochs 200 --resume_epoch 0 \
# --batch_size_train 256 --batch_size_valid 10 \
# --base_lr 1e-3 --lr_decay_epochs 50 --lr_decay_factor 0.1 \
# --lambda_sparsity 0 --weight_dacay 1e-4 \
# --img_size $IMG_SIZE --min_label $MIN_LABEL --max_label $MAX_LABEL \
# 2>&1 | tee output_AE.txt
# python pretrain_CNN_class.py \
# --root_path $ROOT_PATH --data_path $DATA_PATH --seed $SEED --num_workers $NUM_WORKERS \
# --CNN ResNet34_class \
# --epochs 200 --batch_size_train 256 --batch_size_valid 10 \
# --base_lr 0.01 --weight_dacay 1e-4 \
# --img_size $IMG_SIZE --min_label $MIN_LABEL --max_label $MAX_LABEL \
# 2>&1 | tee output_CNN_class.txt
# python pretrain_CNN_regre.py \
# --root_path $ROOT_PATH --data_path $DATA_PATH --seed $SEED --num_workers $NUM_WORKERS \
# --CNN ResNet34_regre \
# --epochs 200 --batch_size_train 256 --batch_size_valid 10 \
# --base_lr 0.01 --weight_dacay 1e-4 \
# --img_size $IMG_SIZE --min_label $MIN_LABEL --max_label $MAX_LABEL \
# 2>&1 | tee output_CNN_regre.txt
GAN="CcGAN"
DIM_GAN=256
DIM_EMBED=128
resume_niters_gan=2800
python ../main.py \
--root_path $ROOT_PATH --data_path $DATA_PATH --eval_ckpt_path $EVAL_PATH --seed $SEED --num_workers $NUM_WORKERS \
--min_label $MIN_LABEL --max_label $MAX_LABEL --img_size $IMG_SIZE \
--min_label_scale $MIN_LABEL_SCALE --max_label_scale $MAX_LABEL_SCALE \
--max_num_img_per_label $MAX_N_IMG_PER_LABEL --max_num_img_per_label_after_replica $MAX_N_IMG_PER_LABEL_AFTER_REPLICA \
--GAN $GAN --GAN_arch $GAN_ARCH --niters_gan $NITERS --resume_niters_gan $resume_niters_gan --loss_type_gan $LOSS_TYPE \
--save_niters_freq 200 --visualize_freq 200 \
--batch_size_disc $BATCH_SIZE_D --batch_size_gene $BATCH_SIZE_G --num_D_steps $NUM_D_STEPS \
--lr_g $LR_G --lr_d $LR_D --dim_gan $DIM_GAN --dim_embed $DIM_EMBED \
--kernel_sigma $SIGMA --threshold_type soft --kappa $KAPPA \
--gan_DiffAugment --gan_DiffAugment_policy color,translation,cutout \
--visualize_fake_images \
--comp_FID --samp_batch_size $SAMP_BATCH_SIZE --FID_radius $FID_RADIUS --FID_num_centers $FID_NUM_CENTERS \
--num_eval_labels $NUM_EVAL_LABELS --nfake_per_label $NFAKE_PER_LABEL \
--dump_fake_for_NIQE \
2>&1 | tee output_CcGAN_30K.txt
# GAN="cGAN"
# DIM_GAN=128
# resume_niters_gan=0
# python main.py \
# --root_path $ROOT_PATH --data_path $DATA_PATH --eval_ckpt_path $EVAL_PATH --seed $SEED --num_workers $NUM_WORKERS \
# --min_label $MIN_LABEL --max_label $MAX_LABEL --img_size $IMG_SIZE \
# --max_num_img_per_label $MAX_N_IMG_PER_LABEL --max_num_img_per_label_after_replica $MAX_N_IMG_PER_LABEL_AFTER_REPLICA \
# --GAN $GAN --GAN_arch $GAN_ARCH --cGAN_num_classes 150 --niters_gan $NITERS --resume_niters_gan $resume_niters_gan --loss_type_gan $LOSS_TYPE \
# --save_niters_freq 2000 --visualize_freq 1000 \
# --batch_size_disc $BATCH_SIZE_D --batch_size_gene $BATCH_SIZE_G --num_D_steps $NUM_D_STEPS \
# --lr_g $LR_G --lr_d $LR_D --dim_gan $DIM_GAN \
# --gan_DiffAugment --gan_DiffAugment_policy color,translation,cutout \
# --visualize_fake_images \
# --comp_FID --samp_batch_size $SAMP_BATCH_SIZE --FID_radius $FID_RADIUS --FID_num_centers $FID_NUM_CENTERS \
# --num_eval_labels $NUM_EVAL_LABELS --nfake_per_label $NFAKE_PER_LABEL \
# --dump_fake_for_NIQE \
# 2>&1 | tee output_cGAN_150classes_30K.txt
# GAN="cGAN-concat"
# DIM_GAN=128
# resume_niters_gan=0
# python main.py \
# --root_path $ROOT_PATH --data_path $DATA_PATH --eval_ckpt_path $EVAL_PATH --seed $SEED --num_workers $NUM_WORKERS \
# --min_label $MIN_LABEL --max_label $MAX_LABEL --img_size $IMG_SIZE \
# --max_num_img_per_label $MAX_N_IMG_PER_LABEL --max_num_img_per_label_after_replica $MAX_N_IMG_PER_LABEL_AFTER_REPLICA \
# --GAN $GAN --GAN_arch $GAN_ARCH --niters_gan $NITERS --resume_niters_gan $resume_niters_gan --loss_type_gan $LOSS_TYPE \
# --save_niters_freq 2000 --visualize_freq 1000 \
# --batch_size_disc $BATCH_SIZE_D --batch_size_gene $BATCH_SIZE_G --num_D_steps $NUM_D_STEPS \
# --lr_g $LR_G --lr_d $LR_D --dim_gan $DIM_GAN \
# --gan_DiffAugment --gan_DiffAugment_policy color,translation,cutout \
# --visualize_fake_images \
# --comp_FID --samp_batch_size $SAMP_BATCH_SIZE --FID_radius $FID_RADIUS --FID_num_centers $FID_NUM_CENTERS \
# --num_eval_labels $NUM_EVAL_LABELS --nfake_per_label $NFAKE_PER_LABEL \
# --dump_fake_for_NIQE \
# 2>&1 | tee output_cGAN-concat_30K.txt
- Modify the code, if needed, especially if you use your own custom dataset
  - I slightly changed the input condition, so I had to modify `main.py` and `train_ccgan.py`
- See the generated results, which will appear in the `output` folder
- If you're using a Jupyter notebook, you might need to set everything up as mentioned above and invoke the shell script in a notebook cell like `!bash run_train.sh`
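If the `!` shell escape gives you trouble on Windows, a small `subprocess` cell is an alternative (just a sketch: it assumes `bash` is on your PATH, e.g. via Git Bash, the notebook runs from the repo root, and the script path follows the layout above):

```python
# Sketch: run the training script from a notebook cell via bash.
import subprocess

result = subprocess.run(
    ["bash", "run_train.sh"],
    cwd="RC-49/RC-49_128x128/CcGAN-improved/scripts",  # path from this thread
    capture_output=True,
    text=True,
)
print(result.stdout[-2000:])  # tail of the training log
```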
Hope this helps!
Thank you so much! This is very helpful. So the dataset is loaded from main.py, correct? Specifically lines 70-78? And just curious, do you know how I would pass in two conditional inputs instead of one? Do I need to pass the second input into the CNN to create an embedding and all of that (just like the first conditional input)?
- Yes, you're right on the dataset loading part. The dataset is loaded in `main.py` and passed to the `train_ccgan` function defined in `train_ccgan.py`
- Actually, I was in the same situation as you, and I had to modify the input layer to receive two conditions
  - In this case, you may be right in your guess. But to be honest, I can't remember exactly how I modified the code... 😢 I think I definitely modified the model part as well.
- I tried to upload my code to my repo, but somehow couldn't due to unnecessary files I staged in the commit 😢. Below are some of the files I changed.
- The point is how the `NUM_CONDITIONS` variable is used (a toy illustration of the key change follows this list)
- It might seem daunting since it's very long, but if you look for which parts differ from the original, you'll see the modified parts are not that difficult to understand
- Also, if you are using a dataset different from `RC-49` or `Cell-200`, which are used in the paper, you will need to modify the data-loading part more than I did, since the code assumes those datasets
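Before the full files, here is a self-contained toy illustration of the key mechanical change: with two conditions, each condition gets its own vicinity, and the usable real-image indices are the intersection. This mirrors the hard-threshold branch in my `train_ccgan.py` below (`np.intersect1d` is equivalent to the set intersection used there; the numbers are made up):

```python
# Toy illustration of two-condition vicinity sampling (hard threshold).
import numpy as np

train_labels = np.random.uniform(0, 1, size=(1000, 2))  # normalized (angle, scale)
target = np.array([0.5, 0.5])   # target label drawn for this batch slot
kappa = 0.05                    # hard-vicinity half-width

idx_1 = np.where(np.abs(train_labels[:, 0] - target[0]) <= kappa)[0]
idx_2 = np.where(np.abs(train_labels[:, 1] - target[1]) <= kappa)[0]
idx = np.intersect1d(idx_1, idx_2)  # real images close to the target in BOTH conditions
print(len(idx_1), len(idx_2), len(idx))
```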
ResNet_embed.py
'''
ResNet-based model to map an image from pixel space to a feature space.
Needs to be pretrained on the dataset.
If isometric_map = True, there is an extra step (self.classifier_1 = nn.Linear(512, 32*32*3)) to increase the dimension of the feature map from 512 to 32*32*3. This option is for density-ratio estimation in feature space.
The code is based on
@article{
zhang2018mixup,
title={mixup: Beyond Empirical Risk Minimization},
author={Hongyi Zhang, Moustapha Cisse, Yann N. Dauphin, David Lopez-Paz},
journal={International Conference on Learning Representations},
year={2018},
url={https://openreview.net/forum?id=r1Ddp1-Rb},
}
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
NC = 3
IMG_SIZE = 128
DIM_EMBED = 128
NUM_CONDITIONS = 2
#------------------------------------------------------------------------------
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet_embed(nn.Module):
def __init__(self, block, num_blocks, nc=NC, dim_embed=DIM_EMBED):
super(ResNet_embed, self).__init__()
self.in_planes = 64
self.main = nn.Sequential(
nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False), # h=h
# nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1, bias=False), # h=h/2
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2,2), #h=h/2 64
# self._make_layer(block, 64, num_blocks[0], stride=1), # h=h
self._make_layer(block, 64, num_blocks[0], stride=2), # h=h/2 32
self._make_layer(block, 128, num_blocks[1], stride=2), # h=h/2 16
self._make_layer(block, 256, num_blocks[2], stride=2), # h=h/2 8
self._make_layer(block, 512, num_blocks[3], stride=2), # h=h/2 4
# nn.AvgPool2d(kernel_size=4)
nn.AdaptiveAvgPool2d((1, 1))
)
self.x2h_res = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, dim_embed),
nn.BatchNorm1d(dim_embed),
nn.ReLU(),
)
self.h2y = nn.Sequential(
nn.Linear(dim_embed, NUM_CONDITIONS),
nn.ReLU()
)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
features = self.main(x)
features = features.view(features.size(0), -1)
features = self.x2h_res(features)
out = self.h2y(features)
return out, features
def ResNet18_embed(dim_embed=DIM_EMBED):
return ResNet_embed(BasicBlock, [2,2,2,2], dim_embed=dim_embed)
def ResNet34_embed(dim_embed=DIM_EMBED):
return ResNet_embed(BasicBlock, [3,4,6,3], dim_embed=dim_embed)
def ResNet50_embed(dim_embed=DIM_EMBED):
return ResNet_embed(Bottleneck, [3,4,6,3], dim_embed=dim_embed)
#------------------------------------------------------------------------------
# map labels to the embedding space
class model_y2h(nn.Module):
def __init__(self, dim_embed=DIM_EMBED):
super(model_y2h, self).__init__()
self.main = nn.Sequential(
nn.Linear(NUM_CONDITIONS, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
nn.ReLU()
)
def forward(self, y):
# y = y.view(-1, 1) +1e-8
y = y + 1e-8
# y = torch.exp(y.view(-1, 1))
return self.main(y)
if __name__ == "__main__":
net = ResNet34_embed(dim_embed=128).cuda()
x = torch.randn(16,NC,IMG_SIZE,IMG_SIZE).cuda()
out, features = net(x)
print(out.size())
print(features.size())
net_y2h = model_y2h().cuda()
y_hat = net_y2h(out)
print(f"{y_hat.size() = }")
train_net_for_label_embed.py
import torch
import torch.nn as nn
from torchvision.utils import save_image
import numpy as np
import os
import timeit
from PIL import Image
NUM_CONDITIONS = 2
#-------------------------------------------------------------
def train_net_embed(net, net_name, trainloader, testloader, epochs=200, resume_epoch = 0, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = None):
''' learning rate decay '''
def adjust_learning_rate_1(optimizer, epoch):
"""decrease the learning rate """
lr = lr_base
num_decays = len(lr_decay_epochs)
for decay_i in range(num_decays):
if epoch >= lr_decay_epochs[decay_i]:
lr = lr * lr_decay_factor
#end if epoch
#end for decay_i
for param_group in optimizer.param_groups:
param_group['lr'] = lr
net = net.cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
# resume training; load checkpoint
if path_to_ckpt is not None and resume_epoch>0:
save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(resume_epoch)
checkpoint = torch.load(save_file)
net.load_state_dict(checkpoint['net_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
torch.set_rng_state(checkpoint['rng_state'])
#end if
start_tmp = timeit.default_timer()
for epoch in range(resume_epoch, epochs):
net.train()
train_loss = 0
adjust_learning_rate_1(optimizer, epoch)
for _, (batch_train_images, batch_train_labels) in enumerate(trainloader):
# batch_train_images = nn.functional.interpolate(batch_train_images, size = (299,299), scale_factor=None, mode='bilinear', align_corners=False)
batch_train_images = batch_train_images.type(torch.float).cuda()
batch_train_labels = batch_train_labels.type(torch.float).view(-1,NUM_CONDITIONS).cuda()
#Forward pass
outputs, _ = net(batch_train_images)
loss = criterion(outputs, batch_train_labels)
#backward pass
optimizer.zero_grad()
loss.backward()
optimizer.step()
train_loss += loss.cpu().item()
#end for batch_idx
train_loss = train_loss / len(trainloader)
if testloader is None:
print('Train net_x2y for embedding: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
else:
net.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance)
with torch.no_grad():
test_loss = 0
for batch_test_images, batch_test_labels in testloader:
batch_test_images = batch_test_images.type(torch.float).cuda()
batch_test_labels = batch_test_labels.type(torch.float).view(-1,NUM_CONDITIONS).cuda()
outputs,_ = net(batch_test_images)
loss = criterion(outputs, batch_test_labels)
test_loss += loss.cpu().item()
test_loss = test_loss/len(testloader)
print('Train net_x2y for label embedding: [epoch %d/%d] train_loss:%f test_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, test_loss, timeit.default_timer()-start_tmp))
#save checkpoint
if path_to_ckpt is not None and (((epoch+1) % 50 == 0) or (epoch+1==epochs)):
save_file = path_to_ckpt + "/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_{}.pth".format(epoch+1)
os.makedirs(os.path.dirname(save_file), exist_ok=True)
torch.save({
'epoch': epoch,
'net_state_dict': net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'rng_state': torch.get_rng_state()
}, save_file)
#end for epoch
return net
###################################################################################
class label_dataset(torch.utils.data.Dataset):
def __init__(self, labels):
super(label_dataset, self).__init__()
self.labels = labels
self.n_samples = len(self.labels)
def __getitem__(self, index):
y = self.labels[index]
return y
def __len__(self):
return self.n_samples
def train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=500, lr_base=0.01, lr_decay_factor=0.1, lr_decay_epochs=[150, 250, 350], weight_decay=1e-4, batch_size=128):
'''
unique_labels_norm: an array of normalized unique labels
'''
''' learning rate decay '''
def adjust_learning_rate_2(optimizer, epoch):
"""decrease the learning rate """
lr = lr_base
num_decays = len(lr_decay_epochs)
for decay_i in range(num_decays):
if epoch >= lr_decay_epochs[decay_i]:
lr = lr * lr_decay_factor
#end if epoch
#end for decay_i
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# assume unique_labels_norm.shape == (B, NUM_CONDITIONS)
assert np.max(unique_labels_norm)<=1 and np.min(unique_labels_norm)>=0
trainset = label_dataset(unique_labels_norm)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
net_embed.eval()
net_h2y=net_embed.module.h2y # maps embeddings back to label space
optimizer_y2h = torch.optim.SGD(net_y2h.parameters(), lr = lr_base, momentum= 0.9, weight_decay=weight_decay)
start_tmp = timeit.default_timer()
for epoch in range(epochs):
net_y2h.train()
train_loss = 0
adjust_learning_rate_2(optimizer_y2h, epoch)
for _, batch_labels in enumerate(trainloader):
batch_labels = batch_labels.type(torch.float).view(-1,NUM_CONDITIONS).cuda()
# generate noises which will be added to labels
batch_size_curr = len(batch_labels)
batch_gamma = np.random.normal(0, 0.2, (batch_size_curr, NUM_CONDITIONS))
batch_gamma = torch.from_numpy(batch_gamma).view(-1,NUM_CONDITIONS).type(torch.float).cuda()
# add noise to labels
batch_labels_noise = torch.clamp(batch_labels+batch_gamma, 0.0, 1.0)
#Forward pass
batch_hiddens_noise = net_y2h(batch_labels_noise)
batch_rec_labels_noise = net_h2y(batch_hiddens_noise)
loss = nn.MSELoss()(batch_rec_labels_noise, batch_labels_noise)
#backward pass
optimizer_y2h.zero_grad()
loss.backward()
optimizer_y2h.step()
train_loss += loss.cpu().item()
#end for batch_idx
train_loss = train_loss / len(trainloader)
print('\n Train net_y2h: [epoch %d/%d] train_loss:%f Time:%.4f' % (epoch+1, epochs, train_loss, timeit.default_timer()-start_tmp))
#end for epoch
return net_y2h
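Before moving on to `train_ccgan.py`, here is a minimal sketch of how these two stages wire together (random placeholder labels, not a full training recipe; it assumes the `ResNet_embed.py` above is importable, and `net_embed` must be wrapped in `DataParallel` because `train_net_y2h` accesses `net_embed.module.h2y`):

```python
# Minimal wiring sketch for the two embedding-training stages above.
import numpy as np
import torch.nn as nn
from ResNet_embed import ResNet34_embed, model_y2h, NUM_CONDITIONS
from train_net_for_label_embed import train_net_y2h

net_embed = nn.DataParallel(ResNet34_embed(dim_embed=128).cuda())
net_y2h = model_y2h(dim_embed=128).cuda()

# Stage 1: train_net_embed(...) would be called here with an image/label DataLoader.
# Stage 2: map labels back into the embedding space learned in stage 1.
unique_labels_norm = np.random.uniform(0, 1, size=(500, NUM_CONDITIONS))
net_y2h = train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=5)
```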
train_ccgan.py
import torch
import numpy as np
import os
import timeit
from PIL import Image
from torchvision.utils import save_image
import torch.cuda as cutorch
from utils import SimpleProgressBar, IMGs_dataset
from opts import parse_opts
from DiffAugment_pytorch import DiffAugment
NUM_CONDITIONS = 2
''' Settings '''
args = parse_opts()
# some parameters in opts
gan_arch = args.GAN_arch
loss_type = args.loss_type_gan
niters = args.niters_gan
resume_niters = args.resume_niters_gan
dim_gan = args.dim_gan
lr_g = args.lr_g_gan
lr_d = args.lr_d_gan
save_niters_freq = args.save_niters_freq
batch_size_disc = args.batch_size_disc
batch_size_gene = args.batch_size_gene
# batch_size_max = max(batch_size_disc, batch_size_gene)
num_D_steps = args.num_D_steps
visualize_freq = args.visualize_freq
num_workers = args.num_workers
threshold_type = args.threshold_type
nonzero_soft_weight_threshold = args.nonzero_soft_weight_threshold
num_channels = args.num_channels
img_size = args.img_size
max_label = args.max_label
use_DiffAugment = args.gan_DiffAugment
policy = args.gan_DiffAugment_policy
## normalize images
def normalize_images(batch_images):
batch_images = batch_images/255.0
batch_images = (batch_images - 0.5)/0.5
return batch_images
def train_ccgan(kernel_sigma, kappa, train_images, train_labels, netG, netD, net_y2h, save_images_folder, save_models_folder = None, clip_label=False):
'''
Note that train_images are not normalized to [-1,1]
'''
netG = netG.cuda()
netD = netD.cuda()
net_y2h = net_y2h.cuda()
net_y2h.eval()
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr_g, betas=(0.5, 0.999))
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr_d, betas=(0.5, 0.999))
if save_models_folder is not None and resume_niters>0:
save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, resume_niters)
checkpoint = torch.load(save_file)
netG.load_state_dict(checkpoint['netG_state_dict'])
netD.load_state_dict(checkpoint['netD_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizerG_state_dict'])
optimizerD.load_state_dict(checkpoint['optimizerD_state_dict'])
torch.set_rng_state(checkpoint['rng_state'])
print(f"Got model from {save_file} successfully")
#end if
#################
# --- unique_train_labels = np.sort(np.array(list(set(train_labels))))
unique_train_labels_1 = np.sort(np.array(list(set(train_labels[:, 0]))))
unique_train_labels_2 = np.sort(np.array(list(set(train_labels[:, 1]))))
# print(f"------------{unique_train_labels_1 = }")
# print(f"------------{unique_train_labels_2 = }")
# printed images with labels between the 5-th quantile and 95-th quantile of training labels
n_row=10; n_col = n_row
z_fixed = torch.randn(n_row*n_col, dim_gan, dtype=torch.float).cuda()
# --- start_label = np.quantile(train_labels, 0.05)
# --- end_label = np.quantile(train_labels, 0.95)
# --- selected_labels = np.linspace(start_label, end_label, num=n_row)
start_label_1 = np.quantile(train_labels[:, 0], 0.05)
end_label_1 = np.quantile(train_labels[:, 0], 0.95)
selected_labels_1 = np.linspace(start_label_1, end_label_1, num=n_row)
start_label_2 = np.quantile(train_labels[:, 1], 0.05)
end_label_2 = np.quantile(train_labels[:, 1], 0.95)
selected_labels_2 = np.linspace(start_label_2, end_label_2, num=n_col)
# --- y_fixed = np.zeros(n_row*n_col)
# --- for i in range(n_row):
# --- curr_label = selected_labels[i]
# --- for j in range(n_col):
# --- y_fixed[i*n_col+j] = curr_label
y_fixed = np.zeros((n_row * n_col, NUM_CONDITIONS))
for i in range(n_row):
curr_label_1 = selected_labels_1[i]
for j in range(n_col):
curr_label_2 = selected_labels_2[j]
y_fixed[i*n_col+j, 0] = curr_label_1
y_fixed[i*n_col+j, 1] = curr_label_2
y_fixed = torch.from_numpy(y_fixed).type(torch.float).view(-1,NUM_CONDITIONS).cuda()
# print(f"{y_fixed = }")
start_time = timeit.default_timer()
for niter in range(resume_niters, niters):
''' Train Discriminator '''
for _ in range(num_D_steps):
# ## randomly draw batch_size_disc y's from unique_train_labels
# batch_target_labels_in_dataset = np.random.choice(unique_train_labels, size=batch_size_disc, replace=True)
# ## add Gaussian noise; we estimate image distribution conditional on these labels
# batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
# batch_target_labels = batch_target_labels_in_dataset + batch_epsilons
# ## find index of real images with labels in the vicinity of batch_target_labels
# ## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
# batch_real_indx = np.zeros(batch_size_disc, dtype=int) #index of images in the datata; the labels of these images are in the vicinity
# batch_fake_labels = np.zeros(batch_size_disc)
# for j in range(batch_size_disc):
# ## index for real images
# if threshold_type == "hard":
# indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
# else:
# # reverse the weight function for SVDL
# indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
# ## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
# while len(indx_real_in_vicinity)<1:
# batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
# batch_target_labels[j] = batch_target_labels_in_dataset[j] + batch_epsilons_j
# if clip_label:
# batch_target_labels = np.clip(batch_target_labels, 0.0, 1.0)
# ## index for real images
# if threshold_type == "hard":
# indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
# else:
# # reverse the weight function for SVDL
# indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
# #end while len(indx_real_in_vicinity)<1
# assert len(indx_real_in_vicinity)>=1
# batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]
# ## labels for fake images generation
# if threshold_type == "hard":
# lb = batch_target_labels[j] - kappa
# ub = batch_target_labels[j] + kappa
# else:
# lb = batch_target_labels[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
# ub = batch_target_labels[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
# lb = max(0.0, lb); ub = min(ub, 1.0)
# assert lb<=ub
# assert lb>=0 and ub>=0
# assert lb<=1 and ub<=1
# batch_fake_labels[j] = np.random.uniform(lb, ub, size=1)[0]
# #end for j
# ----------------------------------------------------------------------------------------------------------------
## randomly draw batch_size_disc y's from unique_train_labels
batch_target_labels_in_dataset_1 = np.random.choice(unique_train_labels_1, size=batch_size_disc, replace=True)
batch_target_labels_in_dataset_2 = np.random.choice(unique_train_labels_2, size=batch_size_disc, replace=True)
## add Gaussian noise; we estimate image distribution conditional on these labels
batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_disc)
batch_target_labels_1 = batch_target_labels_in_dataset_1 + batch_epsilons
batch_target_labels_2 = batch_target_labels_in_dataset_2 + batch_epsilons
## find index of real images with labels in the vicinity of batch_target_labels
## generate labels for fake image generation; these labels are also in the vicinity of batch_target_labels
batch_real_indx = np.zeros(batch_size_disc, dtype=int) #index of images in the dataset; the labels of these images are in the vicinity
batch_fake_labels = np.zeros((batch_size_disc, NUM_CONDITIONS))
for j in range(batch_size_disc):
## index for real images
if threshold_type == "hard":
indx_real_in_vicinity_1 = np.where(np.abs(train_labels[:, 0]-batch_target_labels_1[j])<= kappa)[0]
indx_real_in_vicinity_2 = np.where(np.abs(train_labels[:, 1]-batch_target_labels_2[j])<= kappa)[0]
else:
# reverse the weight function for SVDL
indx_real_in_vicinity_1 = np.where((train_labels[:, 0]-batch_target_labels_1[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
indx_real_in_vicinity_2 = np.where((train_labels[:, 1]-batch_target_labels_2[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
indx_real_in_vicinity = np.array(list(set.intersection(set(indx_real_in_vicinity_1), set(indx_real_in_vicinity_2))))
## if the max gap between two consecutive ordered unique labels is large, it is possible that len(indx_real_in_vicinity)<1
while len(indx_real_in_vicinity)<1:
batch_epsilons_j = np.random.normal(0, kernel_sigma, 1)
batch_target_labels_1[j] = batch_target_labels_in_dataset_1[j] + batch_epsilons_j
batch_target_labels_2[j] = batch_target_labels_in_dataset_2[j] + batch_epsilons_j
if clip_label:
batch_target_labels_1 = np.clip(batch_target_labels_1, 0.0, 1.0)
batch_target_labels_2 = np.clip(batch_target_labels_2, 0.0, 1.0)
## index for real images
# if threshold_type == "hard":
# indx_real_in_vicinity = np.where(np.abs(train_labels-batch_target_labels[j])<= kappa)[0]
# else:
# # reverse the weight function for SVDL
# indx_real_in_vicinity = np.where((train_labels-batch_target_labels[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
if threshold_type == "hard":
indx_real_in_vicinity_1 = np.where(np.abs(train_labels[:, 0]-batch_target_labels_1[j])<= kappa)[0]
indx_real_in_vicinity_2 = np.where(np.abs(train_labels[:, 1]-batch_target_labels_2[j])<= kappa)[0]
else:
# reverse the weight function for SVDL
indx_real_in_vicinity_1 = np.where((train_labels[:, 0]-batch_target_labels_1[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
indx_real_in_vicinity_2 = np.where((train_labels[:, 1]-batch_target_labels_2[j])**2 <= -np.log(nonzero_soft_weight_threshold)/kappa)[0]
indx_real_in_vicinity = np.array(list(set.intersection(set(indx_real_in_vicinity_1), set(indx_real_in_vicinity_2))))
#end while len(indx_real_in_vicinity)<1
assert len(indx_real_in_vicinity)>=1
batch_real_indx[j] = np.random.choice(indx_real_in_vicinity, size=1)[0]
## labels for fake images generation
if threshold_type == "hard":
lb_1 = batch_target_labels_1[j] - kappa
ub_1 = batch_target_labels_1[j] + kappa
lb_2 = batch_target_labels_2[j] - kappa
ub_2 = batch_target_labels_2[j] + kappa
else:
lb_1 = batch_target_labels_1[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
ub_1 = batch_target_labels_1[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
lb_2 = batch_target_labels_2[j] - np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
ub_2 = batch_target_labels_2[j] + np.sqrt(-np.log(nonzero_soft_weight_threshold)/kappa)
lb_1 = max(0.0, lb_1); ub_1 = min(ub_1, 1.0)
lb_2 = max(0.0, lb_2); ub_2 = min(ub_2, 1.0)
assert lb_1<=ub_1
assert lb_2<=ub_2
assert lb_1>=0 and ub_1>=0
assert lb_1<=1 and ub_1<=1
assert lb_2>=0 and ub_2>=0
assert lb_2<=1 and ub_2<=1
batch_fake_labels[j] = np.array([np.random.uniform(lb_1, ub_1, size=1)[0], np.random.uniform(lb_2, ub_2, size=1)[0]])
batch_target_labels = np.stack([batch_target_labels_1, batch_target_labels_2], axis=1)
#end for j
# ----------------------------------------------------------------------------------------------------------------
# print(f"===1 {batch_target_labels.shape = }")
# print(f"===1 {np.min(batch_target_labels) = }, {np.max(batch_target_labels) = }")
# print(f"===2 {batch_fake_labels.shape = }")
# print(f"===2 {np.min(batch_fake_labels) = }, {np.max(batch_fake_labels) = }")
## draw real image/label batch from the training set
batch_real_images = torch.from_numpy(normalize_images(train_images[batch_real_indx]))
batch_real_images = batch_real_images.type(torch.float).cuda()
batch_real_labels = train_labels[batch_real_indx]
batch_real_labels = torch.from_numpy(batch_real_labels).type(torch.float).cuda()
# print(f"===3 {batch_real_labels.shape = }")
# print(f"===3 {torch.min(batch_real_labels) = }, {torch.max(batch_real_labels) = }")
# print(f"===4 {batch_real_images.shape = }")
# print(f"===4 {torch.min(batch_real_images) = }, {torch.max(batch_real_images) = }")
## generate the fake image batch
batch_fake_labels = torch.from_numpy(batch_fake_labels).type(torch.float).cuda()
z = torch.randn(batch_size_disc, dim_gan, dtype=torch.float).cuda()
batch_fake_images = netG(z, net_y2h(batch_fake_labels))
# print(f"===5 {batch_fake_images.shape = }")
# print(f"===5 {torch.min(batch_fake_images) = }, {torch.max(batch_fake_images) = }")
## target labels on gpu
batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()
## weight vector
if threshold_type == "soft":
real_weights = torch.exp(-kappa*(batch_real_labels-batch_target_labels)**2).cuda()
fake_weights = torch.exp(-kappa*(batch_fake_labels-batch_target_labels)**2).cuda()
else:
real_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
fake_weights = torch.ones(batch_size_disc, dtype=torch.float).cuda()
#end if threshold type
# forward pass
if use_DiffAugment:
real_dis_out = netD(DiffAugment(batch_real_images, policy=policy), net_y2h(batch_target_labels))
fake_dis_out = netD(DiffAugment(batch_fake_images.detach(), policy=policy), net_y2h(batch_target_labels))
else:
real_dis_out = netD(batch_real_images, net_y2h(batch_target_labels))
fake_dis_out = netD(batch_fake_images.detach(), net_y2h(batch_target_labels))
# print(f"===6 {real_dis_out.shape = }")
# print(f"===6 {torch.min(real_dis_out) = }, {torch.max(real_dis_out) = }")
# print(f"===7 {real_weights.shape = }")
# print(f"===7 {torch.min(real_weights) = }, {torch.max(real_weights) = }")
if loss_type == "vanilla":
real_dis_out = torch.nn.Sigmoid()(real_dis_out)
fake_dis_out = torch.nn.Sigmoid()(fake_dis_out)
d_loss_real = - torch.log(real_dis_out+1e-20)
d_loss_fake = - torch.log(1-fake_dis_out+1e-20)
elif loss_type == "hinge":
d_loss_real = torch.nn.ReLU()(1.0 - real_dis_out)
d_loss_fake = torch.nn.ReLU()(1.0 + fake_dis_out)
else:
raise ValueError('Not supported loss type!!!')
# TODO: added this averaging, but unsure whether reducing the dimension of real_weights (as here) or widening the output dimension of netD is the better choice
real_weights = torch.mean(real_weights, axis=1)
fake_weights = torch.mean(fake_weights, axis=1)
d_loss = torch.mean(real_weights.view(-1) * d_loss_real.view(-1)) + torch.mean(fake_weights.view(-1) * d_loss_fake.view(-1))
optimizerD.zero_grad()
d_loss.backward()
optimizerD.step()
#end for step_D_index
''' Train Generator '''
netG.train()
# generate fake images
## randomly draw batch_size_gene y's from unique_train_labels
batch_target_labels_in_dataset_1 = np.random.choice(unique_train_labels_1, size=batch_size_gene, replace=True)
batch_target_labels_in_dataset_2 = np.random.choice(unique_train_labels_2, size=batch_size_gene, replace=True)
## add Gaussian noise; we estimate image distribution conditional on these labels
batch_epsilons = np.random.normal(0, kernel_sigma, batch_size_gene)
batch_target_labels_1 = batch_target_labels_in_dataset_1 + batch_epsilons
batch_target_labels_2 = batch_target_labels_in_dataset_2 + batch_epsilons
batch_target_labels = np.stack([batch_target_labels_1, batch_target_labels_2], axis=1)
batch_target_labels = torch.from_numpy(batch_target_labels).type(torch.float).cuda()
z = torch.randn(batch_size_gene, dim_gan, dtype=torch.float).cuda()
batch_fake_images = netG(z, net_y2h(batch_target_labels))
# loss
if use_DiffAugment:
dis_out = netD(DiffAugment(batch_fake_images, policy=policy), net_y2h(batch_target_labels))
else:
dis_out = netD(batch_fake_images, net_y2h(batch_target_labels))
if loss_type == "vanilla":
dis_out = torch.nn.Sigmoid()(dis_out)
g_loss = - torch.mean(torch.log(dis_out+1e-20))
elif loss_type == "hinge":
g_loss = - dis_out.mean()
# backward
optimizerG.zero_grad()
g_loss.backward()
optimizerG.step()
# print loss
if (niter+1) % 20 == 0:
print ("CcGAN,%s: [Iter %d/%d] [D loss: %.4e] [G loss: %.4e] [real prob: %.3f] [fake prob: %.3f] [Time: %.4f]" % (gan_arch, niter+1, niters, d_loss.item(), g_loss.item(), real_dis_out.mean().item(), fake_dis_out.mean().item(), timeit.default_timer()-start_time))
if (niter+1) % visualize_freq == 0:
netG.eval()
with torch.no_grad():
gen_imgs = netG(z_fixed, net_y2h(y_fixed))
gen_imgs = gen_imgs.detach().cpu()
save_image(gen_imgs.data, save_images_folder + '/{}.png'.format(niter+1), nrow=n_row, normalize=True)
if save_models_folder is not None and ((niter+1) % save_niters_freq == 0 or (niter+1) == niters):
save_file = save_models_folder + "/CcGAN_{}_{}_nDsteps_{}_checkpoint_intrain/CcGAN_checkpoint_niters_{}.pth".format(gan_arch, threshold_type, num_D_steps, niter+1)
os.makedirs(os.path.dirname(save_file), exist_ok=True)
torch.save({
'netG_state_dict': netG.state_dict(),
'netD_state_dict': netD.state_dict(),
'optimizerG_state_dict': optimizerG.state_dict(),
'optimizerD_state_dict': optimizerD.state_dict(),
'rng_state': torch.get_rng_state()
}, save_file)
#end for niter
return netG, netD
def sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = 500, to_numpy=True, denorm=True, verbose=True):
'''
netG: pretrained generator network
labels: float. normalized labels.
'''
nfake = len(labels)
if batch_size>nfake:
batch_size=nfake
fake_images = []
fake_labels = np.concatenate((labels, labels[0:batch_size]))
netG=netG.cuda()
netG.eval()
net_y2h = net_y2h.cuda()
net_y2h.eval()
with torch.no_grad():
if verbose:
pb = SimpleProgressBar()
n_img_got = 0
while n_img_got < nfake:
z = torch.randn(batch_size, dim_gan, dtype=torch.float).cuda()
y = torch.from_numpy(fake_labels[n_img_got:(n_img_got+batch_size)]).type(torch.float).view(-1,NUM_CONDITIONS).cuda()
batch_fake_images = netG(z, net_y2h(y))
if denorm: #denorm imgs to save memory
assert batch_fake_images.max().item()<=1.0 and batch_fake_images.min().item()>=-1.0
batch_fake_images = batch_fake_images*0.5+0.5
batch_fake_images = batch_fake_images*255.0
batch_fake_images = batch_fake_images.type(torch.uint8)
# assert batch_fake_images.max().item()>1
fake_images.append(batch_fake_images.cpu())
n_img_got += batch_size
if verbose:
pb.update(min(float(n_img_got)/nfake, 1)*100)
##end while
fake_images = torch.cat(fake_images, dim=0)
#remove extra entries
fake_images = fake_images[0:nfake]
fake_labels = fake_labels[0:nfake]
if to_numpy:
fake_images = fake_images.numpy()
return fake_images, fake_labels
if __name__ == "__main__":
from models import CcGAN_SAGAN_Generator, CcGAN_SAGAN_Discriminator
from models.ResNet_embed import model_y2h
B = 500
images_train = np.random.normal(size=(B, 3, 128, 128))
labels_train = np.random.normal(size=(B, 2))
kernel_sigma = -1.0
if kernel_sigma < 0:
kernel_sigma = 1.06 * np.std(labels_train) * (len(labels_train))**(-1/5)
kappa = -1
if kappa < 0:
unique_labels_norm = np.unique(labels_train[:, 0])
n_unique = len(unique_labels_norm)
diff_list = []
for i in range(1, n_unique):
diff_list.append(unique_labels_norm[i] - unique_labels_norm[i-1])
kappa_base = np.abs(kappa) * np.max(np.array(diff_list))
# there is a branch on threshold_type, but here we proceed with soft
kappa = 1 / kappa_base ** 2
dim_embed = 128
netG = CcGAN_SAGAN_Generator(dim_z=dim_gan, dim_embed=128)
netD = CcGAN_SAGAN_Discriminator(dim_embed=dim_embed)
net_y2h = model_y2h(dim_embed=dim_embed)
save_image_in_train_folder = "."
save_models_folder = "."
# images_train = torch.from_numpy(images_train).type(torch.float)
# labels_train = torch.from_numpy(labels_train).type(torch.float)
netG, netD = train_ccgan(kernel_sigma,
kappa,
images_train,
labels_train,
netG,
netD,
net_y2h,
save_image_in_train_folder,
save_models_folder)
main.py
print("\n===================================================================================================")
import argparse
import copy
import gc
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import matplotlib as mpl
import h5py
import os
import random
from tqdm import tqdm
import torch
import torchvision
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torchvision.utils import save_image
import timeit
from PIL import Image
import sys
### import my stuffs ###
from opts import parse_opts
args = parse_opts()
wd = args.root_path
os.chdir(wd)
from utils import *
from models import *
from train_cgan import train_cgan, sample_cgan_given_labels
from train_cgan_concat import train_cgan_concat, sample_cgan_concat_given_labels
from train_ccgan import train_ccgan, sample_ccgan_given_labels
from train_net_for_label_embed import train_net_embed, train_net_y2h
from eval_metrics import cal_FID, cal_labelscore
#######################################################################################
''' Settings '''
#######################################################################################
#-------------------------------
# seeds
random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = True
cudnn.benchmark = False
np.random.seed(args.seed)
NUM_CONDITIONS = 2
#-------------------------------
# output folders
path_to_output = os.path.join(wd, "output/output_{}_arch_{}".format(args.GAN, args.GAN_arch))
print(f"{path_to_output = }")
os.makedirs(path_to_output, exist_ok=True)
save_models_folder = os.path.join(path_to_output, 'saved_models')
os.makedirs(save_models_folder, exist_ok=True)
save_images_folder = os.path.join(path_to_output, 'saved_images')
os.makedirs(save_images_folder, exist_ok=True)
path_to_embed_models = os.path.join(wd, 'output/embed_models')
os.makedirs(path_to_embed_models, exist_ok=True)
#-------------------------------
# Embedding
base_lr_x2y = 0.01
base_lr_y2h = 0.01
#######################################################################################
''' Data loader '''
#######################################################################################
# data loader
# data_filename = args.data_path + '/RC-49_{}x{}.h5'.format(args.img_size, args.img_size)
# new dataset
data_filename = args.data_path + '/RC-49_{}x{}_downscale.h5'.format(args.img_size, args.img_size)
hf = h5py.File(data_filename, 'r')
labels_all = hf['labels'][:]
labels_all = labels_all.astype(float)
images_all = hf['images'][:]
indx_train = hf['indx_train'][:]
hf.close()
print("\n RC-49 dataset shape: {}x{}x{}x{}".format(images_all.shape[0], images_all.shape[1], images_all.shape[2], images_all.shape[3]))
# data split
if args.data_split == "train":
images_train = images_all[indx_train]
labels_train_raw = labels_all[indx_train]
else:
images_train = copy.deepcopy(images_all)
labels_train_raw = copy.deepcopy(labels_all)
# only take images with label in (q1, q2)
q1 = args.min_label
q2 = args.max_label
# in case of multiple conditions
if labels_train_raw.shape[-1] == NUM_CONDITIONS:
scale_q1 = args.min_label_scale
scale_q2 = args.max_label_scale
indx_1 = np.where((labels_train_raw[:, 0]>q1)*(labels_train_raw[:, 0]<q2)==True)[0]
indx_2 = np.where((labels_train_raw[:, 1]>scale_q1)*(labels_train_raw[:, 1]<scale_q2)==True)[0]
indx = np.array(list(set.intersection(set(indx_1), set(indx_2))))
else:
indx = np.where((labels_train_raw>q1)*(labels_train_raw<q2)==True)[0]
labels_train_raw = labels_train_raw[indx]
images_train = images_train[indx]
assert len(labels_train_raw)==len(images_train)
if args.visualize_fake_images or args.comp_FID:
indx = np.where((labels_all>q1)*(labels_all<q2)==True)[0]
labels_all = labels_all[indx]
images_all = images_all[indx]
assert len(labels_all)==len(images_all)
### show some real images
if args.show_real_imgs:
unique_labels_show = np.array(sorted(list(set(labels_all))))
indx_show = np.arange(0, len(unique_labels_show), len(unique_labels_show)//9)
unique_labels_show = unique_labels_show[indx_show]
nrow = len(unique_labels_show); ncol = 1
sel_labels_indx = []
for i in range(nrow):
curr_label = unique_labels_show[i]
indx_curr_label = np.where(labels_all==curr_label)[0]
np.random.shuffle(indx_curr_label)
indx_curr_label = indx_curr_label[0:ncol]
sel_labels_indx.extend(list(indx_curr_label))
sel_labels_indx = np.array(sel_labels_indx)
images_show = images_all[sel_labels_indx]
print(images_show.mean())
images_show = (images_show/255.0-0.5)/0.5
images_show = torch.from_numpy(images_show)
save_image(images_show.data, save_images_folder +'/real_images_grid_{}x{}.png'.format(nrow, ncol), nrow=ncol, normalize=True)
# for each angle, take no more than args.max_num_img_per_label images
# image_num_threshold = args.max_num_img_per_label
# print("\n Original set has {} images; For each angle, take no more than {} images>>>".format(len(images_train), image_num_threshold))
# unique_labels_tmp = np.sort(np.array(list(set(labels_train_raw))))
# for i in tqdm(range(len(unique_labels_tmp))):
# indx_i = np.where(labels_train_raw == unique_labels_tmp[i])[0]
# if len(indx_i)>image_num_threshold:
# np.random.shuffle(indx_i)
# indx_i = indx_i[0:image_num_threshold]
# if i == 0:
# sel_indx = indx_i
# else:
# sel_indx = np.concatenate((sel_indx, indx_i))
# images_train = images_train[sel_indx]
# labels_train_raw = labels_train_raw[sel_indx]
print("{} images left and there are {}, {} unique labels".format(len(images_train), len(set(labels_train_raw[:, 0])), len(set(labels_train_raw[:, 1]))))
# normalize labels_train_raw
print("\n Range of unnormalized labels for first axis: ({},{})".format(np.min(labels_train_raw[:, 0]), np.max(labels_train_raw[:, 0])))
print("\n Range of unnormalized labels for second axis: ({},{})".format(np.min(labels_train_raw[:, 1]), np.max(labels_train_raw[:, 1])))
if args.GAN == "cGAN": #treated as classification; convert angles to class labels
unique_labels = np.sort(np.array(list(set(labels_train_raw))))
num_unique_labels = len(unique_labels)
print("{} unique labels are split into {} classes".format(num_unique_labels, args.cGAN_num_classes))
## convert steering angles to class labels and vice versa
### step 1: prepare two dictionaries
label2class = dict()
class2label = dict()
num_labels_per_class = num_unique_labels//args.cGAN_num_classes
class_cutoff_points = [unique_labels[0]] #the cutoff points on [min_label, max_label] to determine classes
curr_class = 0
for i in range(num_unique_labels):
label2class[unique_labels[i]]=curr_class
if (i+1)%num_labels_per_class==0 and (curr_class+1)!=args.cGAN_num_classes:
curr_class += 1
class_cutoff_points.append(unique_labels[i+1])
class_cutoff_points.append(unique_labels[-1])
assert len(class_cutoff_points)-1 == args.cGAN_num_classes
for i in range(args.cGAN_num_classes):
class2label[i] = (class_cutoff_points[i]+class_cutoff_points[i+1])/2
### step 2: convert angles to class labels
labels_new = -1*np.ones(len(labels_train_raw))
for i in range(len(labels_train_raw)):
labels_new[i] = label2class[labels_train_raw[i]]
assert np.sum(labels_new<0)==0
labels_train = labels_new
del labels_new; gc.collect()
unique_labels = np.sort(np.array(list(set(labels_train)))).astype(int)
assert len(unique_labels) == args.cGAN_num_classes
elif args.GAN == "CcGAN":
if labels_train_raw.shape[-1] == NUM_CONDITIONS:
labels_train = labels_train_raw / [args.max_label, args.max_label_scale]
else:
labels_train = labels_train_raw / args.max_label
print("\n Range of normalized labels: ({},{})".format(np.min(labels_train), np.max(labels_train)))
# proceed with the two normalized conditions
# unique_labels_norm = np.sort(np.array(list(set(labels_train[:, 0]))))
# if args.kernel_sigma<0:
# std_label = np.std(labels_train)
# args.kernel_sigma = 1.06*std_label*(len(labels_train))**(-1/5)
# print("\n Use rule-of-thumb formula to compute kernel_sigma >>>")
# print("\n The std of {} labels is {} so the kernel sigma is {}".format(len(labels_train), std_label, args.kernel_sigma))
# if args.kappa<0:
# n_unique = len(unique_labels_norm)
# diff_list = []
# for i in range(1,n_unique):
# diff_list.append(unique_labels_norm[i] - unique_labels_norm[i-1])
# kappa_base = np.abs(args.kappa)*np.max(np.array(diff_list))
# if args.threshold_type=="hard":
# args.kappa = kappa_base
# else:
# args.kappa = 1/kappa_base**2
unique_labels_norm_1 = np.sort(np.array(list(set(labels_train[:, 0]))))
unique_labels_norm_2 = np.sort(np.array(list(set(labels_train[:, 1]))))
unique_labels_norm = np.zeros((len(unique_labels_norm_1) * len(unique_labels_norm_2), 2))
for idx_1, c_1 in enumerate(unique_labels_norm_1):
for idx_2, c_2 in enumerate(unique_labels_norm_2):
unique_labels_norm[idx_1*len(unique_labels_norm_2) + idx_2, 0] = c_1
unique_labels_norm[idx_1*len(unique_labels_norm_2) + idx_2, 1] = c_2
if args.kernel_sigma<0:
std_label = np.std(labels_train)
args.kernel_sigma = 1.06*std_label*(len(labels_train))**(-1/5)
print("\n Use rule-of-thumb formula to compute kernel_sigma >>>")
print("\n The std of {} labels is {} so the kernel sigma is {}".format(len(labels_train), std_label, args.kernel_sigma))
# TODO: for now, this part is handled as if there were only one condition
if args.kappa<0:
n_unique = len(unique_labels_norm)
diff_list = []
for i in range(1,n_unique):
diff_list.append(unique_labels_norm[i] - unique_labels_norm[i-1])
kappa_base = np.abs(args.kappa)*np.max(np.array(diff_list))
if args.threshold_type=="hard":
args.kappa = kappa_base
else:
args.kappa = 1/kappa_base**2
elif args.GAN == "cGAN-concat":
labels_train = labels_train_raw / args.max_label
print("\n Range of normalized labels: ({},{})".format(np.min(labels_train), np.max(labels_train)))
else:
raise ValueError('Not supported')
## end if args.GAN
#######################################################################################
''' Pre-trained CNN and GAN for label embedding '''
#######################################################################################
if args.GAN == "CcGAN":
net_embed_filename_ckpt = os.path.join(path_to_embed_models, 'ckpt_{}_epoch_{}_seed_{}.pth'.format(args.net_embed, args.epoch_cnn_embed, args.seed))
net_y2h_filename_ckpt = os.path.join(path_to_embed_models, 'ckpt_net_y2h_epoch_{}_seed_{}.pth'.format(args.epoch_net_y2h, args.seed))
print("\n "+net_embed_filename_ckpt)
print("\n "+net_y2h_filename_ckpt)
trainset = IMGs_dataset(images_train, labels_train, normalize=True)
trainloader_embed_net = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size_embed, shuffle=True, num_workers=args.num_workers)
if args.net_embed == "ResNet18_embed":
net_embed = ResNet18_embed(dim_embed=args.dim_embed)
elif args.net_embed == "ResNet34_embed":
net_embed = ResNet34_embed(dim_embed=args.dim_embed)
elif args.net_embed == "ResNet50_embed":
net_embed = ResNet50_embed(dim_embed=args.dim_embed)
net_embed = net_embed.cuda()
net_embed = nn.DataParallel(net_embed)
net_y2h = model_y2h(dim_embed=args.dim_embed)
net_y2h = net_y2h.cuda()
net_y2h = nn.DataParallel(net_y2h)
## (1). Train net_embed first: x2h+h2y
if not os.path.isfile(net_embed_filename_ckpt):
print("\n Start training CNN for label embedding >>>")
net_embed = train_net_embed(net=net_embed, net_name=args.net_embed, trainloader=trainloader_embed_net, testloader=None, epochs=args.epoch_cnn_embed, resume_epoch = args.resumeepoch_cnn_embed, lr_base=base_lr_x2y, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = path_to_embed_models)
# save model
torch.save({
'net_state_dict': net_embed.state_dict(),
}, net_embed_filename_ckpt)
else:
print("\n net_embed ckpt already exists")
print("\n Loading...")
checkpoint = torch.load(net_embed_filename_ckpt)
net_embed.load_state_dict(checkpoint['net_state_dict'])
#end not os.path.isfile
## (2). Train y2h
#train a net which maps a label back to the embedding space
if not os.path.isfile(net_y2h_filename_ckpt):
print("\n Start training net_y2h >>>")
net_y2h = train_net_y2h(unique_labels_norm, net_y2h, net_embed, epochs=args.epoch_net_y2h, lr_base=base_lr_y2h, lr_decay_factor=0.1, lr_decay_epochs=[150, 250, 350], weight_decay=1e-4, batch_size=128)
# save model
torch.save({
'net_state_dict': net_y2h.state_dict(),
}, net_y2h_filename_ckpt)
else:
print("\n net_y2h ckpt already exists")
print("\n Loading...")
checkpoint = torch.load(net_y2h_filename_ckpt)
net_y2h.load_state_dict(checkpoint['net_state_dict'])
#end not os.path.isfile
##some simple test
indx_tmp = np.arange(len(unique_labels_norm))
np.random.shuffle(indx_tmp)
indx_tmp = indx_tmp[:10]
labels_tmp = unique_labels_norm[indx_tmp].reshape(-1,NUM_CONDITIONS)
labels_tmp = torch.from_numpy(labels_tmp).type(torch.float).cuda()
epsilons_tmp = np.random.normal(0, 0.2, (len(labels_tmp), NUM_CONDITIONS))
epsilons_tmp = torch.from_numpy(epsilons_tmp).view(-1,NUM_CONDITIONS).type(torch.float).cuda()
labels_tmp = torch.clamp(labels_tmp+epsilons_tmp, 0.0, 1.0)
net_embed.eval()
net_h2y = net_embed.module.h2y
net_y2h.eval()
with torch.no_grad():
labels_rec_tmp = net_h2y(net_y2h(labels_tmp)).cpu().numpy().reshape(-1,NUM_CONDITIONS)
# results = np.concatenate((labels_tmp.cpu().numpy(), labels_rec_tmp), axis=1)
print("\n labels vs reconstructed labels")
# print(results)
print(labels_tmp)
print()
print(labels_rec_tmp)
#put models on cpu
net_embed = net_embed.cpu()
net_h2y = net_h2y.cpu()
del net_embed, net_h2y; gc.collect()
net_y2h = net_y2h.cpu()
#######################################################################################
''' GAN training '''
#######################################################################################
if args.GAN == 'CcGAN':
print("CcGAN: {}, {}, Sigma is {}, Kappa is {}.".format(args.GAN_arch, args.threshold_type, args.kernel_sigma, args.kappa))
save_images_in_train_folder = save_images_folder + '/{}_{}_{}_{}_in_train'.format(args.GAN_arch, args.threshold_type, args.kernel_sigma, args.kappa)
elif args.GAN == "cGAN":
print("cGAN: {}, {} classes.".format(args.GAN_arch, args.cGAN_num_classes))
save_images_in_train_folder = save_images_folder + '/{}_{}_in_train'.format(args.GAN_arch, args.cGAN_num_classes)
elif args.GAN == "cGAN-concat":
print("cGAN-concat: {}.".format(args.GAN_arch))
save_images_in_train_folder = save_images_folder + '/{}_in_train'.format(args.GAN_arch)
os.makedirs(save_images_in_train_folder, exist_ok=True)
start = timeit.default_timer()
print("\n Begin Training %s:" % args.GAN)
#----------------------------------------------
# cGAN: treated as a classification dataset
if args.GAN == "cGAN":
Filename_GAN = save_models_folder + '/ckpt_{}_niters_{}_nDsteps_{}_nclass_{}_seed_{}.pth'.format(args.GAN_arch, args.niters_gan, args.num_D_steps, args.cGAN_num_classes, args.seed)
print(Filename_GAN)
if not os.path.isfile(Filename_GAN):
print("There are {} unique labels".format(len(unique_labels)))
if args.GAN_arch=="SAGAN":
netG = cGAN_SAGAN_Generator(z_dim=args.dim_gan, num_classes=args.cGAN_num_classes)
netD = cGAN_SAGAN_Discriminator(num_classes=args.cGAN_num_classes)
else:
raise ValueError('Do not support!!!')
netG = nn.DataParallel(netG)
netD = nn.DataParallel(netD)
# Start training
netG, netD = train_cgan(images_train, labels_train, netG, netD, save_images_folder=save_images_in_train_folder, save_models_folder = save_models_folder)
# store model
torch.save({
'netG_state_dict': netG.state_dict(),
}, Filename_GAN)
else:
print("Loading pre-trained generator >>>")
checkpoint = torch.load(Filename_GAN)
netG = cGAN_SAGAN_Generator(z_dim=args.dim_gan, num_classes=args.cGAN_num_classes).cuda()
netG = nn.DataParallel(netG)
netG.load_state_dict(checkpoint['netG_state_dict'])
# function for sampling from a trained GAN
def fn_sampleGAN_given_labels(labels, batch_size):
labels = labels*args.max_label
fake_images, fake_labels = sample_cgan_given_labels(netG, labels, class_cutoff_points=class_cutoff_points, batch_size = batch_size)
fake_labels = fake_labels / args.max_label
return fake_images, fake_labels
#----------------------------------------------
# cGAN: simple concatenation
elif args.GAN == "cGAN-concat":
Filename_GAN = save_models_folder + '/ckpt_{}_niters_{}_nDsteps_{}_seed_{}.pth'.format(args.GAN_arch, args.niters_gan, args.num_D_steps, args.seed)
print(Filename_GAN)
if not os.path.isfile(Filename_GAN):
if args.GAN_arch=="SAGAN":
netG = cGAN_concat_SAGAN_Generator(z_dim=args.dim_gan)
netD = cGAN_concat_SAGAN_Discriminator()
else:
raise ValueError('Do not support!!!')
netG = nn.DataParallel(netG)
netD = nn.DataParallel(netD)
# Start training
netG, netD = train_cgan_concat(images_train, labels_train, netG, netD, save_images_folder=save_images_in_train_folder, save_models_folder = save_models_folder)
# store model
torch.save({
'netG_state_dict': netG.state_dict(),
}, Filename_GAN)
else:
print("Loading pre-trained generator >>>")
checkpoint = torch.load(Filename_GAN)
netG = cGAN_concat_SAGAN_Generator(z_dim=args.dim_gan).cuda()
netG = nn.DataParallel(netG)
netG.load_state_dict(checkpoint['netG_state_dict'])
# function for sampling from a trained GAN
def fn_sampleGAN_given_labels(labels, batch_size):
labels = labels*args.max_label
fake_images, fake_labels = sample_cgan_concat_given_labels(netG, labels, batch_size = batch_size, denorm=True, to_numpy=True, verbose=True)
fake_labels = fake_labels / args.max_label
return fake_images, fake_labels
#----------------------------------------------
# Continuous cGAN
elif args.GAN == "CcGAN":
Filename_GAN = save_models_folder + '/ckpt_{}_niters_{}_nDsteps_{}_seed_{}_{}_{}_{}.pth'.format(args.GAN_arch, args.niters_gan, args.num_D_steps, args.seed, args.threshold_type, args.kernel_sigma, args.kappa)
print(Filename_GAN)
if not os.path.isfile(Filename_GAN):
netG = CcGAN_SAGAN_Generator(dim_z=args.dim_gan, dim_embed=args.dim_embed)
netD = CcGAN_SAGAN_Discriminator(dim_embed=args.dim_embed)
netG = nn.DataParallel(netG)
netD = nn.DataParallel(netD)
# Start training
netG, netD = train_ccgan(args.kernel_sigma, args.kappa, images_train, labels_train, netG, netD, net_y2h, save_images_folder=save_images_in_train_folder, save_models_folder = save_models_folder)
# store model
torch.save({
'netG_state_dict': netG.state_dict(),
}, Filename_GAN)
else:
print("Loading pre-trained generator >>>")
checkpoint = torch.load(Filename_GAN)
netG = CcGAN_SAGAN_Generator(dim_z=args.dim_gan, dim_embed=args.dim_embed).cuda()
netG = nn.DataParallel(netG)
netG.load_state_dict(checkpoint['netG_state_dict'])
def fn_sampleGAN_given_labels(labels, batch_size):
fake_images, fake_labels = sample_ccgan_given_labels(netG, net_y2h, labels, batch_size = batch_size, to_numpy=True, denorm=True, verbose=True)
return fake_images, fake_labels
stop = timeit.default_timer()
print("GAN training finished; Time elapses: {}s".format(stop - start))
#######################################################################################
''' Evaluation '''
#######################################################################################
if args.comp_FID:
print("\n Evaluation in Mode {}...".format(args.eval_mode))
PreNetFID = encoder(dim_bottleneck=512).cuda()
PreNetFID = nn.DataParallel(PreNetFID)
Filename_PreCNNForEvalGANs = args.eval_ckpt_path + '/ckpt_AE_epoch_200_seed_2020_CVMode_False.pth'
checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs)
PreNetFID.load_state_dict(checkpoint_PreNet['net_encoder_state_dict'])
# Diversity: entropy of predicted chair types within each eval center
PreNetDiversity = ResNet34_class_eval(num_classes=49, ngpu = torch.cuda.device_count()).cuda() #49 chair types
Filename_PreCNNForEvalGANs_Diversity = args.eval_ckpt_path + '/ckpt_PreCNNForEvalGANs_ResNet34_class_epoch_200_seed_2020_classify_49_chair_types_CVMode_False.pth'
checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs_Diversity)
PreNetDiversity.load_state_dict(checkpoint_PreNet['net_state_dict'])
# for LS
PreNetLS = ResNet34_regre_eval(ngpu = torch.cuda.device_count()).cuda()
Filename_PreCNNForEvalGANs_LS = args.eval_ckpt_path + '/ckpt_PreCNNForEvalGANs_ResNet34_regre_epoch_200_seed_2020_CVMode_False.pth'
checkpoint_PreNet = torch.load(Filename_PreCNNForEvalGANs_LS)
PreNetLS.load_state_dict(checkpoint_PreNet['net_state_dict'])
#####################
# generate nfake images
print("\n Start sampling {} fake images per label from GAN >>>".format(args.nfake_per_label))
if args.eval_mode == 1: #Mode 1: eval on unique labels used for GAN training
eval_labels = np.sort(np.array(list(set(labels_train_raw)))) #not normalized
elif args.eval_mode in [2, 3]: #Mode 2 and 3: eval on all unique labels in the dataset
eval_labels = np.sort(np.array(list(set(labels_all)))) #not normalized
else: #Mode 4: eval on an interval [min_label, max_label] with num_eval_labels labels
eval_labels = np.linspace(np.min(labels_all), np.max(labels_all), args.num_eval_labels) #not normalized
unique_eval_labels = list(set(eval_labels))
print("\n There are {} unique eval labels.".format(len(unique_eval_labels)))
eval_labels_norm = eval_labels/args.max_label #normalized
for i in range(len(eval_labels)):
curr_label = eval_labels_norm[i]
if i == 0:
fake_labels_assigned = np.ones(args.nfake_per_label)*curr_label
else:
fake_labels_assigned = np.concatenate((fake_labels_assigned, np.ones(args.nfake_per_label)*curr_label))
fake_images, _ = fn_sampleGAN_given_labels(fake_labels_assigned, args.samp_batch_size)
assert len(fake_images) == args.nfake_per_label*len(eval_labels)
assert len(fake_labels_assigned) == args.nfake_per_label*len(eval_labels)
assert fake_images.min()>=0 and fake_images.max()<=255.0
## dump fake images for computing NIQE
if args.dump_fake_for_NIQE:
print("\n Dumping fake images for NIQE...")
dump_fake_images_folder = save_images_folder + '/fake_images_for_NIQE_nfake_{}'.format(len(fake_images))
os.makedirs(dump_fake_images_folder, exist_ok=True)
for i in tqdm(range(len(fake_images))):
label_i = fake_labels_assigned[i]*args.max_label
filename_i = dump_fake_images_folder + "/{}_{}.png".format(i, label_i)
os.makedirs(os.path.dirname(filename_i), exist_ok=True)
image_i = fake_images[i].astype(np.uint8)
# image_i = ((image_i*0.5+0.5)*255.0).astype(np.uint8)
image_i_pil = Image.fromarray(image_i.transpose(1,2,0))
image_i_pil.save(filename_i)
#end for i
# sys.exit()
print("End sampling! We got {} fake images.".format(len(fake_images)))
#####################
# prepare real/fake images and labels
if args.eval_mode in [1, 3]:
# real_images = (images_train/255.0-0.5)/0.5
real_images = images_train
real_labels = labels_train_raw #not normalized
else: #for both mode 2 and 4
# real_images = (images_all/255.0-0.5)/0.5
real_images = images_all
real_labels = labels_all #not normalized
# fake_images = (fake_images/255.0-0.5)/0.5
#######################
# For each label take nreal_per_label images
unique_labels_real = np.sort(np.array(list(set(real_labels))))
indx_subset = []
for i in range(len(unique_labels_real)):
label_i = unique_labels_real[i]
indx_i = np.where(real_labels==label_i)[0]
np.random.shuffle(indx_i)
if args.nreal_per_label>1:
indx_i = indx_i[0:args.nreal_per_label]
indx_subset.append(indx_i)
indx_subset = np.concatenate(indx_subset)
real_images = real_images[indx_subset]
real_labels = real_labels[indx_subset]
nfake_all = len(fake_images)
nreal_all = len(real_images)
#####################
# Evaluate FID within a sliding window of radius R over the label range (not normalized, i.e., [min_label, max_label]). The window centers are located in [min_label+R, ..., max_label-R].
if args.eval_mode == 1:
center_start = np.min(labels_train_raw)+args.FID_radius ##bug???
center_stop = np.max(labels_train_raw)-args.FID_radius
else:
center_start = np.min(labels_all)+args.FID_radius
center_stop = np.max(labels_all)-args.FID_radius
if args.FID_num_centers<=0 and args.FID_radius==0: #completely overlap
centers_loc = eval_labels #not normalized
elif args.FID_num_centers>0:
centers_loc = np.linspace(center_start, center_stop, args.FID_num_centers) #not normalized
else:
print("\n Error.")
FID_over_centers = np.zeros(len(centers_loc))
entropies_over_centers = np.zeros(len(centers_loc)) # entropy at each center
labelscores_over_centers = np.zeros(len(centers_loc)) #label score at each center
num_realimgs_over_centers = np.zeros(len(centers_loc))
for i in range(len(centers_loc)):
center = centers_loc[i]
interval_start = (center - args.FID_radius)#/args.max_label
interval_stop = (center + args.FID_radius)#/args.max_label
indx_real = np.where((real_labels>=interval_start)*(real_labels<=interval_stop)==True)[0]
np.random.shuffle(indx_real)
real_images_curr = real_images[indx_real]
real_images_curr = (real_images_curr/255.0-0.5)/0.5
num_realimgs_over_centers[i] = len(real_images_curr)
indx_fake = np.where((fake_labels_assigned>=(interval_start/args.max_label))*(fake_labels_assigned<=(interval_stop/args.max_label))==True)[0]
np.random.shuffle(indx_fake)
fake_images_curr = fake_images[indx_fake]
fake_images_curr = (fake_images_curr/255.0-0.5)/0.5
fake_labels_assigned_curr = fake_labels_assigned[indx_fake]
# FID
FID_over_centers[i] = cal_FID(PreNetFID, real_images_curr, fake_images_curr, batch_size = 200, resize = None)
# Entropy of predicted class labels
predicted_class_labels = predict_class_labels(PreNetDiversity, fake_images_curr, batch_size=200, num_workers=args.num_workers)
entropies_over_centers[i] = compute_entropy(predicted_class_labels)
# Label score
labelscores_over_centers[i], _ = cal_labelscore(PreNetLS, fake_images_curr, fake_labels_assigned_curr, min_label_before_shift=0, max_label_after_shift=args.max_label, batch_size = 500, resize = None, num_workers=args.num_workers)
print("\n [{}/{}] Center:{}; Real:{}; Fake:{}; FID:{}; LS:{}; ET:{}.".format(i+1, len(centers_loc), center, len(real_images_curr), len(fake_images_curr), FID_over_centers[i], labelscores_over_centers[i], entropies_over_centers[i]))
# end for i
# average over all centers
print("\n {} SFID: {}({}); min/max: {}/{}.".format(args.GAN_arch, np.mean(FID_over_centers), np.std(FID_over_centers), np.min(FID_over_centers), np.max(FID_over_centers)))
print("\n {} LS over centers: {}({}); min/max: {}/{}.".format(args.GAN_arch, np.mean(labelscores_over_centers), np.std(labelscores_over_centers), np.min(labelscores_over_centers), np.max(labelscores_over_centers)))
print("\n {} entropy over centers: {}({}); min/max: {}/{}.".format(args.GAN_arch, np.mean(entropies_over_centers), np.std(entropies_over_centers), np.min(entropies_over_centers), np.max(entropies_over_centers)))
# dump FID versus number of samples (for each center) to npy
dump_fid_ls_entropy_over_centers_filename = os.path.join(path_to_output, 'fid_ls_entropy_over_centers')
np.savez(dump_fid_ls_entropy_over_centers_filename, fids=FID_over_centers, labelscores=labelscores_over_centers, entropies=entropies_over_centers, nrealimgs=num_realimgs_over_centers, centers=centers_loc)
#####################
# FID: Evaluate FID on all fake images
indx_shuffle_real = np.arange(nreal_all); np.random.shuffle(indx_shuffle_real)
indx_shuffle_fake = np.arange(nfake_all); np.random.shuffle(indx_shuffle_fake)
FID = cal_FID(PreNetFID, real_images[indx_shuffle_real], fake_images[indx_shuffle_fake], batch_size = 200, resize = None, norm_img = True)
print("\n {}: FID of {} fake images: {}.".format(args.GAN_arch, nfake_all, FID))
#####################
# Overall LS: abs(y_assigned - y_predicted)
ls_mean_overall, ls_std_overall = cal_labelscore(PreNetLS, fake_images, fake_labels_assigned, min_label_before_shift=0, max_label_after_shift=args.max_label, batch_size = 200, resize = None, norm_img = True, num_workers=args.num_workers)
print("\n {}: overall LS of {} fake images: {}({}).".format(args.GAN_arch, nfake_all, ls_mean_overall, ls_std_overall))
#####################
# Dump evaluation results
eval_results_logging_fullpath = os.path.join(path_to_output, 'eval_results_{}.txt'.format(args.GAN_arch))
if not os.path.isfile(eval_results_logging_fullpath):
eval_results_logging_file = open(eval_results_logging_fullpath, "w")
eval_results_logging_file.close()
with open(eval_results_logging_fullpath, 'a') as eval_results_logging_file:
eval_results_logging_file.write("\n===================================================================================================")
eval_results_logging_file.write("\n Eval Mode: {}; Radius: {}; # Centers: {}. \n".format(args.eval_mode, args.FID_radius, args.FID_num_centers))
print(args, file=eval_results_logging_file)
eval_results_logging_file.write("\n SFID: {}({}).".format(np.mean(FID_over_centers), np.std(FID_over_centers)))
eval_results_logging_file.write("\n LS: {}({}).".format(np.mean(labelscores_over_centers), np.std(labelscores_over_centers)))
eval_results_logging_file.write("\n Diversity: {}({}).".format(np.mean(entropies_over_centers), np.std(entropies_over_centers)))
#######################################################################################
''' Visualize fake images of the trained GAN '''
#######################################################################################
if args.visualize_fake_images:
# First, visualize conditional generation # vertical grid
## 10 rows; 10 columns (10 samples for each label)
n_row = 10
n_col = 10
displayed_unique_labels = np.sort(np.array(list(set(labels_all))))
displayed_labels_indx = (np.linspace(0.05, 0.95, n_row)*len(displayed_unique_labels)).astype(int)
displayed_labels = displayed_unique_labels[displayed_labels_indx] #not normalized
displayed_normalized_labels = displayed_labels/args.max_label
### output fake images from a trained GAN
filename_fake_images = os.path.join(save_images_folder, 'fake_images_grid_{}x{}.png').format(n_row, n_col)
fake_labels_assigned = []
for tmp_i in range(len(displayed_normalized_labels)):
curr_label = displayed_normalized_labels[tmp_i]
fake_labels_assigned.append(np.ones(shape=[n_col, 1])*curr_label)
fake_labels_assigned = np.concatenate(fake_labels_assigned, axis=0)
images_show, _ = fn_sampleGAN_given_labels(fake_labels_assigned, args.samp_batch_size)
images_show = (images_show/255.0-0.5)/0.5
images_show = torch.from_numpy(images_show)
save_image(images_show.data, filename_fake_images, nrow=n_col, normalize=True)
if args.GAN == "CcGAN":
# Second, fix z but increase y; check whether there is a continuous change, only for CcGAN
n_continuous_labels = 10
normalized_continuous_labels = np.linspace(0.05, 0.95, n_continuous_labels)
z = torch.randn(1, args.dim_gan, dtype=torch.float).cuda()
continuous_images_show = torch.zeros(n_continuous_labels, args.num_channels, args.img_size, args.img_size, dtype=torch.float)
netG.eval()
with torch.no_grad():
for i in range(n_continuous_labels):
y = np.ones(1) * normalized_continuous_labels[i]
y = torch.from_numpy(y).type(torch.float).view(-1,1).cuda()
fake_image_i = netG(z, net_y2h(y))
continuous_images_show[i,:,:,:] = fake_image_i.cpu()
filename_continuous_fake_images = os.path.join(save_images_folder, 'continuous_fake_images_grid.png')
save_image(continuous_images_show.data, filename_continuous_fake_images, nrow=n_continuous_labels, normalize=True)
print("Continuous ys: ", (normalized_continuous_labels*args.max_label))
### output some real images as baseline
filename_real_images = save_images_folder + '/real_images_grid_{}x{}.png'.format(n_row, n_col)
if not os.path.isfile(filename_real_images):
images_show = np.zeros((n_row*n_col, args.num_channels, args.img_size, args.img_size))
for i_row in range(n_row):
# sample n_col real images for each displayed label
curr_label = displayed_labels[i_row]
for j_col in range(n_col):
indx_curr_label = np.where(labels_all==curr_label)[0]
np.random.shuffle(indx_curr_label)
indx_curr_label = indx_curr_label[0]
images_show[i_row*n_col+j_col] = images_all[indx_curr_label]
#end for i_row
images_show = (images_show/255.0-0.5)/0.5
images_show = torch.from_numpy(images_show)
save_image(images_show.data, filename_real_images, nrow=n_col, normalize=True)
print("\n===================================================================================================")
Oh that's cool! I am glad we are both trying to solve similar problems haha. I will implement what you kindly suggested and let you know if it works. Thank you so much! Also did you happen to modify CcGAN_SAGAN.py at all?
I thought I had modified CcGAN_SAGAN.py, but I can't see any trace of it in the commit history or git diff. Perhaps the file didn't need to be modified for our situation.
I see, that makes sense. And once you trained it, how did you use the generator model?
I used a Jupyter notebook for inference (generating images on given conditions):
- Bring the model architecture code into the notebook
  - It seems like `netG` and `net_y2h` are needed. Or perhaps other models as well...
- Use the checkpoint to `load_state_dict` for those models
  - For me, the names in the stored state_dict and the model I was making in the Jupyter notebook didn't match, so I had to make them match each other. For example, I had to remove `'module.'` from the names of all the layers inside the state dict
- Make conditions and input those into the model
- Visualise the output
Below are some of the core parts (in my opinion) of the inference notebook I wrote.
Checkpoint
# x2y
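# (net_x2y itself is constructed earlier in the notebook; only the
# checkpoint handling is shown here)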
x2y_ckpt_path = "../../../../output/embed_models/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_200.pth"
x2y_state_dict = torch.load(x2y_ckpt_path)
x2y_state_dict["new_net_state_dict"] = dict()
for k, v in x2y_state_dict["net_state_dict"].items():
x2y_state_dict["new_net_state_dict"].update({k.replace("module.", ""):v})
net_x2y.load_state_dict(x2y_state_dict["new_net_state_dict"])
Inference
labels = torch.tensor([[30, 1.0],
[60, 1.3],
[90, 1.6],
[120, 1.9],
[150, 2.1],
[180, 2.4],
[210, 2.7]]).cuda()
with torch.no_grad():
h = net_y2h(labels)
print(f"{h.shape = }")
z = torch.randn((h.shape[0], DIM_GAN)).cuda()
outputs = netG(z, h)
print(f"{outputs.shape = }")
labels_hat, embedded_features = net_x2y(outputs)
Hi,
Thanks a lot for your interest in our work. In this repository, we only provide the .sh shell script for training the model on Linux. If you want to train the model on Windows, you may need to convert the .sh files into .bat files suitable for Windows. Please refer to https://github.com/UBCDingXin/Dual-NDA or https://github.com/UBCDingXin/CCDM, where the Windows batch scripts are provided.
Thank you @Foundsheep and @UBCDingXin for your help! So I tried to run main.py and I got the following error:
return F.mse_loss(input, target, reduction=self.reduction)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File ~\Desktop\Research\Adding ML TO AR\GAN and cGAN and fastGAN and WGAN\Continuous Conditional GAN\improved_CcGAN\RC-49\RC-49_128x128\CcGAN-improved\main.py:307
305 if not os.path.isfile(net_embed_filename_ckpt):
306 print("\n Start training CNN for label embedding >>>")
--> 307 net_embed = train_net_embed(net=net_embed, net_name=args.net_embed, trainloader=trainloader_embed_net, testloader=None, epochs=args.epoch_cnn_embed, resume_epoch = args.resumeepoch_cnn_embed, lr_base=base_lr_x2y, lr_decay_factor=0.1, lr_decay_epochs=[80, 140], weight_decay=1e-4, path_to_ckpt = path_to_embed_models)
308 # save model
309 torch.save({
310 'net_state_dict': net_embed.state_dict(),
311 }, net_embed_filename_ckpt)
File ~\Desktop\Research\Adding ML TO AR\GAN and cGAN and fastGAN and WGAN\Continuous Conditional GAN\improved_CcGAN\RC-49\RC-49_128x128\CcGAN-improved\train_net_for_label_embed.py:59, in train_net_embed(net, net_name, trainloader, testloader, epochs, resume_epoch, lr_base, lr_decay_factor, lr_decay_epochs, weight_decay, path_to_ckpt)
57 #Forward pass
58 outputs, _ = net(batch_train_images)
---> 59 loss = criterion(outputs, batch_train_labels)
61 #backward pass
62 optimizer.zero_grad()
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\loss.py:520, in MSELoss.forward(self, input, target)
519 def forward(self, input: Tensor, target: Tensor) -> Tensor:
--> 520 return F.mse_loss(input, target, reduction=self.reduction)
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\functional.py:3111, in mse_loss(input, target, size_average, reduce, reduction)
3108 if size_average is not None or reduce is not None:
3109 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3111 expanded_input, expanded_target = torch.broadcast_tensors(input, target)
3112 return torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\functional.py:72, in broadcast_tensors(*tensors)
70 if has_torch_function(tensors):
71 return handle_torch_function(broadcast_tensors, tensors, *tensors)
---> 72 return _VF.broadcast_tensors(tensors)
RuntimeError: The size of tensor a (256) must match the size of tensor b (1809408) at non-singleton dimension 0
I know this error occurs because the dimensions of the conditional inputs don't match, but I do not understand where the 1809408 came from; if I am taking a batch of 256 images, then it makes sense that the conditional inputs would be 256 as well. Do you think you know where the problem is?
@UBCDingXin Thanks for the pointer to .bat files. I didn't know about that approach, so I used Git Bash to invoke the shell script haha.
@ashamy97 It seems like the loss function is broadcasting the input and target tensors. My guess, perhaps a wrong one, is that you are feeding a different size from the one the model expects; for example, the model might expect 128x128 inputs, but you are using 256x256 or something like that.
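If it helps, here is a quick sanity check on the batch shapes before the loss is computed (just a sketch; NUM_CONDITIONS is our own constant for the number of labels, not something from the original repo):
# A sketch only: verify that a batch of labels lines up with its images.
images_batch, labels_batch = next(iter(trainloader_embed_net))
assert images_batch.shape[0] == labels_batch.shape[0], "batch sizes must match"
assert labels_batch.ndim == 2 and labels_batch.shape[1] == NUM_CONDITIONS, \
    "labels should be (batch_size, NUM_CONDITIONS) so MSELoss doesn't broadcast"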
Yup, you were absolutely right. My labels were the wrong dimensions but the images were the right dimensions. Silly me! Thank you so much :)
Also, I believe for this section in the main.py code:
if labels_train_raw.shape[-1] == NUM_CONDITIONS:
scale_q1 = args.min_label_scale
scale_q2 = args.max_label_scale
indx_1 = np.where((labels_train_raw[:, 0]>q1)*(labels_train_raw[:, 0]<q2)==True)[0]
indx_2 = np.where((labels_train_raw[:, 1]>scale_q1)*(labels_train_raw[:, 1]<scale_q2)==True)[0]
indx = np.array(list(set.intersection(set(indx_1), set(indx_2))))
else:
indx = np.where((labels_train_raw>q1)*(labels_train_raw<q2)==True)[0]
I think these are meant to be the min and max for the second label? They weren't added to the opts.py file, so it gave an error. I added them there and it is running fine.
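For reference, my additions looked roughly like this (just a sketch; the argument names follow the MIN_LABEL_SCALE/MAX_LABEL_SCALE variables in the shell script, and the defaults are my own choice):
# A sketch of the two arguments added to opts.py (parser is the existing
# argparse.ArgumentParser there); defaults are my own choice.
parser.add_argument('--min_label_scale', type=float, default=0.0,
                    help='minimum value of the second (scale) label')
parser.add_argument('--max_label_scale', type=float, default=1.0,
                    help='maximum value of the second (scale) label')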
@ashamy97 Yes, you're correct on it :) Because the original code assumed only one condition, the min-max filtering mechanism existed only for the single label set, and I thought it would need to be expanded to the other conditions as well.
Perhaps if the code is running well, you must have implemented it right. As I mentioned earlier, to expand the number of conditions, we need to find and modify the bits where only one condition is assumed. Hope your experiment goes well :)
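To illustrate, one of those bits is the first layer of model_y2h (see the full ResNet_embed.py further down), which changes from taking a single scalar label to taking a label vector:
# Original repo: each sample has one scalar condition
# nn.Linear(1, dim_embed)
# Our modification: each sample has NUM_CONDITIONS labels
nn.Linear(NUM_CONDITIONS, dim_embed)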
Will let you know. The CNN embed part seems to be taking a while to train but I am happy it's training haha. Thanks for your help @Foundsheep .
Also @UBCDingXin if you think our approach is correct when it comes to including more than one conditional input, let us know.
@Foundsheep I am not sure if you remember this or not, but after how many epochs of training the CcGAN with the following parameters did you get results?
Because so far here are my generated images after 800 epochs
This is what my images are supposed to look like
@ashamy97 As you can see, 800 epochs is almost the very beginning of the training. To get something noticeable, I recall waiting until about 2000 epochs, and most of the pictures in the grid became recognizable after about 5000 epochs. The grid finally looked somewhat like real pictures after about 12000 epochs, and even when I waited until about 30000 epochs, it didn't get better than the one around 12000 epochs.
It might vary depending on the dataset we're training on, but all the training (as I mentioned, I skipped some of the training parts in the shell script, such as cGAN training or training the evaluation network, etc.) took about 3~4 days on my RTX 4070 Laptop GPU.
To be honest with you, our result didn't match our expectation. Even though we gave two conditions, the model completely ignored one of them and only tried to give us something realistic-looking rather than something matching the conditions, which I guess is mode collapse. We might have needed to prepare the dataset more neatly so that the model could learn more easily, but that is where we stopped with this model.
@Foundsheep @ashamy97 Hi guys,
The current CcGAN is designed specifically for univariate conditions, so it might face challenges when directly transferring to multi-dimensional conditions. For multi-dimensional conditions, you may need to re-design the condition input mechanism and re-define the hard/soft vicinity.
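For instance, one possible (untested) generalization of the hard vicinity would replace the scalar interval condition with a norm ball on the label vectors; a sketch only:
import numpy as np
# A sketch only (not from the paper): a hard vicinity for d-dimensional
# labels, replacing the scalar condition |y_i - y| <= kappa with a
# Euclidean ball of radius kappa around the target label.
def hard_vicinity_indices(labels_train, target_label, kappa):
    labels_train = np.asarray(labels_train)                   # (N, d), normalized labels
    target_label = np.asarray(target_label).reshape(1, -1)    # (1, d)
    dists = np.linalg.norm(labels_train - target_label, axis=1)
    return np.where(dists <= kappa)[0]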
We are working on such a problem but lack suitable datasets. Is there any public image dataset that is labeled with multi-dimensional conditions?
@UBCDingXin I don't know if this will help, but I found this link (https://height-weight-chart.com/) that contains images of individuals with their corresponding height and weight. Both are continuous variables, and it has a limited number of images.
@UBCDingXin Thanks for the instruction. That is exactly what I did when trying to convert it into a multi-dimensional conditioning model.
I'm not sure either if this would answer your question, but we actually used the RC-49 dataset, and the way we used it was to scale the size of the chair and use the scaling ratio as another condition. For example, the condition would be an (angle, scaling ratio) vector, and a scaling ratio of 0.5 would mean the chair is resized to half its size (the resized chair is then pasted onto a white background of the same size as the original image, e.g. 128x128).
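Roughly, the resizing step looked like this (a sketch from memory, not our exact code):
from PIL import Image
# A sketch only: shrink the chair image by `ratio` and paste it centered
# onto a white canvas of the original size (e.g. 128x128).
def rescale_on_white(img: Image.Image, ratio: float) -> Image.Image:
    w, h = img.size
    small = img.resize((max(1, int(w * ratio)), max(1, int(h * ratio))))
    canvas = Image.new("RGB", (w, h), (255, 255, 255))
    canvas.paste(small, ((w - small.size[0]) // 2, (h - small.size[1]) // 2))
    return canvas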
@UBCDingXin @Foundsheep I am curious if you guys tried training this CcGAN on a dataset that is limited? Like, for example, on 100 images?
@ashamy97 I haven't tried it, but I am curious about the result.
@Foundsheep Yeah, I think I will try that some other time.
Question: so I trained the CcGAN and I am trying to do the inference, but I don't quite know how many models we need. Isn't x2y the model that maps the input image to a latent space and then maps that latent space to the regression label y? But we want to use the generator, so in that case we need net_y2h and the generator and that's it, no? Or am I understanding this wrong?
@ashamy97 Yes, your guess is in the right direction. I mentioned how I used the generator for inference up there; here is the link https://github.com/UBCDingXin/improved_CcGAN/issues/12#issuecomment-2097120415
@Foundsheep So I am running into an error when trying to find the reconstructed labels. I copied the same code you used with some modifications.
# x2y
DIM_EMBED = 128
x2y_ckpt_path = ".../output/embed_models/embed_x2y_ckpt_in_train/embed_x2y_checkpoint_epoch_200.pth"
x2y_state_dict = torch.load(x2y_ckpt_path)
# for k, v in x2y_state_dict["net_state_dict"].items():
# x2y_state_dict["new_net_state_dict"].update({k.replace("module.", ""):v})
net_embed = ResNet34_embed(dim_embed=DIM_EMBED)
net_embed = net_embed.cuda()
net_embed = nn.DataParallel(net_embed)
net_embed.load_state_dict(x2y_state_dict['net_state_dict'])
#y2h
y2h_ckpt_path = ".../output/embed_models/ckpt_net_y2h_epoch_500_seed_2020.pth"
y2h_state_dict = torch.load(y2h_ckpt_path)
net_y2h = model_y2h(dim_embed=DIM_EMBED)
net_y2h = net_y2h.cuda()
net_y2h = nn.DataParallel(net_y2h)
net_y2h.load_state_dict(y2h_state_dict['net_state_dict'])
#netG
DIM_GAN = 256
Filename_GAN = ".../output/output_CcGAN_arch_SAGAN/saved_models/SAGAN_soft_2_checkpoint_intrain/checkpoint_20400.pth"
checkpoint = torch.load(Filename_GAN)
netG = CcGAN_SAGAN_Generator(dim_z=DIM_GAN, dim_embed=DIM_EMBED).cuda()
netG = nn.DataParallel(netG)
netG.load_state_dict(checkpoint['netG_state_dict'])
This is the inference:
labels = torch.tensor([[9, 10]]).cuda()
with torch.no_grad():
h = net_y2h(labels)
print(f" shape of h: {h.shape}")
z = torch.randn((h.shape[0], DIM_GAN)).cuda()
outputs = netG(z, h)
print(f"Output shape: {outputs.shape}")
labels_hat, embedded_features = net_embed(outputs)
print(labels_hat.shape)
print(embedded_features.shape)
This is the error message log:
shape of h: torch.Size([1, 128])
Output shape: torch.Size([1, 3, 128, 128])
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 16
13 outputs = netG(z, h)
14 print(f"Output shape: {outputs.shape}")
---> 16 labels_hat, embedded_features = net_embed(outputs)
18 print(labels_hat.shape)
19 print(embedded_features.shape)
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\parallel\data_parallel.py:166, in DataParallel.forward(self, *inputs, **kwargs)
163 kwargs = ({},)
165 if len(self.device_ids) == 1:
--> 166 return self.module(*inputs[0], **kwargs[0])
167 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
168 outputs = self.parallel_apply(replicas, inputs, kwargs)
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~\Desktop\Research\Adding ML TO AR\GAN and cGAN and fastGAN and WGAN\Continuous Conditional GAN\improved_CcGAN\RC-49\RC-49_128x128\CcGAN-improved\models\ResNet_embed.py:131, in ResNet_embed.forward(self, x)
129 features = self.main(x)
130 features = features.view(features.size(0), -1)
--> 131 features = self.x2h_res(features)
132 out = self.h2y(features)
134 return out, features
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\container.py:141, in Sequential.forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\module.py:1102, in Module._call_impl(self, *input, **kwargs)
1098 # If we don't have any hooks, we want to skip the rest of the logic in
1099 # this function, and just call forward.
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\modules\batchnorm.py:168, in _BatchNorm.forward(self, input)
161 bn_training = (self.running_mean is None) and (self.running_var is None)
163 r"""
164 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
165 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
166 used for normalization (i.e. in eval mode when buffers are not None).
167 """
--> 168 return F.batch_norm(
169 input,
170 # If buffers are not to be tracked, ensure that they won't be updated
171 self.running_mean
172 if not self.training or self.track_running_stats
173 else None,
174 self.running_var if not self.training or self.track_running_stats else None,
175 self.weight,
176 self.bias,
177 bn_training,
178 exponential_average_factor,
179 self.eps,
180 )
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\functional.py:2280, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
2267 return handle_torch_function(
2268 batch_norm,
2269 (input, running_mean, running_var, weight, bias),
(...)
2277 eps=eps,
2278 )
2279 if training:
-> 2280 _verify_batch_size(input.size())
2282 return torch.batch_norm(
2283 input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
2284 )
File ~\anaconda3\envs\ContinuousCGAN\lib\site-packages\torch\nn\functional.py:2248, in _verify_batch_size(size)
2246 size_prods *= size[i + 2]
2247 if size_prods == 1:
-> 2248 raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 512])
It's weird because I didn't get error messages when training x2y.
Do you think you might know where the problem is? I think this will then be my last question. Thanks so much for your help :)
@ashamy97 It seems related to the shape of a tensor, likely affected by our modification of the conditioning inputs.
Could you check whether there's any difference between the code below and the code you're using in ResNet_embed.py? Perhaps those commented lines have something to do with this error.
Basically, I think I faced a similar kind of error, and I resolved it by checking line by line what shape the model expects internally. If you keep running into this error, that might be helpful :)
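One more thing worth checking: that particular ValueError is what BatchNorm layers raise when they see a batch of size 1 while still in training mode, so switching the networks to eval mode before inference may already fix it (a sketch, using the variable names from your snippet):
# BatchNorm1d cannot compute batch statistics over a single sample in
# training mode; eval mode uses the stored running statistics instead.
net_embed.eval()
net_y2h.eval()
netG.eval()
with torch.no_grad():
    labels_hat, embedded_features = net_embed(outputs)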
My version of ResNet_embed.py
import torch
import torch.nn as nn
import torch.nn.functional as F
NC = 3
IMG_SIZE = 128
DIM_EMBED = 128
NUM_CONDITIONS = 2
#------------------------------------------------------------------------------
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(self.expansion*planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out
class ResNet_embed(nn.Module):
def __init__(self, block, num_blocks, nc=NC, dim_embed=DIM_EMBED):
super(ResNet_embed, self).__init__()
self.in_planes = 64
self.main = nn.Sequential(
nn.Conv2d(nc, 64, kernel_size=3, stride=1, padding=1, bias=False), # h=h
# nn.Conv2d(nc, 64, kernel_size=4, stride=2, padding=1, bias=False), # h=h/2
nn.BatchNorm2d(64),
nn.ReLU(),
nn.MaxPool2d(2,2), #h=h/2 64
# self._make_layer(block, 64, num_blocks[0], stride=1), # h=h
self._make_layer(block, 64, num_blocks[0], stride=2), # h=h/2 32
self._make_layer(block, 128, num_blocks[1], stride=2), # h=h/2 16
self._make_layer(block, 256, num_blocks[2], stride=2), # h=h/2 8
self._make_layer(block, 512, num_blocks[3], stride=2), # h=h/2 4
# nn.AvgPool2d(kernel_size=4)
nn.AdaptiveAvgPool2d((1, 1))
)
self.x2h_res = nn.Sequential(
nn.Linear(512, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
nn.Linear(512, dim_embed),
nn.BatchNorm1d(dim_embed),
nn.ReLU(),
)
self.h2y = nn.Sequential(
nn.Linear(dim_embed, NUM_CONDITIONS),
nn.ReLU()
)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
features = self.main(x)
features = features.view(features.size(0), -1)
features = self.x2h_res(features)
out = self.h2y(features)
return out, features
def ResNet18_embed(dim_embed=DIM_EMBED):
return ResNet_embed(BasicBlock, [2,2,2,2], dim_embed=dim_embed)
def ResNet34_embed(dim_embed=DIM_EMBED):
return ResNet_embed(BasicBlock, [3,4,6,3], dim_embed=dim_embed)
def ResNet50_embed(dim_embed=DIM_EMBED):
return ResNet_embed(Bottleneck, [3,4,6,3], dim_embed=dim_embed)
#------------------------------------------------------------------------------
# map labels to the embedding space
class model_y2h(nn.Module):
def __init__(self, dim_embed=DIM_EMBED):
super(model_y2h, self).__init__()
self.main = nn.Sequential(
nn.Linear(NUM_CONDITIONS, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
# nn.BatchNorm1d(dim_embed),
nn.GroupNorm(8, dim_embed),
nn.ReLU(),
nn.Linear(dim_embed, dim_embed),
nn.ReLU()
)
def forward(self, y):
# y = y.view(-1, 1) +1e-8
y = y + 1e-8
# y = torch.exp(y.view(-1, 1))
return self.main(y)
if __name__ == "__main__":
net = ResNet34_embed(dim_embed=128).cuda()
x = torch.randn(16,NC,IMG_SIZE,IMG_SIZE).cuda()
out, features = net(x)
print(out.size())
print(features.size())
net_y2h = model_y2h().cuda()
y_hat = net_y2h(out)
print(f"{y_hat.size() = }")