
Shouldn't the StylePredictor and Transfer Network be trained jointly?

virgile-blg opened this issue 4 years ago

Hi,

Thanks for your work and for sharing your code! In the StylePredictor code you say you detach the embedding vector because you don't want to train "end-to-end".

In the original paper from Ghiasi et al., they explicitly say: "We find it sufficient to jointly train the style prediction network P(·) and style transfer network T(·) on a large corpus of photographs and paintings"

How, then, can the StylePredictor network be trained with your implementation?

virgile-blg avatar Jun 03 '20 10:06 virgile-blg

Hi @virgile-blg,

This repo doesn't contain any code to train the style augmentation network; it's just models, checkpoints and an example script. The idea is to provide a ready-to-use pretrained StyleAugmentor class that can be slotted into a standard training loop and just work, without any training on the user's part. That's why that detach (and the one in StyleAugmentor) is there: it's a precaution that ensures you don't accidentally backpropagate errors through the (restyled) input image and into the transformer or style predictor networks while training your own model. Doing so would definitely be bad - you'd end up changing the transformer weights in a way that minimized your training loss, whereas data augmentation is supposed to make the task harder at train time so that you generalize better to unseen images.
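
To make that concrete, the intended usage looks roughly like the sketch below. This is not the repo's actual example script: `dataloader` and the toy classifier are placeholders, and the exact call signature of `StyleAugmentor` may differ slightly from what's shown.

```python
import torch
import torch.nn as nn
from styleaug import StyleAugmentor  # pretrained augmentor shipped with this repo

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
augmentor = StyleAugmentor().to(device)   # frozen; its parameters are never optimized
classifier = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 10)).to(device)  # stand-in model
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for images, labels in dataloader:         # placeholder: your own DataLoader of 3x224x224 images
    images, labels = images.to(device), labels.to(device)
    with torch.no_grad():                 # belt and braces: no graph is built for the augmentor
        restyled = augmentor(images)      # randomly restyled batch
    loss = criterion(classifier(restyled), labels)
    optimizer.zero_grad()
    loss.backward()                       # gradients stop at `restyled`; the augmentor never changes
    optimizer.step()
```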

If you're wondering why this repo even has a StylePredictor when it's not supposed to train the transformer: it's there for style interpolation. If you set alpha to something other than 1, the StyleAugmentor will linearly interpolate between a randomly sampled embedding and the predicted style of the input image (see Equation 6 in our paper). There's also the option to set useStylePredictor to False, in which case the mean style embedding from ImageNet is used instead of computing the style of the input image, saving you a bit of compute.
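
In pseudocode, the embedding fed to the transformer looks roughly like this. This is just my shorthand for Equation 6, not the repo's exact variable names: `mean`, `cov_chol` and the sampling convention are placeholders.

```python
import torch

def interpolated_embedding(x, style_predictor, mean, cov_chol, alpha=0.5):
    # sample a random style embedding z ~ N(mean, cov) via a Cholesky factor of cov
    z = mean + torch.randn(x.shape[0], mean.shape[0], device=mean.device) @ cov_chol.T
    if alpha == 1.0:
        return z                        # purely random style; no need to run the StylePredictor
    p = style_predictor(x)              # predicted style of the input batch
    return alpha * z + (1.0 - alpha) * p   # linear interpolation, as in Equation 6
```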

If you want to (re)train the transformer network and style predictor yourself, you could remove the detach and then proceed as in Ghiasi et al. I'm actually a bit puzzled as to why I've given an option not to detach the output in StyleAugmentor, since I can't think of a situation where you would want to backpropagate through it. If you wanted to retrain the transformer / style predictor, you wouldn't use the StyleAugmentor class anyway, since it's just a wrapper around the transformer / style predictor for convenient use in downstream tasks. It seems more sensible to have a fixed detach in StyleAugmentor and no detach in Ghiasi or StylePredictor, for that reason. Thanks for bringing this to my attention, I'll push a change up.
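
For what it's worth, a joint training loop in the spirit of Ghiasi et al. might look something like the sketch below. This is not code from this repo: the import path and constructor arguments are assumptions, and `content_loss`, `style_loss`, `style_weight` and the two data loaders are placeholders you'd have to supply yourself (e.g. VGG feature and Gram-matrix losses over a corpus of photographs and paintings).

```python
import itertools
import torch
from styleaug import Ghiasi, StylePredictor   # assumption: importable as in this repo

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transformer = Ghiasi().to(device)              # transfer network T
style_predictor = StylePredictor().to(device)  # prediction network P, with the detach removed
optimizer = torch.optim.Adam(
    itertools.chain(transformer.parameters(), style_predictor.parameters()), lr=1e-3)

for content, style in zip(content_loader, style_loader):   # placeholders: photos and paintings
    content, style = content.to(device), style.to(device)
    embedding = style_predictor(style)         # no .detach() here
    stylized = transformer(content, embedding)
    loss = content_loss(stylized, content) + style_weight * style_loss(stylized, style)
    optimizer.zero_grad()
    loss.backward()                            # gradients now reach both T and P
    optimizer.step()
```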

philipjackson avatar Aug 29 '20 12:08 philipjackson

Hi @philipjackson, and thanks for your detailed answer!

I am indeed interested in retraining the StylePredictor (as well as the transfer network), but I am still a bit confused about how. Did you end up training those nets with a batch size of 8, as described in the Ghiasi et al. paper? If so, does the style batch contain the same style at each step, or 8 different random styles?

Also, have you experimented with implementing the two fully connected layers of StylePredictor as Conv2d with kernel size 1 (which seems to be the case in the Magenta TF implementation, if I am not mistaken)?
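
To illustrate what I mean, here is a toy check (placeholder dimensions, not code from either repo) showing that a Linear layer and a 1x1 Conv2d compute the same thing on 1x1 spatial feature maps once the weights are shared:

```python
import torch
import torch.nn as nn

fc = nn.Linear(768, 100)                   # placeholder dims for the StylePredictor bottleneck
conv = nn.Conv2d(768, 100, kernel_size=1)
conv.weight.data = fc.weight.data.view(100, 768, 1, 1).clone()
conv.bias.data = fc.bias.data.clone()

feats = torch.randn(8, 768)                            # e.g. pooled backbone features
out_fc = fc(feats)                                     # (8, 100)
out_conv = conv(feats.view(8, 768, 1, 1)).flatten(1)   # (8, 100)
print(torch.allclose(out_fc, out_conv, atol=1e-5))     # True
```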

Thanks a lot!

virgile-blg avatar Sep 16 '20 17:09 virgile-blg