
Implement all Ops in PyTorch (help welcome!)

Open ricardoV94 opened this issue 1 year ago • 29 comments

Description

If you want to help implement some of these Ops, just leave a comment below saying which ones you are interested in. We'll give you some time to work on them (and then put them back up for grabs).

See the documentation for How to implement PyTorch Ops and tests: https://pytensor.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html

Example PR: #836

See https://github.com/pymc-devs/pytensor/issues/821#issuecomment-2202258929 for suggestions on equivalent torch functions
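
For orientation, here is a minimal sketch of the dispatch pattern (assuming the PyTorch backend mirrors the JAX/Numba pattern from the docs above, i.e. a `pytorch_funcify` singledispatch living under `pytensor.link.pytorch.dispatch`; check the example PR for the exact module path):

```python
import torch

# Assumed import path, mirroring jax_funcify/numba_funcify; check PR #836 for the real one.
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.basic import TensorFromScalar


@pytorch_funcify.register(TensorFromScalar)
def pytorch_funcify_TensorFromScalar(op, **kwargs):
    # The registered function returns a callable; at runtime it receives the
    # node's inputs and should return torch tensors.
    def tensor_from_scalar(x):
        return torch.as_tensor(x)

    return tensor_from_scalar
```

A test then typically builds a small PyTensor graph containing the Op and checks the PyTorch-compiled output against the default backend (see the linked docs for the test helpers used by the JAX/Numba backends).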

Tensor creation Ops

  • [x] Alloc and AllocEmpty #836
  • [x] Arange #836
  • [x] Eye #877
  • [ ] ScalarFromTensor
  • [ ] TensorFromScalar
  • [x] Repeat #890
  • [x] Unique #890
  • [x] Sort / Argsort #897
  • [ ] Tri

Shape Ops

  • [x] Dimshuffle - Done in #764
  • [x] Reshape #926
  • [x] Shape, Shape_i #926
  • [x] SpecifyShape #926
  • [x] Unbroadcast #926
  • [x] Join #869
  • [ ] Split
  • [x] MakeVector #926

Math Ops

  • [x] Elemwise - Done in #764 (but not complete! Specific cases will require a custom ScalarOp dispatch; see the sketch after this list)
  • [x] CAReduce (Sum, All, Any...) - Assigned to @HarshvirSandhu
  • [x] CumOp - #837
  • [x] Softmax, LogSoftmax and Grads in #846
  • [x] Dot #878
  • [x] BatchedDot #878
  • [ ] Argmax
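
As a rough illustration of the "custom ScalarOp dispatch" mentioned in the Elemwise item above, a sketch only (assuming scalar Ops are dispatched through the same `pytorch_funcify` mechanism that the Elemwise dispatch consults, as in the JAX backend):

```python
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify  # assumed path
from pytensor.scalar.math import Softplus


@pytorch_funcify.register(Softplus)
def pytorch_funcify_Softplus(op, **kwargs):
    def softplus(x):
        # torch provides a numerically stable softplus directly
        return torch.nn.functional.softplus(x)

    return softplus
```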

Indexing Ops

  • [ ] Subtensor #910
  • [ ] Inc/SetSubtensor #910
  • [ ] AdvancedSubtensor[1] #910
  • [ ] AdvancedIncSubtensor[1] #910

Branching Ops

  • [x] CheckAndRaise - Done in #764
  • [ ] Ifelse #940
  • [ ] ScalarLoop
  • [ ] Scan
  • [ ] OpFromGraph
  • [ ] Blockwise

Linalg Ops

  • [ ] SVD #920
  • [ ] Det #920
  • [ ] Eig #920
  • [ ] Eigh #920
  • [ ] MatrixInverse #920
  • [ ] MatrixPinv #920
  • [ ] QRFull #920
  • [ ] SLogDet #920
  • [ ] BlockDiagonal #922
  • [ ] Cholesky #922
  • [ ] Solve #922
  • [ ] SolveTriangular #922

SparseOps

  • [ ] ... (to be filled)

RandomVariable Ops

  • [ ] ... Need to figure out API differences

If you need an Op that's not in this list, comment below and we'll add it!

ricardoV94 avatar Jun 14 '24 16:06 ricardoV94

Hi there @ricardoV94, I'm attending the hackathon at Pydata London.

I'd like to work on Softmax.

HAKSOAT avatar Jun 15 '24 15:06 HAKSOAT

Hi there @ricardoV94, I'm attending the hackathon at Pydata London.

I'd like to work on Softmax.

Sure!

ricardoV94 avatar Jun 16 '24 08:06 ricardoV94

Thanks @ricardoV94

I have been doing some reading around the docs, and I noticed that #764 has the initial setup. I believe it contains the groundwork for adding other Ops, so I'll be following that PR, and the Softmax Op will likely use some code from it.

Am I thinking about this the right way?

HAKSOAT avatar Jun 17 '24 17:06 HAKSOAT

@HAKSOAT yes, you should be able to start a branch already from #764. All the functionality is done, except some tests are failing because they depend non-optionally on GPU. We should make that optional and get it merged soon enough, but you need not wait :)

ricardoV94 avatar Jun 17 '24 18:06 ricardoV94

#764 is merged

ricardoV94 avatar Jun 20 '24 15:06 ricardoV94

Hello @ricardoV94, I would like to work on Reshape

t3chw avatar Jun 20 '24 17:06 t3chw

Go ahead. We'll link and lock the Op when you open a PR

ricardoV94 avatar Jun 20 '24 17:06 ricardoV94

Hi @ricardoV94, I have opened a PR for the Softmax Ops. I see it has been grouped with LogSoftmax and Grads, so I can update the PR to include them.
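
For context, a minimal sketch of how the Softmax part could be dispatched (assuming the `pytorch_funcify` entry point and the `axis` attribute on `pytensor.tensor.special.Softmax`; the grad Ops would mirror the hand-written JAX versions):

```python
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify  # assumed path
from pytensor.tensor.special import Softmax


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis

    def softmax(x):
        if axis is None:
            # softmax over all elements: flatten, apply, restore the shape
            return torch.nn.functional.softmax(x.ravel(), dim=0).reshape(x.shape)
        return torch.nn.functional.softmax(x, dim=axis)

    return softmax
```

LogSoftmax would look the same with `torch.nn.functional.log_softmax`.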

HAKSOAT avatar Jun 23 '24 22:06 HAKSOAT

Hi @ricardoV94, I will work on the Dot Op now.

HangenYuu avatar Jun 28 '24 07:06 HangenYuu

If someone wants to look through the codebase and populate the list of Ops above that would also be very helpful :)

ricardoV94 avatar Jun 28 '24 11:06 ricardoV94

If someone wants to look through the codebase and populate the list of Ops above that would also be very helpful :)

@ricardoV94 Does something like this work? There are corresponding torch functions/methods attached to each op (a sketch of wiring up one of these mappings follows the list).

  1. pytensor.tensor.elemwise
  • [ ] Elemwise:
  • [ ] CAReduce:
  • [ ] DimShuffle: torch.transpose.
  • [ ] Softmax: torch.nn.functional.softmax.
  • [ ] SoftmaxGrad: Hand-crafted like JAX or use PyTorch autograd.
  • [ ] LogSoftmax: torch.nn.functional.log_softmax.
  2. pytensor.tensor.extra_ops
  • [ ] Bartlett: torch.bartlett_window.
  • [ ] CumOp: torch.cumsum & torch.cumprod.
  • [ ] FillDiagonal: torch.Tensor.fill_diagonal_.
  • [ ] FillDiagonalOffset: torch.diagonal_scatter with src parameter set to a vector of identical values with compatible shape.
  • [ ] RavelMultiIndex: no equivalent in PyTorch, must be crafted from native ops.
  • [ ] Repeat: torch.repeat_interleave.
  • [ ] Unique: torch.unique.
  • [ ] UnravelIndex: torch.unravel_index.
  • [ ] bincount: torch.bincount.
  • [ ] broadcast_to: torch.broadcast_to.
  3. pytensor.tensor.nlinalg
  • [ ] BatchedDot: torch.matmul with check for batch dimension or torch.mm.
  • [ ] Dot: torch.matmul.
  • [ ] MaxAndArgmax: torch.max & torch.argmax.
  • [ ] SVD: torch.linalg.svd.
  • [ ] Det: torch.linalg.det.
  • [ ] Eig: torch.linalg.eig.
  • [ ] Eigh: torch.linalg.eigh.
  • [ ] MatrixInverse: torch.linalg.inv.
  • [ ] MatrixPinv: torch.linalg.pinv.
  • [ ] QRFull: torch.linalg.qr.
  • [ ] SLogDet: torch.linalg.slogdet.
  4. pytensor.tensor.slinalg
  • [ ] BlockDiagonal: torch.block_diag.
  • [ ] Cholesky: torch.linalg.cholesky.
  • [ ] Solve: torch.linalg.solve.
  • [ ] SolveTriangular: torch.linalg.solve_triangular.
  5. pytensor.tensor.random - Sampling from a distribution will have to happen through torch.distributions. The underlying random number generator modules are torch.Generator and torch.random. The documentation for the different families is listed at Probability distributions - torch.distributions — PyTorch 2.2 documentation.
  • [ ] RandomStateType(type): Random state or random seed in PyTorch is not exposed as a separate class. It is returned as a torch.Tensor via torch.get_rng_state() and also in get_state() of torch.Generator.
  • [ ] RandomVariable(func): Like JAX, I will have to implement a new class to wrap the random number generator of a specific distribution.
  • [ ] Generator: torch.Generator.
  • [ ] Generic family
    • [ ] BetaRV: torch.distributions.beta.Beta
    • [ ] DirichletRV: torch.distributions.dirichlet.Dirichlet
    • [ ] PoissonRV: torch.distributions.poisson.Poisson
    • [ ] MvNormalRV: torch.distributions.multivariate_normal.MultivariateNormal
  • [ ] Loc-scale family
    • [ ] CauchyRV: torch.distributions.cauchy.Cauchy
    • [ ] GumbelRV: torch.distributions.gumbel.Gumbel
    • [ ] LaplaceRV: torch.distributions.laplace.Laplace
    • [ ] LogisticRV:
    • [ ] NormalRV: torch.distributions.normal.Normal
    • [ ] StandardNormalRV: A special case of torch.distributions.normal.Normal
  • [ ] No datatype family:
    • [ ] BernoulliRV: torch.distributions.bernoulli.Bernoulli
    • [ ] CategoricalRV: torch.distributions.categorical.Categorical
  • [ ] Uniform density family:
    • [ ] RandIntRV: torch.randint
    • [ ] IntegersRV:
    • [ ] UniformRV: torch.distributions.uniform.Uniform
  • [ ] Shape-scale family:
    • [ ] ParetoRV: torch.distributions.pareto.Pareto
    • [ ] GammaRV: torch.distributions.gamma.Gamma
  • [ ] ExponentialRV: torch.distributions.exponential.Exponential
  • [ ] StudentTRV: torch.distributions.studentT.StudentT
  • [ ] ChoiceRV: torch.multinomial?
  • [ ] PermutationRV: torch.randperm
  • [ ] BinomialRV: torch.distributions.binomial.Binomial
  • [ ] MultinomialRV: torch.distributions.multinomial.Multinomial
  • [ ] VonMisesRV: torch.distributions.von_mises.VonMises
  6. pytensor.scalar - still unsure how to convert this to PyTorch due to missing docstrings and my limited knowledge of scalar handling in PyTorch.
  • [ ] ScalarOp:
  • [ ] Add:
  • [ ] Mul:
  • [ ] Sub:
  • [ ] IntDiv:
  • [ ] Mod:
  • [ ] Cast:
  • [ ] Clip:
  • [ ] Composite:
  • [ ] Identity:
  • [ ] Second:
  • [ ] BetaIncInv:
  • [ ] Erf:
  • [ ] Erfc:
  • [ ] Erfcinv:
  • [ ] Erfcx:
  • [ ] Erfinv:
  • [ ] GammaIncCInv:
  • [ ] GammaIncInv:
  • [ ] Iv:
  • [ ] Ive:
  • [ ] Log1mexp:
  • [ ] Psi:
  • [ ] TriGamma:
  • [ ] Softplus: torch.nn.functional.softplus
  7. pytensor.scan.op.Scan - still unsure how to convert this to PyTorch
  8. pytensor.sparse - torch.sparse supports sparse matrices. Beyond the methods already implemented in JAX, I can map more methods to Torch (unless I am missing something and the current ones turn out to be enough to build everything from there).
  • [ ] SparseTensorType: torch.tensor(layout=torch.sparse_<sparse_type>) via the different variations of torch.Tensor.to_sparse.
  • [ ] Dot: torch.smm.
  • [ ] StructuredDot: torch.sparse.mm() (currently not working for CSR matrix, especially on GPU).
  9. pytensor.tensor.basic
  • [ ] Alloc:
  • [ ] AllocEmpty:
  • [ ] ARange:
  • [ ] ExtractDiag:
  • [ ] Eye:
  • [ ] Join:
  • [ ] MakeVector:
  • [ ] ScalarFromTensor:
  • [ ] Split:
  • [ ] TensorFromScalar:
  • [ ] Tri:
  • [ ] SortOp:
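
As mentioned above the list, here is a sketch of how one of these mappings could be wired up, using FillDiagonal as the example (the `pytorch_funcify` import path is an assumption to verify against the codebase):

```python
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify  # assumed path
from pytensor.tensor.extra_ops import FillDiagonal


@pytorch_funcify.register(FillDiagonal)
def pytorch_funcify_FillDiagonal(op, **kwargs):
    def fill_diagonal(a, val):
        out = a.clone()  # fill_diagonal_ is in-place, so copy first to keep the Op functional
        out.fill_diagonal_(val.item() if torch.is_tensor(val) else val)
        return out

    return fill_diagonal
```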

HangenYuu avatar Jul 02 '24 08:07 HangenYuu

I could help with the remaining Tensor creation ops to begin with. Let me know.

twaclaw avatar Jul 02 '24 08:07 twaclaw

Thanks @twaclaw, feel free to open a PR

ricardoV94 avatar Jul 02 '24 08:07 ricardoV94

I will have a look at Repeat, Unique, etc. during the weekend.

@ricardoV94, regarding Unique:

  • is it desired to emulate the behaviour of return_index=True or rather throw NotImplementedError? (torch.unique doesn't feature that param)
  • is the axis param required? Your current JAX Op implementation doesn't support a non-None axis

twaclaw avatar Jul 04 '24 18:07 twaclaw

@twaclaw in general we want to support exactly the same functionality as the original Op. When that is not possible or too complicated, raising NotImplementedError is fine.

Regarding JAX, we probably cannot compile (JIT) any function that has unique in it, because JAX can't handle dynamic shapes. So it's a bit moot whether we say we support axis or not, although the NotImplementedError could be removed and we could just dispatch to jax.numpy.unique rather straightforwardly. Feel free to open a PR for that if you want.
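
A rough sketch of what that policy could look like for Unique in the PyTorch backend (assuming the `pytorch_funcify` entry point and that the Op exposes the same return_index/return_inverse/return_counts/axis attributes as the NumPy version):

```python
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify  # assumed path
from pytensor.tensor.extra_ops import Unique


@pytorch_funcify.register(Unique)
def pytorch_funcify_Unique(op, **kwargs):
    if op.return_index:
        # torch.unique has no return_index equivalent
        raise NotImplementedError("return_index is not supported in the PyTorch backend")

    def unique(x):
        return torch.unique(
            x,
            return_inverse=op.return_inverse,
            return_counts=op.return_counts,
            dim=op.axis,
        )

    return unique
```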

ricardoV94 avatar Jul 05 '24 09:07 ricardoV94

I'll be working on the indexing Ops now

HarshvirSandhu avatar Jul 05 '24 21:07 HarshvirSandhu

@twaclaw in general we want to support exactly the same functionality as the original Op. When that is not possible or too complicated, raising NotImplementedError is fine.

Regarding JAX, we probably cannot compile (JIT) any function that has unique in it, because JAX can't handle dynamic shapes. So it's a bit moot whether we say we support axis or not, although the NotImplementedError could be removed and we could just dispatch to jax.numpy.unique rather straightforwardly. Feel free to open a PR for that if you want.

@ricardoV94, regarding the JAX implementation of unique, one possible option would be to make the size param of jax.numpy.unique static. Anyhow, I think the current implementation might be broken (i.e., lax_numpy is undefined).

twaclaw avatar Jul 07 '24 13:07 twaclaw

I can't imagine many cases where I would know the number of unique elements but not what/where they are. If I knew what/where they are, I would just select them instead of using unique.

More importantly, we don't have size in our Unique Op, and I don't think it makes sense to add it for this edge case. A more general approach would be to be clever about what can and cannot be jitted. I think there's an issue open for that already.

For now we can probably just remove the implementation and let it raise NotImplementedError if, as you say, it's broken anyway.

ricardoV94 avatar Jul 07 '24 13:07 ricardoV94

You are right, unique is not compatible with JIT. You don't need to know the exact size, just guess it (maybe overestimate it). I was wondering whether it is possible to insert parameters (like a Constant size) into the Op graph 🤔. But I agree a NotImplementedError would be appropriate.
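
For reference, the "guess/overestimate the size" idea in plain JAX (size and fill_value are real jax.numpy.unique parameters; the fixed size of 8 here is just an illustration):

```python
import jax
import jax.numpy as jnp


@jax.jit
def unique_padded(x):
    # `size` must be a static Python constant under jit; the result is the sorted
    # unique values, padded with `fill_value` up to length `size`.
    return jnp.unique(x, size=8, fill_value=-1)


print(unique_padded(jnp.array([3, 1, 3, 2, 1])))
```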

twaclaw avatar Jul 07 '24 14:07 twaclaw

I can take a look at the linear algebra ops.

twaclaw avatar Jul 08 '24 10:07 twaclaw

Seems like Reshape is becoming more and more relevant, if anyone wants to tackle it

ricardoV94 avatar Jul 10 '24 07:07 ricardoV94

Seems like Reshape is becoming more and more relevant, if anyone wants to tackle it

Is someone working on this?

twaclaw avatar Jul 11 '24 17:07 twaclaw

Seems like Reshape is becoming more and more relevant, if anyone wants to tackle it

Is someone working on this?

Not yet I think. You can go ahead

ricardoV94 avatar Jul 11 '24 18:07 ricardoV94

Is the checklist at the top up to date on what else is needed?

Ch0ronomato avatar Jul 12 '24 21:07 Ch0ronomato

More or less up to date, except linalg and indexing are being worked on

ricardoV94 avatar Jul 12 '24 21:07 ricardoV94

If someone is interested we need to check whether we can bridge nicely between PyTensor and Torch random number generator APIs.

We have added a recent documentation page explaining how random variables work in PyTensor: https://github.com/pymc-devs/pytensor/pull/928

As a reminder, we're targeting torch.compile functionality, in case that matters
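
To make the question concrete, one possible direction (nothing decided; the helper below is hypothetical) is to map each PyTensor shared RNG variable to a torch.Generator and pass it explicitly to the functional sampling calls:

```python
import torch


def torch_generator_from_seed(seed: int) -> torch.Generator:
    # Hypothetical helper: one torch.Generator per PyTensor shared RNG variable
    gen = torch.Generator()
    gen.manual_seed(seed)
    return gen


gen = torch_generator_from_seed(42)
# e.g. NormalRV could then lower to the functional torch call with an explicit generator:
draws = torch.normal(mean=torch.zeros(3), std=torch.ones(3), generator=gen)
```

The open questions are whether explicit generators play nicely with torch.compile and how PyTensor's RNG update semantics map onto the generator's mutable state.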

ricardoV94 avatar Jul 17 '24 19:07 ricardoV94

@ricardoV94 I'll take a stab after I finish some of the operators in #939. I need to build a bit more familiarity with the PyTensor codebase first

Ch0ronomato avatar Jul 20 '24 19:07 Ch0ronomato

Coming from PyMC, adding a sparse solve would be useful I believe...

danielgerardclaassen avatar Aug 24 '24 11:08 danielgerardclaassen

Coming from PyMC, adding a sparse solve would be useful I believe...

This issue is not very relevant for that request, since we first need a sparse solve in PyTensor itself before we can add it to the PyTorch backend; we haven't done anything with sparse Ops in the PyTorch backend yet

ricardoV94 avatar Aug 24 '24 11:08 ricardoV94