torchdiffeq

How to use an event function when each time interval has to be pre-processed differently?

Open • xlk369293141 opened this issue 2 years ago • 1 comment

Thanks for sharing!

I use a GNN as the ODE function and initialize it with a different graph at each time point. A simplified version of the code is shown below.

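```python
import torch
from torchdiffeq import odeint_adjoint

odefunc = GNN()  # user-defined nn.Module with a set_graph() method
times = torch.linspace(0., 1., 10)

z = torch.randn(num_nodes, hidden_dim)  # initial state; placeholder sizes that depend on the GNN
for i in range(len(times) - 1):  # one solve per interval [times[i], times[i+1]]
    odefunc.set_graph(edge[i])  # swap in the graph for this interval
    integration_time = times[i:i + 2]
    solution = odeint_adjoint(odefunc, z, integration_time)
    z = solution[-1]  # state at times[i+1] seeds the next interval
```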

Because `odefunc` has to be updated at each time point, `odeint` can only be run over one interval between adjacent time points at a time. How can I introduce an event function in this setting? It seems difficult to use `odeint_event` directly. Any help is appreciated.

xlk369293141 • Mar 05 '23 07:03

Not entirely sure I follow the issue.

If the issue is that the ODE needs to be updated once the solve passes t_{i+1}, you can also treat the end of the time interval as an event (using g(t, x) = t - t_{i+1}). If this event is the one that triggers (you can check the returned event time after `odeint_event` returns), update `odefunc`. This effectively lets you define a different event function within each time interval.

rtqichen • Mar 06 '23 22:03