
Request for torch.distributions primitives

Open jerry73204 opened this issue 6 years ago • 19 comments

It seems the current tch-rs lacks distribution primitives like torch.distributions and tf.distributions. Though it already has random samplers, there is still room for probability arithmetic and related derivations.

I see a promising crate rv that may fit my needs. However, its interface is not general enough: we cannot pass mean/std tensors to rv's Gaussian new() function.

Porting torch.distributions would solve the problem, though it will take some patience to get it working; alternatively, I could go ahead and improve rv. I'd like to know whether the author has any plan to port the module, or would rather leave it to other crates.

In case you're wondering where this feature is needed: I'm implementing the ELBO loss for GQN (paper).

jerry73204 avatar Apr 28 '19 17:04 jerry73204

Thanks for the suggestion. This indeed seems like a nice thing to add - the underlying torch primitives for sampling should already exist in the rust wrappers, e.g. normal so I guess this would mostly consist in adding some trait for distributions and implementing it mimicking the python implementation, e.g. Normal. Would that be useful to you? Also which methods/distributions would be the most interesting to you?
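
For concreteness, such a trait might look roughly like this (a sketch only; none of these names exist in tch-rs today, and the real surface would mimic python's torch.distributions):

use tch::Tensor;

// Illustrative sketch of a distributions trait; names and signatures are
// assumptions, not part of tch-rs.
pub trait Distribution {
    /// Log probability density (or mass) of `value`, differentiable w.r.t. the parameters.
    fn log_prob(&self, value: &Tensor) -> Tensor;
    /// Draw a sample with the given shape.
    fn sample(&self, shape: &[i64]) -> Tensor;
}

/// A Normal distribution whose parameters are tensors, mirroring torch.distributions.Normal.
pub struct Normal {
    pub mean: Tensor,
    pub std: Tensor,
}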

LaurentMazare avatar Apr 28 '19 17:04 LaurentMazare

Yes, I see that normal function. I need probability arithmetic rather than samplers. Take GQN for example: I need a Normal object whose parameters are defined by tensors, and to compute log_prob (the log probability density) for any value under that Normal. I also have to compute the KL divergence between two normals. There's no need for sampling.

Torch and TensorFlow each have their own such functions. Interestingly, you'll see NotImplementedError if you look into Torch's source code. So I bet improving rv would be a good direction. Currently I just write out the raw formulas, being careful about floating point precision.
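
For reference, those raw formulas could look roughly like this in tch-rs (a sketch assuming diagonal Gaussians with tensor-valued mean/std; the helper names are made up):

use tch::Tensor;

// log N(x; mean, std) = -(x - mean)^2 / (2 * var) - log(std) - 0.5 * log(2 * pi)
fn normal_log_prob(x: &Tensor, mean: &Tensor, std: &Tensor) -> Tensor {
    let var = std * std;
    let diff = x - mean;
    -(&diff * &diff) / (&var * 2.0) - std.log() - 0.5 * (2.0 * std::f64::consts::PI).ln()
}

// KL(N(mean0, std0) || N(mean1, std1)), element-wise for diagonal Gaussians:
// log(std1 / std0) + (var0 + (mean0 - mean1)^2) / (2 * var1) - 0.5
fn normal_kl_div(mean0: &Tensor, std0: &Tensor, mean1: &Tensor, std1: &Tensor) -> Tensor {
    let var0 = std0 * std0;
    let var1 = std1 * std1;
    let diff = mean0 - mean1;
    (std1 / std0).log() + (&var0 + &diff * &diff) / (&var1 * 2.0) - 0.5
}

Since everything is built from ordinary tensor ops, gradients flow back to the mean/std tensors for free.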

jerry73204 avatar Apr 28 '19 18:04 jerry73204

Do you need to backprop through your KL divergence? If that's the case I'm not sure that rv could be used out of the box, but maybe I'm missing something? When looking at the 'Normal' distribution in torch, I don't see many NotImplementedErrors besides in the base Distribution class. Also, the KL divergence for the normal distribution can be found here.

LaurentMazare avatar Apr 28 '19 18:04 LaurentMazare

The NotImplementedError is here. Sorry for the imprecise comment.

Backprop is desired in my case. As far as I know, torch.distributions is an add-on implemented in PyTorch's Python layer rather than in libtorch, so an implementation in Rust would essentially be a new library.

Let me write some code for this and see what a proper way to build this feature would look like.

jerry73204 avatar Apr 28 '19 19:04 jerry73204

Re NotImplementedError, indeed that's the base class for distributions in pytorch, so nothing is implemented there and the various methods are implemented in the derived classes like 'Normal'. The distribution bit is included within the main pytorch repo and package (contrary to, say, vision which is an external repo and pypi package). I don't have a strong opinion on where this belongs - in the main crate or in a separate one - but starting with an external crate sounds good, and if there is some upside to merging it into the main repo we can consider it later. Let me know if you notice some pytorch primitives missing from tch-rs that could be useful to you!

LaurentMazare avatar Apr 28 '19 19:04 LaurentMazare

Just to mention that I added a variational auto-encoder example to tch-rs. This includes a KL divergence loss here. It's certainly very far away from what a nice distributions api would provide but it may be handy.
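
For reference, the KL term in such a VAE boils down to the closed-form KL between N(mu, sigma) and a standard normal; roughly (a sketch, not the exact example code, assuming the encoder outputs mu and logvar):

use tch::Tensor;

// Element-wise KL(N(mu, sigma) || N(0, 1)) with sigma^2 = exp(logvar):
// 0.5 * (mu^2 + exp(logvar) - 1 - logvar). Sum or average it for the loss term.
fn vae_kl(mu: &Tensor, logvar: &Tensor) -> Tensor {
    ((mu * mu) + logvar.exp() - logvar - 1.0) * 0.5
}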

LaurentMazare avatar Apr 30 '19 07:04 LaurentMazare

It is just a case of adding numerical approximation functions for pdfs and cdfs of popular statistical distributions. I personally just implemented the Normal cdf derived in this article using tensors and could compute gradients through it with the library as usual:

// Logistic-style approximation of the standard normal cdf from the article above;
// gradients flow through it like any other tensor op.
fn norm_cdf(x: &Tensor) -> Tensor {
    let denom = ((-358.0 * x / 23.0) + 111.0 * (37.0 * x / 294.0).atan()).exp() + 1.0;
    1.0 / denom
}

vegapit avatar Jul 17 '19 07:07 vegapit

@vegapit yes, that's mostly about adding such functions and probably some traits for the various distributions. You can see the implementation for the normal distribution in the python api here. The code for the cdf is a bit different from yours and relies on torch.erf. Not sure which one has better precision.

def cdf(self, value):
    return 0.5 * (1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)))
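
A rough tch-rs equivalent using the erf binding might look like this (a sketch; loc/scale stand for the mean and standard deviation tensors, and the function name is made up):

use tch::Tensor;

// cdf(x) = 0.5 * (1 + erf((x - loc) / (scale * sqrt(2))))
fn normal_cdf(value: &Tensor, loc: &Tensor, scale: &Tensor) -> Tensor {
    let z = (value - loc) / (scale * std::f64::consts::SQRT_2);
    (z.erf() + 1.0) * 0.5
}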

LaurentMazare avatar Jul 17 '19 09:07 LaurentMazare

@LaurentMazare I did not know you actually had already added the error function implementation in Torch, otherwise I would have used it. I do not know what the numerical approximation in torch.erf is but my guess is that its precision must be similar to the function I described above.

vegapit avatar Jul 17 '19 10:07 vegapit

Btw the error function is also available in tch-rs https://docs.rs/tch/0.1.0/tch/struct.Tensor.html#method.erf (as the rust low level wrappers are automatically generated, we mostly get these for free)

LaurentMazare avatar Jul 17 '19 10:07 LaurentMazare

A little update here. My previous gqnrs project has some useful traits for probability distributions, though only the Normal distribution is implemented there. Shall we start a working branch to fill in the blanks for the other distributions (Bernoulli, Exponential, Categorical, etc.)?

jerry73204 avatar Jul 18 '19 07:07 jerry73204

Yes that kind of trait would indeed probably be useful, @vegapit do you think this would cover your use case?

LaurentMazare avatar Jul 18 '19 09:07 LaurentMazare

I use Torch to solve for maximum likelihood in non-linear parametrised models. The density functions can be reconstructed or approximated using tensor methods, but it is clearly more user-friendly to provide them in wrappers as Pytorch does.

vegapit avatar Jul 18 '19 10:07 vegapit

I started to work on a crate to port the distributions: https://github.com/spebern/tch-distr-rs. Besides porting, the most tedious work is testing everything.

I think the best way to ease porting is to test against the python implementations directly, by supplying the same inputs and comparing the outputs. Later, this could also be extended with fuzzed inputs.
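
Such a comparison test could look roughly like this (illustrative only, not from tch-distr-rs; the reference value is what torch.distributions.Normal(0., 1.).log_prob(0.) evaluates to, i.e. -0.5 * ln(2 * pi)):

use tch::Tensor;

#[test]
fn log_prob_matches_pytorch_reference() {
    // Standard normal log-density at zero, computed with the raw formula.
    let x = Tensor::of_slice(&[0.0f64]);
    let mean = Tensor::of_slice(&[0.0f64]);
    let std = Tensor::of_slice(&[1.0f64]);
    let diff = &x - &mean;
    let log_prob = -(&diff * &diff) / (&std * &std * 2.0)
        - std.log()
        - 0.5 * (2.0 * std::f64::consts::PI).ln();
    // Reference from python: torch.distributions.Normal(0., 1.).log_prob(torch.tensor(0.))
    let expected = -0.918_938_533_204_672_7; // -0.5 * ln(2 * pi)
    assert!((log_prob.double_value(&[0]) - expected).abs() < 1e-12);
}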

The Distribution trait is open for discussion and pull requests are more than welcome to add more distributions/tests.

spebern avatar May 31 '20 14:05 spebern

This looks very nice, thanks for sharing! Once this has been polished/tested a bit, we could probably mention it in the tch-rs main readme file for better discoverability if that's ok with you.

LaurentMazare avatar May 31 '20 14:05 LaurentMazare

That would be really nice! It definitely needs some polishing and more implemented distributions, but I really think that the testing against python implementations takes away a lot of work.

spebern avatar May 31 '20 14:05 spebern

> This looks very nice, thanks for sharing! Once this has been polished/tested a bit, we could probably mention it in the tch-rs main readme file for better discoverability if that's ok with you.

@LaurentMazare, I wonder if it would be a good idea to introduce @spebern's tch-distr-rs as a feature of tch-rs, so that users don't need to pull in two crates when using torch in Rust.

dbsxdbsx avatar Nov 20 '21 07:11 dbsxdbsx

If it's just to avoid having an additional dependency for crates that would want to use this, I would lean more towards keeping an external crate, and in general having smaller composable crates for the bits that are outside of the core tch-rs, e.g. I'm more thinking about moving the vision models in their own crate, the RL bits to their own thing too etc.

LaurentMazare avatar Nov 23 '21 19:11 LaurentMazare

> If it's just to avoid having an additional dependency for crates that would want to use this, I would lean more towards keeping an external crate, and in general having smaller composable crates for the bits that are outside of the core tch-rs, e.g. I'm more thinking about moving the vision models in their own crate, the RL bits to their own thing too etc.

@LaurentMazare, the reason I hope tch-distr-rs could be part of tch-rs is that in pytorch the distributions code is also part of the main python torch package, even though it is not part of the C++ version. Also, I don't think it is proper to treat tch-distr-rs as a tool ONLY for reinforcement learning or any other specific field.

Therefore, I suggest making it an optional feature. That would remain flexible (users could decide whether to include it via the "features" tag in Cargo.toml) and would make the transition from pytorch easier for users who are familiar with it.
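
To illustrate, the opt-in could look something like this on the tch-rs side (purely hypothetical; the "distributions" feature name and the tch_distr crate path are assumptions, and users would enable it through the features list of the tch dependency in Cargo.toml):

// In tch-rs's lib.rs, gated behind a hypothetical "distributions" cargo feature
// that pulls in the external crate as an optional dependency.
#[cfg(feature = "distributions")]
pub mod distributions {
    // Re-export so downstream users only depend on tch directly.
    pub use tch_distr::*;
}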

dbsxdbsx avatar Nov 24 '21 02:11 dbsxdbsx