ControlNet
Add Support for Python 3.11 and PyTorch 2.0
Raising some issues I found when running ControlNet with Python 3.11 and PyTorch 2.0. The changes required to run with these versions are outlined below:
Fix 1:
Encountered a ValueError when running python tool_add_control_sd21.py ./models/v2-1_768-ema-pruned.ckpt ./models/control_sd21_ini.ckpt.
Error message:
ValueError: mutable default <class 'timm.models.maxxvit.MaxxVitConvCfg'> for field conv_cfg is not allowed: use default_factory
Solution: Implemented the fix from the following commit in HuggingFace's PyTorch Image Models repository: https://github.com/huggingface/pytorch-image-models/pull/1649/commits/a4823653b940482b553ae5e2fe02086847029b36
Alternatively, you can use the pre-release version of timm by running pip install --pre timm.
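For context, here is a minimal sketch of the pattern behind this error. The class names `ConvCfg` and `ModelCfg` are hypothetical stand-ins, not timm's actual classes. Python 3.11 rejects any unhashable object as a dataclass field default, and a regular dataclass instance is unhashable, so the default has to go through `default_factory`:

```python
from dataclasses import dataclass, field


@dataclass
class ConvCfg:
    # Hypothetical stand-in for a nested config such as timm's MaxxVitConvCfg.
    kernel_size: int = 3
    padding: str = ""


@dataclass
class ModelCfg:
    # Writing `conv_cfg: ConvCfg = ConvCfg()` here raises
    # "ValueError: mutable default ... is not allowed: use default_factory"
    # on Python 3.11, because the ConvCfg instance is unhashable.
    conv_cfg: ConvCfg = field(default_factory=ConvCfg)


a = ModelCfg()
b = ModelCfg()
a.conv_cfg.kernel_size = 5  # changes only a's own ConvCfg instance
```

With `default_factory`, each `ModelCfg` gets its own `ConvCfg`, so mutating one instance does not leak into another.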
Fix 2:
Encountered an ImportError for rank_zero_only from pytorch_lightning.utilities.distributed.
Replace:
from pytorch_lightning.utilities.distributed import rank_zero_only
With:
from pytorch_lightning.utilities.rank_zero import rank_zero_only
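If the same code has to run against both older and newer pytorch_lightning releases, one option is to try the import locations in order. The helper `import_first` below is a hypothetical sketch, not part of ControlNet or Lightning, and it is demonstrated with stdlib modules so it runs without Lightning installed:

```python
import collections
import importlib


def import_first(name, module_paths):
    """Return attribute `name` from the first module in `module_paths` that provides it."""
    for path in module_paths:
        try:
            return getattr(importlib.import_module(path), name)
        except (ImportError, AttributeError):
            # Module missing, or present without the attribute: try the next one.
            continue
    raise ImportError(f"{name!r} not found in any of {module_paths}")


# Stdlib demonstration: the first module does not exist, so the second is used.
ordered_dict = import_first("OrderedDict", ["no_such_module_xyz", "collections"])
```

For this fix it would be called as `rank_zero_only = import_first("rank_zero_only", ["pytorch_lightning.utilities.rank_zero", "pytorch_lightning.utilities.distributed"])`.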
Fix 3:
Remove the dataloader_idx param from 'ImageLogger' and from on_train_batch_start, so that the latter reads 'def on_train_batch_start(self, batch, batch_idx):'.
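An alternative to deleting the parameter is to give it a default value, so the hook accepts both call styles. The sketch below uses a plain stand-in class (`ImageLoggerSketch` is hypothetical, not the repo's `ImageLogger`), and it assumes the affected hook is `on_train_batch_end` and that newer Lightning versions simply stop passing `dataloader_idx`. Neither detail is confirmed in this thread.

```python
class ImageLoggerSketch:
    """Stand-in for a Lightning callback; records which batches it was called for."""

    def __init__(self):
        self.seen = []

    # dataloader_idx=0 keeps the hook callable whether or not the trainer
    # passes the extra positional argument.
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,
                           dataloader_idx=0):
        self.seen.append((batch_idx, dataloader_idx))


logger = ImageLoggerSketch()
logger.on_train_batch_end(None, None, None, {"x": 1}, 0)     # call without dataloader_idx
logger.on_train_batch_end(None, None, None, {"x": 1}, 1, 0)  # call with dataloader_idx
```

Removing the parameter outright, as in this PR, is simpler if only the newer versions need to be supported.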
With these changes, training runs successfully with Python 3.11 and PyTorch 2.0. I also added a CPU core count variable for more efficient data loading.
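A rough sketch of what such a CPU core variable could look like. The helper name `default_num_workers` and the choice to leave one core free are assumptions for illustration, not necessarily what the PR does:

```python
import os


def default_num_workers(reserve=1):
    """Pick a DataLoader worker count from the number of CPU cores.

    `reserve` cores are left free for the main training process.
    os.cpu_count() can return None, so fall back to a single core.
    """
    cores = os.cpu_count() or 1
    return max(cores - reserve, 0)


num_workers = default_num_workers()
```

The result would then be passed as the `num_workers` argument of `torch.utils.data.DataLoader` in the training script. A value of 0 means data is loaded in the main process.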
Did you notice any speedup when using pt 2.0 and compiling the models?
@SamSamhuns
Most of the speedup during training and inference will come from manually adding the new optimization techniques. ~~This PR serves as a bare bones setup of the latest python/pytorch version and the rest is really up to the end user.~~
See: https://huggingface.co/docs/diffusers/optimization/torch2.0
@tensorneko Hi, I am using version 1.5, and I followed your modifications. It worked fine. ~~I encountered an error related to the missing dataloader_idx, thus I did anything with that part. I am not certain if this issue is due to differences between versions 1.5 and 2.1.~~ Additionally, perhaps you could include your changes in the tutorial_train.py file. Furthermore, maybe torch.compile could be incorporated in your pull request.
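On the torch.compile suggestion, a guarded wrapper might look like the sketch below. `maybe_compile` is a hypothetical helper, not part of this PR. It falls back to the original model when torch is missing, when the installed torch predates `torch.compile`, or when compilation is refused with a RuntimeError (as far as I know, PyTorch 2.0 does this on Python 3.11).

```python
def maybe_compile(model):
    """Return torch.compile(model) when available, otherwise the model unchanged."""
    try:
        import torch
    except ImportError:
        return model  # torch not installed: nothing to compile with

    compile_fn = getattr(torch, "compile", None)
    if compile_fn is None:
        return model  # torch older than 2.0 has no torch.compile

    try:
        return compile_fn(model)
    except RuntimeError:
        # Unsupported Python version or platform: keep the eager model.
        return model


def identity(x):
    return x


wrapped = maybe_compile(identity)
```

In tutorial_train.py this would wrap the model after it is created and before it is handed to the trainer. Whether it gives a real speedup for ControlNet is not something this thread has measured.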