
How could I move sliders in latent space? I'm looking to mess around more with my generation.


erinbeesley · Sep 16 '20 17:09

@erinbeesley Hey Erin, I'll look into exposing a way to use your model in code, so you can generate from whatever set of latent codes you wish.

lucidrains · Sep 24 '20 22:09

Hi, I wanted to do this too and figured out a way. It's pretty rudimentary, but it works! Add this function to cli.py and run it; it will pop up a matplotlib window with a slider that, on change, evaluates the model at a new point along the interpolation. Each time you run it, it picks a new pair of random latent endpoints to interpolate between. @lucidrains I made this a while back, before you had the model exposed to code. Sorry for butchering your code! I'll make a better one soon (a rough sketch of a version using the newer exposed interface follows the code below).

# extra imports needed by this snippet; Trainer, timestamped_filename and cast_list
# are already defined in cli.py, and torch / numpy may be imported there as well
import numpy as np
import torch
import matplotlib.pyplot as plt
from math import log2
from matplotlib.widgets import Slider


def eval_model_points(
    data = './data',
    results_dir = './results',
    models_dir = './models',
    name = 'default',
    new = False,
    load_from = -1,
    image_size = 512,
    network_capacity = 16,
    transparent = False,
    batch_size = 5,
    gradient_accumulate_every = 6,
    num_train_steps = 150000,
    learning_rate = 2e-4,
    lr_mlp = 0.1,
    ttur_mult = 1.5,
    rel_disc_loss = False,
    num_workers =  None,
    save_every = 1000,
    evaluate_every = 1000,
    generate = True,
    generate_interpolation = True,
    interpolation_num_steps = 100,
    save_frames = True,
    num_image_tiles = 4,
    trunc_psi = 0.75,
    mixed_prob = 0.9,
    fp16 = False,
    cl_reg = False,
    fq_layers = [],
    fq_dict_size = 256,
    attn_layers = [],
    no_const = False,
    aug_prob = 0.,
    aug_types = ['translation', 'cutout'],
    top_k_training = False,
    generator_top_k_gamma = 0.99,
    generator_top_k_frac = 0.5,
    dataset_aug_prob = 0.,
    multi_gpus = True,
    calculate_fid_every = None,
    seed = 42
):
    model_args = dict(
        name = name,
        results_dir = results_dir,
        models_dir = models_dir,
        batch_size = batch_size,
        gradient_accumulate_every = gradient_accumulate_every,
        image_size = image_size,
        network_capacity = network_capacity,
        transparent = transparent,
        lr = learning_rate,
        lr_mlp = lr_mlp,
        ttur_mult = ttur_mult,
        rel_disc_loss = rel_disc_loss,
        num_workers = num_workers,
        save_every = save_every,
        evaluate_every = evaluate_every,
        trunc_psi = trunc_psi,
        fp16 = fp16,
        cl_reg = cl_reg,
        fq_layers = fq_layers,
        fq_dict_size = fq_dict_size,
        attn_layers = attn_layers,
        no_const = no_const,
        aug_prob = aug_prob,
        aug_types = cast_list(aug_types),
        top_k_training = top_k_training,
        generator_top_k_gamma = generator_top_k_gamma,
        generator_top_k_frac = generator_top_k_frac,
        dataset_aug_prob = dataset_aug_prob,
        calculate_fid_every = calculate_fid_every,
        mixed_prob = mixed_prob
    )

    if generate:
        model = Trainer(**model_args)
        model.load(load_from)
        samples_name = timestamped_filename()

        # figure with the image axis plus a small axis for the slider
        fig = plt.figure(figsize = (6, 4))
        ax = fig.add_subplot(111)
        fig.subplots_adjust(bottom = 0.2, top = 0.75)
        ax_Ef = fig.add_axes([0.3, 0.85, 0.4, 0.05])
        ax_Ef.spines['top'].set_visible(True)
        ax_Ef.spines['right'].set_visible(True)

        def tile(a, dim, n_tile):
            # tiles a tensor along one dimension; currently unused
            # (see the commented-out calls below)
            init_dim = a.size(dim)
            repeat_idx = [1] * a.dim()
            repeat_idx[dim] = n_tile
            a = a.repeat(*repeat_idx)
            order_index = torch.LongTensor(
                np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda(0)
            return torch.index_select(a, dim, order_index)

        num_rows = 1
        latent_dim = 512

        # two random endpoints in latent space; the slider interpolates between them
        latents_low = noise(num_rows ** 2, latent_dim, device = 0)
        latents_high = noise(num_rows ** 2, latent_dim, device = 0)
        n = image_noise(num_rows ** 2, image_size, device = 0)
        # tmp1 = tile(n, 0, num_rows)
        # tmp2 = n.repeat(num_rows, 1)

        num_layers = int(log2(image_size) - 1)

        # slider positions 0..1000 index into these ratios; ratios above 1.0
        # extrapolate past latents_high
        ratios = torch.linspace(0., 8., 1000)
        point = ratios[0]

        interp_latents = slerp(point, latents_low, latents_high)
        latents = [(interp_latents, num_layers)]

        s_Ef = Slider(ax = ax_Ef, label = 'Axis1', valmin = 0, valmax = 1000,
                      valfmt = ' %1.1f ', facecolor = '#cc7000')

        def update(val):
            # map the slider position back into the ratio linspace
            idx = min(int(round(val)), len(ratios) - 1)
            point = ratios[idx]

            interp_latents = slerp(point, latents_low, latents_high)
            latents = [(interp_latents, num_layers)]

            # grad_map evaluates the generator for the given noise and latents and
            # returns an image (it appears to be a custom Trainer helper rather than
            # a stock method)
            im1 = model.grad_map(n, latents)
            ax.imshow(im1)
            fig.canvas.draw_idle()

        s_Ef.on_changed(update)

        # draw the initial image, then hand control to the slider
        im1 = model.grad_map(n, latents)
        ax.imshow(im1)
        plt.show()

        # note: this snippet only displays images; nothing is written under samples_name
        print(f'sample images generated at {results_dir}/{name}/{samples_name}')
        return

def noise(n, latent_dim, device):
    # random gaussian latents (redefined here so the snippet is self-contained)
    return torch.randn(n, latent_dim).cuda(device)

def image_noise(n, im_size, device):
    # per-pixel uniform noise fed to the generator; mirrors the image_noise helper
    # inside stylegan2_pytorch, redefined here so the snippet is self-contained
    return torch.FloatTensor(n, im_size, im_size, 1).uniform_(0., 1.).cuda(device)

def slerp(val, low, high):
    # spherical linear interpolation between two batches of latent vectors
    low_norm = low / torch.norm(low, dim = 1, keepdim = True)
    high_norm = high / torch.norm(high, dim = 1, keepdim = True)
    omega = torch.acos((low_norm * high_norm).sum(1))
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res

def main():
    # quick hack: run the slider directly instead of going through fire
    eval_model_points()
    # fire.Fire(eval_model_points)
    # fire.Fire(train_from_folder)

if __name__ == '__main__':
    main()
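
As a pointer towards the "better one": here is a rough, untested sketch of the same slider driven through the ModelLoader interface the repo now exposes. The helper names and arguments (noise_to_styles, styles_to_images, base_dir, name, trunc_psi) are taken from the README; treat the details as assumptions rather than a tested recipe.

# sketch only: assumes the ModelLoader helper exposed by stylegan2_pytorch
# (noise_to_styles / styles_to_images, as shown in the README); untested
import torch
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from stylegan2_pytorch import ModelLoader

loader = ModelLoader(base_dir = '.', name = 'default')  # where you ran the CLI / your project name

latent_dim = 512
z_low  = torch.randn(1, latent_dim).cuda()  # two random latent endpoints
z_high = torch.randn(1, latent_dim).cuda()

fig, ax = plt.subplots(figsize = (6, 6))
fig.subplots_adjust(bottom = 0.2)
slider_ax = fig.add_axes([0.2, 0.05, 0.6, 0.04])
slider = Slider(ax = slider_ax, label = 'interp', valmin = 0.0, valmax = 1.0, valinit = 0.0)

def render(t):
    z = (1 - t) * z_low + t * z_high                      # plain linear interpolation in z
    styles = loader.noise_to_styles(z, trunc_psi = 0.75)  # map z through the mapping network
    images = loader.styles_to_images(styles)              # run the synthesis network
    img = images[0].detach().clamp(0., 1.).permute(1, 2, 0).cpu().numpy()
    ax.imshow(img)
    fig.canvas.draw_idle()

slider.on_changed(render)
render(0.0)
plt.show()

This keeps the interpolation linear for simplicity; the slerp helper above could be swapped in for picking the point instead.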

MLTQ · May 26 '21 14:05