Questions about the wikiart training sets

Open Rancherzhang opened this issue 4 years ago • 3 comments

Hello, I want to reproduce your great work, but with my limited knowledge I have two questions right now. Firstly, I am trying to rewrite the training phase and have started training on wikiart with a content dir of 'wikiart/Rococo' and a style dir of 'wikiart/Symbolism', but my intermediate results are not as good as yours, so I would like to know which content dir and style dir you chose from the wikiart dataset. Secondly, my style distribution loss does not converge; it always stays between about 4.2 and 4.4. My code is below:

import random

import torch
import torch.nn as nn


class StyleDistLoss(nn.Module):
    '''
    Style distribution loss between s and s'.
    '''
    def __init__(self, pool_size):
        super(StyleDistLoss, self).__init__()
        self.pool_size = pool_size  # size of the pool of detached style codes kept from previous iterations
        if self.pool_size > 0:
            self.num_style_batch = 0
            self.style_batches = []
        self.loss = nn.L1Loss()

    def __call__(self, sc, st):
        '''
        Return the standard Gaussian distribution loss for the input style codes:
        source {sc} and target {st}, corresponding to s and s' in the paper.
        '''
        styles = []
        if self.pool_size == 0:
            styles.extend([sc, st])
        else:
            styles += self.style_batches
            styles.extend([sc, st])

            # Keep a pool of detached style codes from previous iterations so the
            # mean/covariance statistics are estimated over more than one batch.
            detach_sc = sc.clone().detach()
            detach_st = st.clone().detach()

            if self.num_style_batch + 2 < self.pool_size:
                self.style_batches.extend([detach_sc, detach_st])
                self.num_style_batch += 2
            else:
                # Pool is full: overwrite two randomly chosen entries.
                random_idx = [x for x in range(self.num_style_batch)]
                random.shuffle(random_idx)
                self.style_batches[random_idx[0]] = detach_sc
                self.style_batches[random_idx[1]] = detach_st
        tensor_styles = torch.squeeze(torch.cat(styles, 0))
        # Empirical mean and covariance of the pooled style codes.
        styles_mean = torch.mean(tensor_styles, dim=0)
        tminuss = tensor_styles - styles_mean
        cov = torch.mm(tminuss.t(), tminuss) / tensor_styles.shape[0]
        std_cov = cov.diag(diagonal=0)
        # Push the pooled statistics toward those of a standard Gaussian.
        total_loss = self.loss(styles_mean, torch.zeros_like(styles_mean))
        total_loss += self.loss(cov, torch.ones_like(cov))
        total_loss += self.loss(std_cov, torch.ones_like(std_cov))
        return total_loss
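
A quick way to sanity-check the shapes the loss expects is to feed it random style codes (the batch size and style dimension below are arbitrary placeholders):

loss_fn = StyleDistLoss(pool_size=0)
sc = torch.randn(8, 256)   # s  in the paper (style code of the original image)
st = torch.randn(8, 256)   # s' in the paper (style code after the swap)
print(loss_fn(sc, st))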

Could you please give me some advice? Thanks!

Rancherzhang avatar Oct 27 '20 02:10 Rancherzhang

Hi! I believe the issue is with the line total_loss += self.loss(cov, torch.ones_like(cov))

torch.ones_like(cov) returns a matrix filled with 1s, not an identity matrix. Therefore, your loss does not enforce the disentanglement of the components of the style vector. Try torch.eye instead.
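
As a rough sketch, only the target of the covariance term needs to change:

# Target the identity matrix instead of an all-ones matrix: diagonal entries
# (variances) are pushed to 1, off-diagonal entries (correlations between
# style components) are pushed to 0.
identity = torch.eye(cov.shape[0], device=cov.device, dtype=cov.dtype)
total_loss += self.loss(cov, identity)

With the identity target, the separate std_cov term against torch.ones_like(std_cov) only duplicates the diagonal part and can be dropped.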

denkorzh avatar Oct 27 '20 07:10 denkorzh

Thank you for your advice, I will give it a try! Besides, I have a question about training on the wikiart dataset: I noticed that in the inference step in your README.md you use landscape images as content images and wikiart images as style images, so should I use the same strategy during my training stage on wikiart?

Rancherzhang avatar Oct 27 '20 08:10 Rancherzhang

Sorry for the delayed reply. For the style transfer model, both content and style images are sampled from the wikiart dataset. We observed that, trained in this manner, the model can also be applied to real images.
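
A simplified sketch of that sampling setup (the dataset class and directory layout below are illustrative, not the actual HiDT training code):

import random
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset


class WikiartPairDataset(Dataset):
    '''Draws both the content and the style image from the same wikiart pool.'''

    def __init__(self, root='wikiart', transform=None):
        self.paths = sorted(Path(root).rglob('*.jpg'))  # all genres pooled together
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        content = Image.open(self.paths[idx]).convert('RGB')
        style = Image.open(random.choice(self.paths)).convert('RGB')  # independent style sample
        if self.transform is not None:
            content, style = self.transform(content), self.transform(style)
        return {'content': content, 'style': style}

At inference time, as in the README, the content image can then be a real landscape photo while the style image is still taken from wikiart.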

belkakari avatar Nov 12 '20 13:11 belkakari