NumPy-like interaction with tensors
As our random variables are now tensors, manipulations that users of pymc3 are familiar with might no longer be valid, requiring them to use TensorFlow's particular syntax for manipulating tensors. I might be wrong, but advanced indexing of numpy arrays seems to no longer work and requires tf.gather_nd. Is this deviation from pymc3-like syntax desirable?
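For illustration, a minimal sketch of the kind of rewrite this forces (assuming TF 2.x eager execution; the array and indices are made up):

```python
import numpy as np
import tensorflow as tf

x = np.arange(12).reshape(3, 4)
rows = np.array([0, 2])
cols = np.array([1, 3])

# NumPy advanced indexing: selects elements (0, 1) and (2, 3) in one shot.
print(x[rows, cols])  # [ 1 11]

# TF has no direct equivalent; the index pairs must be assembled
# explicitly and passed to tf.gather_nd.
xt = tf.constant(x)
idx = tf.stack([rows, cols], axis=-1)  # shape (2, 2): [[0, 1], [2, 3]]
print(tf.gather_nd(xt, idx))  # [ 1 11]
```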
I don't think there is a good workaround for that, as we rely on TensorFlow for tensor operations.
I wasn't aware that advanced indexing is not supported by TF, that's a bummer. It's the only place where non-numpy syntax would be required, however, or am I missing something else?
Hmmm.... @junpenglao & @twiecki I am commenting from the dark on this, but jax devs explicitly want to use the NumPy API - any chance we could experiment with that? What would be needed? A re-implementation of the probability distributions and bijectors is what I can see at the moment - but is there anything more?
The numpy API is idiomatic, which is why I like the tensor libraries that strive to use it rather than invent their own. CuPy and jax devs have either a formal or an informal policy of striving to match the NumPy API, but TF and PyTorch devs seem to do so only half-heartedly.
@csuter Any thoughts on this discussion?
There's some historical context here, and (substantially) more here. In particular, @shoyer points to a NEP discussing a proposal to work around shortcomings of numpy advanced indexing. I know nothing about the details of why TF doesn't have numpy-style advanced indexing but it's built and designed by a large group of smart folks so a reasonable prior is that there are at least some good reasons :)
A library of distributions and bijectors built on jax would be awesome, but would also be a ton of work.
> it's built and designed by a large group of smart folks so a reasonable prior is that there are at least some good reasons
Won't deny that there's definitely a large group of smart folks! But we all know what happens to a broth when there are too many chefs involved :smile:, particularly smart (and likely opinionated) ones!
Jokes aside, @csuter, thanks for providing the context on indexing. I see the issues with numpy indexing now.
Apart from numpy-like indexing, though, I had issues with both TF's and PyTorch's use of non-exact function naming (e.g. tf.reduce_sum rather than tf.sum), which makes it tedious (though not difficult) to replace old NumPy code with TF code. If exact function names were copied over, I would have moved my codebase to TF in a heartbeat. Imagine import tensorflow as np! For this reason, Jax has been attractive for writing models, because a lot of the old autograd.numpy code I wrote can be ported over with a single import change.
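To make that concrete, a minimal sketch (the loss function here is made up for illustration, and assumes jax is installed):

```python
# Old autograd version: only the imports differ.
# import autograd.numpy as np
# from autograd import grad

# JAX version: the NumPy-style model code is unchanged.
import jax.numpy as np
from jax import grad

def loss(w):
    return np.sum(w ** 2)  # identical code under both imports

print(grad(loss)(np.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```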
I have other questions; perhaps I'll raise them in a different issue.