
[mps] training / inference dtype issues

Open bghira opened this issue 10 months ago • 32 comments

when training on Diffusers without attention slicing, we see:

/AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:788: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'

but with attention slicing, this error disappears.

    # Base components to prepare
    if torch.backends.mps.is_available():
        # Accelerate's native AMP path is not usable on MPS, so disable it
        accelerator.native_amp = False
    results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders)
    unet = results[0]
    if torch.backends.mps.is_available():
        # attention slicing keeps individual MPS tensors under the 2**32-byte limit;
        # "auto" halves the attention computation - pass an explicit slice_size
        # if your diffusers version requires one
        unet.set_attention_slice("auto")

however, once this issue is resolved, there is a new problem:

    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

this is caused by the following logic:

    # Check that all trainable models are in full precision
    low_precision_error_string = (
        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training. copy of the weights should still be float32."
    )

    if accelerator.unwrap_model(unet).dtype != torch.float32:
        raise ValueError(
            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

    if (
        args.train_text_encoder
        and accelerator.unwrap_model(text_encoder).dtype != torch.float32
    ):
        raise ValueError(
            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
            f" {low_precision_error_string}"
        )

which is done because the stock AdamW optimizer doesn't work with bf16 weights. however, thanks to @AmericanPresidentJimmyCarter, we can now use the adamw_bfloat16 package to get an optimizer that handles pure bf16 training.

once that is the case, we can comment out the low precision check above, and:

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

load the unet directly in the target precision level.
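
for reference, a minimal sketch of that wiring, assuming the same args and accelerator objects as the excerpts above; the AdamWBF16 import name is only illustrative, check the package for the real one:

    import torch
    from diffusers import UNet2DConditionModel
    # illustrative import: check the adamw-bf16 package for the actual class name
    from adamw_bf16 import AdamWBF16

    weight_dtype = torch.bfloat16

    # load the unet directly in bf16 instead of fp32 + autocast
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

    # optimizer whose internals are written to cope with pure-bf16 weights
    optimizer = AdamWBF16(unet.parameters(), lr=args.learning_rate)

    unet, optimizer, lr_scheduler, *train_dataloaders = accelerator.prepare(
        unet, optimizer, lr_scheduler, *train_dataloaders
    )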

Originally posted by @bghira in https://github.com/huggingface/diffusers/issues/7530#issuecomment-2032234598

bghira avatar Apr 02 '24 14:04 bghira

@sayakpaul @pcuenca @DN6

bghira avatar Apr 02 '24 14:04 bghira

if we can rely on the fixed bf16 AdamW optimiser, we save the storage space the fp32 weights would otherwise take. overall, training becomes more efficient and reliable. thoughts?

bghira avatar Apr 02 '24 14:04 bghira

I can try to upload a pypi package of the fixed optimiser later today. It is located here: https://github.com/AmericanPresidentJimmyCarter/test-torch-bfloat16-vit-training/blob/main/adam_bfloat16/__init__.py#L22

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

Would it apply to float16 too?

sayakpaul avatar Apr 02 '24 15:04 sayakpaul

Cc: @patil-suraj too.

sayakpaul avatar Apr 02 '24 15:04 sayakpaul

    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    ).to(weight_dtype)

Would it apply to float16 too?

i can try, but since the SD 2.1 model needs the attention up block's precision upcast to at least bf16, i didn't think it was useful to test. SD 2.1 in particular will produce black outputs without xFormers, which @patrickvonplaten investigated here
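
for reference, the usual mitigation is to keep the attention math in fp32 via the unet config; a sketch, assuming an SD 2.1 checkpoint:

    import torch
    from diffusers import UNet2DConditionModel

    # upcast_attention keeps the attention computation in fp32 even when the
    # rest of the unet runs in half precision, which is the known SD 2.1 mitigation
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1",
        subfolder="unet",
        upcast_attention=True,
        torch_dtype=torch.float16,
    )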

bghira avatar Apr 02 '24 15:04 bghira

the other crappy thing is, without autocast, we have to use the same dtype for vae and unet. this is probably fine because the SD 1.x/2.x VAE handles fp16 like a champ. but the u-net requires at least bf16. manually modifying the unet dtype seems to break the mitigations put in place on the up block attn.
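
in practice that just means loading both at the same dtype up front; a sketch, with the model id purely illustrative:

    import torch
    from diffusers import AutoencoderKL, UNet2DConditionModel

    # without autocast the vae and unet must share a dtype on mps;
    # bf16 is the safer common choice for sd 2.x
    dtype = torch.bfloat16
    vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-2-1", subfolder="vae", torch_dtype=dtype
    )
    unet = UNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=dtype
    )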

bghira avatar Apr 02 '24 15:04 bghira

fp16 is usually degraded, and when you train with it you need tricks like gradient scaling for it to work at all. For training scripts, I am not sure we should recommend anything other than fp32 weights with fp16 autocast.
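
for context, the gradient-scaling trick looks roughly like this on cuda; a generic sketch, not an mps recipe, with model, batch, optimizer and compute_loss as placeholders:

    import torch

    def fp16_training_step(model, batch, optimizer, scaler, compute_loss):
        # one fp16 step with loss/grad scaling; the arguments are placeholders
        # for the real model, batch, optimizer and loss function
        optimizer.zero_grad(set_to_none=True)
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = compute_loss(model, batch)
        scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
        scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
        scaler.update()                # adapts the scale factor for the next step
        return loss.detach()

    # usage: scaler = torch.cuda.amp.GradScaler(), then call the step per batch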

The situation with bfloat16 is different and it seems with the correct optimiser it will always perform near equally to float32. The downside is that old devices may not support it.

older CUDA devices emulate bf16, e.g. the T4 on Colab.

Apple MPS supports it with PyTorch 2.3

AMD ROCm is, i think, the outlier, but it also seems to have emulation. pinging @Beinsezii for clarification

bghira avatar Apr 02 '24 15:04 bghira

bf16 runs on ROCm with a 5-10% perf hit over fp16. Not sure of the implementation details but it's worked for everything I've needed it to.

I only have an RDNA 3 card, so it may not work as cleanly on the older GPUs without WMMAs

Beinsezii avatar Apr 02 '24 18:04 Beinsezii

RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.

can't do fp16 on SD 2.1, as it goes into NaN at the 0th output.
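
for anyone reproducing: that message comes from torch's anomaly detection, which can be switched on to pinpoint the backward op that produces the NaN:

    import torch

    # surfaces "Function 'XBackward0' returned nan values" errors at the exact
    # backward op instead of a silent NaN loss later on
    torch.autograd.set_detect_anomaly(True)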

bghira avatar Apr 02 '24 19:04 bghira

something is still not quite right on mps; maybe the lack of 8-bit optimisers hurts more than i'd have thought, haha.

we see sampling speed improvements up to bsz=8 and then it hits swap space on a 128G unit

(image attached)

bghira avatar Apr 03 '24 21:04 bghira

(image attached)

bghira avatar Apr 03 '24 22:04 bghira

Pypi is here https://pypi.org/project/adamw-bf16/

Repo is here https://github.com/AmericanPresidentJimmyCarter/adamw-bf16

tested the above training implementation on (so far) 300 steps of ptx0/photo-concept-bucket at a decent learning rate and batch size of 4 on an apple m3 max

it's definitely learning. (image attached)

compared to what the unmodified optimizer does: (image attached)

i mean, it added human eyes to the tiger, which might not be unexpected when zero images of tigers exist in the dataset, but that's certainly not what happens now with a fixed bf16 adamw implementation.

edit: importantly, the fixed optimizer does not (so far) run into NaNs, unlike the currently available selection of optimizers

bghira avatar Apr 04 '24 20:04 bghira

unfortunately i hit a NaN at the 628th step of training, approximately the same place as before

bghira avatar Apr 05 '24 01:04 bghira

Ufff. Why that damn step? 😭

sayakpaul avatar Apr 05 '24 02:04 sayakpaul

looks like it could be https://github.com/pytorch/pytorch/issues/118115 as both optimizers in use that fail in this way do use addcdiv

bghira avatar Apr 05 '24 02:04 bghira

looks like it could be pytorch/pytorch#118115 as both optimizers in use that fail in this way do use addcdiv

I will look into a fix for this too; if it's just .contiguous it should be easy.
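
if it is that bug, the workaround would be forcing contiguous operands before the in-place addcdiv; a sketch, not the actual optimizer code:

    import torch

    def safe_addcdiv_(param: torch.Tensor, exp_avg: torch.Tensor,
                      denom: torch.Tensor, step_size: float) -> None:
        # make the operands contiguous before the in-place addcdiv_, which is
        # where non-contiguous tensors reportedly misbehave on mps
        param.addcdiv_(exp_avg.contiguous(), denom.contiguous(), value=-step_size)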

it crashed after 628 steps and then on resume, it crashed after 300 steps, on the 901st.

it also seems to get a lot slower than it should sometimes - but it was mentioned that could be heat-related. i doubt it's fully a thermal problem, but it seems odd

bghira avatar Apr 05 '24 13:04 bghira

@sayakpaul you know what it ended up being: a cached latent with NaN values. i ran the SDXL VAE in fp16 mode since i was using pytorch 2.2 a few days ago, which didn't support bf16 on mps. it worked pretty well, but i guess one or two of the files had corrupt outputs. so there's no inherent issue with the backward pass on torch mps causing NaN in 2.3. the one bug i was observing in pytorch with a small reproducer has now been backported to 2.3 as well, as of yesterday morning
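
a quick sanity check for that, assuming each cached latent is saved as one tensor per file (which is just how my cache happens to be laid out):

    import torch

    def latent_is_valid(path: str) -> bool:
        # reject cached latents containing NaN/Inf before they poison a training run
        latent = torch.load(path, map_location="cpu")
        return bool(torch.isfinite(latent).all())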

torch.compile is also more stable on 2.4, so the aot_eager backend makes a few performance annoyances go away.

so mps support is shaping up to a fair level at this juncture, and i'll be able to work on that this weekend

bghira avatar Apr 05 '24 23:04 bghira

you know what it ended up being is a cached latent with NaN values.

You were caching latents? Even with that how were there NaNs? VAE issue or something else?

sayakpaul avatar Apr 06 '24 02:04 sayakpaul

using the madebyollin sdxl vae fp16 model, it occasionally NaNs, but not often enough to find the issue right away

bghira avatar Apr 06 '24 02:04 bghira

So many sudden glitches.

sayakpaul avatar Apr 06 '24 03:04 sayakpaul

on a new platform, the workarounds that every other platform already needed might not be in place yet.

e.g. cuda handles type casting automatically, but mps requires strict types - the cuda workarounds for issues people saw more than a year ago are now forgotten. we essentially have to rediscover how cuda needed to work and apply a lot of the same changes to MPS.
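
a tiny illustration of the kind of explicit casting mps forces on you, nothing diffusers-specific:

    import torch

    def cast_to_module_dtype(module: torch.nn.Module, *tensors: torch.Tensor):
        # mps will not silently cast mismatched dtypes the way cuda often does,
        # so cast inputs to the module's own parameter dtype before calling it
        target = next(module.parameters()).dtype
        return tuple(t.to(dtype=target) for t in tensors)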

i am removing fp16 training from my own trainer. fp32 is there, but i don't know why anyone would use it.

pure bf16 with the fixed optimizer is the new hero here

bghira avatar Apr 06 '24 12:04 bghira

Thanks for investigating. I guess it's just about time that fp16 support gets fixed. If people are aware of these findings I think it should still be okay. But fp16 inference is one I don't think we can throw out yet.

sayakpaul avatar Apr 06 '24 12:04 sayakpaul

fp16 inference was thrown out long ago:

  • sdxl's vae doesn't work with it
  • sd 2.1's unet doesn't work with it

bghira avatar Apr 06 '24 12:04 bghira

Will respectfully disagree. Not all cards support bfloat16 equally well.

sayakpaul avatar Apr 06 '24 12:04 sayakpaul

the ones that don't will upcast about half of the information to fp32 - e.g. the GTX 1060 also doesn't support fp16; NVIDIA used to lock that behind a Quadro card purchase.

in any case i don't think the cards that fail to support BF16 are useful for training.

the Google Colab T4 emulates bf16 behind the scenes. which others are there?

bghira avatar Apr 06 '24 12:04 bghira

@bghira now that you fixed the latents issue, is bf16 training going well with my optim?