functorch
vmap and forward-mode AD fail sometimes on in-place views
The Problem
import torch
from functorch import jvp, vmap
from functools import partial

B = 2

def f(x, y):
    x = x.clone()
    view = x[0]      # view into the cloned base
    x.copy_(y)       # in-place copy onto the base of the view
    return view, x

def push_jvp(x, y, yt):
    return jvp(partial(f, x), (y,), (yt,))

x = torch.randn(2, B, 6)
y = torch.randn(2, 6, B)
yt = torch.randn(2, 6, B)

outs, tangents = vmap(push_jvp, in_dims=(1, 2, 2))(x, y, yt)
This raises the following error:

RuntimeError: vmap: Calling Tensor.as_strided is not supported unless the batch dims being vmapped over are at the front of the tensor (in memory layout). When they are not at the front of the tensor this operation can be error prone so we actively discourage it; please file us a bug report and/or try to express the as_strided operation in terms of PyTorch view operations
If I'm understanding what is going on correctly, the root cause of the problem (ignoring vmap for a second) is that in x.copy_(y), x is a regular Tensor and y is a dual tensor:
- the copy_ causes x.tangent to be a copy of y.tangent
- then, the tangent on the base (x) gets propagated to the views. This happens by calling .as_strided: view.tangent gets assigned x.tangent.as_strided(something)
Now, if y.tangent is a BatchedTensor, then calling as_strided on it may raise the above error message.
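For contrast, here is a minimal sketch of the jvp-only case (same f, unbatched shapes): without vmap, y's tangent is an ordinary tensor, so propagating it to the view via as_strided goes through without complaint.

import torch
from functorch import jvp
from functools import partial

def f(x, y):
    x = x.clone()
    view = x[0]      # view into the cloned base
    x.copy_(y)       # the base (and hence the view) picks up y's tangent
    return view, x

x = torch.randn(2, 6)
y = torch.randn(2, 6)
yt = torch.randn(2, 6)

# No vmap: the tangent propagated to `view` is built with as_strided on a
# plain tensor, so no error is raised.
outs, tangents = jvp(partial(f, x), (y,), (yt,))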
Is this actually a problem?
Previously, our approach was to say that the vmap x jvp composition only works when the user vmaps over dimension 0. However, that's not quite correct -- if the user uses non-contiguous tensors, they'll still run into this problem. Also, vmap x jvp can produce tensors where the batch dimension is not at 0, so the user has no control over this.
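For example, a sketch reusing f and push_jvp from the repro above: with contiguous inputs batched over dimension 0, the batch dim sits at the front in memory and the same composition goes through.

x0 = torch.randn(B, 2, 6)
y0 = torch.randn(B, 2, 6)
yt0 = torch.randn(B, 2, 6)

# Batch dim at the front + contiguous inputs: the as_strided call made during
# tangent propagation is accepted, so this does not raise.
outs0, tangents0 = vmap(push_jvp, in_dims=(0, 0, 0))(x0, y0, yt0)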
Potential solutions
- When a tangent gets propagated to views as a result of an in-place operation, instead of calling as_strided, we should call the original view operation. This means we should save the original view operation somewhere.
- (From Jeffrey) An alternative to (1): instead of calling as_strided, figure out what the correct non-as_strided view operation(s) are by reading the sizes/strides/storage_offset, and call those instead.
- It is possible to write a batching rule for a "safe as_strided". An as_strided call is safe if it does not expose memory that was not previously exposed in the Tensor. We would (a) add a safe_as_strided operator, (b) save some metadata about whether a view Tensor was created from its base through a chain of "safe" operations, and (c) dispatch to either safe_as_strided or as_strided. (A rough version of the safety check is sketched after this list.)
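To make the "safe" notion concrete, here is a rough, conservative sketch of such a check. The helper names are hypothetical, and this interval-based test only approximates "does not expose memory that wasn't already exposed":

import torch

def _storage_extent(sizes, strides, offset):
    # Range [lo, hi] of storage indices reachable with these sizes/strides/offset.
    lo = hi = offset
    for size, stride in zip(sizes, strides):
        if size == 0:
            return offset, offset          # empty tensor touches no memory
        if stride >= 0:
            hi += (size - 1) * stride
        else:
            lo += (size - 1) * stride
    return lo, hi

def is_safe_as_strided(t, sizes, strides, storage_offset=None):
    # "Safe" here: the requested view stays within the storage range already
    # reachable from t itself (a conservative approximation).
    if storage_offset is None:
        storage_offset = t.storage_offset()
    t_lo, t_hi = _storage_extent(t.size(), t.stride(), t.storage_offset())
    v_lo, v_hi = _storage_extent(sizes, strides, storage_offset)
    return t_lo <= v_lo and v_hi <= t_hi

a = torch.randn(4, 4)
print(is_safe_as_strided(a[1:], (2, 4), (4, 1)))      # True: stays inside the slice
print(is_safe_as_strided(a[1:], (4, 4), (4, 1), 0))   # False: reaches back into row 0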
Thoughts? cc @soulitzer @albanD
Just make https://github.com/pytorch/pytorch/blob/e4ea751810bd1b27a105ac43ce2c8c84fabc1167/c10/core/TensorImpl.h#L1084 return false for BatchedTensor and this will go away! :)
Hmm you mean GradWrapper right?
How does this work? Is there special logic in forward-mode AD that handles support_as_strided?
It's not forward-AD specific. There's logic in ADInplaceOrView that checks the tensor's support_as_strided method, so this would apply to all views.
There is special logic in all of autograd for this :)
It basically replaces all the places where we would usually call as_strided() with a call to the original view op.
We use this to be able to handle conjugate views, cross-dtype views (which can't be replaced with as_strided), or nested tensors (which can't handle a generic as_strided).
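As a purely illustrative sketch (this is not autograd's internal machinery): "replaying" a view means keeping the original view callable around and applying it to a new base, instead of reconstructing the view with as_strided. The _replay_fn attribute below is made up for the illustration.

import torch

def make_view_with_replay(base, view_fn):
    # Record the original view op alongside the view it produced.
    view = view_fn(base)
    view._replay_fn = view_fn
    return view

base = torch.randn(3, 4)
v = make_view_with_replay(base, lambda b: b.diagonal())

# When the base's tangent changes, replay the original view op on the tangent
# rather than calling tangent.as_strided(v.size(), v.stride(), v.storage_offset()).
tangent = torch.randn(3, 4)
v_tangent = v._replay_fn(tangent)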
@albanD do you have a sense of how much overhead this adds?
Making this return false for BatchedTensor doesn't actually work because BatchedTensor isn't directly involved in autograd -- autograd sees the TensorWrapper / GradWrapper. As Jeffrey mentioned, we would have to set support_as_strided=False for GradWrapper, which would mean that even when vmap is not involved (e.g. the user just uses functorch.{jvp, grad}), they would take the performance hit.
The difference is noticeable for very small ops but not a dealbreaker either:
In [12]: a = torch.view_as_real(torch.rand(2, dtype=torch.complex64, requires_grad=True).clone())
In [13]: b = torch.rand(4, requires_grad=True).clone().view(2, 2)
In [14]: %timeit tmp = a.view_as(a)
969 ns ± 2.32 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [15]: %timeit tmp = b.view_as(b)
866 ns ± 12 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [16]: %timeit tmp = a.add_(1)
3.27 µs ± 11 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [17]: %timeit tmp = b.add_(1)
2.91 µs ± 14.5 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
The first pair shows the cost of tracking a full view op instead of nothing; the second shows the cost of replaying all the views instead of a single as_strided.
Thanks Alban. A few hundred nanoseconds is not that bad
Had some more offline discussion with Alban.
It's important to note that:
- people do use vmap over torch.autograd.grad (and will likely attempt to vmap over the dual tensor API); a sketch of this usage pattern follows this list
- Because of that, a solution that tells autograd to "record views for playback" only when it sees a BatchedTensor doesn't work: the user's forward pass runs before vmap is involved, so autograd won't record the views for playback. By the time we pass a BatchedTensor to torch.autograd.grad, it's too late -- the views weren't recorded.
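A sketch of that usage pattern, to make the ordering concrete: the forward pass (and hence any view recording) happens before vmap ever sees a BatchedTensor.

import torch
from functorch import vmap

x = torch.randn(3, requires_grad=True)
y = x ** 2                  # forward pass runs here, with no BatchedTensors in sight

def get_vjp(v):
    # Only this call sees a BatchedTensor (the batched cotangent v).
    return torch.autograd.grad(y, x, v, retain_graph=True)

(jac_rows,) = vmap(get_vjp)(torch.eye(3))   # Jacobian of y w.r.t. x, row by row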
So, here's the current plan on record:
- First, we should see if we can easily prove that the as_strided is "safe". If we can, then there's no problem: we can write a batching rule for it.
- If it is not easy to prove that the as_strided is "safe", then we may need to thread that information through the view system. I.e., when someone calls a view function (like tensor.diag(), as opposed to as_strided() directly), we thread through the information that the view is a "safe as_strided". This is technically complicated, so we prefer solution no. 1 (or something else) if possible.