EG3D-projector

Slow optimization when I optimize 2 images at the same time

Open renrenzsbbb opened this issue 2 years ago • 6 comments

Thanks for your great work. When I project one image, the speed is normal. However, when I project two images in a batch, it slows down about 10x. Can you give me some advice?

renrenzsbbb, Aug 29 '22 14:08

How did you project 2 images in a batch? Please provide more information or upload your code.

oneThousand1000, Aug 29 '22 14:08

Thanks for your reply. I made a toy example: I only added several lines at line 135 (the commented-out lines below are what I added). With them enabled, the optimization time per step goes from 0.2 s to 2 s. Can you give me some advice?

        ws = (w_opt + w_noise)
        # ws = ws.repeat(2,1,1)
        # if step == 0:
        #     target_images = target_images.repeat(2, 1, 1, 1)
        #     c = c.repeat(2, 1)
        synth_images = G.synthesis(ws,c, noise_mode='const')['image']
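
When comparing per-step times such as 0.2 s and 2 s on a GPU, it helps to make sure the measurement itself is sound: CUDA kernels are launched asynchronously, so reading the clock without synchronizing can attribute queued work to the wrong step, and the first few steps often include one-time setup cost. Below is a minimal sketch of a timing helper under those assumptions; `time_steps` is a hypothetical name, and on CUDA you would pass `torch.cuda.synchronize` as `synchronize`. It does not by itself explain the slowdown; it only makes the two numbers comparable.

```python
import time


def time_steps(step_fn, num_steps, warmup=2, synchronize=None):
    """Return the average wall-clock seconds per call of step_fn.

    step_fn: a zero-argument callable that runs one optimization step
    (hypothetical; wrap your own loop body in it).
    synchronize: optional callable such as torch.cuda.synchronize, used so
    that queued GPU work has finished before each clock read.
    """
    # Warm-up calls absorb one-time costs (allocations, kernel selection).
    for _ in range(warmup):
        step_fn()
    if synchronize is not None:
        synchronize()
    start = time.perf_counter()
    for _ in range(num_steps):
        step_fn()
    # Wait for all queued work before stopping the clock.
    if synchronize is not None:
        synchronize()
    return (time.perf_counter() - start) / num_steps
```

For example, `time_steps(one_step, 20, synchronize=torch.cuda.synchronize)`, where `one_step` is a hypothetical closure running a single optimization step, timed once with batch size 1 and once with batch size 2.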

renrenzsbbb, Aug 30 '22 02:08

Hi renrenzsbbb, it seems that you want to do two optimizations at the same time...

But you feed two identical latent codes into G, so G outputs two identical synth_images. The loss is simply doubled.
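
To see why, here is a small sketch with a toy linear "generator" standing in for G.synthesis (an assumption; it is not EG3D). Because the projector's distance is summed (`.square().sum()`), repeating the same latent and target twice doubles the loss, and since both copies come from the same tensor, their gradients add up to twice the single-image gradient. The optimization problem itself does not change; assuming the optimizer is Adam, a gradient scaled by 2 yields almost the same update, so the duplicated batch costs extra compute without changing the result.

```python
import numpy as np

# Toy stand-in for G.synthesis: a fixed linear map from a latent to an "image".
# This is NOT EG3D; it only illustrates the arithmetic of repeating the batch.
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 4))   # hypothetical "generator" weights
w = rng.standard_normal(4)        # latent being optimized
target = rng.standard_normal(8)   # one target image


def loss_and_grad(ws, targets):
    """Summed squared distance over a batch and its gradient w.r.t. each latent."""
    residual = ws @ A.T - targets             # [batch, 8]
    loss = float(np.square(residual).sum())
    grad = 2.0 * residual @ A                 # [batch, 4]
    return loss, grad


# Batch of one versus the same latent and target repeated twice.
loss1, grad1 = loss_and_grad(w[None, :], target[None, :])
loss2, grad2 = loss_and_grad(np.repeat(w[None, :], 2, axis=0),
                             np.repeat(target[None, :], 2, axis=0))

# Both rows come from the same w, so their gradients accumulate into w.
total_grad2 = grad2.sum(axis=0)
print(np.isclose(loss2, 2 * loss1))            # True
print(np.allclose(total_grad2, 2 * grad1[0]))  # True
```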

I recommend projecting one image at a time.

oneThousand1000, Aug 30 '22 03:08

Sorry, this is just a toy example for testing the optimization time. I want to optimize two images with different poses at the same time, rather than optimizing the same image twice. Thanks.

renrenzsbbb, Aug 30 '22 03:08

The code below optimizes a single ws latent code against several images of the same person taken from different views.

    for step in tqdm(range(num_steps)):

        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise)

        for pose_idx, target in enumerate(targets):
            c = cameras[pose_idx]
            target_features = target_feature_list[pose_idx]
            synth_images = G.synthesis(ws, c, noise_mode='const')['image']

            if step % image_log_step == 0:
                with torch.no_grad():
                    vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

                    PIL.Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}_{pose_idx}.png')

            # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
            synth_images = (synth_images + 1) * (255 / 2)
            if synth_images.shape[2] > 256:
                synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

            # Features for synth images.
            synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
            dist = (target_features - synth_features).square().sum()

            # Noise regularization.
            reg_loss = 0.0
            for v in noise_bufs.values():
                noise = v[None, None, :, :]  # must be [1,1,H,W] for F.avg_pool2d()
                while True:
                    reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                    reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                    if noise.shape[2] <= 8:
                        break
                    noise = F.avg_pool2d(noise, kernel_size=2)
            loss = dist + reg_loss * regularize_noise_weight

            if step % 100 == 0:
                with torch.no_grad():
                    print(
                        f'step {step}  dist loss: {dist.detach().cpu()} reg loss {reg_loss.detach().cpu() * regularize_noise_weight}')

            # Step
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            logprint(f'step {step + 1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

            # Normalize noise.
            with torch.no_grad():
                for buf in noise_bufs.values():
                    buf -= buf.mean()
                    buf *= buf.square().mean().rsqrt()
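
The loop above calls optimizer.step() once per view. The same objective can also be expressed as a single batched forward pass: stack the cameras and targets along the batch dimension, repeat ws to match, and sum the per-view distances. The sketch below checks, on a toy model where each view is a hypothetical linear "camera" map (not the EG3D renderer), that the gradient of the batch-summed loss equals the sum of the per-view gradients, so one step on the batch corresponds to accumulating all views before stepping.

```python
import numpy as np

# Toy multi-view setup: one shared latent w, one linear "camera" map per view.
# Hypothetical stand-in for G.synthesis(ws, c); not the EG3D renderer.
rng = np.random.default_rng(1)
num_views, img_dim, w_dim = 2, 6, 3
cams = rng.standard_normal((num_views, img_dim, w_dim))  # one map per view
targets = rng.standard_normal((num_views, img_dim))       # one target per view
w = rng.standard_normal(w_dim)                            # shared latent


def per_view_grad(k):
    """Gradient of ||cams[k] @ w - targets[k]||^2 w.r.t. w (one view alone)."""
    residual = cams[k] @ w - targets[k]
    return 2.0 * cams[k].T @ residual


def batched_loss_and_grad(w):
    """All views in one batched forward pass, distances summed over the batch."""
    synth = np.einsum('kij,j->ki', cams, w)               # [num_views, img_dim]
    residual = synth - targets
    loss = float(np.square(residual).sum())
    grad = 2.0 * np.einsum('kij,ki->j', cams, residual)   # [w_dim]
    return loss, grad


loss, grad = batched_loss_and_grad(w)
accumulated = sum(per_view_grad(k) for k in range(num_views))
print(np.allclose(grad, accumulated))  # True
```

In the projector this would roughly mean repeating ws along the batch dimension (as in the `ws.repeat(2, 1, 1)` toy example), concatenating the camera tensors along the batch dimension, one G.synthesis call, and one optimizer.step() per iteration. Whether that is faster than the per-view loop depends on GPU memory and the renderer, so it is worth timing both.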

oneThousand1000, Aug 30 '22 04:08

Thanks for your reply. Your code still synthesizes and optimizes only one image at a time. How can we optimize two or more images at the same time?
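
If the goal is instead to project several different images at once, each image needs its own latent, and the batch dimension simply holds one latent per image. A toy sketch (the linear map `A` is a hypothetical stand-in for G.synthesis) shows why this works: with the loss summed over the batch, the gradient for latent b depends only on image b, so each latent receives exactly the gradient it would get if optimized alone. With an element-wise optimizer such as Adam the updates stay independent as well; anything shared across the batch, such as the noise buffers in noise_bufs, still couples the images.

```python
import numpy as np

# Two different images, each with its OWN latent, optimized in one batch.
# The linear "generator" A is a hypothetical stand-in for G.synthesis.
rng = np.random.default_rng(2)
A = rng.standard_normal((6, 3))
ws = rng.standard_normal((2, 3))       # one latent per image (batch of 2)
targets = rng.standard_normal((2, 6))  # one target per image


def batched_grad(ws):
    """Gradient of the batch-summed loss w.r.t. every latent in the batch."""
    residual = ws @ A.T - targets      # [2, 6]
    return 2.0 * residual @ A          # [2, 3]; row b depends only on ws[b]


def single_grad(w, target):
    """Gradient when one image is optimized on its own."""
    return 2.0 * A.T @ (A @ w - target)


g = batched_grad(ws)
print(np.allclose(g[0], single_grad(ws[0], targets[0])))  # True
print(np.allclose(g[1], single_grad(ws[1], targets[1])))  # True
```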

renrenzsbbb, Aug 30 '22 06:08