functorch
functorch provides JAX-like composable function transforms for PyTorch.
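As a minimal sketch of what "composable" means here, `grad` and `vmap` can be nested freely. The function `f` and the input values below are illustrative assumptions. The derivative of a scalar function is lifted to a batch of inputs, and `grad` is applied twice for a second derivative.

```python
import torch
from functorch import grad, vmap

def f(x):
    # Scalar-valued function of a 0-dim tensor (illustrative choice)
    return torch.sin(x)

# grad(f) returns a new function computing df/dx; transforms nest
df = grad(f)          # d/dx sin(x)   = cos(x)
d2f = grad(grad(f))   # d2/dx2 sin(x) = -sin(x)

xs = torch.linspace(0.0, 3.0, 5)
# vmap maps the per-scalar derivative over the leading dimension of xs
batched_df = vmap(df)(xs)
batched_d2f = vmap(d2f)(xs)
```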
Makes the changes to `vmap.py` needed for [Using TensorDict for functorch](https://github.com/facebookresearch/rl/pull/364) to work.
We should fix it. Not quite sure how, though. https://gist.github.com/zou3519/73a6189e21561f6ef5b42874e8a4826f
This is a subgraph from the `tts_angular` model. The generated backward pass has many `None` outputs, suggesting that `requires_grad` is somehow not passed correctly when an LSTM cell is used. ...
Hello, I saw that this is a place to leave suggestions and use cases. I am working on solving PDEs numerically with PyTorch, and I cannot avoid computing Jacobians (Newton iteration,...
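For the Newton-iteration use case, a minimal sketch of forming the dense Jacobian at each step with `functorch.jacrev` follows. The `residual` system is a hypothetical 2x2 stand-in for a discretized PDE, not the poster's problem, and `newton_solve` is an illustrative helper; a real PDE solver would likely need sparse or matrix-free Jacobians instead of a dense `torch.linalg.solve`.

```python
import torch
from functorch import jacrev

def residual(x):
    # Hypothetical nonlinear system F(x) = 0 with root x0 = x1 = sqrt(2)
    return torch.stack([x[0] ** 2 + x[1] ** 2 - 4.0, x[0] - x[1]])

def newton_solve(F, x0, tol=1e-10, max_iter=50):
    # Newton's method: x <- x - J(x)^{-1} F(x)
    x = x0.clone()
    for _ in range(max_iter):
        r = F(x)
        if torch.linalg.norm(r) < tol:
            break
        J = jacrev(F)(x)  # dense Jacobian dF/dx, shape (n, n)
        x = x - torch.linalg.solve(J, r)
    return x

x = newton_solve(residual, torch.tensor([1.0, 2.0], dtype=torch.float64))
```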
TL;DR - `torch.linalg.slogdet` is more than an order of magnitude slower at computing per-sample gradients in the latest nightly version of PyTorch/functorch (`1.13.0.dev20220721` / `0.3.0a0+e8a68f4`) than in a previous version of...
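A minimal sketch of the per-sample-gradient pattern through `torch.linalg.slogdet`, assuming a batch of square matrices and differentiating only the log-magnitude output. The helper name `logabsdet`, the shapes, and the diagonal shift (to keep the matrices well-conditioned) are illustrative assumptions; this shows the pattern, not the benchmark from the report. The result can be checked against the identity d log|det A| / dA = A^{-T}.

```python
import torch
from functorch import grad, vmap

def logabsdet(A):
    # slogdet returns (sign, logabsdet); differentiate the log-magnitude only
    sign, logdet = torch.linalg.slogdet(A)
    return logdet

# grad gives d/dA for one matrix; vmap maps it over the batch dimension
per_sample_grad = vmap(grad(logabsdet))

torch.manual_seed(0)
eye = torch.eye(4, dtype=torch.float64)
A = torch.randn(8, 4, 4, dtype=torch.float64) + 4 * eye
g = per_sample_grad(A)  # shape (8, 4, 4)
```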
Hello, I am using functorch for my project. It's a good substitute for JAX and makes life easy. However, it does not have GPU support for M1 chips, which is...
## Motivation We recently used [TorchOpt](https://github.com/metaopt/TorchOpt) as the functional optimizer API mentioned in the [functorch parallel training example](https://github.com/pytorch/functorch/blob/main/examples/ensembling/parallel_train.py#L86) to achieve [batchable optimization training small neural networks on one GPU](https://github.com/metaopt/TorchOpt/blob/main/examples/FuncTorch/parallel_train_torchopt.py) with `functorch.vmap`....
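A minimal sketch of the batched-ensemble training idea, using functorch's `combine_state_for_ensemble` to stack the parameters of several models and `vmap(grad(...))` to compute all their gradients in one call. Plain SGD stands in for TorchOpt here, since this sketch does not reproduce TorchOpt's API; the model sizes, random data, and learning rate are assumptions.

```python
import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, grad, vmap

torch.manual_seed(0)
num_models = 4
models = [nn.Linear(3, 1) for _ in range(num_models)]
# params: tuple of tensors, each with a leading dim of size num_models
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(16, 3)
y = torch.randn(16, 1)

def loss_fn(params, buffers, x, y):
    # Mean-squared error of one ensemble member on the shared batch
    pred = fmodel(params, buffers, x)
    return ((pred - y) ** 2).mean()

# Map over the model dimension of params/buffers; share x and y
batched_loss = vmap(loss_fn, in_dims=(0, 0, None, None))
batched_grad = vmap(grad(loss_fn), in_dims=(0, 0, None, None))

lr = 0.1
losses_before = batched_loss(params, buffers, x, y)
for _ in range(100):
    grads = batched_grad(params, buffers, x, y)
    # Plain SGD step applied to every ensemble member at once
    params = tuple(p - lr * g for p, g in zip(params, grads))
losses_after = batched_loss(params, buffers, x, y)
```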
## The Problem

```py
import torch
from functorch import jvp, vmap
from functools import partial

B = 2

def f(x, y):
    x = x.clone()
    view = x[0]
    x.copy_(y)
    return view,...
```
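The repro above is truncated, so it is not reconstructed here. As a separate minimal sketch of how `jvp` and `vmap` compose (without the in-place `copy_` into a view that the report involves), vmapping `jvp` over basis tangents yields a Jacobian one direction at a time. The function `g` and the helper `jvp_at_x` are hypothetical examples.

```python
import torch
from functorch import jvp, vmap

def g(x):
    # Elementwise function, so its Jacobian is diagonal: sin(x) + x*cos(x)
    return x.sin() * x

torch.manual_seed(0)
x = torch.randn(3)
tangents = torch.eye(3)  # basis directions e_0, e_1, e_2

def jvp_at_x(t):
    # jvp returns (g(x), J @ t); keep only the tangent output
    return jvp(g, (x,), (t,))[1]

# Row i is J @ e_i, i.e. column i of the Jacobian
jac_rows = vmap(jvp_at_x)(tangents)
```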
Hi, I have a GRU model and want to calculate its Jacobian with functorch, but there is a performance drop because we have not yet implemented the...
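A minimal sketch of computing a Jacobian through a GRU with `functorch.jacrev`, using a single `nn.GRUCell` step as a stand-in for the poster's model (the sizes and the zero initial hidden state are assumptions). Whether this path hits a dedicated batching rule or a slower fallback depends on the PyTorch/functorch version, so the sketch only demonstrates correctness, checked against `torch.autograd.functional.jacobian`.

```python
import torch
import torch.nn as nn
from functorch import jacrev

torch.manual_seed(0)
cell = nn.GRUCell(input_size=4, hidden_size=3)
h0 = torch.zeros(1, 3)  # fixed initial hidden state (assumption)

def step(x):
    # x: (1, 4) single-step input -> next hidden state (1, 3)
    return cell(x, h0)

x = torch.randn(1, 4)
# Reverse-mode Jacobian d h_next / d x, shape out.shape + in.shape = (1, 3, 1, 4)
J = jacrev(step)(x)
```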
This happens to us a lot; we should add a test to prevent it from happening. The test probably needs to shell out.