video-diffusion-pytorch

Generating longer videos at test time

Open mrkulk opened this issue 2 years ago • 46 comments

Thank you for quickly implementing this model @lucidrains ! Maybe you already have or are planning to do this -- "To manage the computational requirements of training our models, we only train on a small subset of say 16 frames at a time. However, at test time we can generate longer videos by extending our samples." (sec 3.1). Currently the sample() function is fixed length. Wanted to check in with you before taking a stab
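For anyone who wants to take that stab, here is a rough sketch of the idea from section 3.1: autoregressively extend a sample by conditioning each new clip on the last frames of the previous one via the "replacement" trick, i.e. at every denoising step the conditioning frames are overwritten with a noised copy of the already-generated frames. This is not the repo's current API; `sample`, `q_sample`, `p_sample` and `num_timesteps` are assumed to follow the usual DDPM layout, and the argument names are illustrative only.

import torch

@torch.no_grad()
def sample_longer(diffusion, num_clips = 4, overlap = 4, num_frames = 16, image_size = 64, channels = 3):
    # sketch only: extend a fixed-length sampler by conditioning each new clip
    # on the last `overlap` frames of the previous clip (replacement method)
    clips = [diffusion.sample(batch_size = 1)]                     # first fixed-length clip
    device = clips[0].device

    for _ in range(num_clips - 1):
        cond = clips[-1][:, :, -overlap:]                          # known frames to condition on
        video = torch.randn(1, channels, num_frames, image_size, image_size, device = device)

        for t in reversed(range(diffusion.num_timesteps)):
            times = torch.full((1,), t, device = device, dtype = torch.long)
            # overwrite the first `overlap` frames with a noised copy of the known frames
            video[:, :, :overlap] = diffusion.q_sample(cond, times)
            video = diffusion.p_sample(video, times)               # one reverse-diffusion step

        clips.append(video[:, :, overlap:])                        # keep only the newly generated frames

    return torch.cat(clips, dim = 2)                               # concatenate along the frame axis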

mrkulk avatar Apr 19 '22 03:04 mrkulk

You are welcome and no immediate plans yet for that portion. If you get to it first, do submit a PR :)

lucidrains avatar Apr 19 '22 03:04 lucidrains

@mrkulk still planning on giving it a stab?

lucidrains avatar Apr 26 '22 18:04 lucidrains

been planning to, but I came across a big issue earlier in the stack before getting to it. I am unable to get really good temporal coherence in the basic prediction task (the gradient method is supposed to fix this, but I'm still getting choppy predictions)


mrkulk avatar Apr 27 '22 07:04 mrkulk

@mrkulk ahh got it, ok, i'll take a look at the gradient method tomorrow

lucidrains avatar Apr 27 '22 16:04 lucidrains

@lucidrains still training (moving mnist) and it might reach temporal coherence after more training, but this is where it's at after 40k steps.

[W&B loss chart and sample at 40k iters attached]

mrkulk avatar Apr 29 '22 06:04 mrkulk

@mrkulk cool! have you tried training it with periodic arresting of attention across time as described in the paper?

i'll have to revisit the gradient method next week

lucidrains avatar Apr 30 '22 00:04 lucidrains

@lucidrains let me start a run now and see

mrkulk avatar Apr 30 '22 00:04 mrkulk

much more stable, but there is still a temporal coherence issue. actually the problem might be deeper than inference sampling -- we would expect the predictions within just the training snippets to be consistent, but they are not

[W&B loss chart and generated sample attached]

mrkulk avatar May 02 '22 07:05 mrkulk

@mrkulk very cool! how often are you arresting the attention across time for the experiment above?

lucidrains avatar May 02 '22 19:05 lucidrains

@lucidrains tried this https://github.com/lucidrains/video-diffusion-pytorch/blob/main/video_diffusion_pytorch/video_diffusion_pytorch.py#L284. Do we mean the same thing or are you referring to something else?

mrkulk avatar May 02 '22 20:05 mrkulk

@mrkulk yup that's it! but what i got from the paper is that they didn't train with the arresting of attention across time exclusively? they traded off training normally vs restricting it to each frame. wasn't sure if this was done on a schedule, or alternating, or some other strategy. correct me if i'm way off base

lucidrains avatar May 03 '22 03:05 lucidrains

@lucidrains it looks like they are mixing random frames at the end of a video seq -- "To implement this joint training, we concatenate random independent image frames to the end of each video sampled from the dataset, and we mask the attention in the temporal attention blocks to prevent mixing information across video frames and each individual image frame. We choose these random independent images from random videos within the same dataset;"
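If it helps, here is a minimal sketch of the temporal attention mask that quote describes: video frames attend to each other, while each appended independent image attends only to itself. The function and its arguments are illustrative, not the repo's implementation.

import torch

def temporal_attention_mask(num_video_frames, num_image_frames, device = None):
    # frames 0..V-1 are the video clip; frames V..V+I-1 are the independent images
    # concatenated to the end of the sequence (True = attention allowed)
    total = num_video_frames + num_image_frames
    mask = torch.zeros(total, total, dtype = torch.bool, device = device)
    mask[:num_video_frames, :num_video_frames] = True              # video frames attend to each other
    idx = torch.arange(num_video_frames, total, device = device)
    mask[idx, idx] = True                                          # each image attends only to itself
    return mask

# e.g. a 16-frame clip with 4 random independent images appended
mask = temporal_attention_mask(16, 4)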

mrkulk avatar May 03 '22 04:05 mrkulk

@lucidrains one interesting thing I noticed is that without focus_on_the_present turned on, it has a hard time even memorizing a single video frame. It does a good job with this turned on, but it's choppy even after a lot of training. Trying to see if it gets better with focus_on_the_present=np.random.uniform(0, 1) > 0.5 and focus_on_the_present=True if self.global_step <= 2000 else False

[generated sample attached]

mrkulk avatar May 03 '22 23:05 mrkulk


ohh nice! is that generated?

lucidrains avatar May 04 '22 16:05 lucidrains

@lucidrains yes, but it is an overfitting test; I suspect it will work (smooth motions won't happen, or there might also be mixing of digits). I also ran a schedule (2k steps with focus on present and then turned it off). It didn't work as expected (although I rarely see some smooth motions).

[W&B loss chart attached]

It's fine around 2k while focus on present is on, but then it diverges -- [charts and generated sample attached]

@lucidrains I am beginning to suspect it's something in the Unet? Could it be attention or positional embeddings?

mrkulk avatar May 04 '22 16:05 mrkulk

@mrkulk positional embedding should be fine, i'm using classic T5 relative positional bias (could even switch to a stronger one, rotary embeddings, if need be)
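For anyone unfamiliar, a condensed sketch of the T5-style relative positional bias idea mentioned here: a learned scalar per head for each bucketed frame offset, added to the attention logits. The bucketing below is simplified to clamped offsets (T5 proper uses log-spaced buckets for large distances), and this is not the repo's exact module.

import torch
from torch import nn

class RelativePositionBias(nn.Module):
    def __init__(self, heads = 8, num_buckets = 32):
        super().__init__()
        self.num_buckets = num_buckets
        self.bias = nn.Embedding(num_buckets, heads)               # one learned scalar per (bucket, head)

    def forward(self, num_frames, device = None):
        pos = torch.arange(num_frames, device = device)
        rel = pos[None, :] - pos[:, None]                          # (frames, frames) signed offsets
        # simplified bucketing: clamp offsets into range and shift them to be non-negative
        buckets = rel.clamp(-self.num_buckets // 2, self.num_buckets // 2 - 1) + self.num_buckets // 2
        return self.bias(buckets).permute(2, 0, 1)                 # (heads, frames, frames), added to attention logits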

let me offer a way to turn off the sparse linear attention i have at every layer, and we can debug to see if that is the culprit

i have also switched from resnet blocks to the newer convnext https://arxiv.org/abs/2201.03545 , but can always bring back resnets if somehow it isn't suitable for generative work
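Roughly, the two block families being compared look like this. This is a sketch for orientation only, with (1, k, k) kernels so both act per-frame inside the 3D U-Net; these are not the repo's exact modules.

import torch
from torch import nn

class ResnetBlock3D(nn.Module):
    # classic conv -> norm -> activation, twice, with a residual connection
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv3d(dim, dim, (1, 3, 3), padding = (0, 1, 1)),
            nn.GroupNorm(8, dim),
            nn.SiLU(),
            nn.Conv3d(dim, dim, (1, 3, 3), padding = (0, 1, 1)),
            nn.GroupNorm(8, dim)
        )
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.net(x) + x)

class ConvNextBlock3D(nn.Module):
    # ConvNeXt recipe: large depthwise conv, norm, inverted bottleneck with GELU, residual
    def __init__(self, dim, mult = 2):
        super().__init__()
        self.ds_conv = nn.Conv3d(dim, dim, (1, 7, 7), padding = (0, 3, 3), groups = dim)
        self.net = nn.Sequential(
            nn.GroupNorm(1, dim),
            nn.Conv3d(dim, dim * mult, 1),
            nn.GELU(),
            nn.Conv3d(dim * mult, dim, 1)
        )

    def forward(self, x):
        return self.net(self.ds_conv(x)) + x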

lucidrains avatar May 04 '22 16:05 lucidrains

@lucidrains it looks like they are mixing random frames at the end of a video seq -- "To implement this joint training, we concatenate random independent image frames to the end of each video sampled from the dataset ..."

i see, this must be to counter overfitting, as most videos have very similar frames. i'll think about how to build this into the trainer

lucidrains avatar May 04 '22 16:05 lucidrains

@mrkulk https://github.com/lucidrains/video-diffusion-pytorch/commit/233c1d695e1a80267dac7ddd64d1d8acab17b1f6#diff-4ff1a95f5e6b9add82d0e523fd2d858ca38e67b393ea87c2ae88a8b14a0fbb1cR305 this should allow you to turn off the linear attention, in case that is causing the divergence
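A hedged usage sketch for ruling out the linear attention follows. I'm assuming the linked commit exposes this as a Unet3D constructor flag named use_sparse_linear_attn, and the GaussianDiffusion keyword names below follow my reading of the README, so treat all of them as assumptions.

from video_diffusion_pytorch import Unet3D, GaussianDiffusion

# flag name assumed from the linked commit: disables the spatial linear attention layers
model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    use_sparse_linear_attn = False
)

# keyword names assumed from the README
diffusion = GaussianDiffusion(
    model,
    image_size = 64,
    num_frames = 10,
    timesteps = 1000
)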

lucidrains avatar May 04 '22 16:05 lucidrains

@mrkulk ok, i've brought back the old resnet blocks in version 0.3.1, and started a run on my own moving mnist dataset

perhaps jumping onwards to convnext wasn't the greatest idea :sweat_smile:

lucidrains avatar May 04 '22 18:05 lucidrains

@lucidrains ok sounds good. will wait for your ping to do some more debugging/testing once you take a stab. btw before 0.3.1 I got errors on forward due to SpatialLinearAttention (goes away if you turn it off though). you have probably already run into it

mrkulk avatar May 04 '22 18:05 mrkulk

@mrkulk yup, that attention error should be fixed! here is the experiment https://wandb.ai/lucidrains/video-diffusion-redo/reports/moving-mnist-video-diffusion--VmlldzoxOTQ3OTM0?accessToken=6m0nlx9992n6pind2j3113v03tbsps52v0rtkyw4jqotpgz99ziwlx2zsh6remna also, thanks for the sponsorship! :heart:

lucidrains avatar May 04 '22 19:05 lucidrains

@lucidrains awesome! the samples, even at the beginning, look qualitatively different. loss seems to be steadily going down. wonder why it's blobby

mrkulk avatar May 04 '22 22:05 mrkulk


haha, i actually compared it to my previous convnext run and it looks about the same

this moving mnist dataset actually comes from https://github.com/RoshanRane/segmentation-moving-MNIST (minus the salt and pepper background noise)

i'll retry a run tonight with the new focus-on-present probabilities hyperparameter
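In case it's useful to anyone reproducing this, the training-side knob presumably looks something like the following; the prob_focus_present keyword name is an assumption here (whatever the new hyperparameter ends up being called), and `diffusion` is the GaussianDiffusion instance from the earlier sketch.

import torch

videos = torch.randn(2, 3, 10, 64, 64)                  # (batch, channels, frames, height, width)

# assumed keyword: probability that a given training step restricts temporal
# attention to the present frame only, instead of attending across all frames
loss = diffusion(videos, prob_focus_present = 0.5)
loss.backward()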

lucidrains avatar May 05 '22 00:05 lucidrains

@lucidrains i was using a different trainer, but if i use yours and the code below then it seems more reasonable (only 1k iters). The moving mnist is from: https://www.cs.toronto.edu/~nitish/unsupervised_video/

import os
from pathlib import Path

import numpy as np
from moviepy.editor import ImageSequenceClip

def moving_mnist_gif_creator():
  # mnist_test_seq.npy is (num_frames, num_videos, 64, 64)
  root = os.path.expanduser('~/datasets/mnist_test_seq.npy')
  out_root = os.path.expanduser('~/datasets/moving_mnist')
  os.makedirs(out_root, exist_ok=True)
  data = np.load(root)
  for ii in range(data.shape[1]):
    # first 10 frames of video ii, with a trailing channel dimension added
    clip = ImageSequenceClip(list(data[:10, ii][..., None]), fps=20)
    name = str(ii) + '.gif'
    clip.write_gif(str(Path(out_root) / name), fps=60)

[generated sample GIF attached]
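For reference, a minimal sketch of pointing the repo's Trainer at the folder of GIFs written above. The argument names and defaults (train_batch_size, train_lr, train_num_steps, gradient_accumulate_every, ema_decay) are taken from my reading of the README and should be treated as assumptions; adjust to whatever your local version exposes.

import os
from video_diffusion_pytorch import Trainer

# assumes `diffusion` is a GaussianDiffusion instance as set up earlier in the thread;
# keyword names below are assumptions based on the README, not a verified interface
trainer = Trainer(
    diffusion,
    os.path.expanduser('~/datasets/moving_mnist'),   # folder of GIFs written by the snippet above
    train_batch_size = 4,
    train_lr = 1e-4,
    train_num_steps = 100000,
    gradient_accumulate_every = 2,
    ema_decay = 0.995
)

trainer.train()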

mrkulk avatar May 05 '22 01:05 mrkulk

@mrkulk It turns out I had the wrong settings

You are right, the old resnet blocks work much better than convnext blocks, and I will likely remove the convnext blocks today so as not to confuse researchers

[sample attached]

14k steps on moving mnist - it has already figured out some of the background objects, though it has yet to segment the digits, but it is already training way faster than a pure attention-based solution

lucidrains avatar May 05 '22 13:05 lucidrains

i've also added rotary positional embeddings to the temporal attention. that should definitely help, no question
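For readers following along, a generic sketch of rotary embeddings applied to the queries and keys of the temporal attention. This is the standard interleaved-pair formulation, not the repo's exact module; the shapes below are illustrative.

import torch

def rotary_embedding(seq_len, dim, device = None, base = 10000):
    # one rotation frequency per feature pair, evaluated at every frame position
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2, device = device).float() / dim))
    angles = torch.arange(seq_len, device = device).float()[:, None] * freqs[None, :]
    return angles.cos(), angles.sin()

def apply_rotary(x, cos, sin):
    # x: (..., seq_len, dim); rotate each consecutive feature pair by its position-dependent angle
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim = -1)
    return out.flatten(-2)

# queries/keys over the frame axis before temporal attention
q = torch.randn(1, 8, 16, 32)                 # (batch, heads, frames, dim_head)
cos, sin = rotary_embedding(16, 32)
q = apply_rotary(q, cos, sin)                 # the same rotation would be applied to k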

lucidrains avatar May 05 '22 13:05 lucidrains

[sample attached]

for reference, using convnext blocks at around 13k

lucidrains avatar May 05 '22 13:05 lucidrains

very interesting! can give this a shot today. didn't expect convnext to not work as well. Another interesting and somewhat related point -- the transframer paper (https://arxiv.org/pdf/2203.09494.pdf) used NF-ResNet blocks within a 3D U-Net for video generation. but I still need to mentally consolidate and think about these variations

mrkulk avatar May 05 '22 16:05 mrkulk

[sample at 19k steps attached]

@mrkulk cool! was not aware of transframer - will need to queue that up for reading

second day of training, it has yet again improved (but it still has not segmented the digits as the attention-based one has; it is still early in training)

there are these weird flickering artifacts as well, from frame to frame, and i'll have to see if they persist once i move the training onto some remote compute that came my way

lucidrains avatar May 06 '22 15:05 lucidrains


it seems to have stabilized on the background and filling in higher frequency details -- quite promising!

how many videos are you training this on and what batch size?

ran a test on 10k videos on the same dataset but with a smaller batch size (on a V100). This is without sparse linear attn. I think the trend is similar:

https://wandb.ai/csmai/multimodal-video-generation/reports/moving-mnist-video_diffusion_model--VmlldzoxOTYwNjY0?accessToken=q26z0simmgj0zqjc61kfsixezrzvo3clidjkuujwl5ttrpivpoj3c5idj886qh4c

Seems like the digits will come towards the very end, but it is interesting that pure attn gets it. How many global steps are you at compared to the pure attn-based one?

mrkulk avatar May 06 '22 18:05 mrkulk