
[Question] What is PixelSnail? How to Train it?

Open EibrielInv opened this issue 5 years ago • 35 comments

Hi! I'm failing to understand the function of PixelSNAIL. Is it meant to generate the latent space, similar to a GAN?

I trained the VQ-VAE correctly (until the samples were good enough):

python train_vqvae.py ./dataset_path

Then I did a test run of training PixelSNAIL. Is this correct?

Extracted the codes (I assume these are the encodings for each image in the dataset):

python extract_code.py --ckpt checkpoint/vqvae_241.pt --name small256 ./dataset_path
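To spell out that assumption, this is roughly what I think the extraction step does for every image. This is only my own sketch with made-up names (the encoder call in particular is a placeholder), not the actual extract_code.py:

import lmdb
import pickle

import torch

@torch.no_grad()
def build_code_db(encode_to_ids, loader, path='small256'):
    # encode_to_ids is a hypothetical wrapper that runs the frozen VQ-VAE encoder
    # and returns the discrete top/bottom codebook index maps for a batch of images
    env = lmdb.open(path, map_size=10 * 1024 ** 3)
    index = 0
    with env.begin(write=True) as txn:
        for img, _ in loader:
            top_ids, bottom_ids = encode_to_ids(img)
            for top, bottom in zip(top_ids, bottom_ids):
                row = (top.cpu().numpy(), bottom.cpu().numpy())
                txn.put(str(index).encode('utf-8'), pickle.dumps(row))
                index += 1
        txn.put(b'length', str(index).encode('utf-8'))

PixelSNAIL would then be trained on these stored code maps rather than on the images themselves.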

Then I trained the top hierarchy (about 30 minutes per batch; I only trained 1 batch):

python train_pixelsnail.py --hier top --batch 8 small256

Then I trained the bottom hierarchy (about 30 minutes per batch; I only trained 1 batch):

python train_pixelsnail.py --hier bottom --batch 8 small256

And finally I sampled:

python sample.py --vqvae checkpoint/vqvae_001.pt --top checkpoint/pixelsnail_top_001.pt --bottom checkpoint/pixelsnail_bottom_001.pt output.png

The output, as expected, is just noise, since I only trained PixelSNAIL for 1 batch.

(output image)

If I just keep training PixelSNAIL, will I be able to obtain good samples?

Hardware: NVIDIA 1080Ti

Thank you!

EibrielInv avatar Jun 24 '19 01:06 EibrielInv

Yes, it generates samples of latent codes for the VQ-VAE. I have checked that it can produce some samples if you train it enough, but you will need quite a large model.
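Conceptually, generation is a two-stage process: the top PixelSNAIL samples the top code map, the bottom PixelSNAIL samples the bottom code map conditioned on it, and the frozen VQ-VAE decoder turns the codes into pixels. A simplified sketch of that flow (the map sizes assume 256x256 training images; see sample.py for the actual code):

import torch
from torchvision.utils import save_image

@torch.no_grad()
def generate(model_vqvae, model_top, model_bottom, sample_model, device, batch=8, temperature=1.0):
    # stage 1: unconditional prior over the 32x32 top code map
    top_code = sample_model(model_top, device, batch, [32, 32], temperature)
    # stage 2: 64x64 bottom code map, conditioned on the sampled top codes
    bottom_code = sample_model(model_bottom, device, batch, [64, 64], temperature, condition=top_code)
    # the VQ-VAE decoder maps the two discrete code maps back to an image
    images = model_vqvae.decode_code(top_code, bottom_code)
    save_image(images, 'generated.png', normalize=True)
    return images

Here sample_model is the autoregressive sampling loop and decode_code is the VQ-VAE's code-to-image decoder.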

rosinality avatar Jun 26 '19 12:06 rosinality

Would you mind sharing samples? Just to get an idea of what to expect.

pclucas14 avatar Jun 26 '19 13:06 pclucas14

(sample image) Not very nice, but it is from a somewhat smaller model than the one in the paper.

rosinality avatar Jun 26 '19 13:06 rosinality

that's pretty good! thanks for sharing :)

pclucas14 avatar Jun 26 '19 13:06 pclucas14

looks great! would you mind sharing your (hyper-)parameter setting and the resultant accuracy of top/bottom PixelSNAIL for this result?

1Konny avatar Jun 27 '19 04:06 1Konny

Top
  • channel: 512
  • n_block: 4
  • n_res_block: 5
  • res_channel: 512
  • n_cond_res_block: 0
  • n_out_res_block: 5
  • attention: True
  • dropout: 0.1
  • batch size: 63 (1e-4) / 64 (1e-5)

Trained 109 epochs with lr 1e-4 and 9 epochs with lr 1e-5, and accuracy was about 48%

Bottom
  • channel: 512
  • n_block: 4
  • n_res_block: 5
  • res_channel: 512
  • n_cond_res_block: 5
  • cond_res_channel: 512
  • n_out_res_block: 0
  • attention: False
  • dropout: 0.1
  • batch size: 64

Trained 70 epochs with lr 1e-4 and 4 epochs with lr 1e-5, and accuracy was about 21%

I think you can increase res_channel & dropout to match the hyperparameters in the paper, but I couldn't use that setting because of the large memory requirements.
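For reference, the top setting above corresponds roughly to a command like this (flag names as in the train_pixelsnail.py options; attention is not passed as a flag here, and you may need a smaller batch depending on your GPUs):

python train_pixelsnail.py --hier top --batch 64 --lr 1e-4 --channel 512 --n_res_block 5 --n_res_channel 512 --n_out_res_block 5 --dropout 0.1 small256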

Hope this helps.

rosinality avatar Jun 27 '19 05:06 rosinality

it really helps! thanks for the details.

1Konny avatar Jun 27 '19 05:06 1Konny

Thank you for sharing the details. Can you also share how many GPUs you used to train these networks?

k-eak avatar Jul 03 '19 02:07 k-eak

@k-eak I have used 4 V100s with mixed precision training.

rosinality avatar Jul 03 '19 03:07 rosinality

@rosinality Thanks for sharing the details. I am curious about the mixed precision training: do you use some package? Does mixed precision training help you increase the batch size? Is it possible to share more on this part?

ywang370 avatar Jul 09 '19 22:07 ywang370

@ywang370 I have used NVIDIA apex amp (https://github.com/NVIDIA/apex) with opt_level O1. I think mixed precision training was quite helpful for increasing batch sizes and reducing training times. It is hard to compare directly as the GPUs are different, but mixed precision training on V100s is more than 2x faster, with 2x the batch size, compared to FP32 training on P40s.
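The integration itself is just the standard apex amp recipe: initialize once after building the model and optimizer, then scale the loss in the training loop. Roughly (this is the generic pattern, not the exact code in this repo):

from apex import amp  # https://github.com/NVIDIA/apex

def wrap_for_amp(model, optimizer):
    # O1 patches whitelisted ops to run in FP16 while master weights stay in FP32
    return amp.initialize(model, optimizer, opt_level='O1')

def training_step(model, optimizer, criterion, inp, target):
    optimizer.zero_grad()
    loss = criterion(model(inp), target)
    # scale the loss so FP16 gradients don't underflow, unscale before the optimizer step
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
    return loss.item()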

rosinality avatar Jul 09 '19 23:07 rosinality

@rosinality Hi, I have 2 P100s. Is there any improvement in general if I use apex for training PixelSNAIL? Would you mind sharing the code you used to enable the mixed precision training you mentioned above? I cannot find where apex is used in the GitHub repo.

phongnhhn92 avatar Jul 18 '19 15:07 phongnhhn92

@phongnhhn92 Added simple support for apex at 7a2fbda.

rosinality avatar Jul 21 '19 12:07 rosinality

Hi, I have trained the VQ-VAE and I got very similar images; my dataset is 159 images. Then I ran extract_code.py (my question here: how many checkpoints should I use in the end?).

After that I tried to run train_pixelsnail.py (every time I get a problem at line 40 in dataset.py, something about 'no decode'). Then I tried to check whether the lmdb file has any data; I printed the env stat and got this output: ({'psize': 4096, 'depth': 0, 'branch_pages': 0, 'leaf_pages': 0, 'overflow_pages': 0, 'entries':)

I am trying to solve it but it is not working.

thanks a lot.

zaitoun90 avatar Aug 14 '19 21:08 zaitoun90

How long did it take you per epoch (and how many iterations did you have in an epoch)? I'm finding it takes a considerable amount of time (~7 hours for 34k iterations of batch size 32).

karamarieliu avatar Aug 14 '19 21:08 karamarieliu

Which one do you mean (train_vqvae.py or extract_code.py)? For both of them it is not that much, about 15 minutes. I have a small dataset and I am using 2x GTX 1080 GPUs.

For train_pixelsnail.py I have not succeeded so far; I have the problem above.

I used the original parameters and only changed the batch_size to 32 for train_vqvae.py.

zaitoun90 avatar Aug 14 '19 21:08 zaitoun90

@zaitoun90 Could you recheck the extract_code.py step? I think it might be an lmdb-related problem. @karamarieliu Yes, train_pixelsnail.py requires a lot of time as the PixelSNAIL model is quite large.
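If you want to sanity-check the database itself, something like this should work (it assumes the keys dataset.py expects, i.e. a 'length' entry plus one pickled row per image; if 'length' comes back as None, that would match the .decode error at line 40):

import lmdb

# quick check of the code database written by extract_code.py (path = the db name, e.g. small256)
env = lmdb.open('small256', readonly=True, lock=False)
with env.begin() as txn:
    print('entries:', env.stat()['entries'])              # 0 or 1 means the codes were never written
    print('length:', txn.get('length'.encode('utf-8')))   # None here would explain the decode error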

rosinality avatar Aug 14 '19 23:08 rosinality

@rosinality thanks, now it is working.

zaitoun90 avatar Aug 20 '19 08:08 zaitoun90

Hi, one more question. I ran everything correctly, but I am still getting samples similar to the original ones. I thought I would be able to generate different images; could that be? @rosinality, are the samples you shared different from the original dataset?

zaitoun90 avatar Aug 23 '19 10:08 zaitoun90

@zaitoun90 Do you mean the output from ground-truth code inputs? Then it should be similar to the input images. To get samples from the model you can use sample.py.

rosinality avatar Aug 23 '19 14:08 rosinality

@rosinality Yes, I use sample.py, but the output is still similar to the input images. After this long training of the VQ-VAE and PixelSNAIL, I expected to be able to generate different samples.

zaitoun90 avatar Aug 23 '19 14:08 zaitoun90

@zaitoun90 sample.py doesn't use image inputs. sample.py should generate samples from scratch.

rosinality avatar Aug 24 '19 06:08 rosinality

@rosinality When I checked sample.py I noticed that the F.one_hot function seems to be taking too much time (190 seconds for the top level with batch size 32). I tried to replace it with a scatter function that updates according to the previous samples, but for some reason the network is now processing much slower. Do you have any idea why this is happening, and any suggestions on how to improve the sampling time?

k-eak avatar Aug 29 '19 22:08 k-eak

@k-eak The current implementation is quite inefficient; for example, one_hot will operate on sequences of 16896 elements per example at the top level. Maybe you can use some kind of caching. I have also tried to implement caching, but I only got about a 2x improvement...

rosinality avatar Aug 30 '19 04:08 rosinality

@rosinality Thank you for the suggestion. I replaced the one-hot and now update it after each sample with the scatter function. Although this improved the speed compared to one_hot, the network now takes longer to process and, in the end, the improvement is very small. Do you think I am missing something?

Here is the changed sampling code (I removed the one_hot call in the PixelSNAIL model):

import torch
from tqdm import tqdm

def sample_model(model, device, batch, size, temperature, condition=None):
    # one-hot code map kept on the GPU and updated in place via scatter,
    # instead of calling F.one_hot on the whole sequence at every step
    row = torch.zeros(batch, 512, *size, dtype=torch.int64).to(device)
    row_sample = torch.zeros(batch, *size, dtype=torch.int64).to(device)
    cache = {}

    for i in tqdm(range(size[0])):
        for j in range(size[1]):
            out, cache = model(row[:, :, : i + 1, :], condition=condition, cache=cache)
            prob = torch.softmax(out[:, :, i, j] / temperature, 1)
            sample = torch.multinomial(prob, 1)
            # mark the sampled code in the one-hot map and record its index
            row[:, :, i, j] = row[:, :, i, j].scatter(1, sample, 1)
            row_sample[:, i, j] = sample.squeeze(-1)

    return row_sample

k-eak avatar Aug 30 '19 22:08 k-eak

@k-eak Did you add torch.cuda.synchronize()? I think the speed measurement can be inaccurate because of the asynchronous nature of PyTorch. Also, the speed gain can be small since much of the computation occurs in the rest of the model.
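For reference, a minimal way to time a GPU call correctly:

import time
import torch

def timed(fn, *args, **kwargs):
    # CUDA kernels launch asynchronously; synchronize before and after
    # so the wall-clock measurement reflects the actual GPU work
    torch.cuda.synchronize()
    start = time.time()
    out = fn(*args, **kwargs)
    torch.cuda.synchronize()
    return out, time.time() - start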

rosinality avatar Sep 01 '19 01:09 rosinality

@rosinality Oh, my bad, I needed to add torch.cuda.synchronize(). So my method does not change the speed that much and mostly saves a couple of seconds for large batches. I might try adding the caching idea from https://github.com/PrajitR/fast-pixel-cnn/blob/master/fast_pixel_cnn_pp/fast_nn.py, but it might take some time to implement it in PyTorch.

k-eak avatar Sep 01 '19 02:09 k-eak

When I use train_pixelsnail.py accuracy immediately hits 1.0 and the loss goes to basically zero after less than 100 iterations. This feels weird to me, what is going on?

I've got these settings:

amp='O0', batch=12, channel=512, ckpt=None, dropout=0.1, epoch=200, hier='top', lr=0.0001, n_cond_res_block=0, n_out_res_block=5, n_res_block=5, n_res_channel=512

Mut1nyJD avatar Sep 06 '19 11:09 Mut1nyJD

@Mut1nyJD

I have the same issue. In my case the dataset might be "too simple", but this is just my guess... What about your data?

Slimco86 avatar Oct 02 '19 14:10 Slimco86

@Slimco86

No, I don't think it is too simple. I am using this dataset:

https://www.mut1ny.com/peoplepose20k

Mut1nyJD avatar Oct 02 '19 20:10 Mut1nyJD