transformer_latent_diffusion
Adding Flash attention
Adding Flash Attention would improve training and inference speed by a large margin!
Hey, the model uses https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html, which should already dispatch to Flash Attention when the hardware and inputs support it.
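To illustrate the point above, here is a minimal sketch of calling `scaled_dot_product_attention` directly (the tensor shapes are made up for the example). PyTorch picks the fastest available backend automatically: Flash Attention on supported CUDA GPUs, otherwise the memory-efficient or plain math implementation, so no extra code is needed in the model itself.

```python
import torch
import torch.nn.functional as F

# Toy attention inputs: (batch, heads, seq_len, head_dim)
q = torch.randn(2, 4, 16, 32)
k = torch.randn(2, 4, 16, 32)
v = torch.randn(2, 4, 16, 32)

# The backend (Flash / memory-efficient / math) is chosen internally
# based on device, dtype, and input shapes.
out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```

If you want to verify that the Flash kernel is actually being used on your GPU, you can restrict the backend selection with the `torch.backends.cuda.sdp_kernel` context manager, which raises an error when the requested kernel cannot be dispatched.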