conditional-flow-matching
Allow non-integer conditions for image generation with Unet
- Added tests that run the forward functions of the models
- Added a test that extends Unet to take floating-point conditions (or any other conditioning input)
- Edited Unet to accept a custom embedding network, which allows non-integer labels to condition generation
- Added a demo notebook that demonstrates this behavior
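To make the motivation concrete: a class-label embedding behaves like a lookup table, so it can only be indexed with integers, whereas a custom embedding network can map any real-valued condition to a vector. The sketch below uses plain Python stand-ins rather than PyTorch modules to stay self-contained. `LookupEmbedding` and `SinusoidalEmbedding` are hypothetical names, not classes from this repository, and sinusoidal features are only one possible choice of embedding net.

```python
import math


class LookupEmbedding:
    """Hypothetical stand-in for a class-label embedding (a lookup table).

    Like an integer-indexed embedding table, it can only be indexed
    with integer labels.
    """

    def __init__(self, num_classes: int, dim: int):
        # Toy table: row i is filled with float(i).
        self.table = [[float(i)] * dim for i in range(num_classes)]

    def __call__(self, label):
        # Raises TypeError for a float label such as 0.5.
        return self.table[label]


class SinusoidalEmbedding:
    """Hypothetical custom embedding net for a real-valued condition.

    Maps a scalar to sinusoidal features at geometrically spaced
    frequencies, similar in spirit to common timestep embeddings.
    """

    def __init__(self, dim: int, max_period: float = 10000.0):
        if dim % 2 != 0:
            raise ValueError("dim must be even")
        self.dim = dim
        self.max_period = max_period

    def __call__(self, value) -> list:
        half = self.dim // 2
        freqs = [math.exp(-math.log(self.max_period) * k / half) for k in range(half)]
        args = [float(value) * f for f in freqs]
        # First half cosines, second half sines.
        return [math.cos(a) for a in args] + [math.sin(a) for a in args]


lookup = LookupEmbedding(num_classes=10, dim=4)
print(lookup(3))  # [3.0, 3.0, 3.0, 3.0]
try:
    lookup(0.5)
except TypeError as err:
    print("lookup table rejects float labels:", err)

embed = SinusoidalEmbedding(dim=4)
print(len(embed(0.5)))  # 4
```

Passing such a network into the model, instead of hard-coding a lookup table, is what lets the same architecture be conditioned on integer classes, floats, or other inputs.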
What does this PR do?
Fixes #163
No breaking changes, all tests pass.
Before submitting
- [x] Did you make sure title is self-explanatory and the description concisely explains the PR?
- [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
- [x] Did you list all the breaking changes introduced by this pull request?
- [x] Did you test your PR locally with `pytest` command?
- [x] Did you run pre-commit hooks with `pre-commit run -a` command?
Did you have fun?
Make sure you had fun coding 🙃
Summary by Sourcery
Enhance the Unet model to support non-integer conditions and custom embedding networks for image generation
New Features:
- Allow Unet to accept custom embedding networks for conditioning
- Support non-integer labels for model conditioning
Bug Fixes:
- Fix label embedding for non-class-conditional networks
Enhancements:
- Make the Unet forward method more flexible in how it handles label embeddings
- Remove strict type checking for conditional labels
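The enhancements above amount to delegating label handling to an injected embedding net. A minimal sketch of that pattern follows, assuming a forward signature `forward(t, x, y=None)`; `ToyConditionalModel` is a hypothetical, heavily simplified stand-in (plain lists instead of tensors, a trivial body instead of a Unet) and does not reproduce the actual diff. Ignoring labels when no embedding net is configured is also an assumption about what the non-class-conditional behavior should be.

```python
import math
from typing import Callable, Optional, Sequence


class ToyConditionalModel:
    """Hypothetical stand-in for a conditional Unet.

    Only the conditioning path is modelled: the time embedding is combined
    with an optional label embedding produced by an injected embedding net.
    """

    def __init__(self, dim: int,
                 label_embed: Optional[Callable[[object], Sequence[float]]] = None):
        self.dim = dim
        # None means "not class-conditional": labels are never embedded.
        self.label_embed = label_embed

    def time_embed(self, t: float) -> list:
        # Toy time embedding; all zeros at t = 0.
        return [math.sin(t * (k + 1)) for k in range(self.dim)]

    def forward(self, t: float, x: Sequence[float], y=None) -> list:
        emb = self.time_embed(t)
        # No isinstance/dtype check on y: validation is left to the
        # embedding net, so integer, float, or other conditions all work.
        if self.label_embed is not None and y is not None:
            label_vec = self.label_embed(y)
            emb = [e + l for e, l in zip(emb, label_vec)]
        # Stand-in for the network body: shift the input by the embedding mean.
        shift = sum(emb) / len(emb)
        return [xi + shift for xi in x]


# Unconditional model: a label passed in is ignored rather than embedded.
uncond = ToyConditionalModel(dim=4)
print(uncond.forward(0.0, [1.0, 2.0], y=3))  # [1.0, 2.0]

# Conditional model with a custom embedding net for a float condition.
cond = ToyConditionalModel(dim=4, label_embed=lambda y: [float(y)] * 4)
print(cond.forward(0.0, [1.0, 2.0], y=0.5))  # [1.5, 2.5]
```

Keeping the type check out of `forward` means the model itself never needs to know what a label is; whichever embedding net is passed in decides which inputs are valid.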
Tests:
- Add test for Unet initialization
- Add test for MLP model
- Add test for conditional model with non-integer labels