Vid-ODE
Why is the ODE in the Encoder implemented this way?
Thanks for sharing this, and congrats on a nice paper.
Starting from the line referenced below (full code block underneath it), it seems that the ODE in the Encoder does not call DiffeqSolver directly the way the Decoder does. Instead, ode_func is called and its output is multiplied by the time step, and the ODE is run backward. Any clues as to why it was implemented this way?
https://github.com/psh01087/Vid-ODE/blob/040f0ab19f3d840c1ffef53a0bcdf01fcd2fc444/models/base_conv_gru.py#L151
```python
# Time configuration
# Run ODE backwards and combine the y(t) estimates using gating
prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]
latent_ys = []

time_points_iter = range(0, time_steps.size(-1))
if run_backwards:
    time_points_iter = reversed(time_points_iter)

for idx, i in enumerate(time_points_iter):
    inc = self.z0_diffeq_solver.ode_func(prev_t, prev_input_tensor) * (t_i - prev_t)
    assert (not torch.isnan(inc).any())
    tracker.write_info(key=f"inc{idx}", value=inc.clone().cpu())

    ode_sol = prev_input_tensor + inc
    tracker.write_info(key=f"prev_input_tensor{idx}", value=prev_input_tensor.clone().cpu())
    tracker.write_info(key=f"ode_sol{idx}", value=ode_sol.clone().cpu())
    ode_sol = torch.stack((prev_input_tensor, ode_sol), dim=1)  # [1, b, 2, c, h, w] => [b, 2, c, h, w]
    assert (not torch.isnan(ode_sol).any())

    if torch.mean(ode_sol[:, 0, :] - prev_input_tensor) >= 0.001:
        print("Error: first point of the ODE is not equal to initial value")
        print(torch.mean(ode_sol[:, :, 0, :] - prev_input_tensor))
        exit()

    yi_ode = ode_sol[:, -1, :]
    xi = input_tensor[:, i, :]

    # only 1 now
    yi = self.cell_list[0](input_tensor=xi,
                           h_cur=yi_ode,
                           mask=mask[:, i])

    # return to iteration
    prev_input_tensor = yi
    prev_t, t_i = time_steps[i], time_steps[i - 1]
    latent_ys.append(yi)

latent_ys = torch.stack(latent_ys, 1)

return yi, latent_ys
```
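For what it's worth, the update inside the loop, `inc = ode_func(prev_t, prev_input_tensor) * (t_i - prev_t)`, looks like a single explicit Euler step rather than a full solve. A minimal sketch of what that update computes (the names `euler_step` and `f` below are mine for illustration, not from the repo):

```python
def euler_step(f, x, t0, t1):
    """Advance x from t0 to t1 with one explicit Euler step:
    x(t1) ~ x(t0) + f(t0, x(t0)) * (t1 - t0)."""
    return x + f(t0, x) * (t1 - t0)

# Toy dynamics dx/dt = -x (exact solution: x0 * exp(-t)).
f = lambda t, x: -x

# Forward step of size 0.1 from x0 = 1.0:
print(euler_step(f, 1.0, 0.0, 0.1))   # 0.9

# Backward step (t1 < t0): the (t1 - t0) factor is negative, so the
# increment flips sign -- this is how the loop integrates in reverse.
print(euler_step(f, 1.0, 0.01, 0.0))  # 1.01
```

Since the encoder only needs the state at the next observation time, one coarse step between adjacent time points may have been chosen for speed, whereas the decoder calls DiffeqSolver for a proper solve; that is my reading, not a confirmed explanation.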
I have the same question; any answers to this? @psh01087