conditional-flow-matching
OT-CFM performs worse on conditional generation tasks
Thanks very much for this code base, it's been a great way to learn about flow matching. I have a question regarding conditional generation with OT-CFM.
When testing different FM approaches on my own data, I noticed that OT-CFM trains significantly more slowly and tends to perform much worse on tasks with conditioning. To isolate the problem, I tried conditional MNIST, comparing OT-CFM with FM (using the example provided).
After a single epoch of training, I visualized the generations of both approaches, sampling with a single Euler step and with the adaptive dopri5 solver. FM is on the left, OT-CFM is on the right.
One step generation (euler with 1 step):
Adaptive generation with dopri5:
After one epoch of training, FM produces much nicer generations, both with a single sampling step and with dopri5. Even after longer training, FM continues to outperform OT-CFM (it converges much faster).
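For clarity, this is what I mean by "euler with 1 step". It is a minimal sketch in plain Python, not the sampling code from the example: `euler_sample` and `toy_velocity` are hypothetical names, and `toy_velocity` is only a stand-in for a trained conditional model (a straight-line field toward a made-up per-class target).

```python
def euler_sample(velocity, x0, y, n_steps=1):
    """Integrate dx/dt = velocity(t, x, y) from t=0 to t=1 with fixed-step Euler."""
    x = list(x0)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt
        v = velocity(t, x, y)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


def toy_velocity(t, x, y):
    """Hypothetical stand-in for a trained model (assumption, not a real network).

    Straight-line field toward a per-class target: (target - x) / (1 - t).
    """
    target = [float(y)] * len(x)
    return [(ti - xi) / (1.0 - t) for ti, xi in zip(target, x)]


# With n_steps=1 the sample is simply x0 + velocity(0, x0, y).
one_step = euler_sample(toy_velocity, [0.5, -1.0], y=3, n_steps=1)
```

If the learned paths were perfectly straight, as in this toy field, one Euler step would already land on the target, which is why I used the 1-step samples as a rough check of how straight the flows are.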
After reading more, I noticed that both the OT-CFM and Multisample Flow Matching papers only report results for unconditional generation, while papers doing conditional generation, such as Stable Diffusion 3 and Flow Matching in Latent Space, seem to use standard flow matching without batch optimal transport.
I wonder if the authors have studied this, and if there are any results for OT-CFM on conditional tasks, or if there is a reason or explanation for why OT-CFM should not work in this setting. My intuition is that adding conditioning makes the combinatorial space of the OT plan extremely hard to approximate from the limited samples in a batch, and that this would be further exacerbated if the conditioning is not on simple class labels but on continuous values (for example, language embeddings for text-to-image generation).
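To make the setting concrete, here is a minimal sketch of batch OT pairing with labels, as I understand it. It is not the torchcfm implementation: `minibatch_ot_pairing` and `sq_dist` are hypothetical names, and I brute-force the assignment over permutations, which is only feasible for tiny batches. For equal-size batches with uniform weights, an optimal plan can be taken to be a permutation, so this gives an exact batch OT pairing. As far as I understand, the library instead samples pairs from the plan.

```python
from itertools import permutations


def sq_dist(a, b):
    """Squared Euclidean distance between two points given as lists."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def minibatch_ot_pairing(x0, x1, y1):
    """Re-pair noise x0 with data x1 to minimise the total squared distance.

    The label y1[j] stays attached to its data point x1[j]; only the
    noise-to-data pairing changes. Note that the labels play no role in
    choosing the pairing.
    """
    n = len(x0)
    best = min(
        permutations(range(n)),
        key=lambda p: sum(sq_dist(x0[i], x1[p[i]]) for i in range(n)),
    )
    x1_paired = [x1[j] for j in best]
    y1_paired = [y1[j] for j in best]
    return x0, x1_paired, y1_paired


# Tiny 1-D batch: three noise points, three data points with class labels.
x0 = [[0.0], [1.0], [2.0]]
x1 = [[2.1], [0.1], [1.2]]
y = [7, 3, 5]
_, x1_paired, y1_paired = minibatch_ot_pairing(x0, x1, y)
# x1_paired == [[0.1], [1.2], [2.1]] and y1_paired == [3, 5, 7]
```

The pairing is computed over the whole batch regardless of label, so each condition only sees the few noise points that happen to be matched to its samples in that batch.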
I would greatly appreciate any insight on this, and if there is an approach that is applicable to conditional generation. Thank you!
The code tweaks for this were:
    # Choose the flow matcher
    sigma = 0.0
    if args.fm_method == "fm":
        FM = TargetConditionalFlowMatcher(sigma=sigma)
    elif args.fm_method == "otcfm":
        FM = ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)

    # Inside the training loop
    if args.fm_method == "fm":
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
        y1 = y
    elif args.fm_method == "otcfm":
        # The OT sampler re-pairs the batch, so the labels must be returned in the matching order
        t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)