How to use an event function when odefunc must be updated differently for each time interval?
Thanks for sharing!
I use a GNN as the ODE function and initialize it with a different graph for each time interval. The simplified code is shown below.
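import torch
from torchdiffeq import odeint_adjoint

odefunc = GNN()  # the user's GNN-based ODE function
times = torch.linspace(0., 1., 10)
z = torch.randn()  # initial state (shape omitted in this simplified example)
for i in range(len(times) - 1):  # stop one short so times[i + 1] stays in range
    odefunc.set_graph(edge[i])  # swap in the graph for this interval
    integration_time = torch.tensor([times[i], times[i + 1]]).float()
    solution = odeint_adjoint(odefunc, z, integration_time)
    z = solution[-1]  # the final state becomes the next initial state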
Because odefunc has to be updated for each interval, odeint can only be run between adjacent time points. How can I introduce an event function in this case? It seems difficult to use odeint_event directly. Any help is appreciated.
Not entirely sure I follow the issue.
If the issue is that the ODE function needs to be updated once you solve past t_{i+1}, then you can also treat the end of the time interval as an event, using g(t, x) = t - t_{i+1}. If this event triggers (you can check the event time after odeint_event returns), update odefunc and continue solving. This effectively allows you to define a different event function within each time interval.
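To make the control flow concrete, here is a minimal dependency-free sketch of the idea. It is not torchdiffeq code: a fixed-step Euler loop (integrate_until_event) stands in for odeint_event, and make_func(i) is a hypothetical stand-in for odefunc.set_graph(edge[i]). All names, the step size, and the toy dynamics are assumptions chosen for illustration.

```python
def integrate_until_event(func, x0, t0, event_fns, dt=1e-3, max_steps=1_000_000):
    """Euler-integrate dx/dt = func(t, x) until one of event_fns changes sign.

    Returns (t_event, x_event, index_of_the_event_that_fired).
    """
    t, x = t0, x0
    start_signs = [g(t, x) > 0 for g in event_fns]
    for _ in range(max_steps):
        x = x + dt * func(t, x)
        t = t + dt
        # Events are checked in order, so earlier entries win ties.
        for k, g in enumerate(event_fns):
            if (g(t, x) > 0) != start_signs[k]:
                return t, x, k
    raise RuntimeError("no event fired within max_steps")


def solve_with_interval_updates(make_func, times, x0, user_event):
    """Integrate across intervals, rebuilding the dynamics at each boundary.

    Returns (t_event, x_event, interval_index) when user_event fires,
    or None if the final time is reached first.
    """
    t, x = times[0], x0
    for i in range(len(times) - 1):
        func = make_func(i)  # hypothetical stand-in for odefunc.set_graph(edge[i])
        t_end = times[i + 1]
        # Time-boundary event: g(t, x) = t - t_{i+1}
        boundary = lambda t_, x_, t_end=t_end: t_ - t_end
        t, x, fired = integrate_until_event(func, x, t, [user_event, boundary])
        if fired == 0:
            return t, x, i  # the user's event fired inside interval i
        # Otherwise only the boundary fired: move on and update the dynamics.
    return None


# Toy usage: a constant rate per interval, event when x crosses 2.5.
rates = [1.0, 2.0, -1.0]
result = solve_with_interval_updates(
    make_func=lambda i: (lambda t, x: rates[i]),
    times=[0.0, 1.0, 2.0, 3.0],
    x0=0.0,
    user_event=lambda t, x: x - 2.5,
)
t_event, x_event, interval = result
```

With torchdiffeq, odeint_event would take the place of integrate_until_event, and the check on which event fired would become a comparison of the returned event time against t_{i+1}. Note that the Euler stand-in overshoots each event by up to one step, whereas a real solver locates the event time more precisely.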