
Adding Latent SDE

Open · anh-tong opened this issue 2 years ago • 13 comments

Hi Patrick,

Following up on the last discussion, I've created a pull request containing:

  • A small change in diffrax.misc.sde_kl_divergence, i.e. handling the context and computing the KL divergence
  • Add a new latent SDE notebook as examples/neural_sde_vae.ipynb
  • Rename examples/neural_sde.ipynb to examples/neural_sde_gan.ipynb (and fix the link in its description)
  • Update mkdocs.yml

anh-tong avatar May 11 '22 10:05 anh-tong

Looks like the formatting is failing. Have a look at CONTRIBUTING.md.

I've not really gone through most of the example yet; I'll leave a proper review of that once everything so far has been organised.

I will say that I don't think I really believe what's happening here, mathematically. In the infinite-training limit you're just matching the SDE against a single trajectory, so it collapses to an ODE (zero noise). Have a look at the Lorenz example in torchsde for a more convincing (to me) example of training a latent SDE as a generative model, rather than this case which I think is pretty much just supervised learning.

(The real giveaway here is that you're using context=None in sde_kl_divergence. Whilst you don't have to have a context for the abstract notion of "KL divergence between two SDEs", you absolutely need one to have a meaningful latent SDE.)
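
To make the distinction concrete, here is a rough sketch (the drift functions and parameter names below are purely illustrative, not anything in the PR): without a context the approximate posterior's drift can only depend on (t, y), so matching a single trajectory just regresses onto it, whereas a latent SDE conditions the posterior drift on an encoding of the observed data.

import jax.numpy as jnp

# Illustrative only -- hypothetical drifts, not the PR's code.

def prior_drift(t, y, theta):
    # Prior drift: depends only on (t, y).
    return jnp.tanh(theta["A"] @ y)

def posterior_drift(t, y, phi, ctx):
    # Posterior drift: additionally conditioned on a context path ctx(t),
    # e.g. the output of an encoder run over the observed data.
    return jnp.tanh(phi["A"] @ y + phi["B"] @ ctx(t))

# Both SDEs share the same diffusion, which is what makes the KL divergence
# between them finite.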

patrick-kidger avatar May 13 '22 12:05 patrick-kidger

Thanks for the detailed review. I will get back to this in a few days :)

anh-tong avatar May 16 '22 04:05 anh-tong

Sorry for taking so long.

In the recent commits, I have changed diffrax.misc.sde_kl_divergence so that the vector field is constructed from the control term of the input SDE (given as a MultiTerm). I have also added a simple unit test for this.

I've implemented the latent SDE notebook on Lorenz data, as you suggested, to make it more like a VAE than just supervised learning. It took me some time to get it running. (It seems KL annealing is the trick to training this model.)

Your other comments are addressed in the recent changes as well.
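
For reference, a rough sketch of how the prior and posterior SDEs get packaged as MultiTerms with a shared diffusion; the drift/diffusion functions here are placeholders, and the way the context enters is only an assumption for illustration.

import diffrax
import jax.numpy as jnp
import jax.random as jrandom

t0, t1, latent_size = 0.0, 1.0, 3
key = jrandom.PRNGKey(0)

# Placeholder vector fields.  Both SDEs share the same diffusion, which is
# what makes the KL divergence between them finite.
def diffusion(t, y, args):
    return 0.1 * jnp.ones((latent_size,))

def prior_drift(t, y, args):
    return -y

def posterior_drift(t, y, args):
    ctx = args  # context passed through `args` here, purely for illustration
    return -y + ctx(t)

bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-3, shape=(latent_size,), key=key)
prior = diffrax.MultiTerm(
    diffrax.ODETerm(prior_drift),
    diffrax.WeaklyDiagonalControlTerm(diffusion, bm),
)
posterior = diffrax.MultiTerm(
    diffrax.ODETerm(posterior_drift),
    diffrax.WeaklyDiagonalControlTerm(diffusion, bm),
)
# The PR's sde_kl_divergence then works with SDEs of this form; its exact
# signature is defined in the PR itself.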

anh-tong avatar Jun 13 '22 11:06 anh-tong

Hi all, thank you for the great work! I was just trying to run the code from the pull request and encountered this error:

NotImplementedError                       Traceback (most recent call last)
Cell In [13], line 4
      1 while iter < train_iters:
      2     # optimizing
      3     _, training_key = jrandom.split(training_key)
----> 4     loss, grads = make_step(latent_sde)
      5     loss = loss.item()
      6     updates, opt_state = optim.update(grads, opt_state)

File ~/diffrax/.env/lib/python3.8/site-packages/equinox/jit.py:95, in _JitWrapper.__call__(_JitWrapper__self, *args, **kwargs)
     94 def __call__(__self, *args, **kwargs):
---> 95     return __self._fun_wrapper(False, args, kwargs)

File ~/diffrax/.env/lib/python3.8/site-packages/equinox/jit.py:91, in _JitWrapper._fun_wrapper(self, is_lower, args, kwargs)
     89     return self._cached.lower(dynamic, static)
     90 else:
---> 91     dynamic_out, static_out = self._cached(dynamic, static)
     92     return combine(dynamic_out, static_out.value)

    [... skipping hidden 11 frame]

File ~/diffrax/.env/lib/python3.8/site-packages/jax/experimental/host_callback.py:1806, in <lambda>(j)
   1803 id_p.def_abstract_eval(lambda *args: args)
   1804 xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args)
...
   1687           )))
   1688 else:
-> 1689   raise NotImplementedError(f"outfeed rewrite {eqn.primitive}")

NotImplementedError: outfeed rewrite closed_call

I'm not sure whether I missed something; I didn't change anything in the code. This occurs when I call the make_step() function. The library versions are:

equinox             0.7.1
jax                 0.3.17
jaxlib              0.3.15+cuda11.cudnn82
optax               0.1.3

Thanks in advance!

harrisonzhu508 avatar Sep 09 '22 19:09 harrisonzhu508

Hmm. This looks like a bug in core JAX -- closed_call looks like a new JAX primitive, which jax.experimental.host_callback.call hasn't been updated to be able to handle.

I'm not sure when closed_call is used, but you should be able to construct a MWE by using jax.experimental.host_callback.call inside a function called by whatever-it-is that generates closed_call. (In particular this should be doable without Diffrax.)

patrick-kidger avatar Sep 09 '22 20:09 patrick-kidger

> Hmm. This looks like a bug in core JAX -- closed_call looks like a new JAX primitive, which jax.experimental.host_callback.call hasn't been updated to be able to handle.
>
> I'm not sure when closed_call is used, but you should be able to construct a MWE by using jax.experimental.host_callback.call inside a function called by whatever-it-is that generates closed_call. (In particular this should be doable without Diffrax.)

Thanks for the quick reply Patrick! I see, this makes sense. It's strange that the other examples in Diffrax do still seem to work; I'll investigate this a bit more.

harrisonzhu508 avatar Sep 09 '22 20:09 harrisonzhu508

@patrick-kidger Sorry for taking so long. I will try to get back to this pull request in a couple of days. What I can think of for now is to handle the case where the diffusion matrices are diagonal.

@harrisonzhu508 I will take a look at the bug. If the other examples do not have the problem, it may be because of the current implementation of the latent SDE.

anh-tong avatar Sep 10 '22 04:09 anh-tong

Hi @harrisonzhu508, you can run the current code with jax=0.3.15 (while waiting for a further update). The latest version of equinox, 0.7.1, works fine.

As Patrick mentioned, it must be something to do with core JAX. I also find that the bug occurs when we use eqx.filter_value_and_grad (eqx.filter_jit on its own still works). This is related to this pull request (https://github.com/google/jax/pull/10711), which host_callback does not handle yet.
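
In case it helps anyone construct the MWE Patrick mentioned, a rough sketch of the shape such a reproduction might take, based on the observation above that eqx.filter_value_and_grad is the trigger. It uses host_callback.id_tap rather than host_callback.call purely to keep the sketch short; both go through the same outfeed machinery. Whether this exact snippet reproduces the error will depend on the JAX/Equinox versions involved (e.g. jax 0.3.17, equinox 0.7.1 as reported).

import equinox as eqx
import jax.numpy as jnp
from jax.experimental import host_callback

@eqx.filter_jit
@eqx.filter_value_and_grad
def loss(y):
    # A host callback inside a function transformed by filter_value_and_grad;
    # on the affected versions this reportedly fails at lowering time with
    # "NotImplementedError: outfeed rewrite closed_call".
    y = host_callback.id_tap(lambda arg, transforms: None, y, result=y)
    return jnp.sum(y ** 2)

loss(jnp.ones(3))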

anh-tong avatar Sep 20 '22 16:09 anh-tong

Hi @patrick-kidger ,

I tried to make sde_kl_divergence handle more general cases, but I do not have a complete solution.

If I understand correctly, the goal of materialise_vf is to find the PyTree structure of the output of vf_prod. However, that output should agree with the PyTree structure of the output of the drift (as an ODETerm), which we always have access to. Therefore, we may not need to materialise vf_prod, and we may not need to convert everything into arrays.

If we restrict to the case where both the drift and the diffusion have the same PyTree structure (with jnp.ndarray leaf nodes), we can simply handle block-diagonal diffusions using tree_map and checking the shape of the array at each leaf.

The current implementation can handle block-diagonal diffusion matrices with PyTrees such as

drift = {
        "block1": jnp.zeros((2,)),
        "block2": jnp.zeros((2,)),
        "block3": jnp.zeros((3,)),
    }
diffusion = {
        "block1": jnp.ones((2,)),
        "block2": jnp.ones((2, 3)),
        "block3": jnp.ones((3, 4)),
    }

The first block corresponds to WeaklyDiagonalControlTerm, and the remaining ones correspond to the general ControlTerm. I have not run any experiments on this part, only a unit test.
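
A rough sketch of this shape-based dispatch (illustrative only, not the PR's implementation): 1-D diffusion leaves are treated like WeaklyDiagonalControlTerm (elementwise product) and 2-D leaves like a general ControlTerm (matrix-vector product).

import jax.numpy as jnp
import jax.tree_util as jtu

def apply_diffusion(diffusion, dw):
    # Apply a block-diagonal diffusion to a Brownian increment, block by block.
    def one_block(g, w):
        if g.ndim == 1:   # diagonal block, as in WeaklyDiagonalControlTerm
            return g * w
        else:             # general block, as in ControlTerm
            return g @ w
    return jtu.tree_map(one_block, diffusion, dw)

diffusion = {
    "block1": jnp.ones((2,)),
    "block2": jnp.ones((2, 3)),
    "block3": jnp.ones((3, 4)),
}
dw = {
    "block1": jnp.ones((2,)),
    "block2": jnp.ones((3,)),
    "block3": jnp.ones((4,)),
}
out = apply_diffusion(diffusion, dw)  # same PyTree structure as the drift above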

I also pass the context via args in vf(t, y, args), but this may break the API convention that args should be a PyTree, whereas the context is a function.

The difficulty I encounter when handling the more general case can be illustrated with this code.

import jax.tree_util as jtu
import jax.numpy as jnp

vf_prod = {'block1': jnp.ones((2,)), 'block2': jnp.ones((1))}
diffusion = {'block1': jnp.ones((2,)), "block2": [[1., 1., 1.]]}

# vf_prod_tree obtained either from `materialise_vf` or input `drift`
vf_prod_tree = jtu.tree_structure(vf_prod) # PyTreeDef({'block1': *, 'block2': *})
diffusion_tree = jtu.tree_structure(diffusion) # PyTreeDef({'block1': *, 'block2': [[*, *, *]]})

transposed = jtu.tree_map(lambda *xs: list(xs), *[vf_prod, diffusion])
# PyTreeDef({'block1': [*, *], 'block2': [*, [[*, *, *]]]})

# The next step is to convert the diffusion part of `block2` to an array, but it's not clear how.
# Maybe `is_leaf` in `jtu.tree_map` could help, but what is the condition for deciding a leaf?
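
One possible direction for that conversion, offered only as a thought rather than a complete answer: because vf_prod's treedef is a prefix of diffusion's, passing vf_prod as the first argument to jtu.tree_map hands over each diffusion block as a whole subtree, which can then be coerced to an array without needing an is_leaf predicate.

# Continuing the snippet above (one possible direction, not a complete solution):
# coerce each diffusion block to an array, using vf_prod's shallower treedef
# to decide what counts as a block.
diffusion_as_arrays = jtu.tree_map(lambda _, d: jnp.asarray(d), vf_prod, diffusion)
# {'block1': shape (2,), 'block2': shape (1, 3)}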

anh-tong avatar Sep 21 '22 12:09 anh-tong

> Hi @harrisonzhu508, you can run the current code with jax=0.3.15 (while waiting for a further update). The latest version of equinox, 0.7.1, works fine.
>
> As Patrick mentioned, it must be something to do with core JAX. I also find that the bug occurs when we use eqx.filter_value_and_grad (eqx.filter_jit on its own still works). This is related to this pull request (google/jax#10711), which host_callback does not handle yet.

Thanks a lot!

harrisonzhu508 avatar Sep 26 '22 21:09 harrisonzhu508

Hi @anh-tong, thanks a lot for the very clean implementation again! I was trying to reproduce an example that is very similar to https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py. Running the latter script yields the attached. But using your notebook implementation, I've noticed that the posterior sample paths seem to collapse to a deterministic function (even in intervals where there's no data); I was wondering if you noticed something similar too? Thanks a lot!

[attached image: global_step_950]

harrisonzhu508 avatar Sep 26 '22 21:09 harrisonzhu508

> Hi @anh-tong, thanks a lot for the very clean implementation again! I was trying to reproduce an example that is very similar to https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py. Running the latter script yields the attached. But using your notebook implementation, I've noticed that the posterior sample paths seem to collapse to a deterministic function (even in intervals where there's no data); I was wondering if you noticed something similar too? Thanks a lot!

Hi, I guess this happens because the current parameter setting with kl_anneal_iters = 1000 may not be suitable for the data in the plot. Please try kl_anneal_iters = 100 instead (like in torchsde).

kl_anneal_iters helps the training figure out a good set of parameters in the early stages by prioritising the likelihood over the KL divergence. Fitting the data in the figure is relatively simple, so it should not take long to reach a region with a reasonable likelihood. The collapse is explained in this paper (see Section 5): the model only learns via the likelihood.
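
For reference, the kind of schedule being described is a linear ramp on the KL weight over the first kl_anneal_iters steps, as in torchsde's latent_sde.py; the function and variable names below are just illustrative.

import jax.numpy as jnp

def kl_weight(step, kl_anneal_iters=100):
    # Linearly ramp the KL coefficient from 0 to 1 over the first
    # `kl_anneal_iters` training steps, so that early training prioritises
    # the likelihood term.
    return jnp.minimum(step / kl_anneal_iters, 1.0)

def annealed_loss(log_likelihood, kl, step):
    # Annealed negative ELBO: -log-likelihood + beta * KL.
    return -log_likelihood + kl_weight(step) * kl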

anh-tong avatar Sep 27 '22 00:09 anh-tong

>> Hi @anh-tong, thanks a lot for the very clean implementation again! I was trying to reproduce an example that is very similar to https://github.com/google-research/torchsde/blob/master/examples/latent_sde.py. Running the latter script yields the attached. But using your notebook implementation, I've noticed that the posterior sample paths seem to collapse to a deterministic function (even in intervals where there's no data); I was wondering if you noticed something similar too? Thanks a lot!
>
> Hi, I guess this happens because the current parameter setting with kl_anneal_iters = 1000 may not be suitable for the data in the plot. Please try kl_anneal_iters = 100 instead (like in torchsde).
>
> kl_anneal_iters helps the training figure out a good set of parameters in the early stages by prioritising the likelihood over the KL divergence. Fitting the data in the figure is relatively simple, so it should not take long to reach a region with a reasonable likelihood. The collapse is explained in this paper (see Section 5): the model only learns via the likelihood.

That makes sense, thanks for the explanation! I haven't got it working (I'm training on samples from a stochastic process) but I'll try and play around with the KL annealing!

harrisonzhu508 avatar Sep 27 '22 16:09 harrisonzhu508