[mps] training / inference dtype issues
when training on Diffusers without attention slicing, we see:
/AppleInternal/Library/BuildRoots/ce725a5f-c761-11ee-a4ec-b6ef2fd8d87b/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:788: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'
but with attention slicing, this error disappears.
# Base components to prepare
if torch.backends.mps.is_available():
    # autocast is not supported on MPS, so disable Accelerate's native AMP
    accelerator.native_amp = False

results = accelerator.prepare(unet, lr_scheduler, optimizer, *train_dataloaders)
unet = results[0]

if torch.backends.mps.is_available():
    # slice the attention computation so no single MPS tensor exceeds 2**32 bytes
    unet.set_attention_slice("auto")
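for inference, the pipeline-level equivalent of the same workaround is enable_attention_slicing(); a minimal sketch (the model id and prompt are just placeholders):

import torch
from diffusers import StableDiffusionPipeline

# load the pipeline and move it to the MPS device
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipe = pipe.to("mps")

# slice the attention computation into chunks so no intermediate tensor
# exceeds the 2**32-byte MPSNDArray limit from the assertion above
pipe.enable_attention_slicing()

image = pipe("a photo of a tiger", num_inference_steps=30).images[0]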
however, once this issue is resolved, there is a new problem:
return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
this is caused by the following logic:
# Check that all trainable models are in full precision
low_precision_error_string = (
    "Please make sure to always have all model weights in full float32 precision when starting training - even if"
    " doing mixed precision training. copy of the weights should still be float32."
)

if accelerator.unwrap_model(unet).dtype != torch.float32:
    raise ValueError(
        f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
    )

if (
    args.train_text_encoder
    and accelerator.unwrap_model(text_encoder).dtype != torch.float32
):
    raise ValueError(
        f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
        f" {low_precision_error_string}"
    )
which is done because the AdamW optimiser doesn't work with bf16 weights. however, thanks to @AmericanPresidentJimmyCarter we are now able to use the adamw_bfloat16 package to benefit from an optimizer that can handle pure bf16 training.
once that is the case, we can comment out the low precision warning code, and:
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
).to(weight_dtype)
load the unet directly in the target precision level.
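a sketch of what relaxing that check could look like (args.use_bf16_optimizer is a hypothetical flag, not an existing argument of the training scripts):

import torch

# sketch: allow pure-bf16 weights when a bf16-capable optimizer is in use,
# instead of hard-requiring float32 for every trainable model.
# `args.use_bf16_optimizer` is a hypothetical flag, not an existing argument.
allowed_dtypes = {torch.float32}
if args.use_bf16_optimizer:
    allowed_dtypes.add(torch.bfloat16)

unet_dtype = accelerator.unwrap_model(unet).dtype
if unet_dtype not in allowed_dtypes:
    raise ValueError(
        f"Unet loaded as datatype {unet_dtype}, expected one of {allowed_dtypes}."
    )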
Originally posted by @bghira in https://github.com/huggingface/diffusers/issues/7530#issuecomment-2032234598
@sayakpaul @pcuenca @DN6
if we can rely on the bf16-fixed AdamW optimiser, we can skip keeping fp32 copies of the weights around and save that storage. overall, training becomes more efficient and reliable. thoughts?
I can try to upload a pypi package of the fixed optimiser later today. It is located here: https://github.com/AmericanPresidentJimmyCarter/test-torch-bfloat16-vit-training/blob/main/adam_bfloat16/__init__.py#L22
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
).to(weight_dtype)
Would it apply to float16 too?
Cc: @patil-suraj too.
i can try, but since the sd 2.1 model needs the attention up block's precision upcast to at least bf16, i didn't think it was useful to test. it will produce black outputs without Xformers in use for SD 2.1 in particular, which @patrickvonplaten investigated here
the other crappy thing is, without autocast, we have to use the same dtype for vae and unet. this is probably fine because the SD 1.x/2.x VAE handles fp16 like a champ. but the u-net requires at least bf16. manually modifying the unet dtype seems to break the mitigations put in place on the up block attn.
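for reference, the SD 2.1 mitigation being referred to is the upcast_attention flag on the unet config; a sketch of loading unet and vae in one dtype while keeping that flag intact (the checkpoint id is just an example):

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel

dtype = torch.bfloat16
model_id = "stabilityai/stable-diffusion-2-1"  # example checkpoint

unet = UNet2DConditionModel.from_pretrained(
    model_id, subfolder="unet", torch_dtype=dtype
)
# SD 2.1's shipped config sets upcast_attention=True; that is the mitigation
# that breaks if the unet dtype is later modified by hand
assert unet.config.upcast_attention

# without autocast, the vae has to run in the same dtype as the unet
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=dtype)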
fp16 is usually degraded, and when you are training with it you need tricks like gradient scaling for it to work at all. For training scripts I am not sure we should recommend anything other than fp32 weights with fp16 autocast.
The situation with bfloat16 is different and it seems with the correct optimiser it will always perform near equally to float32. The downside is that old devices may not support it.
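for context, the standard fp16 recipe being referred to keeps fp32 weights and wraps only the forward pass in autocast, with gradient scaling to avoid underflow; a generic CUDA-flavoured sketch (the model and data are placeholders):

import torch

model = torch.nn.Linear(64, 64).cuda()  # placeholder model, weights stay fp32
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    x = torch.randn(8, 64, device="cuda")
    target = torch.randn(8, 64, device="cuda")

    # forward pass runs in fp16 via autocast; the weights remain fp32
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), target)

    # scale the loss so small fp16 gradients do not underflow to zero
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()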
older CUDA devices emulate bf16, e.g. the T4 on Colab.
Apple MPS supports it with PyTorch 2.3
AMD ROCm i think is the outlier, but it also seems to have emulation. pinging @Beinsezii for clarification
bf16 runs on ROCm with a 5-10% perf hit over fp16. Not sure of the implementation details but it's worked for everything I've needed it to.
I only have an RDNA 3 card, so it may not work as cleanly on the older GPUs without WMMAs
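a quick way to check what a given device actually reports (a generic sketch; as noted above, a positive result does not guarantee native-speed bf16):

import torch

if torch.cuda.is_available():
    # native on Ampere and newer; older or emulating cards may still report True
    print("cuda bf16 supported:", torch.cuda.is_bf16_supported())
elif torch.backends.mps.is_available():
    # bf16 on MPS needs a recent PyTorch (2.3+), so probe it directly
    try:
        torch.zeros(1, dtype=torch.bfloat16, device="mps")
        print("mps bf16 supported: yes")
    except (RuntimeError, TypeError):
        print("mps bf16 supported: no (upgrade PyTorch)")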
RuntimeError: Function 'MseLossBackward0' returned nan values in its 0th output.
can't do fp16 on SD 2.1, as it goes into NaN at the 0th output.
something still not quite right for mps, maybe the lack of 8bit optimisers really hurts more than i'd think, haha.
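that message is produced by PyTorch's anomaly detection; enabling it is the quickest way to localise where the NaN first appears. tiny reproducer:

import torch

# with anomaly detection on, backward() raises at the first function whose
# output contains NaN, which is exactly the MseLossBackward0 error above
torch.autograd.set_detect_anomaly(True)

pred = torch.tensor([float("nan"), 1.0], requires_grad=True)
target = torch.zeros(2)
loss = torch.nn.functional.mse_loss(pred, target)
loss.backward()  # RuntimeError: Function 'MseLossBackward0' returned nan values ...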
we see sampling speed improvements up to bsz=8 and then it hits swap space on a 128G unit
Pypi is here https://pypi.org/project/adamw-bf16/
Repo is here https://github.com/AmericanPresidentJimmyCarter/adamw-bf16
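usage is meant to be a drop-in swap for torch.optim.AdamW; a sketch assuming the package installs as adamw_bf16 and exposes an AdamWBF16 class (check the repo above for the exact import path):

import torch
from adamw_bf16 import AdamWBF16  # assumed import name; see the repo for the real one

# weights live directly in bfloat16 - no fp32 master copy is kept
unet = unet.to(torch.bfloat16)

# the bf16-aware optimizer applies updates in a way that does not lose
# small steps to bfloat16 rounding, unlike stock AdamW on bf16 parameters
optimizer = AdamWBF16(unet.parameters(), lr=1e-5, weight_decay=1e-2)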
tested the above training implementation on (so far) 300 steps of ptx0/photo-concept-bucket at a decent learning rate and batch size of 4 on an apple m3 max. it's definitely learning.
compared to what the unmodified optimizer does
i mean, it added human eyes to the tiger which might not be unexpected when zero images of tigers exist in the dataset, but it's certainly not what happens now with a fixed bf16 adamw implementation.
edit: importantly, the fixed optimizer does not (so far) run into NaNs unlike the currently-available selections of optimizers
unfortunately i hit a NaN at the 628th step of training, approximately the same place as before
Ufff. Why that damn step?
looks like it could be https://github.com/pytorch/pytorch/issues/118115 as both optimizers in use that fail in this way do use addcdiv
I will look into a fix for this too; if it's just .contiguous it should be easy.
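the kind of workaround being hinted at: force contiguous inputs around the addcdiv_ call inside an Adam-style step (a hedged sketch, not the actual fix):

import torch

def adam_style_update(param, exp_avg, exp_avg_sq, lr, eps=1e-8):
    """one simplified Adam-style parameter update (bias correction omitted)"""
    denom = exp_avg_sq.sqrt().add_(eps)
    # workaround for the linked MPS addcdiv issue: make both inputs contiguous
    # before the in-place addcdiv_ (assumption - the real fix may differ)
    param.data.addcdiv_(exp_avg.contiguous(), denom.contiguous(), value=-lr)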
it crashed after 628 steps and then on resume, it crashed after 300 steps, on the 901st.
it also seems to get a lot slower than it should sometimes - but it was mentioned that could be heat-related. i doubt it's purely a thermal problem, but it seems odd
@sayakpaul you know what it ended up being is a cached latent with NaN values. i ran the SDXL VAE in fp16 mode since i was using pytorch 2.2 a few days ago, and that didn't support bf16. it worked pretty well, but i guess one or two of the files had corrupt outputs. so there's no inherent issue with the backward pass on torch mps causing nan in 2.3. the fix for the one bug i was observing in pytorch with a small reproducer has also been backported to 2.3, as of yesterday morning.
torch.compile is also more stable on 2.4 - so aot_eager makes a few performance annoyances go away.
so it's shaping up to be a fair level of support at this juncture for mps and i'll be able to work on that this weekend
you know what it ended up being is a cached latent with NaN values.
You were caching latents? Even with that how were there NaNs? VAE issue or something else?
using madebyollin's sdxl-vae-fp16-fix model, it occasionally NaNs, but not often enough to find the issue right away
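a quick way to scan a latent cache for exactly this kind of corruption, assuming the latents are stored as individual .pt files (the directory layout here is hypothetical):

import pathlib
import torch

cache_dir = pathlib.Path("latent_cache")  # hypothetical cache location

for path in sorted(cache_dir.glob("*.pt")):
    latent = torch.load(path, map_location="cpu")
    # flag any cached latent the fp16 VAE corrupted with NaN or Inf values
    if not torch.isfinite(latent).all():
        print(f"corrupt latent: {path}")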
So many sudden glitches.
on a new platform, the workarounds that are required for all platforms might not be added yet.
e.g. cuda handles type casting automatically, but mps requires strict types - any of the cuda workarounds for issues people saw >1 year ago are now forgotten. we have to essentially rediscover how cuda needed to work, and apply a lot of the same changes to MPS.
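in practice that means adding explicit casts wherever the earlier conv2d dtype error shows up; a minimal sketch on MPS:

import torch

# a bf16 module fed a float32 input reproduces the conv2d dtype mismatch
# from earlier; the fix is an explicit cast to the module's dtype
conv = torch.nn.Conv2d(4, 4, kernel_size=3).to("mps", dtype=torch.bfloat16)
x = torch.randn(1, 4, 64, 64, device="mps")  # float32 input

out = conv(x.to(conv.weight.dtype))  # cast before the forward pass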
i am removing fp16 training from my own trainer. fp32 is there, but i don't know why anyone would use it.
pure bf16 with the fixed optimizer is the new hero here
Thanks for investigating. I guess it’s just about time now fp16 support gets fixes. If people are aware of these findings I think it should still be okay. But fp16 inference — I don’t think we can throw that one out yet.
fp16 inference was thrown out long ago:
- sdxl's vae doesn't work with it
- sd 2.1's unet doesn't work with it
Will respectfully disagree. Not all the cards equally support bfloat16 well.
the ones that don't are going to be upcasting about half of the information to fp32 - e.g. the GTX 1060 also doesn't support fp16. NVIDIA used to lock it behind a Quadro card purchase.
in any case i don't think the cards that fail to support BF16 are useful for training.
the Google Colab T4 emulates bf16 behind the scenes. which others are there?
@bghira now that you fixed the latents issue, is bf16 training going well with my optim?