
Allow non-integer conditions for image generation with Unet

Open · psteinb opened this pull request 6 months ago · 2 comments

  • added more tests that run the forward functions of the models
  • added a test that extends the Unet to take floating-point conditions (or any other conditioning input)
  • edited the Unet to accept a custom embedding network (this allows non-integer labels to condition the generation)
  • added a demo notebook that demonstrates this behavior
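Why a custom embedding network is needed: a class-label embedding is a lookup table, so it accepts only integer indices. A real-valued condition needs an embedding that works for any float. The sketch below contrasts the two. It uses a sinusoidal encoding (the same idea as the timestep embedding in diffusion-style U-Nets) purely as an illustration. The function names are hypothetical, and this is not the embedding net used in this PR.

```python
import math
from typing import List


def sinusoidal_embedding(value: float, dim: int, max_period: float = 10000.0) -> List[float]:
    """Map a real-valued condition to a `dim`-dimensional vector (hypothetical helper).

    Half of the features are cosines and half are sines, at geometrically spaced
    frequencies, so any float (not only integers) yields a valid embedding.
    """
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    return [math.cos(value * f) for f in freqs] + [math.sin(value * f) for f in freqs]


def class_embedding(table: List[List[float]], label: int) -> List[float]:
    """Integer-label lookup, analogous to an embedding table: only int indices work."""
    return table[label]


if __name__ == "__main__":
    print(sinusoidal_embedding(0.37, 8))   # 8 floats for a non-integer condition
    table = [[0.0] * 8 for _ in range(10)]
    print(class_embedding(table, 3))       # fine: integer class label
    try:
        class_embedding(table, 0.37)       # a float label cannot index a table
    except TypeError as err:
        print("lookup failed:", err)
```

In a real model the custom net would typically be a learned module (for example, a small MLP applied after such an encoding). The fixed encoding here just keeps the example dependency-free.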

What does this PR do?

Fixes #163

No breaking changes; all tests pass.

Before submitting

  • [x] Did you make sure the title is self-explanatory and the description concisely explains the PR?
  • [x] Did you make sure your PR does only one thing, instead of bundling different changes together?
  • [x] Did you list all the breaking changes introduced by this pull request?
  • [x] Did you test your PR locally with the pytest command?
  • [x] Did you run pre-commit hooks with the pre-commit run -a command?

Did you have fun?

Make sure you had fun coding 🙃

Summary by Sourcery

Enhance Unet model to support non-integer and custom embedding conditions for image generation

New Features:

  • Allow Unet to accept custom embedding networks for conditioning
  • Support non-integer labels for model conditioning

Bug Fixes:

  • Fix label embedding for networks that are not class-conditional

Enhancements:

  • Modify the Unet forward method to handle label embeddings more flexibly
  • Remove strict type checking for conditional labels
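The forward-method change can be pictured as a dispatch on how the condition is embedded. The sketch below is a minimal, dependency-free illustration, not the torchcfm code. It assumes, as in many diffusion-style U-Nets, that the condition embedding is added to the timestep embedding. The class name, `embedding_net`, and `num_classes` arguments are hypothetical. If a custom embedding net is supplied, the label is passed to it without any type check. Otherwise, integer labels index a class table.

```python
from typing import Any, Callable, List, Optional, Sequence

Vector = List[float]


class ConditioningSketch:
    """Hypothetical conditioning path of a U-Net forward pass (not the torchcfm API)."""

    def __init__(
        self,
        emb_dim: int,
        num_classes: Optional[int] = None,
        embedding_net: Optional[Callable[[Any], Vector]] = None,
    ) -> None:
        self.embedding_net = embedding_net
        self.table: Optional[List[Vector]] = None
        if embedding_net is None and num_classes is not None:
            # Stand-in for a learned class-embedding table; row c is filled with
            # float(c) only so that the example is deterministic.
            self.table = [[float(c)] * emb_dim for c in range(num_classes)]

    def embed_condition(self, y: Any) -> Vector:
        if self.embedding_net is not None:
            # No type check: y may be a float, a vector, or anything the net accepts.
            return self.embedding_net(y)
        if self.table is not None:
            # Default path: integer class labels index the table directly.
            return self.table[y]
        raise ValueError("model is unconditional but a condition y was given")

    def conditioning_vector(self, t_emb: Sequence[float], y: Any = None) -> Vector:
        """Add the (optional) condition embedding to the timestep embedding."""
        if y is None:
            return list(t_emb)
        y_emb = self.embed_condition(y)
        if len(y_emb) != len(t_emb):
            raise ValueError("condition embedding must match the timestep embedding size")
        return [a + b for a, b in zip(t_emb, y_emb)]


if __name__ == "__main__":
    float_cond = ConditioningSketch(emb_dim=2, embedding_net=lambda y: [y, 2 * y])
    print(float_cond.conditioning_vector([1.0, 1.0], 0.5))  # non-integer condition
    class_cond = ConditioningSketch(emb_dim=2, num_classes=3)
    print(class_cond.conditioning_vector([0.0, 0.0], 2))    # integer class label
```

Keeping the integer-table path as the default means existing class-conditional users see no change, which is consistent with the PR's "no breaking changes" claim.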

Tests:

  • Add test for Unet initialization
  • Add test for MLP model
  • Add test for conditional model with non-integer labels

psteinb · May 08 '25 16:05