training fails with torch.float16 dtype - weights are deleted in train loop
Describe the bug
I'm using the "textual inversion" training notebooks/scripts. I could never get the TPU setup in Flax working, but I blame Google for that one - they do love treating paying customers as beta testers.
In the PyTorch one, when using the torch.float16 dtype, the text_encoder's weights somehow get reset on the second iteration of the training loop, so the embedding for placeholder_token_id becomes all NaNs instead of either what it was after being initialized from initializer_token_id, or that plus gradient updates. No error is raised - I suspect it's suppressed somewhere, because when calling some methods directly I get "not implemented for Half"-style errors, inconsistently.
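For reference, here's a minimal sketch of the check I mean (the text_encoder and placeholder_token_id names follow the notebook; your local variable names may differ):

import torch

def placeholder_embedding_is_nan(text_encoder, placeholder_token_id):
    # Returns True if the learned token's embedding row has turned into NaNs.
    embeds = text_encoder.get_input_embeddings().weight
    return torch.isnan(embeds[placeholder_token_id]).any().item()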
This doesn't happen with float32, nor with bfloat16, all running identical code... it took me all day to determine it was the dtype causing those weights to magically disappear.
bfloat16 performance is terrible, and I'm not sure why, since it's supposed to be faster according to PyTorch's own documentation. So: training works with float32, but only with batch size 1 (on a 16 GB GPU) and at double the time per iteration. float16 does run double the batch size at double the speed, but all losses and grads are NaNs - so you can get through a full 90 minutes of training without knowing that nothing happened.
Reproduction
https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb
Adjustments and tests are summarized below (a rough sketch of these changes follows the list):
- put all args into a dict, then SimpleNamespace(**args) - including an arg for dtype, which is set to torch.float16
- adjust all from_pretrained calls to include torch_dtype=args.dtype
- make sure you're jamming some proper drum'n'bass to keep your spirits up...
- in the training loop, add .type(args.dtype) to the assignments for noise and noisy_latents
- add a check for if torch.isnan(loss).any(): so you don't waste your time
- once that check fires, you can verify the weights are gone: token_embeds.weight[placeholder_token_id] or text_encoder.get_input_embeddings().weight[placeholder_token_id] now returns NaNs, where just a few minutes earlier - before cranking up the drum'n'bass - there were weights there equal to token_embeds.weight.data[initializer_token_id]
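Roughly, those modifications look like this (a sketch only, not my exact notebook: the model id is a placeholder, and the training-loop lines are shown as comments because they depend on the notebook's local variables):

from types import SimpleNamespace

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

# All hyperparameters in one namespace, including the dtype under test.
args = SimpleNamespace(
    pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",  # placeholder model id
    dtype=torch.float16,
)

# Every from_pretrained call gets torch_dtype=args.dtype.
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=args.dtype
)
vae = AutoencoderKL.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=args.dtype
)
unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", torch_dtype=args.dtype
)

# Inside the training loop (as comments, since these depend on loop-local variables):
#   noise = torch.randn_like(latents).type(args.dtype)
#   noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps).type(args.dtype)
#   ...
#   if torch.isnan(loss).any():
#       raise RuntimeError("loss is NaN - stop now instead of burning 90 minutes")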
Logs
No response
System Info
Colab. It's Colab Pro, not that Pro does anything but cost more.
Hey @krahnikblis,
We don't actively maintain the Flax example script at the moment because we sadly don't have the time.
We'd love to review a PR though if you find time.
cc'ing @pcuenca and @duongna21 in case you have any ideas or are curious
@patrickvonplaten actually, @krahnikblis is reporting a PyTorch error when training textual inversion using float16. Pinging @patil-suraj in case he knows what's going on; otherwise I'm happy to debug it myself (assign it to me, Suraj, if you don't have time to investigate it right now).
Right, for clarity: the issue happens in the torch version when using float16. My reference to bfloat16 was also the torch version - it's super slow (slower than torch.float32), but it works.
Regarding the Flax version, I did get it working on TPUs (using jnp.bfloat16, of course), but I pretty much rewrote everything to make it work. Wow, does it run fast on Flax with a TPU, though!
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Gently pinging @patil-suraj in case he has time, otherwise I'll debug next week :)
Let me see if I can reproduce the error on main.
I just tried this using diffusers main and couldn't reproduce it. The training worked fine in fp16. Here's the command I used:
accelerate launch --gpu_ids="0" --mixed_precision="fp16" \
../diffusers/examples/textual_inversion/textual_inversion.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATA_DIR \
--learnable_property="object" \
--placeholder_token="<cat-toy>" --initializer_token="toy" \
--resolution=512 \
--train_batch_size=4 \
--gradient_accumulation_steps=1 \
--max_train_steps=1000 --save_steps=2000 \
--learning_rate=2e-4 --scale_lr \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--output_dir=$OUTPUT_DIR \
--mixed_precision="fp16"
In that case I'm closing it for now. @krahnikblis feel free to reopen if you are still observing this :)