Fix `latents.dtype` before `vae.decode()` on ROCm devices in `StableDiffusionPipeline`s
Co-authored-by: Bagheera [email protected]
What does this PR do?
This was discussed in a previous PR: #7858.
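For context, the change boils down to aligning the latents' dtype with the VAE's parameter dtype right before decoding. A minimal sketch of the idea (not the exact diff; `self.vae.dtype` and the decode call follow the public `StableDiffusionPipeline` conventions):

```python
# Sketch: inside the pipeline, just before decoding. The scheduler may have
# promoted `latents` to float32 while the VAE weights are fp16/bf16, so cast
# explicitly instead of relying on torch.autocast being active.
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
```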
Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline?
- [x] Did you read our philosophy doc (important for complex PRs)?
- [x] Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [ ] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. @yiyixuxu @bghira
So I can actually hit this one on CUDA as well, but it only happens during training without autocast 🤔
The weights are in bf16 precision, not fp32. Maybe that's what causes it?
Since I don't have a ROCm device, I can't work on this PR directly :/ What should we do here? Doesn't running inference in bf16 produce a casting error?
Inference is when the error occurs, but it generally happens at training time, since the components can be initialised with different weight dtypes. It's not really clear why the scheduler can change the dtype, other than that certain calculations end up in the default torch dtype of float32 when one isn't specified, and I think autocast normally takes care of this. But we can't rely on autocast existing or being in use, because not all platforms support it, and future training might not use it at all, instead relying on bf16 optimiser states.
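To make that scenario concrete, here's a minimal, self-contained sketch of the failure mode, assuming a VAE loaded in bf16 and latents left in the default float32 (the checkpoint id is only illustrative):

```python
import torch
from diffusers import AutoencoderKL

# Load the VAE in bf16, as a training script might (checkpoint id is illustrative).
vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=torch.bfloat16
)

# Latents coming out of a scheduler step can end up in float32 (torch's default)
# when no dtype is specified along the way.
latents = torch.randn(1, 4, 64, 64)

# Without torch.autocast, decoding float32 latents through bf16 weights raises a
# dtype mismatch, so the cast has to happen explicitly:
latents = latents.to(vae.dtype)
image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
```

That explicit cast is essentially what this PR moves inside the pipeline, so callers don't have to wrap the decode in autocast themselves.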