TATS
Script for training conditioned on text
Thanks for the awesome project. How can I train the final model conditioned on text?
Thank you for your interest in our project! We are working on the final steps of the conditional training scripts as well as a Colab notebook. The notebook is almost done, so please feel free to try it now: https://colab.research.google.com/drive/1yblr4IolH91ZA61FfZyk2n8rvndCIFmm?usp=sharing. Please stay tuned for the training scripts!
Thanks for the notebook; I will look into it.
We have added the relevant instructions on how to train a conditional transformer and sample videos with it. Please check the ReadMe file and feel free to let me know if you have any questions!
Thanks for uploading the code.
Hi, if I want to load pretrained weights for text-to-video generation and fine-tune the model on another dataset: I found there are two pretrained weights, (1) vqgan and (2) gpt_ckpt. I can load the VQGAN with --vqvae vqgan_ckpt, but how do I load gpt_ckpt for the transformer weights?
Moreover, do I need to save both vqgan_ckpt and gpt_ckpt to use the test script? If yes, how can I do that? I found only one ckpt is saved during training, with no separate ckpts for vqgan and gpt.
Hi, if you want to fine-tune from a checkpoint, you can pass --resume_from_checkpoint=your_checkpoint_file to any of the training scripts. If you are using pytorch-lightning > 2.0, the argument is --ckpt_path instead. See https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html?highlight=resume_from_checkpoint#resume-from-checkpoint.
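To make the version difference concrete, here is a minimal sketch of a helper that picks the right flag from the installed Lightning version string. The helper itself is hypothetical (not part of the TATS repo); the flag names come from the Lightning docs linked above.

```python
# Hypothetical helper: choose the checkpoint-resume CLI argument based on
# the installed pytorch-lightning version. Not part of the TATS codebase.

def resume_arg(lightning_version: str, ckpt_file: str) -> str:
    """Return the CLI fragment for resuming training from a checkpoint."""
    major = int(lightning_version.split(".")[0])
    if major >= 2:
        # Lightning >= 2.0 renamed the Trainer argument to `ckpt_path`.
        return f"--ckpt_path={ckpt_file}"
    return f"--resume_from_checkpoint={ckpt_file}"

print(resume_arg("1.9.5", "last.ckpt"))  # --resume_from_checkpoint=last.ckpt
print(resume_arg("2.1.0", "last.ckpt"))  # --ckpt_path=last.ckpt
```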
Yes, you need both the vqgan and gpt ckpts to do generation; they are trained separately. The VQGAN is trained first. When training the transformer, the VQGAN weights are used but kept frozen, so only the gpt ckpt is saved.
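If your run produced a single combined checkpoint and you need one component's weights, one common approach is to filter the checkpoint's state dict by key prefix. This is a generic sketch, not TATS code, and the prefixes ("vqgan.", "gpt.") are assumptions for illustration; inspect your checkpoint's actual keys first.

```python
# Sketch: split a combined state_dict by key prefix. In a real Lightning
# checkpoint the weights live under ckpt["state_dict"] and the values are
# tensors; plain ints stand in here so the example is self-contained.

def split_state_dict(state_dict, prefix):
    """Keep entries whose key starts with `prefix`, stripping the prefix."""
    return {k[len(prefix):]: v for k, v in state_dict.items()
            if k.startswith(prefix)}

# Toy stand-in for a real checkpoint's state_dict (prefixes are assumed).
ckpt = {"vqgan.encoder.weight": 1, "gpt.blocks.0.weight": 2}
gpt_weights = split_state_dict(ckpt, "gpt.")
print(gpt_weights)  # {'blocks.0.weight': 2}
```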
Thanks for the guidelines. Another question: to train text-to-video generation, which data loader did you use, tats/coinrun/coinrun_dataset.py or tats/coinrun/coinrun_dataset_v2.py?
When I try to load coinrun_dataset.py I get a "KeyError: 'ground'" error.
Did you use train.json for training on the MUGEN dataset, or did you modify the dataset somehow?
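A "KeyError: 'ground'" typically means the annotation schema in your train.json differs from what the dataloader expects. A quick sanity check is to list the keys actually present in a record before training. The field names below ("data", "video", "text") are illustrative assumptions, not the real MUGEN schema.

```python
# Sketch: inspect the keys of the first annotation record to debug a
# schema mismatch such as KeyError: 'ground'. Field names are assumed.
import json

def annotation_keys(json_text):
    """Return the set of keys found in the first annotation record."""
    records = json.loads(json_text)["data"]
    return set(records[0].keys())

sample = '{"data": [{"video": "a.mp4", "text": "a caption"}]}'
keys = annotation_keys(sample)
print("ground" in keys)  # False -> a loader expecting 'ground' would raise
```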
I tried to fine-tune using the given pretrained weights, but it generates weird videos. Do I need to decrease the learning rate when fine-tuning from the given weights?
We used an old, internal version of the MUGEN dataset. I will be working on improvements to support the new official dataloader.
I don't have any experience with finetuning. How did your loss curve look during finetuning?
When I fine-tune using the new data loader, the loss decreases and acc1 and acc5 are quite good during training, but at inference the generated videos get worse.
@trahman8 Hi, have you tried changing text_seq_len?