
Sequence outputs from Neural ODE (similar to 'many to many' RNN)?

Open dactylogram opened this issue 2 years ago • 3 comments

Hi, Patrick! I'm currently training on irregularly sampled data, and previously I used a many-to-many RNN for modelling. When the time series is sampled at times t1, t2, t3, and t4, my intended model predicts the matching outcomes y1, y2, y3, and y4.

I also want my model not to see future data. For example, when predicting at time t2, the model should not have access to the information at t3 and t4, and it should yield the same result even if the values at t3 and t4 change.

My previous RNN code looks like this:

    L_LSTM = nn.LSTM(n_hidden, n_hidden, batch_first=True)
    sequences_l, _ = L_LSTM(X, (state_h, state_c))
    # sequences_l.shape => (n_batch, n_sequence, n_hidden)

I recently found the neural ODE article, and I'm really interested in its amazing concepts. Because my dataset is severely irregularly sampled, a neural ODE seems likely to improve model performance. I want to apply it to my dataset instead of the RNN, but I'm not quite sure whether my code is right.

The example code in README.md is:

    zt = torchcde.cdeint(X=X, func=self.func, z0=z0, t=X.interval)
    zT = zt[..., -1, :]  # get the terminal value of the CDE; zT.shape => (n_batch, n_hidden)

and I want to change it to:

    sequences_t = torchcde.cdeint(X=X, func=func, z0=z0, t=X._t)
    # sequences_t.shape => (n_batch, n_sequence, n_hidden)

(To the best of my knowledge, X._t holds all of the observation times in the sequence.)

Does it make sense to change the code like that? Will sequences_t behave like sequences_l from the RNN code? Thanks in advance!

dactylogram avatar Feb 17 '22 11:02 dactylogram

Yep, that looks good to me, with one exception: you want X.grid_points rather than X._t.
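
A minimal sketch of that full-sequence version, assuming x, func and an initial network are set up as in the README (initial_network is just a placeholder name here):

    import torchcde

    # Interpolate the data; x has shape (n_batch, n_sequence, n_channels).
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
    X = torchcde.CubicSpline(coeffs)

    # Initial hidden state from the value of the path at the start of its interval.
    z0 = initial_network(X.evaluate(X.interval[0]))

    # Evaluate the solution at every observation time rather than only the terminal time.
    sequences_t = torchcde.cdeint(X=X, func=func, z0=z0, t=X.grid_points)
    # sequences_t.shape => (n_batch, n_sequence, n_hidden)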

Incidentally you may find a variety of other resources interesting:

  • This paper on interpolation schemes: https://arxiv.org/abs/2106.11028. In particular it recommends replacements for the natural cubic splines we originally used in https://arxiv.org/abs/2005.08926. This is because natural cubic splines are "non causal" -- i.e. the data at t4 affects the interpolation between t1 and t2. (This sounds like it'll be important to you; see the sketch after this list.)
  • This "textbook" on NDEs in general: https://arxiv.org/abs/2202.02435
  • If you ever need JAX instead of PyTorch, then there is a (new!) equivalent library you can use: https://github.com/patrick-kidger/diffrax
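
A sketch of swapping in one of those interpolation schemes, assuming x is a tensor of observations with shape (n_batch, n_sequence, n_channels); the function names below are the ones used in the current torchcde README, so treat this as a sketch rather than a guaranteed drop-in:

    import torchcde

    # Hermite cubic splines with backward differences: the general-purpose
    # replacement for natural cubic splines recommended in the README.
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
    X = torchcde.CubicSpline(coeffs)

    # Alternatively, linear interpolation, which only uses the observations
    # adjacent to each interval.
    coeffs = torchcde.linear_interpolation_coeffs(x)
    X = torchcde.LinearInterpolation(coeffs)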

Hope that helps!

patrick-kidger avatar Feb 17 '22 13:02 patrick-kidger

Thank you for your kind comments! May I ask some questions about X.grid_points? I'm afraid I'm not good with mathematical equations, and the details of how the time values are processed are hard for me to follow.

At first, I thought X.grid_points (or X._t) represented scaled time values, such as 0.2, 0.5, or other floats from a predetermined range, with the difference between two consecutive values giving the elapsed time (e.g. t1: 0.2, t2: 0.5 means 0.3 units of scaled time have elapsed). But according to the code, it is actually a sequence of integers increasing from 0 with a fixed step (0, 1, 2, 3, and so on), so it does not carry any information about the irregular sampling times:

    if t is None:
        t = torch.linspace(0, coeffs.size(-2), coeffs.size(-2) + 1, dtype=coeffs.dtype, device=coeffs.device)

I would like to pass scaled time values as t (e.g. [0, 0.05, 0.35, 0.44, ...]), but I also found the note below in the code, which warns against using the t argument when using neural CDEs:

    """
    Arguments:
        coeffs: As returned by `torchcde.natural_cubic_coeffs`.
        t: As passed to linear_interpolation_coeffs. (If it was passed. If you are using neural CDEs then you **do
            not need to use this argument**. See the Further Documentation in README.md.)
    """

In the example code at https://github.com/patrick-kidger/torchcde/blob/master/example/irregular_data.py, the time information is included as the first variable of the pseudodata x (x.shape = (n_batch, n_sequence, n_variables)). The time variable is not treated specially anywhere in that example, and incrementally increasing integers are used for X.grid_points.
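
To illustrate, a minimal sketch of appending time as an extra channel in the spirit of that example (here t is assumed to be a tensor of the, possibly irregular, observation times with shape (n_sequence,), and values the observations with shape (n_batch, n_sequence, n_features)):

    import torch
    import torchcde

    # Broadcast the observation times across the batch and prepend them as channel 0.
    t_channel = t.unsqueeze(0).unsqueeze(-1).expand(values.size(0), values.size(1), 1)
    x = torch.cat([t_channel, values], dim=-1)  # (n_batch, n_sequence, 1 + n_features)

    # The interpolation then carries the (irregular) time information as ordinary data.
    coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(x)
    X = torchcde.CubicSpline(coeffs)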

Is it okay for the time variable to simply be included as a data column without any special treatment (i.e. the model does not know which column is the time variable)? Thanks again!

dactylogram avatar Feb 18 '22 04:02 dactylogram

Have a read of Section 3.2.1.3 of On Neural Differential Equations. The rest of Chapter 3 might also be helpful for context.

patrick-kidger avatar Feb 18 '22 13:02 patrick-kidger