torchdyn
Passing in multiple arguments
Additional Description
I have a network f(x, x_dot, theta) that I wish to train, where x and x_dot are the inputs and theta are the network weights. This is a slightly odd problem since x_dot is the corrupted derivative of x and I wish to train a network to give me the correct x_dot. To solve the ODE, I need to pass in x at t=0, but the network itself doesn't use x in its forward pass, only x_dot.
How would I pass in multiple arguments like this to a NeuralODE in torchdyn? My guess is to concatenate the two, so I get x_x_dot = torch.cat((x, x_dot)), but I am not sure if this is correct.
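In torchdiffeq, what I did was call the solver like so:

    import torch.nn as nn
    from torchdiffeq import odeint_adjoint

    class Network(nn.Module):
        def forward(self, t, args):
            # the state is a tuple; only x_dot is used by the network
            x, x_dot = args
            return self.mlp(x_dot)

    y = odeint_adjoint(network, (x_i, x_dot), t_span)

What would be the equivalent approach in torchdyn?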
Thanks for the Q!
The concat approach should work just fine but I agree that it's not necessarily the most transparent.
Using a pytree based approach is likely most flexible in the long run, though, so that's also something we're keeping an eye on.
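For readers following along, the concat approach mentioned above can be sketched roughly as follows. This is only an illustration under assumptions the thread does not state: the derivative of x is taken to be mlp(x_dot), x_dot is held fixed along the trajectory, and a hand-written Euler loop (the hypothetical `euler` helper) stands in for the solver so the snippet does not depend on any particular torchdyn call. `ConcatField` is likewise a made-up name.

```python
import torch
import torch.nn as nn

class ConcatField(nn.Module):
    """Vector field over the concatenated state z = [x, x_dot] (hypothetical)."""
    def __init__(self, dim, hidden=32):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim)
        )

    def forward(self, t, z):
        # Split the single state tensor back into its two parts.
        x, x_dot = z[..., :self.dim], z[..., self.dim:]
        dx = self.mlp(x_dot)               # only x_dot feeds the network
        dx_dot = torch.zeros_like(x_dot)   # assumption: x_dot stays constant
        return torch.cat([dx, dx_dot], dim=-1)

def euler(f, z0, t_span):
    """Fixed-step explicit Euler, used here only as a stand-in solver."""
    z, traj = z0, [z0]
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        z = z + (t1 - t0) * f(t0, z)
        traj.append(z)
    return torch.stack(traj)

torch.manual_seed(0)
dim = 3
x_i, x_dot = torch.randn(8, dim), torch.randn(8, dim)
field = ConcatField(dim)
t_span = torch.linspace(0.0, 1.0, 11)

traj = euler(field, torch.cat([x_i, x_dot], dim=-1), t_span)
# Recover the two components of the trajectory by slicing.
x_traj, x_dot_traj = traj[..., :dim], traj[..., dim:]
```

A module with this `forward(t, z)` shape is the kind of vector field one would then hand to a neural ODE wrapper; the slicing at the end is the price paid for keeping the state a single tensor.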
A pytree seems like a large hammer for a small nail. If there is no direct support for passing in tuples as arguments, I imagine that would be easier to add in the short term. Just a couple of checks (isinstance(x, tuple) and isinstance(x[i], torch.Tensor)) and then continue from there.
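The kind of check suggested above might look something like this. It is a sketch only, not torchdyn code: `is_tensor_tuple` and `euler_step` are hypothetical helpers, and a single explicit Euler step is used just to show where the tuple branch would sit.

```python
import torch

def is_tensor_tuple(x):
    """True if x is a non-empty tuple whose elements are all tensors."""
    return (
        isinstance(x, tuple)
        and len(x) > 0
        and all(isinstance(xi, torch.Tensor) for xi in x)
    )

def euler_step(f, t, state, dt):
    """One explicit Euler step for either a tensor or a tuple of tensors."""
    if is_tensor_tuple(state):
        # f is expected to return one derivative per state component.
        derivs = f(t, state)
        return tuple(s + dt * ds for s, ds in zip(state, derivs))
    return state + dt * f(t, state)

# Tuple state: dx/dt = x_dot, d(x_dot)/dt = 0.
def f(t, state):
    x, x_dot = state
    return x_dot, torch.zeros_like(x_dot)

x0, v0 = torch.zeros(2), torch.ones(2)
x1, v1 = euler_step(f, 0.0, (x0, v0), 0.5)
```

An adaptive solver would also need tuple-aware error norms and step-size control, which is presumably where most of the extra work lies.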
Update here: fixed_odeint supports the state as a dict for now. We have yet to extend it to the adaptive solver.
Hi, I have similar needs when coding my repo based on torchdyn.
I need odeint to accept a state such as (dx, dlog(x)) for building generative models such as continuous normalizing flows. Here x should be a tensor of any shape, while dlog is simply a scalar. (I tried reshaping and concatenating these tensors into one tensor and reversing this when calling modules, but it is harmful for grad_fn when going backward.)
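For context, the reshape-and-concat workaround described above could be written along these lines. The helper names `pack` and `unpack` are hypothetical, and the sketch assumes one scalar log-density term per sample, which is an interpretation of "simply a scalar" rather than something the comment states.

```python
import torch

def pack(x, dlogp):
    """Flatten x per sample and append the scalar term as a last column.

    x: (batch, *shape), dlogp: (batch,) -> (batch, numel(shape) + 1)
    """
    return torch.cat([x.flatten(start_dim=1), dlogp.unsqueeze(-1)], dim=1)

def unpack(z, shape):
    """Inverse of pack: recover x with its original per-sample shape."""
    x_flat, dlogp = z[:, :-1], z[:, -1]
    return x_flat.reshape(z.shape[0], *shape), dlogp

x = torch.randn(4, 2, 3, requires_grad=True)
dlogp = torch.zeros(4, requires_grad=True)

z = pack(x, dlogp)                        # single tensor handed to the solver
x_back, dlogp_back = unpack(z, x.shape[1:])

# Check that autograd passes through the pack/unpack round trip.
loss = x_back.pow(2).sum() + dlogp_back.sum()
loss.backward()
```

In this toy case gradients do reach both inputs. The sketch is not meant to reproduce the backward-pass problem reported above, which may depend on how the packed state interacts with the solver and the surrounding modules.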
For the time being, it seems that I have to go back to torchdiffeq, which accepts tuples as input.
I suggest that torchdyn support tree-like tensor data types as input. One implementation is https://github.com/opendilab/treevalue.