Vid-ODE

Why is the ODE in the Encoder implemented this way?

Open • abdelwahed opened this issue 3 years ago • 1 comment

Thanks for sharing this, and congrats on a nice paper.

Starting from the line referenced below (full code block underneath it), it seems that the ODE in the Encoder does not call DiffeqSolver directly, the way the Decoder does. Instead, ode_func is called once, its output is multiplied by the time step, and the ODE is run backwards. Any clues why it was implemented this way? (A minimal sketch contrasting the two approaches follows the quoted code.)

https://github.com/psh01087/Vid-ODE/blob/040f0ab19f3d840c1ffef53a0bcdf01fcd2fc444/models/base_conv_gru.py#L151

  # Time configuration
  # Run ODE backwards and combine the y(t) estimates using gating
  prev_t, t_i = time_steps[-1] + 0.01, time_steps[-1]
  latent_ys = []
  
  time_points_iter = range(0, time_steps.size(-1))
  if run_backwards:
      time_points_iter = reversed(time_points_iter)
  
  for idx, i in enumerate(time_points_iter):
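      # Each iteration advances the state with a single explicit Euler step of
      # ode_func (rather than a full DiffeqSolver call); t_i - prev_t is
      # negative here, so the integration runs backwards in time.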

      inc = self.z0_diffeq_solver.ode_func(prev_t, prev_input_tensor) * (t_i - prev_t)
      assert (not torch.isnan(inc).any())
      tracker.write_info(key=f"inc{idx}", value=inc.clone().cpu())
      
      ode_sol = prev_input_tensor + inc
      tracker.write_info(key=f"prev_input_tensor{idx}", value=prev_input_tensor.clone().cpu())
      tracker.write_info(key=f"ode_sol{idx}", value=ode_sol.clone().cpu())
      ode_sol = torch.stack((prev_input_tensor, ode_sol), dim=1)  # [1, b, 2, c, h, w] => [b, 2, c, h, w]
      assert (not torch.isnan(ode_sol).any())
      
      if torch.mean(ode_sol[:, 0, :] - prev_input_tensor) >= 0.001:
          print("Error: first point of the ODE is not equal to initial value")
          print(torch.mean(ode_sol[:, :, 0, :] - prev_input_tensor))
          exit()
      
      yi_ode = ode_sol[:, -1, :]
      xi = input_tensor[:, i, :]
      
      # only one GRU cell is used (cell_list[0])
      yi = self.cell_list[0](input_tensor=xi,
                             h_cur=yi_ode,
                             mask=mask[:, i])
      
      # return to iteration
      prev_input_tensor = yi
      prev_t, t_i = time_steps[i], time_steps[i - 1]
      latent_ys.append(yi)
  
  latent_ys = torch.stack(latent_ys, 1)
  
  return yi, latent_ys 
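
To make the contrast concrete, here is a minimal sketch of the two ways of advancing the hidden state from prev_t to t_i. This is not the repo's code: f, h, prev_t and t_i are illustrative stand-ins for ode_func, the hidden state and the two time points in the loop above, and the solver path calls torchdiffeq.odeint directly instead of going through the DiffeqSolver wrapper.

  import torch
  from torchdiffeq import odeint

  def f(t, h):
      # toy dynamics standing in for self.z0_diffeq_solver.ode_func
      return -h

  h = torch.randn(2, 4)                    # hidden state, e.g. [batch, channels]
  prev_t, t_i = torch.tensor(1.01), torch.tensor(1.00)

  # (a) What the encoder loop above does: a single explicit Euler step.
  #     t_i - prev_t is negative, so the state is pushed backwards in time.
  h_euler = h + f(prev_t, h) * (t_i - prev_t)

  # (b) What calling a solver (as the decoder does) would look like:
  #     integrate f over [prev_t, t_i] and keep the state at the last time point.
  h_solver = odeint(f, h, torch.stack([prev_t, t_i]))[-1]

  print((h_euler - h_solver).abs().max())  # (a) is only a first-order approximation of (b)

So as far as I can tell, the encoder takes one first-order Euler step per frame gap instead of invoking a full solver, and that single-step choice is what I am asking about.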

abdelwahed • Nov 18 '21

I also have the same question. Any answers to this? @psh01087

ercanburak • May 28 '22