Check how functorch interacts with autocast
I have a feeling it doesn't.
- Vmap seems to work with autocast (based on the docs):
import torch
from torch.cuda.amp import autocast
from functorch import vmap

# Create some tensors in the default dtype (here assumed to be float32)
a_float32 = torch.rand(4, 2, 3, device="cuda")
b_float32 = torch.rand(4, 3, 2, device="cuda")
c_float32 = torch.rand(4, 2, 2, device="cuda")
d_float32 = torch.rand(4, 3, 2, device="cuda")

def func(x, y, z, w):
    with autocast():
        e_float16 = torch.matmul(x, y)
        assert e_float16.dtype == torch.float16, e_float16.dtype
        f_float16 = torch.matmul(z, e_float16)
        assert f_float16.dtype == torch.float16, f_float16.dtype
    # Outside the autocast region: cast back to float32 before mixing with w
    return torch.matmul(w, f_float16.float())

expected = func(a_float32, b_float32, c_float32, d_float32)
print(expected.shape)

out = vmap(func, in_dims=0)(a_float32, b_float32, c_float32, d_float32)
print(out.shape)
assert expected.allclose(out)
@autocast()
def func(x, y, z, w):
    e_float16 = torch.matmul(x, y)
    assert e_float16.dtype == torch.float16, e_float16.dtype
    f_float16 = torch.matmul(z, e_float16)
    assert f_float16.dtype == torch.float16, f_float16.dtype
    return torch.matmul(w, f_float16)

expected = func(a_float32, b_float32, c_float32, d_float32)
print(expected.shape)

out = vmap(func, in_dims=0)(a_float32, b_float32, c_float32, d_float32)
print(out.shape)
assert expected.allclose(out)
- Vmap and grad seem to work too:
import torch
from torch.cuda.amp import autocast
from functorch import vmap, grad

def func(x, y):
    with autocast():
        res = torch.matmul(x, y)
        assert res.dtype == torch.float16, res.dtype
        res = res.sum()
        assert res.dtype == torch.float32, res.dtype
    return res

a = torch.rand(4, 2, 3, device="cuda")
b = torch.rand(4, 3, 2, device="cuda")
out = vmap(grad(func), in_dims=0)(a, b)
print(out.shape)
> torch.Size([4, 2, 3])
Updates
- autocast outside vmap:
import torch
from torch.cuda.amp import autocast
from functorch import vmap

a_float32 = torch.rand(4, 2, 3, device="cuda")
b_float32 = torch.rand(4, 3, 2, device="cuda")
c_float32 = torch.rand(4, 2, 2, device="cuda")
d_float32 = torch.rand(4, 3, 2, device="cuda")

def func(x, y, z, w):
    e_float16 = torch.matmul(x, y)
    assert e_float16.dtype == torch.float16, e_float16.dtype
    f_float16 = torch.matmul(z, e_float16)
    assert f_float16.dtype == torch.float16, f_float16.dtype
    return torch.matmul(w, f_float16)

with autocast():
    expected = func(a_float32, b_float32, c_float32, d_float32)
    print(expected.shape)

with autocast():
    out = vmap(func, in_dims=0)(a_float32, b_float32, c_float32, d_float32)
    print(out.shape)

assert expected.allclose(out)
@zou3519 can you detail in which cases you think it won't work?
Hmm, so I have three questions:
1. What do we expect to be the behavior of the following? Does the position of the context manager change anything?
def bar(x):
    with autocast():
        ...

# case 1: context manager inside the function
vmap(bar)(x)

# case 2: context manager outside
with autocast():
    vmap(bar)(x)
2. Do we know when autocast runs?
We would want all of the functorch transforms to run before autocast. That is, if we do vmap(grad(func))(x, y) where func has autocast, then:
- x and y get wrapped in TensorWrapper(BatchedTensor(..))
- We should dispatch to the autograd implementation of e.g. matmul
- This will redispatch to the batching rule for matmul
- Finally, this will redispatch, and now, autocast should probably happen.
3. Is autocast fully composable with our transforms?
Autocast in PyTorch happens before autograd: https://github.com/pytorch/pytorch/blob/f3983f9c478c19c2f9054976d285c4a57a6e9ccb/c10/core/DispatchKey.h#L254-L257. This might mean that it doesn't implement support for backward operators like convolution_backward, because it doesn't need to!
If autocast indeed runs after the functorch transforms, then grad over a conv function (e.g. F.conv2d, or the func you defined above) will cause convolution_forward and convolution_backward (NB: these are not the actual names of these operators) to be emitted, at which point autocast might not be able to handle convolution_backward.
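One quick way to check this is dispatcher introspection. The sketch below uses the private helper torch._C._dispatch_has_kernel_for_dispatch_key, whose availability and exact signature may vary across PyTorch builds, so treat it as an assumption rather than a guaranteed API:

import torch

# Private dispatcher introspection; may not exist or may differ in your version.
has_kernel = torch._C._dispatch_has_kernel_for_dispatch_key

for op in ("aten::conv2d", "aten::convolution_backward"):
    print(op, has_kernel(op, "AutocastCUDA"))
# If the reasoning above is right, the forward op reports an AutocastCUDA
# kernel while the backward op does not.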
Thanks for the details!
I think the answer to question 1 is that autocast works (at least with torch.matmul) as expected inside/outside vmap. I updated my previous message with testing code snippets. Do you think we should add them to test_vmap.py?
I'll look into the other questions in more depth.
> I think the answer to question 1 is that autocast works (at least with torch.matmul) as expected inside/outside vmap. I updated my previous message with testing code snippets. Do you think we should add them to test_vmap.py?
Thanks for checking. Yes, test cases are always great!
Some investigation results from last year:
import torch
from torch import autocast
from functorch import vmap, grad

amp_dtype = torch.float16
device_type = "cuda"

def func(x, y):
    with autocast(dtype=amp_dtype, device_type=device_type):
        res = torch.mm(x, y)
        assert res.dtype == amp_dtype, res.dtype
        # return res
        return res.sum()

def func_no_autocast(x, y):
    res = torch.mm(x, y)
    return res.sum()

a = torch.ones(4, 12, 5, device=device_type)
b = 1.2 * torch.ones(4, 5, 12, device=device_type, dtype=torch.float)

expected = vmap(grad(func_no_autocast), in_dims=0)(a, b)
# -- expected: torch.Size([4, 12, 5]) tensor([[14.4000, 14.4000, 14.4000, 14.4000, 14.4000], ...
out = vmap(grad(func), in_dims=0)(a, b)
# -- out: torch.Size([4, 12, 5]) tensor([[14.4062, 14.4062, 14.4062, 14.4062, 14.4062], ...
# The small difference (14.4062 vs 14.4000) comes from fp16 rounding inside autocast.
AOT Autograd currently doesn't support autocast; we confirmed this through torch->xla lowering.
For example, for a simple conv2d:
PyTorch module
import torch
from functorch.compile import aot_module

# print_mhlo_to_file(...) is a helper (defined elsewhere) that dumps the lowered MHLO
model = torch.nn.Conv2d(16, 33, 3, stride=2)
compiled_module = aot_module(model, print_mhlo_to_file('conv2d_forward'), print_mhlo_to_file('conv2d_backward'))
data = torch.randn(20, 16, 50, 100)
with torch.autocast('cuda'):
    output = compiled_module(data)
Forward FX graph
def forward(self, primals_1, primals_2, primals_3):
    convolution = torch.ops.aten.convolution(primals_3, primals_1, primals_2, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); primals_2 = None
    return [convolution, primals_1, primals_3]
Backward FX graph
def forward(self, primals_1, primals_3, tangents_1):
    convolution_backward = torch.ops.aten.convolution_backward(tangents_1, primals_3, primals_1, [33], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]); tangents_1 = primals_3 = primals_1 = None
    getitem_1 = convolution_backward[1]
    getitem_2 = convolution_backward[2]; convolution_backward = None
    return [getitem_1, getitem_2]
Forward MHLO
autocast is correctly handled by mhlo.convert for the forward computation, but the returned weights (saved_tensors?) are the original fp32 master weights instead of the casted fp16 ones.
module {
func @main(%arg0: tensor<33x16x3x3xf32>, %arg1: tensor<33xf32>, %arg2: tensor<20x16x50x100xf32>) -> tuple<tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>> {
%0 = "mhlo.convert"(%arg2) : (tensor<20x16x50x100xf32>) -> tensor<20x16x50x100xf16>
%1 = "mhlo.convert"(%arg0) : (tensor<33x16x3x3xf32>) -> tensor<33x16x3x3xf16>
%2 = "mhlo.convert"(%arg1) : (tensor<33xf32>) -> tensor<33xf16>
%3 = call @aten.convolution_overrideable.7(%0, %1, %2) : (tensor<20x16x50x100xf16>, tensor<33x16x3x3xf16>, tensor<33xf16>) -> tensor<20x33x24x49xf16>
%4 = "mhlo.tuple"(%3, %arg0, %arg2) : (tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>) -> tuple<tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>>
return %4 : tuple<tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>>
}
func private @aten.convolution_overrideable.7(%arg0: tensor<20x16x50x100xf16>, %arg1: tensor<33x16x3x3xf16>, %arg2: tensor<33xf16>) -> tensor<20x33x24x49xf16> {
%0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<20x16x50x100xf16>, tensor<33x16x3x3xf16>) -> tensor<20x33x24x49xf16>
%1 = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<33xf16>) -> tensor<20x24x49x33xf16>
%2 = "mhlo.transpose"(%1) {minor_to_major = dense<[1, 3, 2, 0]> : tensor<4xindex>, permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<20x24x49x33xf16>) -> tensor<20x33x24x49xf16>
%3 = mhlo.add %0, %2 : tensor<20x33x24x49xf16>
return %3 : tensor<20x33x24x49xf16>
}
}
Backward MHLO
There is no autocast here, which causes a failure in the backend when the convolution is called with mismatched f16/f32 dtypes.
module {
func @main(%arg0: tensor<33x16x3x3xf32>, %arg1: tensor<20x16x50x100xf32>, %arg2: tensor<20x33x24x49xf16>) -> tuple<tensor<33x16x3x3xf32>, tensor<33xf16>> {
%0 = call @aten.convolution_backward_overrideable.8(%arg2, %arg1, %arg0) : (tensor<20x33x24x49xf16>, tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>) -> tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>
%1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>) -> tensor<20x16x50x100xf32>
%2 = "mhlo.get_tuple_element"(%0) {index = 1 : i32, minor_to_major = dense<[0, 1, 3, 2]> : tensor<4xindex>} : (tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>) -> tensor<33x16x3x3xf32>
%3 = "mhlo.get_tuple_element"(%0) {index = 2 : i32} : (tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>) -> tensor<33xf16>
%4 = "mhlo.tuple"(%2, %3) : (tensor<33x16x3x3xf32>, tensor<33xf16>) -> tuple<tensor<33x16x3x3xf32>, tensor<33xf16>>
return %4 : tuple<tensor<33x16x3x3xf32>, tensor<33xf16>>
}
func private @aten.convolution_backward_overrideable.8(%arg0: tensor<20x33x24x49xf16>, %arg1: tensor<20x16x50x100xf32>, %arg2: tensor<33x16x3x3xf32>) -> tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>> {
%0 = "mhlo.transpose"(%arg2) {minor_to_major = dense<[1, 0, 2, 3]> : tensor<4xindex>, permutation = dense<[2, 3, 1, 0]> : tensor<4xi64>} : (tensor<33x16x3x3xf32>) -> tensor<3x3x16x33xf32>
%1 = "mhlo.reverse"(%0) {dimensions = dense<[0, 1]> : tensor<2xi64>, minor_to_major = dense<[1, 0, 2, 3]> : tensor<4xindex>} : (tensor<3x3x16x33xf32>) -> tensor<3x3x16x33xf32>
%2 = mhlo.convolution(%arg0, %1) dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 3], [2, 3]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<20x33x24x49xf16>, tensor<3x3x16x33xf32>) -> tensor<20x16x50x100xf32>
%3 = mhlo.convolution(%arg1, %arg0) dim_numbers = [f, b, 0, 1]x[i, o, 0, 1]->[0, 1, b, f], window = {stride = [1, 1], pad = [[0, -1], [0, -1]], lhs_dilate = [1, 1], rhs_dilate = [2, 2]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<20x16x50x100xf32>, tensor<20x33x24x49xf16>) -> tensor<3x3x16x33xf32>
%4 = "mhlo.transpose"(%3) {minor_to_major = dense<[0, 1, 3, 2]> : tensor<4xindex>, permutation = dense<[3, 2, 0, 1]> : tensor<4xi64>} : (tensor<3x3x16x33xf32>) -> tensor<33x16x3x3xf32>
%5 = mhlo.constant dense<0.000000e+00> : tensor<f16>
%6 = "mhlo.reduce"(%arg0, %5) ( {
^bb0(%arg3: tensor<f16>, %arg4: tensor<f16>): // no predecessors
%8 = mhlo.add %arg3, %arg4 : tensor<f16>
"mhlo.return"(%8) : (tensor<f16>) -> ()
}) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<20x33x24x49xf16>, tensor<f16>) -> tensor<33xf16>
%7 = "mhlo.tuple"(%2, %4, %6) : (tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>) -> tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>
return %7 : tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>
}
}
IMHO there are two possible fixes:
- return weights with casted dtypes in the forward, and feed them directly to the backward
- fix autocast for backward ops in core, possibly at the cost of unnecessary casts in the backward
This was checked already, but I want to take a closer look at it one more time
Problem: Autocast doesn't work with functorch's grad transform
autocast rules are implemented for forward operators only (think convolution) and not for backward operators (e.g. convolution_backward). This leads to some interesting situations:
(1) grad over function that uses autocast
import torch
import torch.nn.functional as F
from functorch import grad
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    with torch.autocast('cuda'):
        return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

gm = make_fx(grad(f))(x, w)
print(gm.code)
- the functorch transforms run first
- grad(f) produces some code that does convolution and convolution_backward
- Then autocast interposes and transforms the convolution to run in fp16, but not convolution_backward.
- Ideally grad(f) would have the same behavior as performing y = f(x); torch.autograd.grad(y, x). The non-functorch version does the right thing (convolution and convolution_backward are run in fp16), but the functorch version does not.
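For reference, here is a minimal sketch of the non-functorch version mentioned in the last bullet (plain eager autograd under autocast; assumes a CUDA device is available):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
w = torch.randn(3, 3, 2, 2, device='cuda', requires_grad=True)

with torch.autocast('cuda'):
    y = F.conv2d(x, w).sum()

# Eager autograd: conv2d ran in fp16 and autograd saved the fp16 casts of
# x and w, so convolution_backward also runs in fp16.
gx, gw = torch.autograd.grad(y, (x, w))
print(gx.dtype, gw.dtype)  # float32: gradients are cast back to match the fp32 leaves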
(2) autocast over grad(f)
import torch
import torch.nn.functional as F
from functorch import grad
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

with torch.autocast('cuda'):
    gm = make_fx(grad(f))(x, w)
print(gm.code)
- the functorch transforms run first
- grad(f) produces some code that does convolution and convolution_backward
- Then autocast interposes and transforms the convolution to run in fp16, but not convolution_backward.
- This case is weird but perhaps not a problem -- autocast doesn't support directly transforming backward ops, and grad(f) creates backward ops.
Potential solutions
Solution 1: have autocast participate in functorch's stack-based dispatching
If with torch.autocast('cuda') pushed a mode onto functorch's stack, then we would get the correct behavior.
- In case 1, at the time of invoking conv2d, the functorch stack would look like [grad, autocast] <- top, and autocast would get to run first!
The downside is that this adds some additional dispatching overhead for people who do not use functorch. But maybe we can optimize this away: only push onto the functorch stack if a functorch transform is already active.
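A toy Python sketch of the stack idea (purely illustrative; the names and structure here are made up and nothing touches the real dispatcher):

# A stack of "interpreters" where the most recently entered one handles an op
# first. If `with autocast()` pushed an entry onto the same stack as grad/vmap,
# then in case 1 the stack when conv2d is called is [grad, autocast], and
# autocast (the top) gets to run first.
stack = []

class Interpreter:
    def __init__(self, name):
        self.name = name
    def __enter__(self):
        stack.append(self.name)
        return self
    def __exit__(self, *exc):
        stack.pop()

def dispatch(op):
    # Each layer, from the top of the stack down, gets a chance to transform
    # the op before handing it to the layer below.
    for layer in reversed(stack):
        print(f"{layer} handles {op}")

with Interpreter("grad"):           # grad(f) is active
    with Interpreter("autocast"):   # `with autocast()` inside f
        dispatch("conv2d")          # prints: autocast handles conv2d, then grad handles conv2d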
Solution 2: Special-case autocast in functorch's grad transform
functorch's grad transform will enable only the autograd keys and then do a dispatch through them. The problem with this in case 1 is that we want autocast to run first. To fix it, we can have functorch's grad transform enable the autocast key as well as the autograd keys.
This feels super hacky to me; autocast is logically something separate from autograd.
Discussion
Thoughts? @samdow @Chillee @ezyang
Why don't you just put functorch keys after autocast? It directly solves (1); and I agree that in (2) it should be "as if" you ran autocast on a program that explicitly had written out its forwards and backwards; weird but acceptable.
> Why don't you just put functorch keys after autocast? It directly solves (1); and I agree that in (2) it should be "as if" you ran autocast on a program that explicitly had written out its forwards and backwards; weird but acceptable.
I think this breaks case 2. The expected behavior for case 2 is as if we ran autocast on a program that had explicitly written out its forward and backward. In that case, convolution would be run in fp16 and convolution_backward would be run in fp32.
Walking through case 2, under the assumption that we put the functorch keys after autocast:
import torch
import torch.nn.functional as F
from functorch import grad
from torch.fx.experimental.proxy_tensor import make_fx

def f(x, w):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

with torch.autocast('cuda'):
    gm = make_fx(grad(f))(x, w)
print(gm.code)
- On the call to conv2d, autocast goes first, so it'll do F.conv2d(x.half(), w.half()).
- When the grad key gets to conv2d, autograd saves the fp16 versions of x and w.
- convolution_backward will get run in fp16 instead of fp32!
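A small eager sketch of the middle step (plain autocast + autograd, using the public torch.autograd.graph.saved_tensors_hooks API to peek at what autograd saves; the expected dtypes are an assumption that follows from autocast casting conv2d's inputs to fp16 before autograd records them):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
w = torch.randn(3, 3, 2, 2, device='cuda', requires_grad=True)

def pack(t):
    # Called for every tensor autograd saves for backward.
    print("autograd saved a tensor of dtype", t.dtype)
    return t

with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
    with torch.autocast('cuda'):
        out = F.conv2d(x, w)
# Expectation: the saved input and weight print as torch.float16 (the casts),
# which is why the backward convolution would also run in fp16.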
This is one of those "the ordering of the dispatch keys matters" things that functorch tries to solve with the stack-based dispatching: in case 2, we want functorch's grad transform to run before autocast
But... isn't that the point?
import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F

def f(x, w):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
w = torch.randn(3, 3, 2, 2, device='cuda', requires_grad=True)

with torch.autocast('cuda'):
    gm = make_fx(lambda x, w: torch.autograd.grad(f(x, w), (x, w)))(x, w)
print(gm)
gives
def forward(self, x_1, w_1):
    _to_copy = torch.ops.aten._to_copy.default(w_1, dtype = torch.float16); w_1 = None
    _to_copy_3 = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None
    convolution = torch.ops.aten.convolution.default(_to_copy_3, _to_copy, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
    sum_5 = torch.ops.aten.sum.default(convolution, dtype = torch.float32); convolution = None
    ones_like = torch.ops.aten.ones_like.default(sum_5, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False, memory_format = torch.preserve_format); sum_5 = None
    expand = torch.ops.aten.expand.default(ones_like, [2, 3, 3, 3]); ones_like = None
    _to_copy_8 = torch.ops.aten._to_copy.default(expand, dtype = torch.float16); expand = None
    convolution_backward = torch.ops.aten.convolution_backward.default(_to_copy_8, _to_copy_3, _to_copy, [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]); _to_copy_8 = _to_copy_3 = _to_copy = None
    getitem = convolution_backward[0]
    getitem_11 = convolution_backward[1]
    getitem_12 = convolution_backward[2]; convolution_backward = None
    _to_copy_13 = torch.ops.aten._to_copy.default(getitem, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); getitem = None
    _to_copy_14 = torch.ops.aten._to_copy.default(getitem_11, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0)); getitem_11 = None
    return (_to_copy_13, _to_copy_14)
Er, to clarify, my claim is that (A) should behave the same as (B) which is supposed to be different from (C):
Snippet A:
import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F
from functorch import grad

def f(x, w):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

# A
with torch.autocast('cuda'):
    gm = make_fx(grad(f))(x, w)
print(gm.code)
Snippet B:
import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F
from functorch import grad

def f(x, w):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

# B
explicit_forward_backward = make_fx(grad(f))(x, w)
with torch.autocast('cuda'):
    gm = make_fx(explicit_forward_backward)(x, w)
print(gm.code)
Snippet C:
import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F

def f(x, w):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
w = torch.randn(3, 3, 2, 2, device='cuda', requires_grad=True)

# C
with torch.autocast('cuda'):
    gm = make_fx(lambda x, w: torch.autograd.grad(f(x, w), x))(x, w)
print(gm)
Are you saying the expected behavior should be that A behaves the same way as C?
If we put the autocast key before the functorch keys, then A will behave the same way as C.
NB: looks like on PyTorch master, A behaves the same way as C, so under my claim above this is a bug
Yes, I think I am fitting my intuition to what PyTorch master does today. What is muddying the waters is the autocast context manager, as compared to a hypothetical autocast functorch transform.
The claim I'll make is that in conventional PyTorch, when you use autograd and autocast together, autocast ALWAYS gets applied first, no matter where the autocast context manager shows up in the code. This is of course the point of dispatch key ordering: it fixes the order in which transforms apply -- go use functorch if you want to reorder. So of course an autocast transform would work differently, since you'd expect the order in which transforms are applied to be respected.
Does that help?