
[core] Mochi T2V

Open a-r-r-o-w opened this issue 1 year ago • 1 comments

Fixes #9744

Github: https://github.com/genmoai/models Model: https://huggingface.co/genmo/mochi-1-preview

a-r-r-o-w avatar Oct 25 '24 07:10 a-r-r-o-w

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Tried out framewise-tiling by following what we have for CogVideoX. Results don't look too different motion-wise (can spot some subtle changes at lower FPS) but can see some differences in color/brightness/etc. (expected due to difference in decoding methods).

Internal discussion thread: https://huggingface.slack.com/archives/C065E480NN9/p1729948674659229

`replicate` padding (Original implementation without tiling) - OOM
`replicate` padding (Original implementation with spatial tiling)
`conv_cache` previous latent padding (framewise decoding)
`conv_cache` previous latent padding (framewise decoding + spatial tiling)

cc @yiyixuxu @DN6 @sayakpaul

a-r-r-o-w avatar Oct 29 '24 11:10 a-r-r-o-w

Hi, @a-r-r-o-w

I tried to convert the model checkpoints with this script: https://github.com/huggingface/diffusers/blob/mochi/scripts/convert_mochi_to_diffusers.py

Then I ran generation with the pipeline from the mochi branch (https://github.com/huggingface/diffusers/tree/mochi) using this code:

import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

pipe = MochiPipeline.from_pretrained("ckpts", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."

frames = pipe(
    prompt,
    num_inference_steps=50,
    guidance_scale=4.5,
    num_frames=28,
    generator=torch.Generator(device="cuda").manual_seed(42),
).frames[0]

export_to_video(frames, "mochi.mp4")

However, the resulting video is only noise (all data types are bf16):

https://github.com/user-attachments/assets/e479f4d3-114f-4483-bf51-582c6ece8f91

feizc avatar Oct 30 '24 03:10 feizc

@feizc In scheduler/scheduler_config.json, change "invert_sigmas": false, to "invert_sigmas": true,
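For anyone who prefers to script this change, here is a minimal sketch. It assumes the converted checkpoint lives in a local folder (the `ckpts` path is taken from the snippet above), and the helper name `set_invert_sigmas` is made up for illustration; it rewrites only that one key and leaves the rest of the config untouched.

```python
import json
from pathlib import Path


def set_invert_sigmas(config_path, value=True):
    """Set "invert_sigmas" in a scheduler_config.json file and return the updated config.

    Hypothetical helper for illustration; it is not part of diffusers.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())
    config["invert_sigmas"] = value
    # Write the file back with all other keys unchanged.
    path.write_text(json.dumps(config, indent=2) + "\n")
    return config


# Example call (assumes the folder layout from the comment above):
# set_invert_sigmas("ckpts/scheduler/scheduler_config.json")
```

Editing the file by hand works just as well; the script is only convenient when converting checkpoints repeatedly.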

Ednaordinary avatar Oct 30 '24 03:10 Ednaordinary

Is there a working example code somewhere?

tin2tin avatar Nov 06 '24 08:11 tin2tin

Hello @a-r-r-o-w, I'm trying to convert the original Genmo checkpoints into Diffusers format using this script, but the conversion fails for the encoder:

encoder.down_blocks.2.norms.0.norm_layer.bias, encoder.block_out.attentions.0.to_out.0.bias, encoder.down_blocks.2.norms.5.norm_layer.weight, encoder.down_blocks.1.attentions.1.to_out.0.weight, encoder.down_blocks.2.attentions.4.to_q.weight, encoder.down_blocks.1.attentions.1.to_k.weight, encoder.down_blocks.2.attentions.1.to_out.0.bias, encoder.down_blocks.1.attentions.2.to_k.weight, encoder.block_out.attentions.2.to_k.weight, encoder.down_blocks.2.attentions.2.to_out.0.bias, encoder.down_blocks.1.attentions.0.to_v.weight, encoder.down_blocks.0.norms.0.norm_layer.bias, encoder.down_blocks.1.attentions.0.to_k.weight, encoder.down_blocks.1.norms.2.norm_layer.weight, encoder.down_blocks.0.attentions.1.to_v.weight, encoder.down_blocks.2.attentions.1.to_k.weight, encoder.down_blocks.0.attentions.1.to_out.0.bias, encoder.down_blocks.2.attentions.4.to_k.weight, encoder.down_blocks.1.attentions.3.to_v.weight, encoder.down_blocks.1.attentions.1.to_out.0.bias, encoder.down_blocks.0.norms.1.norm_layer.bias, encoder.down_blocks.2.attentions.4.to_out.0.weight, encoder.down_blocks.1.attentions.3.to_q.weight, encoder.down_blocks.0.norms.1.norm_layer.weight, encoder.down_blocks.1.attentions.3.to_out.0.weight, encoder.down_blocks.2.norms.2.norm_layer.bias, encoder.down_blocks.1.norms.0.norm_layer.weight, encoder.block_out.attentions.2.to_v.weight, encoder.down_blocks.2.attentions.0.to_k.weight, encoder.block_out.attentions.1.to_out.0.weight, encoder.down_blocks.2.attentions.1.to_out.0.weight, encoder.down_blocks.1.norms.2.norm_layer.bias, encoder.down_blocks.0.attentions.1.to_q.weight, encoder.down_blocks.2.attentions.5.to_out.0.bias, encoder.down_blocks.0.attentions.0.to_k.weight, encoder.down_blocks.1.attentions.2.to_out.0.weight, encoder.down_blocks.1.attentions.3.to_k.weight, encoder.block_out.norms.0.norm_layer.bias, encoder.down_blocks.0.attentions.2.to_k.weight, encoder.block_out.attentions.2.to_q.weight, encoder.down_blocks.0.attentions.0.to_out.0.weight, 
encoder.down_blocks.0.norms.2.norm_layer.weight, encoder.down_blocks.2.attentions.3.to_k.weight, encoder.block_out.norms.0.norm_layer.weight, encoder.block_out.attentions.1.to_q.weight, encoder.down_blocks.2.norms.0.norm_layer.weight, encoder.down_blocks.2.attentions.3.to_out.0.weight, encoder.down_blocks.0.attentions.0.to_out.0.bias, encoder.down_blocks.2.norms.2.norm_layer.weight, encoder.down_blocks.2.attentions.4.to_v.weight, encoder.down_blocks.2.norms.4.norm_layer.weight, encoder.block_out.norms.1.norm_layer.weight, encoder.down_blocks.1.attentions.3.to_out.0.bias, encoder.down_blocks.0.attentions.0.to_q.weight, encoder.block_out.attentions.2.to_out.0.weight, encoder.down_blocks.2.norms.1.norm_layer.bias, encoder.down_blocks.1.norms.0.norm_layer.bias, encoder.down_blocks.2.attentions.0.to_out.0.bias, encoder.block_out.attentions.0.to_v.weight, encoder.block_out.attentions.1.to_out.0.bias, encoder.down_blocks.2.attentions.5.to_v.weight, encoder.block_out.norms.2.norm_layer.bias, encoder.block_out.attentions.0.to_k.weight, encoder.down_blocks.0.attentions.1.to_k.weight, encoder.down_blocks.0.attentions.2.to_v.weight, encoder.down_blocks.1.attentions.0.to_out.0.bias, encoder.down_blocks.2.attentions.2.to_q.weight, encoder.down_blocks.2.attentions.4.to_out.0.bias, encoder.down_blocks.2.attentions.2.to_out.0.weight, encoder.down_blocks.1.attentions.2.to_v.weight, encoder.down_blocks.2.attentions.3.to_v.weight, encoder.down_blocks.2.norms.3.norm_layer.bias, encoder.down_blocks.2.attentions.5.to_q.weight, encoder.down_blocks.1.attentions.2.to_q.weight, encoder.down_blocks.2.norms.4.norm_layer.bias, encoder.block_out.attentions.0.to_out.0.weight, encoder.down_blocks.1.attentions.2.to_out.0.bias, encoder.down_blocks.2.attentions.3.to_q.weight, encoder.down_blocks.1.attentions.0.to_q.weight, encoder.down_blocks.1.norms.1.norm_layer.weight, encoder.down_blocks.2.attentions.5.to_out.0.weight, encoder.down_blocks.0.attentions.2.to_out.0.weight, 
encoder.down_blocks.1.attentions.1.to_q.weight, encoder.block_out.attentions.0.to_q.weight, encoder.down_blocks.2.attentions.1.to_q.weight, encoder.block_out.attentions.1.to_k.weight, encoder.down_blocks.1.norms.1.norm_layer.bias, encoder.down_blocks.2.norms.5.norm_layer.bias, encoder.down_blocks.2.attentions.2.to_k.weight, encoder.down_blocks.0.norms.2.norm_layer.bias, encoder.down_blocks.2.attentions.0.to_v.weight, encoder.block_out.norms.2.norm_layer.weight, encoder.down_blocks.0.norms.0.norm_layer.weight, encoder.down_blocks.1.norms.3.norm_layer.weight, encoder.down_blocks.0.attentions.2.to_q.weight, encoder.block_out.attentions.1.to_v.weight, encoder.down_blocks.1.attentions.1.to_v.weight, encoder.block_out.attentions.2.to_out.0.bias, encoder.down_blocks.1.norms.3.norm_layer.bias, encoder.down_blocks.2.attentions.5.to_k.weight, encoder.down_blocks.0.attentions.1.to_out.0.weight, encoder.down_blocks.0.attentions.2.to_out.0.bias, encoder.down_blocks.2.attentions.3.to_out.0.bias, encoder.down_blocks.2.attentions.1.to_v.weight, encoder.down_blocks.2.attentions.0.to_out.0.weight, encoder.down_blocks.2.attentions.2.to_v.weight, encoder.down_blocks.2.norms.1.norm_layer.weight, encoder.down_blocks.2.attentions.0.to_q.weight, encoder.down_blocks.2.norms.3.norm_layer.weight, encoder.down_blocks.0.attentions.0.to_v.weight, encoder.down_blocks.1.attentions.0.to_out.0.weight, encoder.block_out.norms.1.norm_layer.bias.

Please make sure to pass low_cpu_mem_usage=False and device_map=None if you want to randomly initialize those weights or else make sure your checkpoint file is correct.

This is the error I get when converting to bf16 format. Is there any fix for it?

VikramxD avatar Nov 06 '24 11:11 VikramxD

@VikramxD It seems like arrow made some changes to covert_hf.py; the last version of it works correctly.

foreverpiano avatar Nov 06 '24 11:11 foreverpiano

@tin2tin Check the main branch docs

https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi

Ednaordinary avatar Nov 06 '24 23:11 Ednaordinary

Hey folks, thanks for the great work as always. I was trying the example posted here: https://huggingface.co/docs/diffusers/main/en/api/pipelines/mochi, but it fails with: Entry Not Found for url: https://huggingface.co/genmo/mochi-1-preview/resolve/main/model_index.json. I guess we have to wait for this to merge: https://huggingface.co/genmo/mochi-1-preview/discussions/18/files ?

agneet42 avatar Nov 08 '24 21:11 agneet42

@agneet42 You can add variant="refs/pr/18" to the model loading statement in the mean time

Ednaordinary avatar Nov 08 '24 21:11 Ednaordinary

thanks @Ednaordinary ! i used revision="refs/pr/18", and it works!

agneet42 avatar Nov 08 '24 21:11 agneet42

I'm getting RuntimeError: "replication_pad3d_cuda" not implemented for 'BFloat16'

edit: solved with

pip install --upgrade torch
pip install --upgrade xformers

nicollegah avatar Nov 09 '24 01:11 nicollegah

VikramSingh178 has shared working Diffusers weights here: https://huggingface.co/VikramSingh178/mochi-diffuser-bf16

However, with Mochi 1, I fail to get it to generate anything decent (without deformities and bad quality). E.g. this is 99 steps and 22 min on a 4090:

https://github.com/user-attachments/assets/273bd1fb-e42e-4ce4-b349-525daf7804b3

tin2tin avatar Nov 09 '24 03:11 tin2tin

looks similar for me

https://github.com/user-attachments/assets/b6ab9e80-755c-47ab-bcb0-380153fe73f8

nicollegah avatar Nov 09 '24 03:11 nicollegah

However, I fail to get it to generate anything decent (without deformities and bad quality). Ex. this is 99 steps and 22 min on a

+1

glide-the avatar Nov 09 '24 08:11 glide-the

Hey @tin2tin, what are your parameter settings with this pipeline?

I am using something similar to the settings below, with CPU offload and VAE tiling enabled, and got this:

{
    "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
    "negative_prompt": "((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), out of frame, extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))), out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))",
    "num_inference_steps": 30,
    "guidance_scale": 3.5,
    "height": 480,
    "width": 848,
    "num_frames": 150
}

https://github.com/user-attachments/assets/30233913-f047-4a8a-bcbd-dbc69293a16d

VikramxD avatar Nov 09 '24 10:11 VikramxD

I might not have specified a negative prompt, and generated much fewer frames. The rest is the same.

tin2tin avatar Nov 09 '24 11:11 tin2tin

Try using this checkpoint for the transformer https://huggingface.co/imnotednamode/mochi-1-preview-mix-nf4 in the pipeline; it's nf4-quantized mixed in with bf16 and gives decent results @tin2tin

VikramxD avatar Nov 09 '24 13:11 VikramxD

Try using this checkpoint for the transformer https://huggingface.co/imnotednamode/mochi-1-preview-mix-nf4 in the pipeline; it's nf4-quantized mixed in with bf16 and gives decent results @tin2tin

There, it says that we must convert the mochi checkpoint to diffusers format. Is this still up-to-date? Is this still required? I've been running

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]

without converting

nicollegah avatar Nov 09 '24 13:11 nicollegah

Yes @nicollegah, the documentation for that is out of date, but the checkpoint would work with it out of the box:

transformer = MochiTransformer3DModel.from_pretrained("imnotednamode/mochi-1-preview-mix-nf4", torch_dtype=torch.bfloat16)
pipe = MochiPipeline.from_pretrained("VikramSingh178/mochi-diffuser-bf16", torch_dtype=torch.bfloat16, transformer=transformer)

this would work out of the box for the bf16 weights

VikramxD avatar Nov 09 '24 13:11 VikramxD

import torch
from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.utils import export_to_video
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", revision="refs/pr/18", torch_dtype=torch.bfloat16 )
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
prompt="A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical atmosphere of this unique musical performance."

negative_prompt="((((ugly)))), (((duplicate))), ((morbid)), ((mutilated)), out of frame, extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), ((ugly)), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))), out of frame, ugly, extra limbs, (bad anatomy), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), mutated hands, (fused fingers), (too many fingers), (((long neck)))"
frames = pipe(prompt=prompt,negative_prompt=negative_prompt, num_inference_steps=30, guidance_scale=3.5, height=480, width=848, num_frames=64).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)

https://github.com/user-attachments/assets/9eb26565-1e8f-408f-90f1-ae7e941aac9c

the results are quite far from what i find online. maybe another sampler would be better?

nicollegah avatar Nov 09 '24 15:11 nicollegah

Hi everyone. We can replicate the quality of the original model by running inference in full precision. Anything lower seems to affect the video entirely (FP32 video is not the same as FP16/BF16 video) in terms of quality and temporal motion. I have fully reproduced results only in FP32 from the original repository, but can take a look again to see if it wasn't an implementation mistake in Diffusers.

We know that padding tokens masking might be an issue for shorter prompts as @jzhang38 suggested, but even without them the videos for short prompts are still very high quality. I haven't found time to investigate this yet since I'm on vacation but will try and take a look soon.

Please keep in mind the following findings:

  • FP32 works best for inference quality
  • Higher frame counts work best. Lower frame counts seem to lead to different artifacts in both the original and the Diffusers implementation in my testing.
  • Diffusers uses the original decoding implementation by default, which does not support framewise decoding and can lead to OOM. If you would like to decode more frames, make sure to enable the framewise decoding implementation by calling pipe.vae._enable_framewise_decoding().

When comparing our implementation to theirs, I rewrote their code to use F.sdpa instead of flash attention, and the final outputs matched with an absmax difference of 0.005, which is essentially equivalent because these numerical differences can arise from order of operations. The outputs of flash attention and F.sdpa implementation are more dissimilar numerically but that is expected.

If you find any problems with the modeling conversion, please do report because that would be extremely valuable and helpful. I don't think there are numerical differences in VAE and transformer if we consider the same conditions and use deterministic algorithms, but if you find that there are any and report it, we would be super grateful!

a-r-r-o-w avatar Nov 09 '24 16:11 a-r-r-o-w

Hi @a-r-r-o-w! Thanks for the great reply.

I've also noticed that different precisions seem to change the output drastically, and I have a theory that the model is extremely reliant on the outer layers (i.e. similar to "proj_out" in sd3). It's also possible mochi relies a lot on fractional values for these layers, which bf16 gives up much of in exchange for exponent range. I think an ultimate solution in this regard will be #9177, as then block weights can be cast to bf16 (or smaller, like float8_e4m3fn) while outer weights can stay in fp32. This is shown especially in https://huggingface.co/imnotednamode/mochi-1-preview-mix-nf4-small, where keeping specific layers in bf16 helped significantly with quality (though it is still far from fp32).
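To make the point about fractional precision concrete, here is a small standard-library sketch of how the same number rounds in bf16 versus fp16. It only illustrates the number formats (bf16 keeps 7 explicit mantissa bits, fp16 keeps 10) and says nothing about Mochi itself. The helper names `to_bf16` and `to_fp16` are made up for illustration, and NaN/inf inputs are not handled.

```python
import struct


def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value and return it as a Python float.

    bf16 is the top 16 bits of an IEEE float32, so we round away the low 16 bits
    (round-to-nearest-even). NaN and inf are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


def to_fp16(x: float) -> float:
    """Round x to the nearest IEEE half-precision value (struct format 'e')."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


x = 1.001
print(to_bf16(x))  # 1.0 (the next bf16 value above 1.0 is 1.0078125)
print(to_fp16(x))  # 1.0009765625
```

Near 1.0, bf16 steps are 2^-7 while fp16 steps are 2^-10, so small fractional differences that survive in fp16 can vanish in bf16. Whether that is what hurts Mochi's output remains the hypothesis stated above.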

The optimizations in your implementation mainly target the VAE, but by default a bf16 version of the transformer cannot fit the full 161 frames on a 24 GB card (at least, I have not been able to). I would also like to find out how much the precision of each component matters, e.g. whether running the VAE in fp32 or bf16 makes a significant difference, so I'm hoping you can provide some baseline latents from the FP32 transformer via the following (I'm unable to run the fp32 transformer on my card):

import torch
import numpy as np
from diffusers import MochiPipeline, AutoencoderKLMochi
torch.manual_seed(42)

model = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float32, vae=None, revision="refs/pr/18")
model.enable_model_cpu_offload()
video = model("A squirrel runs around on a branch", num_inference_steps=50, guidance_scale=6.0, height=480, width=848, num_frames=61, output_type="latent")

np_frames = video.frames.cpu().float().numpy()
np.save("video.npy", np_frames)

That would be greatly appreciated!

Try using this checkpoint for the transformer https://huggingface.co/imnotednamode/mochi-1-preview-mix-nf4 in the pipeline; it's nf4-quantized mixed in with bf16 and gives decent results

Thanks for noticing my checkpoint! I had quite a fun time making it. I'll update docs to the current state.

Ednaordinary avatar Nov 09 '24 19:11 Ednaordinary

Yeah in our findings, Mochi-1 seems to be quite susceptible to precision loss and imnotednamode/mochi-1-preview-mix-nf4 does a good job of finding the layers that should be kept in the full precision. It could be an artifact of how Mochi was trained, though.

The original code suggests running inference with autocasting turned on with BF16 without downcasting with to().

I think an ultimate solution in this regard will be https://github.com/huggingface/diffusers/pull/9177, as then block weights can be cast to bf16 (or smaller, like float8_e4m3fn) while outer weights can stay in fp32.

Perhaps we could try out the dynamic upcasting function from the PR and verify if that helps? Would you maybe like to give that a try?

sayakpaul avatar Nov 09 '24 19:11 sayakpaul

@sayakpaul Unfortunately, I'm running into merge conflicts between layerwise upcasting and the current main branch. mochi got deleted, otherwise I would use it

Ednaordinary avatar Nov 09 '24 19:11 Ednaordinary

if this helps, here is a comparison between fp32 and bf16. The prompt was "A finger pressing a doorbell button", with 67 frames. Are either of the results satisfactory? Thoughts?

https://github.com/user-attachments/assets/ba0e051a-f7ff-494f-86e6-046d7a406dd1

https://github.com/user-attachments/assets/44233376-960a-4f10-9129-6f0c409741c4

Code :

import torch
from diffusers import MochiPipeline

prompt = "A finger pressing a doorbell button"

# Set torch_dtype to torch.float32 or torch.bfloat16 for the two runs
pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.float32, revision="refs/pr/18")

# CPU offload manages device placement, so no explicit .to("cuda") is needed
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.vae._enable_framewise_decoding()

generator = torch.manual_seed(42)
frames = pipe(prompt, num_frames=67, num_inference_steps=50, guidance_scale=3.5, generator=generator).frames[0]

agneet42 avatar Nov 09 '24 20:11 agneet42

Unfortunately, I'm running into merge conflicts between layerwise upcasting and the current main branch. mochi got deleted, otherwise I would use it

@Ednaordinary oh I was suggesting to use the main function introduced in the PR.

# Credits to `dn6`
# Copy-pasted from 
# https://github.com/huggingface/diffusers/blob/layerwise-upcasting/src/diffusers/models/modeling_utils.py
def enable_layerwise_upcasting(model, upcast_dtype=None):
    upcast_dtype = upcast_dtype or torch.float32
    original_dtype = model.dtype
    print(f"{original_dtype=}")

    def upcast_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(upcast_dtype)

    def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(original_dtype)

    def fn_recursive_upcast(module):
        # Register the cast hooks on this module, then recurse into its children
        module.register_forward_pre_hook(upcast_dtype_hook_fn)
        module.register_forward_hook(cast_to_original_dtype_hook_fn)

        has_children = list(module.children())
        if not has_children:
            module.register_forward_pre_hook(upcast_dtype_hook_fn)
            module.register_forward_hook(cast_to_original_dtype_hook_fn)

        for child in module.children():
            fn_recursive_upcast(child)

    for module in model.children():
        fn_recursive_upcast(module)

And then using it like so:

enable_layerwise_upcasting(the_model, torch.bfloat16)

LMK if this is unclear.

sayakpaul avatar Nov 09 '24 20:11 sayakpaul

Okay, here's what I have after some testing.

from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.models.transformers.transformer_mochi import MochiTransformerBlock
import torch
from diffusers.utils import export_to_video

# The below defines how many layers to quantize. The amount of layers quantized is quant_div / quant_mod. Think of this as an inverted quality slider. Values should stay rather low to avoid edge cases
quant_div = 1
quant_mod = 2
full_dtype = torch.float16
cast_dtype = torch.float8_e5m2
torch.manual_seed(42)

if quant_div > quant_mod: print("quant_div should be less than or equal to quant_mod")

# Credits to `dn6`
# Copy-pasted + slight edit from 
# https://github.com/huggingface/diffusers/blob/layerwise-upcasting/src/diffusers/models/modeling_utils.py
def enable_layerwise_upcasting(model, upcast_dtype=None, original_dtype=None):
    upcast_dtype = upcast_dtype or torch.float32
    original_dtype = original_dtype or model.dtype

    def upcast_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(upcast_dtype)

    def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(original_dtype)

    def fn_recursive_upcast(module):
        # Register the cast hooks on this module, then recurse into its children
        module.register_forward_pre_hook(upcast_dtype_hook_fn)
        module.register_forward_hook(cast_to_original_dtype_hook_fn)

        has_children = list(module.children())
        if not has_children:
            module.register_forward_pre_hook(upcast_dtype_hook_fn)
            module.register_forward_hook(cast_to_original_dtype_hook_fn)

        for child in module.children():
            fn_recursive_upcast(child)

    for module in model.children():
        fn_recursive_upcast(module)

print("Loading transformer")
transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", revision="refs/pr/18", subfolder="transformer", torch_dtype=full_dtype)
print("Adding cast hooks to transformer")
block_idx = 0
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock):
        block_idx += quant_div
        if not block_idx % quant_mod and idx != 16 and idx != 2084: # 16 and 2084 are the first and last layer respectively, and should likely be skipped.
            print(idx)
            i.to(cast_dtype)
            enable_layerwise_upcasting(i, upcast_dtype=full_dtype, original_dtype=cast_dtype)

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", revision="refs/pr/18", torch_dtype=full_dtype, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
#pipe.vae._enable_framewise_decoding()
frames = pipe("a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background", num_inference_steps=50, guidance_scale=8.0, height=480, width=848, num_frames=61).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)

https://github.com/user-attachments/assets/b172b018-084d-4237-8e53-18cef310fbb6

It's obviously not full quality, but it's getting closer. My script lets you define a fractional value of block layers to keep and layers to cast to a lower bit type, and keeps all other layers in the original type. fp16 seems better as a high type (fp32 is the preferred type, but takes longer and more vram). This is what happens when I try using bf16, likely because mochi is heavily reliant on the fractional part of the float for whatever reason:

https://github.com/user-attachments/assets/8dcc331e-f30b-499c-9541-78f9d6a1b569
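The block-selection rule used in the script above (a running counter stepped by quant_div and tested against quant_mod) can be looked at in isolation. This is a sketch in plain Python with a made-up block count and a hypothetical helper name, `select_blocks`; it works on block positions rather than on the module indices (16 and 2084) used in the script.

```python
def select_blocks(num_blocks, quant_div=1, quant_mod=2, skip_ends=True):
    """Return the 0-based positions of blocks that the counter rule would cast down.

    Mirrors the loop in the script: the counter advances by quant_div per block,
    and a block is selected when the counter is a multiple of quant_mod.
    The first and last blocks are skipped when skip_ends is True.
    """
    selected = []
    counter = 0
    for position in range(num_blocks):
        counter += quant_div
        is_end = position in (0, num_blocks - 1)
        if counter % quant_mod == 0 and not (skip_ends and is_end):
            selected.append(position)
    return selected


print(select_blocks(8))  # [1, 3, 5]
print(select_blocks(12, quant_div=2, quant_mod=3, skip_ends=False))  # [2, 5, 8, 11]
```

Note that in this sketch quant_div=2, quant_mod=3 selects one block in three rather than two in three, which may be one of the edge cases the script comment warns about.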

Ednaordinary avatar Nov 10 '24 05:11 Ednaordinary

Nice, thanks for the exploration. Maybe it'd be worth applying this upcasting w.r.t. some kind of weight-to-activation norm ratio? For layers having higher norms, we don't apply it, and for layers having lower norms, we do. WDYT?
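One possible reading of this suggestion, sketched in plain Python: rank the blocks by some per-block score and apply the low-precision cast only to the lowest-scoring fraction. The scores below are invented numbers, the helper name `pick_blocks_to_cast` is hypothetical, and a single score per block is a simplification of the weight-to-activation ratio mentioned above.

```python
def pick_blocks_to_cast(block_scores, quant_div=1, quant_mod=2):
    """Return the block indices with the lowest scores.

    block_scores maps a block index to a scalar score (for example a mean norm
    weight). The lowest quant_div / quant_mod fraction of blocks is returned;
    the remaining blocks would stay in the higher-precision dtype.
    """
    ranked = sorted(block_scores, key=block_scores.get)  # ascending by score
    count = len(ranked) * quant_div // quant_mod
    return ranked[:count]


# Invented example scores for four blocks.
scores = {0: 0.9, 1: 0.4, 2: 0.7, 3: 0.2}
print(pick_blocks_to_cast(scores))  # [3, 1]
```

The returned indices would then be the blocks passed through the cast-and-hook step, with everything else left alone.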

sayakpaul avatar Nov 10 '24 20:11 sayakpaul

I'm seeing similar if not slightly better results, though I'll be honest that I'm not sure if I've implemented it right. I also switched fp8 datatype to e4m3fn, which provides better fractional precision

https://github.com/user-attachments/assets/f25d89a3-e77a-429a-ac4f-b56c73c02a93

from diffusers import MochiPipeline, MochiTransformer3DModel
from diffusers.models.transformers.transformer_mochi import MochiTransformerBlock
import torch
from diffusers.utils import export_to_video

# The below defines how many layers to quantize. The fraction of layers quantized is quant_div / quant_mod. Think of this as an inverted quality slider
quant_div = 1
quant_mod = 2
full_dtype = torch.float16
cast_dtype = torch.float8_e4m3fn
torch.manual_seed(42)

if quant_div > quant_mod: print("quant_div should be less than or equal to quant_mod")

# Credits to `dn6`
# Copy-pasted + slight edit from 
# https://github.com/huggingface/diffusers/blob/layerwise-upcasting/src/diffusers/models/modeling_utils.py
def enable_layerwise_upcasting(model, upcast_dtype=None, original_dtype=None):
    upcast_dtype = upcast_dtype or torch.float32
    original_dtype = original_dtype or model.dtype

    def upcast_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(upcast_dtype)

    def cast_to_original_dtype_hook_fn(module, *args, **kwargs):
        module = module.to(original_dtype)

    def fn_recursive_upcast(module):
        # Register the cast hooks on this module, then recurse into its children
        module.register_forward_pre_hook(upcast_dtype_hook_fn)
        module.register_forward_hook(cast_to_original_dtype_hook_fn)

        has_children = list(module.children())
        if not has_children:
            module.register_forward_pre_hook(upcast_dtype_hook_fn)
            module.register_forward_hook(cast_to_original_dtype_hook_fn)

        for child in module.children():
            fn_recursive_upcast(child)

    for module in model.children():
        fn_recursive_upcast(module)

print("Loading transformer")
transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", revision="refs/pr/18", subfolder="transformer", torch_dtype=torch.float32)
print("Adding cast hooks to transformer")
block_idx = 0
mean_attns = []
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock) and idx != 2084: # 2084 is the last layer, and should likely be skipped.
        mean_attns.append((idx, torch.mean(torch.ravel(i.attn1.norm_k.weight)))) # can be changed with norm_q, values seem similar
mean_attns.sort(key=lambda x: x[1])
attn_high = [x[0] for x in mean_attns[:int(len(mean_attns) // (quant_mod / quant_div))]]
transformer.to(torch.float16)
for idx, i in enumerate(transformer.modules()):
    if isinstance(i, MochiTransformerBlock):
        if idx in attn_high:
            print(idx)
            i.to(cast_dtype)
            enable_layerwise_upcasting(i, upcast_dtype=full_dtype, original_dtype=cast_dtype)

pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", revision="refs/pr/18", torch_dtype=full_dtype, transformer=transformer)
pipe.enable_model_cpu_offload()
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
#pipe.vae._enable_framewise_decoding()
frames = pipe("a cat walks along the sidewalk of a city. The camera follows the cat at knee level. The city has many people and cars moving around, with advertisement billboards in the background", num_inference_steps=50, guidance_scale=8.0, height=480, width=848, num_frames=61).frames[0]
export_to_video(frames, "mochi.mp4", fps=15)

Ednaordinary avatar Nov 10 '24 23:11 Ednaordinary