jax-cfd
trajectory_fn takes target_trajectory as input in loss_and_gradient
Hi,
I'm trying to understand the logic behind the various functions in ml.train_utils.
In the definition of loss_and_gradient, the returned _loss function passes the target trajectory as an input to trajectory_fn, and I cannot make sense of it. The description of the function states that trajectory_fn should accept params and initial_velocity, which makes sense to me, but I don't understand why we would want to pass target_trajectory as the initial_velocity (the two presumably don't have the same shape, since one is a whole trajectory and the other is the velocity at a single time step).
I'm pretty much stuck here, because I don't know how to use this when defining a train_step function, etc.
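To make the question concrete, here is a minimal sketch of the pattern as I read it. It is not the jax-cfd source: loss_and_gradient, trajectory_fn and mse below are hypothetical toy stand-ins I wrote for illustration, and the "model" is just a scalar multiplication. It only shows the shape mismatch I am worried about when a whole trajectory is passed where a single-step velocity is expected.

```python
import jax
import jax.numpy as jnp


def loss_and_gradient(trajectory_fn, loss_fn):
    # Hypothetical stand-in that mirrors the pattern described above;
    # this is NOT copied from jax_cfd.ml.train_utils.
    def _loss(params, target_trajectory):
        # The target trajectory is passed where I expected initial_velocity.
        predicted = trajectory_fn(params, target_trajectory)
        return loss_fn(predicted, target_trajectory)

    return jax.value_and_grad(_loss)


def trajectory_fn(params, initial_velocity, num_steps=3):
    # Toy rollout: each step scales the velocity by `params`.
    def step(v, _):
        v_next = params * v
        return v_next, v_next

    _, trajectory = jax.lax.scan(step, initial_velocity, None, length=num_steps)
    return trajectory


def mse(predicted, target):
    return jnp.mean((predicted - target) ** 2)


target_trajectory = jnp.ones((3, 4))     # (time, space)
initial_velocity = target_trajectory[0]  # (space,)

# What I expected: roll out from a single time step.
print(trajectory_fn(0.5, initial_velocity).shape)   # (3, 4)

# What the pattern seems to do: roll out from the whole trajectory.
print(trajectory_fn(0.5, target_trajectory).shape)  # (3, 3, 4)

# The loss still evaluates, but only because of broadcasting.
loss, grad = loss_and_gradient(trajectory_fn, mse)(0.5, target_trajectory)
```

In this toy version the second call silently produces an extra leading axis, so unless trajectory_fn is itself expected to extract the initial condition from the target trajectory, I don't see how the shapes are meant to line up.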
Best regards,
Gwen