stylegan2-pytorch
How can I move sliders in latent space? I'm looking to experiment more with my generations.
@erinbeesley Hey Erin, I'll look into exposing a way to use your model from code, so you can generate from whatever set of latent codes you wish.
Hi, I wanted to do this too and figured out a way — pretty rudimentary, but it works! Add this function to the script and run it. A matplotlib window will pop up with a slider, and each time you click the slider the model is evaluated at a new point along the interpolation. Each time you open the window, it samples two new random latent codes to interpolate between. @lucidrains I made this a while back, before you had the model-exposed-to-code setup. Sorry for butchering your code! I'll make a better one soon.
def eval_model_points(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 512,
    network_capacity = 16,
    transparent = False,
    batch_size = 5,
    gradient_accumulate_every = 6,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    lr_mlp = 0.1,
    ttur_mult = 1.5,
    rel_disc_loss = False,
    num_workers = None,
    save_every = 1000,
    evaluate_every = 1000,
    generate = True,
    generate_interpolation = True,
    interpolation_num_steps = 100,
    save_frames = True,
    num_image_tiles = 4,
    trunc_psi = 0.75,
    mixed_prob = 0.9,
    fp16 = False,
    cl_reg = False,
    fq_layers = [],
    fq_dict_size = 256,
    attn_layers = [],
    no_const = False,
    aug_prob = 0.,
    aug_types = ['translation', 'cutout'],
    top_k_training = False,
    generator_top_k_gamma = 0.99,
    generator_top_k_frac = 0.5,
    dataset_aug_prob = 0.,
    multi_gpus = True,
    calculate_fid_every = None,
    seed = 42,
    latent_dim = 512,
):
    """Show an interactive slider that renders a generator along a latent interpolation.

    Builds a ``Trainer`` from the given options, loads checkpoint ``load_from``,
    samples two random latent codes and displays ``model.grad_map`` evaluated at a
    spherical interpolation between them. Moving the slider selects an index into a
    fixed ``linspace`` of interpolation ratios and re-renders the image in place.

    Several arguments (``data``, ``new``, ``num_train_steps``,
    ``generate_interpolation``, ``interpolation_num_steps``, ``save_frames``,
    ``num_image_tiles``, ``multi_gpus``, ``seed``) are accepted but not used here.

    Args:
        latent_dim: size of each sampled latent code (previously hard-coded to 512).
            NOTE(review): must match the generator's latent size — confirm.

    Returns:
        None. Does nothing when ``generate`` is false; otherwise blocks in
        ``plt.show()`` until the window is closed.
    """
    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        image_size = image_size,
        network_capacity = network_capacity,
        transparent = transparent,
        lr = learning_rate,
        lr_mlp = lr_mlp,
        ttur_mult = ttur_mult,
        rel_disc_loss = rel_disc_loss,
        num_workers = num_workers,
        save_every = save_every,
        evaluate_every = evaluate_every,
        trunc_psi = trunc_psi,
        fp16 = fp16,
        cl_reg = cl_reg,
        fq_layers = fq_layers,
        fq_dict_size = fq_dict_size,
        attn_layers = attn_layers,
        no_const = no_const,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        top_k_training = top_k_training,
        generator_top_k_gamma = generator_top_k_gamma,
        generator_top_k_frac = generator_top_k_frac,
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        mixed_prob = mixed_prob,
    )

    if not generate:
        return

    model = Trainer(**model_args)
    # The original never used `load_from`, so images came from untrained weights.
    # NOTE(review): assumes Trainer.load(num) restores checkpoint `num` (-1 = latest) — confirm.
    model.load(load_from)
    samples_name = timestamped_filename()

    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111)
    fig.subplots_adjust(bottom=0.2, top=0.75)
    ax_Ef = fig.add_axes([0.3, 0.85, 0.4, 0.05])

    num_rows = 1
    latents_low = noise(num_rows ** 2, latent_dim, device=0)
    latents_high = noise(num_rows ** 2, latent_dim, device=0)
    n = image_noise(num_rows ** 2, image_size, device=0)
    num_layers = int(log2(image_size) - 1)

    # Interpolation ratios; values above 1 extrapolate beyond latents_high.
    ratios = torch.linspace(0., 8., 1000)

    def render(position):
        """Evaluate the model at the ratio selected by slider ``position``."""
        # Slider spans 0..1000 but ratios has 1000 entries, so clamp the index.
        index = min(max(int(round(position)), 0), len(ratios) - 1)
        interp_latents = slerp(ratios[index], latents_low, latents_high)
        latents = [(interp_latents, num_layers)]
        return model.grad_map(n, latents)

    image = ax.imshow(render(0))

    s_Ef = Slider(ax=ax_Ef, label='Axis1', valmin=0, valmax=1000,
                  valfmt=' %1.1f ', facecolor='#cc7000')

    def update(val):
        # Reuse the existing AxesImage instead of stacking a new imshow per event.
        image.set_data(render(val))
        fig.canvas.draw_idle()

    # The original defined `update` but never connected it to the slider.
    s_Ef.on_changed(update)

    print(f'sample images generated at {results_dir}/{name}/{samples_name}')
    plt.show()
def noise(n, latent_dim, device):
    """Sample an ``(n, latent_dim)`` batch of standard-normal latent codes.

    The values are drawn with the default CPU generator and then moved to
    ``device``. ``device`` may be a CUDA ordinal (callers in this file pass ``0``,
    which behaves like the former ``.cuda(0)``), a string such as ``'cpu'`` or
    ``'cuda:1'``, or a ``torch.device``.
    """
    return torch.randn(n, latent_dim).to(device)
def slerp(val, low, high, eps=1e-7):
    """Spherically interpolate, row by row, between two batches of vectors.

    Args:
        val: interpolation ratio (Python float or 0-dim tensor). 0 returns ``low``
            and 1 returns ``high``.
        low, high: tensors of matching shape, assumed ``(batch, dim)``; each row
            is interpolated independently.
        eps: rows with ``|sin(omega)| < eps`` use linear interpolation instead of
            the slerp formula, which would otherwise divide 0 by 0 and yield NaN.

    Returns:
        Tensor shaped like ``low``.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    # Parallel rows give sin(omega) == 0; linear interpolation is the limit there.
    # Antiparallel rows have no unique arc and also take the linear fallback.
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)
    w_low = torch.where(
        degenerate,
        torch.ones_like(so) * (1.0 - val),
        torch.sin((1.0 - val) * omega) / safe_so,
    )
    w_high = torch.where(
        degenerate,
        torch.ones_like(so) * val,
        torch.sin(val * omega) / safe_so,
    )
    return w_low.unsqueeze(1) * low + w_high.unsqueeze(1) * high
def main():
if __name__ == '__main__':