DeepDeformable3DCaricatures

Training Parameter Recommendation for Dataset of Both Caricature and Regular Face Models

Open mynametia opened this issue 2 years ago • 7 comments

Hello! Thank you very much for sharing your research.

I am currently using your model to train on a dataset containing both 3DCaricShop heads and regular 3D heads generated from CelebA. Both types of head models have the same connectivity and a similar number of samples (~1000).

Unfortunately, the lowest training error I could achieve is around 40, using the same training parameters as detailed in the DD3C paper, with shuffling turned on for the training dataset.

[Figure: training error over 3000 epochs]

Would you have any advice as to why the training error converges at such a high value and how to lower it? I have tried to follow the paper's methodology as closely as possible, with the exception of the training dataset used. Thank you for your time; any help will be greatly appreciated.

My training code is as follows:

import os
import time

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# utils, SurfaceDeformationField and CombinedDataset come from the DD3C code base
# and the poster's own dataset code; their import paths are not shown here.


def train(opt, dirs, trainset_num, checkpt_dir='checkpoints', set_shuffle=True):
    checkpoints_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), opt.model_dir, checkpt_dir)

    utils.cond_mkdir(checkpoints_dir)
    meta_params = vars(opt)

    # set up model; sum(trainset_num) is the total number of training shapes
    model = SurfaceDeformationField(sum(trainset_num), **meta_params).cuda()
    params = model.parameters()

    train_dataset = CombinedDataset(dirs=dirs, num_samples=11551)
    train_dataloader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=set_shuffle)

    # set up optimiser
    optimizer = torch.optim.Adam(params, lr=1e-4)
    num_epochs = 3000
    total_steps = 0
    steps_til_summary = opt.epochs_til_checkpoint

    with tqdm(total=len(train_dataloader) * num_epochs) as pbar:
        train_losses = []
        for epoch in range(num_epochs):
            # periodically save the weights and the loss history so far
            if not epoch % opt.epochs_til_checkpoint:
                torch.save(model.state_dict(),
                           os.path.join(checkpoints_dir, 'model_epoch_%04d.pth' % epoch))
                np.savetxt(os.path.join(checkpoints_dir, 'train_losses_epoch_%04d.txt' % epoch),
                           np.array(train_losses))
                # valid(model,valid_dataloader)

            for i, (model_input, gt) in enumerate(train_dataloader):
                start_time = time.time()  # time each iteration for the log line below
                model_input = {key: value.cuda() for key, value in model_input[0].items()}
                gt = {key: value.cuda() for key, value in gt[0].items()}

                # total loss = sum of the batch means of every loss term
                losses = model.forward(model_input, gt)
                train_loss = 0.
                for loss_name, loss in losses.items():
                    single_loss = loss.mean()
                    train_loss += single_loss
                train_losses.append(train_loss.item())

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
                pbar.update(1)
                if not total_steps % steps_til_summary:
                    tqdm.write("Epoch %d, Total loss %0.6f, iteration time %0.6f" % (
                        epoch, train_loss.item(), time.time() - start_time))
                total_steps += 1

        torch.save(model.state_dict(),
                   os.path.join(checkpoints_dir, 'model_final.pth'))
        np.savetxt(os.path.join(checkpoints_dir, 'train_losses_final.txt'),
                   np.array(train_losses))

mynametia • Nov 19 '22

Hi, thanks for the interest in our research.

I think you can check these first:

  1. When you try only 3DCaricShop or only the regular face dataset, do you get a good result?
  2. Did you align the 3D models in your mixed dataset? (E.g., when you open one model from 3DCaricShop and another model from the regular face dataset, are they roughly aligned in 3D space with a similar scale?) Otherwise, the network would have a hard time learning the latent manifold.
  3. Did you check the visual result? What does the 3D shape generated from an optimized latent code look like?
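To check point 2 numerically rather than only by eye, one can compare the vertex centroid and overall size of two meshes. The sketch below is an illustration under our own assumptions: it reads only the v lines of an OBJ file and uses the RMS distance of the vertices from their centroid as the size measure; neither choice comes from the DD3C code.

```python
import numpy as np


def load_obj_vertices(path):
    """Read the 'v x y z' lines of a Wavefront OBJ file into an (N, 3) array."""
    verts = []
    with open(path) as f:
        for line in f:
            parts = line.split()
            if parts and parts[0] == "v":
                verts.append([float(x) for x in parts[1:4]])
    return np.asarray(verts, dtype=np.float64)


def centroid_and_scale(verts):
    """Vertex centroid and RMS distance of the vertices from that centroid."""
    center = verts.mean(axis=0)
    scale = np.sqrt(((verts - center) ** 2).sum(axis=1).mean())
    return center, scale


def compare_alignment(verts_a, verts_b):
    """Distance between the two centroids and the ratio of the two RMS scales."""
    center_a, scale_a = centroid_and_scale(np.asarray(verts_a, dtype=np.float64))
    center_b, scale_b = centroid_and_scale(np.asarray(verts_b, dtype=np.float64))
    return float(np.linalg.norm(center_a - center_b)), float(scale_a / scale_b)
```

For example, compare_alignment(load_obj_vertices("caricature.obj"), load_obj_vertices("regular.obj")) (hypothetical file names) returning a large centroid offset, or a scale ratio far from 1, suggests the two datasets are not in a common frame.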

ycjungSubhuman • Nov 21 '22

Thank you very much for the suggestions! I will run the model on each dataset separately, as per your first point, to check. In response to points 2 and 3:

  2. I am using the 3DCaricShop and FaceWareHouse models from the original training set (left and right) and models generated from CelebA (centre). It seems that the origin of the 3DCaricShop and CelebA models is at the nose tip, while the origin of the FaceWareHouse models is between the nose and mouth (where the red line in the background passes through). Is that enough to cause training problems? [image]

  3. This is the result of reconstructing from the latent code at epoch 1500 with the setup from the paper and shuffle=False (I am still generating the validation results for shuffle=True): [image]
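If every mesh shares full vertex correspondence, a given vertex index refers to the same facial point on every mesh, so differing origins like the ones described here can be removed by translating each mesh until one chosen vertex (for instance the nose tip) sits at the origin. A minimal sketch, assuming the meshes are stacked into an (M, N, 3) NumPy array; NOSE_TIP_IDX is a hypothetical placeholder to be looked up once on the template mesh:

```python
import numpy as np

# Hypothetical: index of the nose-tip vertex in the shared mesh topology.
NOSE_TIP_IDX = 0


def recenter_batch(meshes, anchor_idx=NOSE_TIP_IDX):
    """Translate every mesh in an (M, N, 3) array so its anchor vertex is at the origin.

    Only the translation offset is removed; rotation and scale are unchanged.
    """
    meshes = np.asarray(meshes, dtype=np.float64)
    # slicing keeps the vertex axis, so the (M, 1, 3) anchors broadcast over all N vertices
    return meshes - meshes[:, anchor_idx:anchor_idx + 1, :]
```

This only fixes translation; if rotation or scale also differ between the datasets, a similarity (Procrustes) alignment of each mesh to a reference mesh would be needed instead.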

mynametia • Nov 21 '22

Do the visual results depict the 3D shapes that correspond to the optimized latent codes for the training set? First, try generating shapes from the codes optimized during training (model.latent_codes[i]). They should at least overfit the training set.
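Once each training code has been decoded back into a mesh (the decoding call depends on the DD3C model code and is not shown here), a simple way to judge whether the codes overfit is the mean distance between corresponding vertices of the reconstruction and its ground-truth mesh. This metric is our own illustration and is not necessarily the quantity the training loss measures:

```python
import numpy as np


def mean_vertex_error(pred, gt):
    """Mean Euclidean distance between corresponding vertices of two (N, 3) arrays."""
    pred = np.asarray(pred, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    if pred.shape != gt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {gt.shape}")
    # row-wise norms give one distance per vertex pair
    return float(np.linalg.norm(pred - gt, axis=1).mean())
```

Because the meshes share vertex correspondence, no nearest-neighbour search is needed: vertex i of the reconstruction is compared directly with vertex i of the ground truth.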

If overfitting does not work, I suspect there is a problem with the dataset. Check whether the vertex correspondence between the meshes and the 3D alignment of the shapes are correct.

If overfitting works, then you need to plot a validation error curve and pick the checkpoint with the minimum validation error. Without this early stopping based on the validation curve, codes optimized for unseen shapes may produce incorrect shapes with large errors or even severe noise.
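Selecting the checkpoint can be as simple as taking the minimum of the recorded validation errors. In this sketch, val_errors is a hypothetical dict mapping the epoch of each saved checkpoint to its validation error (how that error is computed, e.g. by optimizing codes for held-out shapes, is not shown); the returned file name follows the model_epoch_%04d.pth pattern used in the training code earlier in this thread:

```python
def best_checkpoint(val_errors):
    """Return (epoch, error, file name) for the checkpoint with the lowest validation error.

    val_errors: hypothetical dict of checkpoint epoch -> validation error.
    """
    if not val_errors:
        raise ValueError("no validation errors recorded")
    epoch = min(val_errors, key=val_errors.get)
    return epoch, val_errors[epoch], "model_epoch_%04d.pth" % epoch
```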

ycjungSubhuman • Nov 29 '22

Hi, thank you so much for your advice.

I reconstructed the training set shapes (1268 3DCaricShop caricatures and 150 FaceWareHouse regular heads) from model.latent_codes using the 1500-epoch model trained to recreate the paper results. The reconstructed caricatures ended up looking closer to regular heads, and the regular heads could not be reconstructed well at all:

[image]

I have checked the vertex correspondence and all models have the same vertex connectivity. I am wondering if the issue lies with the Dataset class I am using, which is as shown below. Could I trouble you to check if it looks correct?

import torch
from torch.utils.data import Dataset

# CaricShop3DTrain and FaceWareHouseTrain are the per-dataset training classes,
# defined elsewhere (not shown).


class CombinedDatasetPaperAccurate(Dataset):
    def __init__(self,
                 dirs,
                 num_samples=11581,
                 use_landmarks=False):

        self.caricshopsubset = CaricShop3DTrain(dirs[0],
                                                 num_samples=num_samples,
                                                 use_landmarks=use_landmarks)
        self.facewarehousesubset = FaceWareHouseTrain(dirs[1],
                                                       num_samples=num_samples,
                                                       use_landmarks=use_landmarks)

        # load training dataset from paths_obj
        self.caricshop_ds = self.caricshopsubset.ds
        self.facewarehouse_ds = self.facewarehousesubset.ds

        # global indices run over the caricatures first, then the FaceWareHouse heads
        self.ds = torch.utils.data.ConcatDataset([self.caricshop_ds,
                                                  self.facewarehouse_ds])

    def __getitem__(self, i):
        return self.ds[i]

    def __len__(self):
        return len(self.ds)

Otherwise, I am unsure what else could be wrong; I based the training code exactly on training_loop_surface.py. Thank you once again for all your help so far.
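One detail worth verifying in the combined dataset (an assumption on our part; this thread does not confirm it is the cause) is how each sample identifies its latent code. torch.utils.data.ConcatDataset maps a global index to a sub-dataset with bisect.bisect_right over its cumulative_sizes and then calls that sub-dataset with the local index. The model is constructed with sum(trainset_num), presumably one latent code per shape, so if CaricShop3DTrain and FaceWareHouseTrain each number their shapes from 0, the first caricature and the first FaceWareHouse head would share code 0. This pure-Python sketch reproduces the index mapping so instance ids can be checked for collisions:

```python
import bisect
from collections import Counter
from itertools import accumulate


def cumulative_sizes(sizes):
    """Running totals of sub-dataset lengths, as in ConcatDataset.cumulative_sizes."""
    return list(accumulate(sizes))


def locate(sizes, global_idx):
    """Map a non-negative global index to (sub-dataset index, local index).

    Mirrors the lookup ConcatDataset.__getitem__ performs before it calls the
    chosen sub-dataset with the local index.
    """
    cum = cumulative_sizes(sizes)
    if not 0 <= global_idx < cum[-1]:
        raise IndexError(global_idx)
    ds_idx = bisect.bisect_right(cum, global_idx)
    local_idx = global_idx if ds_idx == 0 else global_idx - cum[ds_idx - 1]
    return ds_idx, local_idx


def duplicate_ids(ids):
    """Return the sorted instance ids that occur more than once."""
    return sorted(i for i, n in Counter(ids).items() if n > 1)


# Hypothetical scenario: each subset reports its local index as the instance id.
sizes = [1268, 150]  # 3DCaricShop and FaceWareHouse counts from this thread
local_ids = [locate(sizes, g)[1] for g in range(sum(sizes))]
print(len(duplicate_ids(local_ids)))  # 150 colliding ids (0..149)
```

In the real dataset, the ids to check would be whatever index each sample returns for the latent-code lookup; they should cover 0 to sum(trainset_num) - 1 with no repeats.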

mynametia • Dec 03 '22

I think it is a dataset problem.

> I have checked the vertex correspondence and all models have the same vertex connectivity.

Could you provide more details on this? What did you do to check this? Or could you share your dataset with us so that we can inspect it and see what the problem is?

ycjungSubhuman • Dec 03 '22

> Could you provide more details on this? What did you do to check this?

I checked all the lines starting with f in the .obj files (f v1 v2 v3 ...), and they are identical for all models in my dataset. Is this the right way to check?

Here are three examples from the dataset I am using, one each from 3DCaricShop, FaceWareHouse and CelebA: Dataset_Examples.zip
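Comparing the f lines, as described above, confirms that the meshes share connectivity. A slightly stricter version of the same check also compares vertex counts and ignores texture/normal indices in entries like 12/5/3. Identical faces do not by themselves prove that vertex i lies on the same facial point in every mesh; that part of correspondence still needs a visual or landmark-based check. A minimal sketch in plain Python, assuming absolute (positive) OBJ indices:

```python
def obj_topology(lines):
    """Vertex count and face list (vertex indices only) from the lines of an OBJ file.

    A face entry such as '12/5/3' is reduced to its vertex index 12.
    Assumes absolute (positive) indices, not OBJ's relative negative form.
    """
    n_verts = 0
    faces = []
    for line in lines:
        parts = line.split()
        if not parts:
            continue
        if parts[0] == "v":
            n_verts += 1
        elif parts[0] == "f":
            faces.append(tuple(int(p.split("/")[0]) for p in parts[1:]))
    return n_verts, tuple(faces)


def same_topology(paths):
    """True if all OBJ files have the same vertex count and identical face lists."""
    topologies = []
    for path in paths:
        with open(path) as f:
            topologies.append(obj_topology(f))
    return all(t == topologies[0] for t in topologies[1:])
```

For instance, same_topology(["caricshop.obj", "facewarehouse.obj", "celeba.obj"]) (hypothetical file names for the three example meshes) should return True for a dataset with shared topology.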

mynametia • Dec 04 '22

Hi, did you solve the problem? I checked the dataset examples you provided, and they seem to share vertex correspondence with good pose alignment.

ycjungSubhuman • Dec 29 '22