NeuralODE trajectory API is quite limiting

Open rsanchezgarc opened this issue 10 months ago • 2 comments

Hi,

I am trying to use your package with SuperResModel (from torchcfm.models.unet.unet import SuperResModel) and other custom models whose forward method takes extra keyword arguments, but the NeuralODE.trajectory method does not seem to be compatible with those models.

Could you please add a model_kwargs parameter to NeuralODE.trajectory, NeuralODE.forward, etc.?

Thanks!

rsanchezgarc, Apr 22 '24 09:04

Hello,

This seems to be a problem with Torchdyn. A workaround might be to use torchdiffeq instead. You could also write your own custom Euler integration method. Unfortunately, as the NeurIPS deadline is only one month away, I will not have time to look into this issue, especially as it is related to Torchdyn rather than to TorchCFM.
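As a rough illustration of the custom Euler route, here is a minimal, framework-agnostic sketch. The names euler_trajectory and toy_field are hypothetical and not part of TorchCFM or Torchdyn, and the model is assumed to be callable as model(t, x, **model_kwargs), following the (t, x) argument order used in this thread. The same loop should work with PyTorch tensors, in which case you would likely wrap it in torch.no_grad() and stack the returned states.

```python
def euler_trajectory(model, x0, t_span, **model_kwargs):
    """Fixed-step explicit Euler for dx/dt = model(t, x, **model_kwargs).

    Hypothetical helper: returns a list with one state per time in t_span.
    """
    x = x0
    trajectory = [x]
    # Step between consecutive time points, forwarding the extra kwargs each call.
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        dt = t1 - t0
        x = x + dt * model(t0, x, **model_kwargs)
        trajectory.append(x)
    return trajectory


# Toy vector field standing in for a network whose forward takes an extra keyword.
def toy_field(t, x, scale=1.0):
    return scale * x


t_span = [i / 4 for i in range(5)]  # 0.0, 0.25, 0.5, 0.75, 1.0
traj = euler_trajectory(toy_field, 1.0, t_span, scale=2.0)
```

Because the keyword arguments are forwarded on every step, nothing has to be baked into the model itself, which is the flexibility the issue asks for.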

Best, K.

kilianFatras, Apr 23 '24 20:04

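I would implement this by subclassing SuperResModel, i.e.:

class MySuperResModel(SuperResModel):
    def forward(self, t, x, *args, **kwargs):
        # self.model_kwargs: dict of extra forward() arguments, set on the instance before solving
        return super().forward(t, x, **self.model_kwargs)

I think we do this in the conditional example.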

atong01, Apr 24 '24 00:04