stylegan2-pytorch
stylegan2-pytorch copied to clipboard
Out-of-memory (OOM) error on a GTX 1080 Ti with image size 128
I fixed the problem I had loading my trained model, and the weights now seem to load successfully.
However, I get an OOM error during inference. Here are my config and inference code:
"""
Generate random images by trained StyleGANV2
"""
from stylegan2_pytorch.stylegan2_pytorch import Generator, StyleGAN2
import torch
import numpy as np
from alfred.dl.torch.common import device
import cv2
def tile(a, dim, n_tile):
    """Repeat each slice of ``a`` along ``dim`` ``n_tile`` times, consecutively.

    For example, tiling rows ``[x0, x1]`` by 2 along dim 0 yields
    ``[x0, x0, x1, x1]`` (not ``[x0, x1, x0, x1]``).

    The original implementation built a permutation index with NumPy and
    moved it to a module-level ``device``. That broke whenever ``a`` lived on
    a different device than ``device``. ``torch.repeat_interleave`` computes
    exactly the same result and stays on ``a``'s device.

    Args:
        a: Input tensor.
        dim: Dimension along which to repeat (negative values allowed).
        n_tile: Number of consecutive copies of each slice.

    Returns:
        A tensor whose size along ``dim`` is ``a.size(dim) * n_tile``.
    """
    return torch.repeat_interleave(a, n_tile, dim=dim)
def evaluate_in_chunks(max_batch_size, model, *args):
    """Call ``model`` on ``args`` in dim-0 chunks and join the results.

    Every tensor in ``args`` is split along dim 0 into pieces of at most
    ``max_batch_size`` rows. The k-th pieces of all tensors are passed
    together as ``model(piece_0, piece_1, ...)``. When only one chunk is
    produced, the model output is returned as-is. Otherwise the outputs are
    concatenated along dim 0.
    """
    per_arg_chunks = [arg.split(max_batch_size, dim=0) for arg in args]

    outputs = []
    for chunk_group in zip(*per_arg_chunks):
        outputs.append(model(*chunk_group))

    if len(outputs) != 1:
        return torch.cat(outputs, dim=0)
    return outputs[0]
def styles_def_to_tensor(styles_def):
    """Expand ``(style, count)`` pairs into one per-layer style tensor.

    Each ``style`` gets a new axis at position 1, which is broadcast to
    ``count`` copies. All pieces are then concatenated along that axis, in
    list order.
    """
    pieces = []
    for style, repeat_count in styles_def:
        pieces.append(style.unsqueeze(1).expand(-1, repeat_count, -1))
    return torch.cat(pieces, dim=1)
def truncate_style(g, tensor, trunc_psi=0.75, num_samples=2000, chunk_size=1):
    """Apply the truncation trick to a batch of style vectors.

    Returns ``trunc_psi * (tensor - w_avg) + w_avg``. Here ``w_avg`` is the
    mean output of the mapping network ``g.S`` over ``num_samples`` random
    Gaussian latents of size ``g.G.latent_dim``. With ``trunc_psi=1`` the
    input is returned unchanged; with ``trunc_psi=0`` every row collapses to
    ``w_avg``.

    ``w_avg`` is computed once and cached on ``g`` as ``g._trunc_w_avg``.
    Repeated calls, such as one per style def, therefore reuse the same
    average instead of re-running ``num_samples`` forward passes and drawing
    a different random mean each time. Delete that attribute to force a
    recompute, e.g. after loading new weights.

    Args:
        g: Object exposing ``S`` (mapping network) and ``G.latent_dim``.
        tensor: Style vectors to truncate.
        trunc_psi: Truncation strength.
        num_samples: Number of random latents used to estimate the average.
        chunk_size: Latents fed through ``g.S`` per forward call.

    Returns:
        The truncated tensor. Gradients still flow w.r.t. ``tensor``.
    """
    w_avg = getattr(g, "_trunc_w_avg", None)
    if w_avg is None:
        # Averaging needs no autograd. Without no_grad, S's outputs require
        # grad, a graph is kept for every forward pass, and the original
        # `.numpy()` call raised on grad-requiring tensors.
        with torch.no_grad():
            z = torch.randn(num_samples, g.G.latent_dim, device=tensor.device)
            samples = torch.cat(
                [g.S(chunk) for chunk in z.split(chunk_size, dim=0)], dim=0
            )
            w_avg = samples.mean(dim=0, keepdim=True)
        g._trunc_w_avg = w_avg

    # The cached average may have been computed for a tensor on another device.
    w_avg = w_avg.to(tensor.device)
    return trunc_psi * (tensor - w_avg) + w_avg
def truncate_style_defs(g, w, trunc_psi=0.75):
    """Truncate the style vector of every ``(style, num_layers)`` pair in ``w``.

    ``truncate_style`` is applied to each style in order. Each layer count is
    kept alongside its truncated style, and the pairs are returned as a new
    list.
    """
    return [
        (truncate_style(g, style, trunc_psi=trunc_psi), layer_count)
        for style, layer_count in w
    ]
@torch.no_grad()
def generate_truncated(g, S, G, style, noi, trunc_psi=0.75, num_image_tiles=8):
    """Generate images from latent definitions using truncated styles.

    Runs without autograd. This is pure inference, and keeping autograd
    graphs alive for the mapping and synthesis passes only wastes GPU memory.

    Args:
        g: Model object forwarded to ``truncate_style_defs``.
        S: Mapping network applied to each latent.
        G: Generator called as ``G(w_styles, noi)`` in chunks of 1.
        style: Iterable of ``(latent, num_layers)`` pairs.
        noi: Noise tensor, split along dim 0 alongside the styles.
            NOTE(review): must be a tensor with one row per image; the
            exact shape is defined by ``G`` and is not visible here.
        trunc_psi: Truncation strength.
        num_image_tiles: Unused; kept for interface compatibility.

    Returns:
        Generated images clamped to ``[0, 1]``.

    Raises:
        TypeError: If ``noi`` is not a tensor.
    """
    if not torch.is_tensor(noi):
        # Otherwise this fails deep inside evaluate_in_chunks with an
        # obscure "'int' object has no attribute 'split'".
        raise TypeError(f"noi must be a torch.Tensor, got {type(noi).__name__}")
    w = [(S(latent), num_layers) for latent, num_layers in style]
    w_truncated = truncate_style_defs(g, w, trunc_psi=trunc_psi)
    w_styles = styles_def_to_tensor(w_truncated)
    generated_images = evaluate_in_chunks(1, G, w_styles, noi)
    return generated_images.clamp_(0., 1.)
if __name__ == "__main__":
    weight_path = 'models/default/model_50.pt'
    # These must match the values used for training (see your config file).
    image_size = 128
    latent_dim = 512
    network_capacity = 16
    fq_dict_size = 256
    transparent = False
    attn_layers = []
    no_const = False
    model = StyleGAN2(image_size=image_size, latent_dim=latent_dim, network_capacity=network_capacity,
                      transparent=transparent, fq_dict_size=fq_dict_size, attn_layers=attn_layers,
                      no_const=no_const)
    model.eval()
    # map_location loads the tensors straight onto the target device instead of
    # onto whatever device they were saved from.
    ckpt = torch.load(weight_path, map_location=device)
    model.load_state_dict(ckpt['GAN'])
    model.to(device)
    print('StyleGAN2 loaded.')

    num_layers = model.G.num_layers
    # Named `z` rather than `nn` so it does not shadow the usual torch.nn alias.
    z = torch.randn([1, latent_dim]).to(device)
    tmp1 = tile(z, 0, 1)
    tmp2 = z.repeat(1, 1)
    tt = num_layers // 2
    mixed_latents = [(tmp1, tt), (tmp2, num_layers - tt)]
    print(mixed_latents)

    # The generator needs a noise *tensor*, not the int 1. The shape
    # (batch, H, W, 1) with values in [0, 1) mirrors stylegan2_pytorch's
    # `image_noise` helper. NOTE(review): confirm against your installed version.
    noise = torch.rand(1, image_size, image_size, 1, device=device)
    # Inference only: disable autograd so no computation graphs are retained.
    with torch.no_grad():
        generated_images = generate_truncated(model, model.SE, model.GE, mixed_latents, noise)
    print(generated_images.shape)
I get an OOM error. What could be causing it? Shouldn't an image size of 128 fit in my 12 GB of GPU memory?
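You need to allocate more virtual memory (i.e., increase the size of the Windows paging file). Note that this only helps when the out-of-memory error comes from host (CPU) RAM; a CUDA out-of-memory error on the GPU is not affected by the paging file.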