jax-cfd

trajectory_fn uses target_trajectory as input in loss_and_gradient

Open · gwen-git opened this issue 2 years ago · 0 comments

Hi,

I'm trying to understand the logic behind the various functions in ml.train_utils.

In the definition of loss_and_gradient, the returned _loss function passes the target trajectory as part of the input to trajectory_fn, and I cannot make sense of this. The function's description states that trajectory_fn should accept params and initial_velocity, which makes sense to me. What I don't understand is why we would want to use target_trajectory as initial_velocity: presumably the two don't even have the same shape, since one is a whole trajectory and the other is the velocity at a single time step.
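For reference, here is a minimal sketch of one common way a trajectory loss is wired, written with hypothetical names and a toy solver. It assumes the initial condition is simply the first frame of the target trajectory, so trajectory_fn only ever sees a single time step. Whether jax-cfd's own loss_and_gradient does this slicing inside _loss or inside trajectory_fn is not confirmed here, and plain arrays stand in for jax-cfd's grid variables.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5  # hypothetical trajectory length, including the initial frame

def trajectory_fn(params, initial_velocity):
    """Hypothetical toy solver: v_{t+1} = params["decay"] * v_t.

    Returns an array of shape (N_STEPS, *initial_velocity.shape) whose
    first entry is the initial velocity itself.
    """
    def step(v, _):
        # Carry the next state forward, emit the current one.
        return params["decay"] * v, v
    _, traj = jax.lax.scan(step, initial_velocity, xs=None, length=N_STEPS)
    return traj

def make_loss(trajectory_fn):
    """Hypothetical loss factory, loosely mirroring the _loss closure."""
    def _loss(params, target_trajectory):
        # Slice a single time step out of the target to use as the
        # initial condition, then roll out and compare frame by frame.
        initial_velocity = target_trajectory[0]
        predicted = trajectory_fn(params, initial_velocity)
        return jnp.mean((predicted - target_trajectory) ** 2)
    return _loss

loss_fn = make_loss(trajectory_fn)
# Synthetic "ground truth" trajectory generated with decay = 0.5.
target = trajectory_fn({"decay": jnp.asarray(0.5)}, jnp.ones(3))
loss, grads = jax.value_and_grad(loss_fn)({"decay": jnp.asarray(0.9)}, target)
```

Under this assumption the shapes line up: target has shape (N_STEPS, 3), while trajectory_fn receives only target[0] with shape (3,).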

I'm stuck here because I don't know how to use this when defining a train_step function and the rest of the training loop.
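As a rough illustration of where such a loss fits, the sketch below builds a jitted train_step around jax.value_and_grad with plain gradient descent. Everything here is hypothetical: the toy solver, the learning rate, and the assumption that the loss slices the initial condition from target_trajectory[0]. It is not jax-cfd's training API; a real setup would more likely use an optimizer library such as optax.

```python
import jax
import jax.numpy as jnp

N_STEPS = 5           # hypothetical trajectory length
LEARNING_RATE = 0.1   # hypothetical step size for plain gradient descent

def trajectory_fn(params, initial_velocity):
    # Hypothetical toy solver: v_{t+1} = decay * v_t, N_STEPS frames.
    def step(v, _):
        return params["decay"] * v, v
    _, traj = jax.lax.scan(step, initial_velocity, xs=None, length=N_STEPS)
    return traj

def loss_fn(params, target_trajectory):
    # Assumed wiring: the rollout starts from the first target frame.
    predicted = trajectory_fn(params, target_trajectory[0])
    return jnp.mean((predicted - target_trajectory) ** 2)

@jax.jit
def train_step(params, target_trajectory):
    # One gradient-descent update on every leaf of the params pytree.
    loss, grads = jax.value_and_grad(loss_fn)(params, target_trajectory)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, loss

# Fit the decay rate back from a synthetic target generated with 0.5.
target = trajectory_fn({"decay": jnp.asarray(0.5)}, jnp.ones(3))
params = {"decay": jnp.asarray(0.9)}
for _ in range(200):
    params, loss = train_step(params, target)
```

Each call to train_step takes one target trajectory per step; batching several trajectories would typically add a leading batch axis and jax.vmap over the loss.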

Best regards,

Gwen

gwen-git, Feb 21 '23 15:02