Support for MeanFlow

Open • prakashjayy opened this issue 5 months ago • 2 comments

The MeanFlow paper requires two things:

  • support for multiple time embeddings (the network is conditioned on two times, r and t)
  • computing a JVP (Jacobian-vector product)
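For context, here is why both pieces are needed, as I understand the MeanFlow formulation (notation assumed here: the path z_t = (1 - t)x + tε, its instantaneous velocity v, and the average velocity u over the interval [r, t]):

```latex
u(z_t, r, t) = \frac{1}{t - r} \int_r^t v(z_\tau, \tau)\, d\tau
\;\Longrightarrow\;
u(z_t, r, t) = v(z_t, t) - (t - r)\, \frac{d}{dt} u(z_t, r, t),
\qquad
\frac{d}{dt} u(z_t, r, t) = \frac{\partial u}{\partial z}\, v(z_t, t) + \frac{\partial u}{\partial t}
```

So the network has to take both r and t as inputs, and du/dt is a Jacobian-vector product of u with tangent (v, 0, 1) with respect to (z, r, t), which is what the tangent (target, [0, 1]) in the JVP call in this issue encodes. During training, v is replaced by the conditional velocity ε - x and the resulting target is used under stop-gradient.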

Support for multiple time embeddings

My idea is to add a parameter called multiple_time_embeddings:

import torch
from diffusers import UNet2DModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    multiple_time_embeddings=True,  # proposed flag, not an existing UNet2DModel argument
    block_out_channels=(64, 128, 256, 512),
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
)
model.to(device)
print("model loaded")
In __init__, the time-embedding setup could then become:
# both modes project each timestep to the same width
embedding_dim = block_out_channels[0]
if multiple_time_embeddings:
    # r and t each get half of time_embed_dim; forward() concatenates the two
    # halves back to the full time_embed_dim that the ResNet blocks expect
    _time_embed_dim = time_embed_dim // 2
else:
    _time_embed_dim = time_embed_dim

if time_embedding_type == "fourier":
    self.time_proj = GaussianFourierProjection(embedding_size=embedding_dim, scale=16)
    timestep_input_dim = 2 * embedding_dim
elif time_embedding_type == "positional":
    self.time_proj = Timesteps(embedding_dim, flip_sin_to_cos, freq_shift)
    timestep_input_dim = embedding_dim
elif time_embedding_type == "learned":
    self.time_proj = nn.Embedding(num_train_timesteps, embedding_dim)
    timestep_input_dim = embedding_dim

# a single TimestepEmbedding MLP is shared between r and t
self.time_embedding = TimestepEmbedding(timestep_input_dim, _time_embed_dim)
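To sanity-check the dimension bookkeeping outside diffusers, here is a minimal standalone sketch in plain PyTorch. TwoTimeEmbedding and sinusoidal_embedding are hypothetical names, not diffusers classes; the sin/cos projection and the Linear-SiLU-Linear MLP only roughly mirror what Timesteps and TimestepEmbedding do. As in the proposal, one MLP is shared between r and t and each gets half of time_embed_dim.

```python
import math

import torch
import torch.nn as nn


def sinusoidal_embedding(timesteps: torch.Tensor, dim: int) -> torch.Tensor:
    """Sin/cos timestep embedding: (N,) -> (N, dim). Assumes dim is even."""
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


class TwoTimeEmbedding(nn.Module):
    """Hypothetical module: embeds one (r, t) pair per sample with a shared MLP.

    Each time gets time_embed_dim // 2 channels; the two halves are concatenated
    so the output has the full time_embed_dim.
    """

    def __init__(self, proj_dim: int, time_embed_dim: int):
        super().__init__()
        assert time_embed_dim % 2 == 0, "time_embed_dim must be even"
        self.proj_dim = proj_dim
        half = time_embed_dim // 2
        self.mlp = nn.Sequential(nn.Linear(proj_dim, half), nn.SiLU(), nn.Linear(half, half))

    def forward(self, rt: torch.Tensor) -> torch.Tensor:
        # rt: (B, 2) with columns (r, t)
        batch = rt.shape[0]
        flat = rt.reshape(-1)  # interleaved: r0, t0, r1, t1, ...
        emb = self.mlp(sinusoidal_embedding(flat, self.proj_dim))  # (2B, time_embed_dim // 2)
        return emb.reshape(batch, -1)  # (B, time_embed_dim): [emb(r_i) | emb(t_i)] per row
```

For example, TwoTimeEmbedding(64, 256) applied to a (4, 2) tensor returns a (4, 256) embedding. The values 64 and 256 correspond to block_out_channels[0] = 64 with a 4x time-embedding multiplier, which I believe is UNet2DModel's default.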

And in forward():

if self.config.multiple_time_embeddings:
    # expect one (r, t) pair per sample, shape (B, 2)
    assert timestep.ndim == 2 and timestep.shape[1] == 2, "timestep should have shape (batch, 2)"
    # row-major flatten gives r0, t0, r1, t1, ...; each sample's pair stays adjacent
    timestep = timestep.flatten()

# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
    timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
    timesteps = timesteps[None].to(sample.device)

# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
n = sample.shape[0] * 2 if self.config.multiple_time_embeddings else sample.shape[0]
timesteps = timesteps * torch.ones(n, dtype=timesteps.dtype, device=timesteps.device)

t_emb = self.time_proj(timesteps)

# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)

if self.config.multiple_time_embeddings:
    # (2B, D/2) -> (B, D): concatenate emb(r_i) and emb(t_i) for each sample.
    # reshape rather than split + cat: the rows are interleaved (r0, t0, r1, t1, ...),
    # so torch.split(emb, bs, dim=0) would pair embeddings from different samples
    bs = sample.shape[0]
    emb = emb.reshape(bs, -1)
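Why the row order matters, shown in plain PyTorch with a toy 1-dimensional stand-in for the embedding (the values are made up for illustration): flatten() on a (B, 2) tensor interleaves r and t, so combining it with torch.cat(torch.split(emb, bs, dim=0), dim=1) pairs embeddings from different samples, while emb.reshape(bs, -1) keeps each sample's (r_i, t_i) pair together.

```python
import torch

bs = 3
# (B, 2) tensor of (r, t) pairs with distinct values so we can track them
rt = torch.tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]])
flat = rt.flatten()  # interleaved order: r0, t0, r1, t1, r2, t2

# stand-in for the per-timestep embedding: a 1-dim "embedding" equal to the timestep
emb = flat[:, None]  # (2B, 1)

# two contiguous chunks of bs rows, concatenated along the feature dim
split_cat = torch.cat(torch.split(emb, bs, dim=0), dim=1)
# reshape keeps each sample's (r_i, t_i) pair on one row
reshaped = emb.reshape(bs, -1)

print(split_cat)  # rows (r0, t1), (t0, r2), (r1, t2): embeddings from different samples
print(reshaped)   # rows (r0, t0), (r1, t1), (r2, t2): identical to rt
```

Equivalently, flattening rt.T (all r values first, then all t values) makes the split + cat version line up; either way, the flatten order and the recombination step have to agree.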

Calculating the JVP

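# jvp needs a tensor output (return_dict=False), and scaling by 1000 inside the
# function makes the derivative w.r.t. the unscaled times r, t
u_fn = lambda x, rt: model(x, rt * 1000, return_dict=False)[0]
rt_tangent = torch.tensor([0.0, 1.0], dtype=rt.dtype).repeat(xt.shape[0], 1)  # (dr, dt) = (0, 1)
func_output, directional_deriv_jvp = torch.autograd.functional.jvp(
    u_fn,
    (xt.to(device), rt.to(device)),
    (target.to(device), rt_tangent.to(device)),
)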
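According to its documentation, torch.autograd.functional.jvp computes the JVP with the double-backward trick, so every op in the model needs a differentiable backward. torch.func.jvp uses forward-mode AD instead. Below is a minimal sketch of the MeanFlow target computed that way; meanflow_target is a hypothetical helper, and u_fn stands for any function (z, r, t) -> u, not specifically the diffusers model.

```python
import torch
from torch.func import jvp


def meanflow_target(u_fn, z, r, t, v):
    """Return u(z, r, t) and the stop-gradient MeanFlow target.

    u_fn(z, r, t) -> u with the same shape as z; r and t have shape (B,);
    v is the velocity dz/dt along the path (e.g. eps - x), same shape as z.
    """
    # tangent (v, 0, 1) for (z, r, t) yields du/dt = (du/dz) v + partial du/dt
    u, dudt = jvp(u_fn, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))
    # broadcast (t - r) over the non-batch dimensions of z
    t_minus_r = (t - r).view(-1, *([1] * (z.dim() - 1)))
    u_tgt = v - t_minus_r * dudt
    return u, u_tgt.detach()  # detach = stop-gradient on the target
```

With the proposed multiple_time_embeddings flag, u_fn could be something like lambda z, r, t: model(z, torch.stack([r, t], dim=1) * 1000, return_dict=False)[0]. I have not verified whether the attention kernels chosen by AttnProcessor2_0 support forward-mode AD in a given PyTorch version, so the AttnProcessor fallback may still be needed.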

Using AttnProcessor2_0 raises the following error:

RuntimeError: derivative for aten::_scaled_dot_product_efficient_attention_backward is not implemented

Right now there is no way to make AttnProcessor the default: each Attention layer picks its processor automatically, choosing AttnProcessor2_0 whenever hasattr(F, "scaled_dot_product_attention") and self.scale_qk are both true.

What is the way forward for this?

prakashjayy • Jun 23 '25 13:06

For now, I am doing this:

from diffusers.models.attention_processor import AttnProcessor


for blocks in model.down_blocks:
    if hasattr(blocks, "attentions"):
        for attn in blocks.attentions:
            attn.processor = AttnProcessor()

for blocks in model.up_blocks:
    if hasattr(blocks, "attentions"):
        for attn in blocks.attentions:
            attn.processor = AttnProcessor()

model.mid_block.attentions[0].processor = AttnProcessor()
print("processor set")
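A slightly more general version of the same workaround walks every submodule instead of hard-coding down_blocks, mid_block and up_blocks, so it also catches attention layers in other block layouts. This is a sketch: set_all_attn_processors is a hypothetical helper, and it relies on duck typing on set_processor, which diffusers' Attention class provides.

```python
import torch.nn as nn


def set_all_attn_processors(model: nn.Module, make_processor) -> int:
    """Give every submodule that has a set_processor method a fresh processor.

    make_processor is any zero-argument callable (e.g. a processor class).
    Returns the number of modules updated.
    """
    count = 0
    for module in model.modules():
        if hasattr(module, "set_processor"):
            module.set_processor(make_processor())
            count += 1
    return count
```

Usage would be set_all_attn_processors(model, AttnProcessor), with AttnProcessor imported from diffusers.models.attention_processor as above; like the loops above, this creates one processor instance per attention layer.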

prakashjayy • Jun 23 '25 13:06

My current approach is here: https://github.com/prakashjayy/genai/blob/main/flow/02_mean_flow.ipynb

prakashjayy • Jun 23 '25 13:06