
Check how functorch interacts with autocast

Open • zou3519 opened this issue 4 years ago • 15 comments

I have a feeling it doesn't

zou3519, Dec 02 '21 14:12

  • Vmap seems to work with autocast (based on the docs)
import torch
from torch.cuda.amp import autocast
from functorch import vmap

# Creates some tensors in default dtype (here assumed to be float32)
a_float32 = torch.rand(4, 2, 3, device="cuda")
b_float32 = torch.rand(4, 3, 2, device="cuda")
c_float32 = torch.rand(4, 2, 2, device="cuda")
d_float32 = torch.rand(4, 3, 2, device="cuda")

def func(x, y, z, w):
    with autocast():
        e_float16 = torch.matmul(x, y)
        assert e_float16.dtype == torch.float16, e_float16.dtype
        f_float16 = torch.matmul(z, e_float16)
        assert f_float16.dtype == torch.float16, f_float16.dtype
    return torch.matmul(w, f_float16.float())

expected = func(a_float32, b_float32, c_float32, d_float32)
print(expected.shape)

out = vmap(func, in_dims=0)(a_float32, b_float32, c_float32, d_float32)
print(out.shape)

assert expected.allclose(out)

@autocast()
def func(x, y, z, w):
    e_float16 = torch.matmul(x, y)
    assert e_float16.dtype == torch.float16, e_float16.dtype
    f_float16 = torch.matmul(z, e_float16)
    assert f_float16.dtype == torch.float16, f_float16.dtype
    return torch.matmul(w, f_float16)

expected = func(a_float32, b_float32, c_float32, d_float32)
print(expected.shape)

out = vmap(func, in_dims=0)(a_float32, b_float32, c_float32, d_float32)
print(out.shape)

assert expected.allclose(out)
  • Vmap and grad seem to work too

import torch
from torch.cuda.amp import autocast
from functorch import vmap, grad

def func(x, y):
    with autocast():
        res = torch.matmul(x, y)
        assert res.dtype == torch.float16, res.dtype
        res = res.sum()
        assert res.dtype == torch.float32, res.dtype
    return res

a = torch.rand(4, 2, 3, device="cuda")
b = torch.rand(4, 3, 2, device="cuda")

out = vmap(grad(func), in_dims=0)(a, b)
print(out.shape)
> torch.Size([4, 2, 3])

Updates

  • autocast outside vmap

import torch
from torch.cuda.amp import autocast
from functorch import vmap

a_float32 = torch.rand(4, 2, 3, device="cuda")
b_float32 = torch.rand(4, 3, 2, device="cuda")
c_float32 = torch.rand(4, 2, 2, device="cuda")
d_float32 = torch.rand(4, 3, 2, device="cuda")

def func(x, y, z, w):
    e_float16 = torch.matmul(x, y)
    assert e_float16.dtype == torch.float16, e_float16.dtype
    f_float16 = torch.matmul(z, e_float16)
    assert f_float16.dtype == torch.float16, f_float16.dtype
    return torch.matmul(w, f_float16)

with autocast():
    expected = func(a_float32, b_float32, c_float32, d_float32)
print(expected.shape)

with autocast():
    out = vmap(func, in_dims=0)(a_float32, b_float32, c_float32, d_float32)
print(out.shape)

assert expected.allclose(out)

vfdev-5, Dec 07 '21 15:12

@zou3519, can you detail in which cases you think it won't work?

vfdev-5, Dec 08 '21 11:12

Hmm, so I have three questions:

What do we expect to be the behavior of the following?

Does the position of the context manager change anything?

def bar(x):
  with autocast():
    ...

def foo(x):
  ...  # same computation as bar, but without autocast inside

# case 1: context manager inside function
vmap(bar)(x)

# case 2: context manager outside
with autocast():
  vmap(foo)(x)

Do we know when autocast runs?

We would want all of the functorch transforms to run before autocast. That is, if we do vmap(grad(func))(x, y) where func has autocast, then:

  1. x and y get wrapped in TensorWrapper(BatchedTensor(..))
  2. We should dispatch to the autograd implementation of e.g. matmul
  3. This will redispatch to the batching rule for matmul
  4. Finally, this will redispatch, and now, autocast should probably happen.
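
As a rough illustration of the steps above, here is a toy redispatch chain in Python. The layer names ("autograd", "batching", "autocast") are just labels and `make_chain` is a hypothetical helper. This is a sketch of the desired ordering only and does not use functorch's real dispatcher or dispatch keys.

```python
from typing import Callable, List


def make_chain(layer_names: List[str], trace: List[str]) -> Callable[[str], None]:
    """Build a toy dispatcher: each layer logs itself, then redispatches to the next one down."""

    def kernel(op: str) -> None:
        trace.append(f"kernel:{op}")

    call = kernel
    # Wrap from the innermost layer outwards so the first name ends up outermost.
    for name in reversed(layer_names):
        def layer(op: str, name: str = name, below: Callable[[str], None] = call) -> None:
            trace.append(f"{name}:{op}")
            below(op)

        call = layer
    return call


trace: List[str] = []
# Desired order for vmap(grad(func)): autograd (grad), then batching (vmap), then autocast.
dispatch_matmul = make_chain(["autograd", "batching", "autocast"], trace)
dispatch_matmul("matmul")
print(trace)
# ['autograd:matmul', 'batching:matmul', 'autocast:matmul', 'kernel:matmul']
```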

Is autocast fully composable with our transforms?

Autocast in PyTorch happens before autograd: https://github.com/pytorch/pytorch/blob/f3983f9c478c19c2f9054976d285c4a57a6e9ccb/c10/core/DispatchKey.h#L254-L257 . This might mean that it doesn't implement support for backward operators like convolution_backward because it doesn't need to!

If autocast does indeed run after the functorch transforms, then grad(F.conv2d) (where func is the function you defined above) will cause convolution_forward and convolution_backward (NB: these are not the actual names of these operators) to be emitted, at which point autocast might not be able to handle convolution_backward.

zou3519, Dec 08 '21 14:12

Thanks for the details!

I think the answer to question 1 is that autocast works as expected (at least with torch.matmul) both inside and outside vmap. I updated my previous message with test code snippets. Do you think we should add them to test_vmap.py?

I'll look into the other questions in more depth.

vfdev-5, Dec 08 '21 16:12

> I think the answer to the question 1 is that autocast works (at least with torch.matmul) as expected inside/outside vmap. I updated my previous message and put testing code snippet. Do you think we should add them to test_vmap.py ?

Thanks for checking. Yes, test cases are always great!

zou3519, Dec 09 '21 00:12

Some results from last year's investigations:

import torch
from torch import autocast
from functorch import vmap, grad

amp_dtype = torch.float16
device_type = "cuda"

def func(x, y):
    with autocast(dtype=amp_dtype, device_type=device_type):
        res = torch.mm(x, y)
        assert res.dtype == amp_dtype, res.dtype
    # return res
    return res.sum()


def func_no_autocast(x, y):
    res = torch.mm(x, y)
    return res.sum()


a = torch.ones(4, 12, 5, device=device_type)
b = 1.2 * torch.ones(4, 5, 12, device=device_type, dtype=torch.float)

expected = vmap(grad(func_no_autocast), in_dims=0)(a, b)
# -- expected: torch.Size([4, 12, 5]) tensor([[14.4000, 14.4000, 14.4000, 14.4000, 14.4000], ...
out = vmap(grad(func), in_dims=0)(a, b)
# -- out: torch.Size([4, 12, 5]) tensor([[14.4062, 14.4062, 14.4062, 14.4062, 14.4062], ...
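
The 14.4062 vs 14.4000 gap is consistent with fp16 rounding. Each per-sample gradient entry of sum(x @ y) with respect to x is a sum of 12 copies of y's value, 1.2. The sketch below reproduces the number using Python's half-precision `struct` format ("e"). It assumes, without verifying this against the CUDA kernels, that the products are accumulated in higher precision and rounded to fp16 once at the end. `to_fp16` is a hypothetical helper.

```python
import struct


def to_fp16(value: float) -> float:
    """Round a Python float to the nearest IEEE 754 binary16 value (round-half-to-even)."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


# Under autocast, y = 1.2 is cast to fp16 first; 1.2 is not exactly representable.
y_fp16 = to_fp16(1.2)

# Per-sample gradient entry: twelve copies of y summed.
# Assumption: accumulated in higher precision, then rounded to fp16 once.
grad_entry_fp16 = to_fp16(12 * y_fp16)

# fp32-style reference without autocast (the "expected" value above).
grad_entry_ref = 12 * 1.2

print(y_fp16, grad_entry_fp16, round(grad_entry_ref, 4))
# 1.2001953125 14.40625 14.4
```

PyTorch prints 14.40625 with four decimals as 14.4062, which matches `out` above.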

vfdev-5, Jan 10 '22 14:01

AOT Autograd currently doesn't support autocast; we confirmed this through a torch->XLA lowering.

For example, a simple conv2d:

PyTorch module

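import torch
from functorch.compile import aot_module

# print_mhlo_to_file(name) is the author's own helper (not shown); it presumably returns
# a compiler callback that dumps the lowered MHLO of the traced graph to a file.
model = torch.nn.Conv2d(16, 33, 3, stride=2)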
compiled_module = aot_module(model, print_mhlo_to_file('conv2d_forward'), print_mhlo_to_file('conv2d_backward'))
data = torch.randn(20, 16, 50, 100)

with torch.autocast('cuda'):
    output = compiled_module(data)

Forward FX graph

def forward(self, primals_1, primals_2, primals_3):
    convolution = torch.ops.aten.convolution(primals_3, primals_1, primals_2, [2, 2], [0, 0], [1, 1], False, [0, 0], 1);  primals_2 = None
    return [convolution, primals_1, primals_3]

Backward FX graph

def forward(self, primals_1, primals_3, tangents_1):
    convolution_backward = torch.ops.aten.convolution_backward(tangents_1, primals_3, primals_1, [33], [2, 2], [0, 0], [1, 1], False, [0, 0], 1, [False, True, True]);  tangents_1 = primals_3 = primals_1 = None
    getitem_1 = convolution_backward[1]
    getitem_2 = convolution_backward[2];  convolution_backward = None
    return [getitem_1, getitem_2]

Forward MHLO

autocast is handled correctly in the forward computation (via mhlo.convert), but the returned weights (the saved tensors?) are the original fp32 master weights rather than the cast fp16 ones.

module  {
  func @main(%arg0: tensor<33x16x3x3xf32>, %arg1: tensor<33xf32>, %arg2: tensor<20x16x50x100xf32>) -> tuple<tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>> {
    %0 = "mhlo.convert"(%arg2) : (tensor<20x16x50x100xf32>) -> tensor<20x16x50x100xf16>
    %1 = "mhlo.convert"(%arg0) : (tensor<33x16x3x3xf32>) -> tensor<33x16x3x3xf16>
    %2 = "mhlo.convert"(%arg1) : (tensor<33xf32>) -> tensor<33xf16>
    %3 = call @aten.convolution_overrideable.7(%0, %1, %2) : (tensor<20x16x50x100xf16>, tensor<33x16x3x3xf16>, tensor<33xf16>) -> tensor<20x33x24x49xf16>
    %4 = "mhlo.tuple"(%3, %arg0, %arg2) : (tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>) -> tuple<tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>>
    return %4 : tuple<tensor<20x33x24x49xf16>, tensor<33x16x3x3xf32>, tensor<20x16x50x100xf32>>
  }
  func private @aten.convolution_overrideable.7(%arg0: tensor<20x16x50x100xf16>, %arg1: tensor<33x16x3x3xf16>, %arg2: tensor<33xf16>) -> tensor<20x33x24x49xf16> {
    %0 = mhlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 2], pad = [[0, 0], [0, 0]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<20x16x50x100xf16>, tensor<33x16x3x3xf16>) -> tensor<20x33x24x49xf16>
    %1 = "mhlo.broadcast_in_dim"(%arg2) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<33xf16>) -> tensor<20x24x49x33xf16>
    %2 = "mhlo.transpose"(%1) {minor_to_major = dense<[1, 3, 2, 0]> : tensor<4xindex>, permutation = dense<[0, 3, 1, 2]> : tensor<4xi64>} : (tensor<20x24x49x33xf16>) -> tensor<20x33x24x49xf16>
    %3 = mhlo.add %0, %2 : tensor<20x33x24x49xf16>
    return %3 : tensor<20x33x24x49xf16>
  }
}

Backward MHLO

No autocast is applied, so the backend fails when calling a convolution with mismatched f16/f32 operand types.

module  {
  func @main(%arg0: tensor<33x16x3x3xf32>, %arg1: tensor<20x16x50x100xf32>, %arg2: tensor<20x33x24x49xf16>) -> tuple<tensor<33x16x3x3xf32>, tensor<33xf16>> {
    %0 = call @aten.convolution_backward_overrideable.8(%arg2, %arg1, %arg0) : (tensor<20x33x24x49xf16>, tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>) -> tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>
    %1 = "mhlo.get_tuple_element"(%0) {index = 0 : i32} : (tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>) -> tensor<20x16x50x100xf32>
    %2 = "mhlo.get_tuple_element"(%0) {index = 1 : i32, minor_to_major = dense<[0, 1, 3, 2]> : tensor<4xindex>} : (tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>) -> tensor<33x16x3x3xf32>
    %3 = "mhlo.get_tuple_element"(%0) {index = 2 : i32} : (tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>) -> tensor<33xf16>
    %4 = "mhlo.tuple"(%2, %3) : (tensor<33x16x3x3xf32>, tensor<33xf16>) -> tuple<tensor<33x16x3x3xf32>, tensor<33xf16>>
    return %4 : tuple<tensor<33x16x3x3xf32>, tensor<33xf16>>
  }
  func private @aten.convolution_backward_overrideable.8(%arg0: tensor<20x33x24x49xf16>, %arg1: tensor<20x16x50x100xf32>, %arg2: tensor<33x16x3x3xf32>) -> tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>> {
    %0 = "mhlo.transpose"(%arg2) {minor_to_major = dense<[1, 0, 2, 3]> : tensor<4xindex>, permutation = dense<[2, 3, 1, 0]> : tensor<4xi64>} : (tensor<33x16x3x3xf32>) -> tensor<3x3x16x33xf32>
    %1 = "mhlo.reverse"(%0) {dimensions = dense<[0, 1]> : tensor<2xi64>, minor_to_major = dense<[1, 0, 2, 3]> : tensor<4xindex>} : (tensor<3x3x16x33xf32>) -> tensor<3x3x16x33xf32>
    %2 = mhlo.convolution(%arg0, %1) dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 3], [2, 3]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<20x33x24x49xf16>, tensor<3x3x16x33xf32>) -> tensor<20x16x50x100xf32>
    %3 = mhlo.convolution(%arg1, %arg0) dim_numbers = [f, b, 0, 1]x[i, o, 0, 1]->[0, 1, b, f], window = {stride = [1, 1], pad = [[0, -1], [0, -1]], lhs_dilate = [1, 1], rhs_dilate = [2, 2]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = ["DEFAULT", "DEFAULT"]} : (tensor<20x16x50x100xf32>, tensor<20x33x24x49xf16>) -> tensor<3x3x16x33xf32>
    %4 = "mhlo.transpose"(%3) {minor_to_major = dense<[0, 1, 3, 2]> : tensor<4xindex>, permutation = dense<[3, 2, 0, 1]> : tensor<4xi64>} : (tensor<3x3x16x33xf32>) -> tensor<33x16x3x3xf32>
    %5 = mhlo.constant dense<0.000000e+00> : tensor<f16>
    %6 = "mhlo.reduce"(%arg0, %5) ( {
    ^bb0(%arg3: tensor<f16>, %arg4: tensor<f16>):  // no predecessors
      %8 = mhlo.add %arg3, %arg4 : tensor<f16>
      "mhlo.return"(%8) : (tensor<f16>) -> ()
    }) {dimensions = dense<[0, 2, 3]> : tensor<3xi64>} : (tensor<20x33x24x49xf16>, tensor<f16>) -> tensor<33xf16>
    %7 = "mhlo.tuple"(%2, %4, %6) : (tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>) -> tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>
    return %7 : tuple<tensor<20x16x50x100xf32>, tensor<33x16x3x3xf32>, tensor<33xf16>>
  }
}

byronyi, Jan 24 '22 23:01

IMHO there are 2 possible fixes:

  1. return the weights with their cast dtypes from forward, and feed them directly to backward
  2. fix autocast for backward ops in core, possibly introducing unnecessary casts in backward
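
Here is a toy sketch of why the backward lowering above fails and what fix 1 changes. It assumes, hypothetically, a strict backend that rejects mixed-dtype operands. `Tensor`, `conv_forward`, `conv_backward` and `train_step` are made-up names that only track dtypes; they are not AOT Autograd or XLA APIs.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class Tensor:
    """Toy tensor: only the dtype matters for this illustration."""
    dtype: str


def conv_forward(x: Tensor, w: Tensor, save_cast_weights: bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    # Autocast casts the forward inputs to f16 (the mhlo.convert ops in the forward MHLO).
    x16, w16 = Tensor("f16"), Tensor("f16")
    out = Tensor("f16")
    # Current behaviour: the original f32 tensors are saved for backward.
    # Fix 1: save the f16 tensors that the forward actually used.
    saved = (x16, w16) if save_cast_weights else (x, w)
    return out, saved


def conv_backward(grad_out: Tensor, x: Tensor, w: Tensor) -> Tuple[Tensor, Tensor]:
    # Hypothetical strict backend: every operand must have the same dtype.
    dtypes = {grad_out.dtype, x.dtype, w.dtype}
    if len(dtypes) != 1:
        raise TypeError(f"mismatched operand dtypes: {sorted(dtypes)}")
    return Tensor(x.dtype), Tensor(w.dtype)


def train_step(save_cast_weights: bool) -> str:
    x, w = Tensor("f32"), Tensor("f32")  # f32 input and master weight
    out, (x_saved, w_saved) = conv_forward(x, w, save_cast_weights)
    grad_out = Tensor(out.dtype)  # the tangent has the forward output's dtype (f16)
    try:
        _, grad_w = conv_backward(grad_out, x_saved, w_saved)
    except TypeError as err:
        return f"error: {err}"
    return f"ok: grad_w is {grad_w.dtype}"


print(train_step(save_cast_weights=False))  # error: mismatched operand dtypes: ['f16', 'f32']
print(train_step(save_cast_weights=True))   # ok: grad_w is f16
```

With fix 1, the weight gradient comes out in f16. It would still need a cast back to f32 before it updates the master weight, as the `_to_copy` to float32 at the end of the plain-autograd graph does.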

byronyi, Jan 24 '22 23:01

This was checked already, but I want to take a closer look at it one more time

zou3519, Sep 15 '22 15:09

Problem: Autocast doesn't work with functorch's grad transform

autocast rules are implemented for forward-passes only (think convolution) and not backward operators (e.g. convolution_backward). This leads to some interesting situations:

(1) grad over function that uses autocast

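import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx
from functorch import grad

def f(x, w):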
  with torch.autocast('cuda'):
    return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

gm = make_fx(grad(f))(x, w)

print(gm.code)
  • the functorch transforms run first
  • grad(f) produces some code that does convolution and convolution_backward
  • Then autocast is applied and casts the convolution to fp16, but not convolution_backward.
  • Ideally grad(f) would behave the same as y = f(x, w); torch.autograd.grad(y, x). The non-functorch version does the right thing (convolution and convolution_backward both run in fp16), but the functorch version does not, because autocast never touches convolution_backward.

(2) autocast over grad(f)

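import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx
from functorch import grad

def f(x, w):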
  return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

with torch.autocast('cuda'):
  gm = make_fx(grad(f))(x, w)

print(gm.code)
  • the functorch transforms run first
  • grad(f) produces some code that does convolution and convolution_backward
  • Then autocast is applied and casts the convolution to fp16, but not convolution_backward.
  • This case is weird but perhaps not a problem -- autocast doesn't support directly transforming backward ops, and grad(f) creates backward ops.

Potential solutions

Solution 1: have autocast participate in functorch's stack-based dispatching

If with torch.autocast('cuda') pushed a mode onto functorch's stack, then we would get the correct behavior.

  • In case 1, at the time of invoking conv2d, the functorch stack would look like [grad, autocast] <- top, and autocast would get to run first!

The downside is that this adds some additional dispatching overhead for people who do not use functorch. But maybe we can optimize this away: only push onto the functorch stack if a functorch transform is already active.
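
Below is a toy model of the stack-based dispatching in Solution 1. It is an illustration under stated assumptions, not functorch's implementation. Tensors are reduced to a dtype, autocast only has a rule for `convolution`, and the backward op emitted by grad is seen only by the levels below grad. `T`, `Autocast`, `Grad` and `run` are hypothetical names.

```python
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union


@dataclass(frozen=True)
class T:
    """Toy tensor: only its dtype."""
    dtype: str


Args = Tuple[T, ...]
Below = Callable[[str, Args], T]

# Assumption mirroring the issue: autocast only has rules for forward ops.
AUTOCAST_FP16_OPS = {"convolution"}


class Autocast:
    def handle(self, op: str, args: Args, below: Below) -> T:
        if op in AUTOCAST_FP16_OPS:
            args = tuple(T("fp16") for _ in args)
        return below(op, args)


class Grad:
    def __init__(self) -> None:
        self.saved: Optional[Args] = None

    def handle(self, op: str, args: Args, below: Below) -> T:
        if op == "convolution":
            self.saved = args  # autograd saves the inputs exactly as it sees them
        return below(op, args)


Interpreter = Union[Autocast, Grad]


def run(stack: List[Interpreter]) -> List[Tuple[str, Tuple[str, ...]]]:
    """Trace grad of conv(x, w) with `stack` ordered bottom -> top."""
    log: List[Tuple[str, Tuple[str, ...]]] = []

    def dispatch(level: int, op: str, args: Args) -> T:
        if level < 0:  # reached the kernel: record operand dtypes
            log.append((op, tuple(a.dtype for a in args)))
            return T(args[0].dtype)
        return stack[level].handle(op, args, lambda o, a: dispatch(level - 1, o, a))

    grad = next(s for s in stack if isinstance(s, Grad))
    out = dispatch(len(stack) - 1, "convolution", (T("fp32"), T("fp32")))
    # grad emits the backward op, so only the levels below grad see it.
    grad_level = stack.index(grad)
    dispatch(grad_level - 1, "convolution_backward", (T(out.dtype),) + grad.saved)
    return log


# Autocast below grad: backward gets mixed fp16/fp32 operands.
print(run([Autocast(), Grad()]))
# [('convolution', ('fp16', 'fp16')), ('convolution_backward', ('fp16', 'fp32', 'fp32'))]

# Solution 1 for case 1, stack [grad, autocast] <- top: grad saves fp16 inputs.
print(run([Grad(), Autocast()]))
# [('convolution', ('fp16', 'fp16')), ('convolution_backward', ('fp16', 'fp16', 'fp16'))]
```

In the toy model, putting autocast above grad on the stack makes the saved inputs fp16, so the backward runs in fp16 as with plain autograd.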

Solution 2: Special-case autocast in functorch's grad transform

functorch's grad transform will enable only the autograd keys and then do a dispatch through them. The problem with this in case 1 is that we want autocast to run first. To fix it, we can have functorch's grad transform enable the autocast key as well as the autograd keys.

This feels super hacky to me; autocast is logically something separate from autograd.

Discussion

Thoughts? @samdow @Chillee @ezyang

zou3519, Sep 23 '22 20:09

Why don't you just put functorch keys after autocast? It directly solves (1); and I agree that in (2) it should be "as if" you ran autocast on a program that explicitly had written out its forwards and backwards; weird but acceptable.

ezyang, Sep 23 '22 20:09

> Why don't you just put functorch keys after autocast? It directly solves (1); and I agree that in (2) it should be "as if" you ran autocast on a program that explicitly had written out its forwards and backwards; weird but acceptable.

I think this breaks case 2. The expected behavior for case 2 is as if we ran autocast on a program that had explicitly written out its forward and backward. In that case, convolution would run in fp16 and convolution_backward would run in fp32.

Walking through case 2, under the assumption that we put the functorch keys after autocast:

def f(x, w):
  return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

with torch.autocast('cuda'):
  gm = make_fx(grad(f))(x, w)

print(gm.code)
  • On the call to conv2d, autocast goes first. So it'll do F.conv2d(x.half(), w.half())
  • When the grad key gets to conv2d, autograd saves the fp16 version of x and w.
  • convolution_backward will get run in fp16 instead of fp32!

This is one of those "the ordering of the dispatch keys matters" things that functorch tries to solve with stack-based dispatching: in case 2, we want functorch's grad transform to run before autocast.

zou3519, Sep 23 '22 20:09

But... isn't that the point?

import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F

def f(x, w):
  return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
w = torch.randn(3, 3, 2, 2, device='cuda', requires_grad=True)

with torch.autocast('cuda'):
  gm = make_fx(lambda x, w: torch.autograd.grad(f(x, w), (x, w)))(x, w)

print(gm)

gives

def forward(self, x_1, w_1):
    _to_copy = torch.ops.aten._to_copy.default(w_1, dtype = torch.float16);  w_1 = None
    _to_copy_3 = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    convolution = torch.ops.aten.convolution.default(_to_copy_3, _to_copy, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
    sum_5 = torch.ops.aten.sum.default(convolution, dtype = torch.float32);  convolution = None
    ones_like = torch.ops.aten.ones_like.default(sum_5, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0), pin_memory = False, memory_format = torch.preserve_format);  sum_5 = None
    expand = torch.ops.aten.expand.default(ones_like, [2, 3, 3, 3]);  ones_like = None
    _to_copy_8 = torch.ops.aten._to_copy.default(expand, dtype = torch.float16);  expand = None
    convolution_backward = torch.ops.aten.convolution_backward.default(_to_copy_8, _to_copy_3, _to_copy, [0], [1, 1], [0, 0], [1, 1], False, [0, 0], 1, [True, True, False]);  _to_copy_8 = _to_copy_3 = _to_copy = None
    getitem = convolution_backward[0]
    getitem_11 = convolution_backward[1]
    getitem_12 = convolution_backward[2];  convolution_backward = None
    _to_copy_13 = torch.ops.aten._to_copy.default(getitem, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  getitem = None
    _to_copy_14 = torch.ops.aten._to_copy.default(getitem_11, dtype = torch.float32, layout = torch.strided, device = device(type='cuda', index=0));  getitem_11 = None
    return (_to_copy_13, _to_copy_14)

ezyang, Sep 23 '22 21:09

Er, to clarify, my claim is that (A) should behave the same as (B) which is supposed to be different from (C):

Snippet A:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F
from functorch import grad

def f(x, w):
  return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

# A
with torch.autocast('cuda'):
  gm = make_fx(grad(f))(x, w)

print(gm.code)

Snippet B:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F
from functorch import grad

def f(x, w):
  return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4).cuda()
w = torch.randn(3, 3, 2, 2).cuda()

# B
explicit_forward_backward = make_fx(grad(f))(x, w)
with torch.autocast('cuda'):
  gm = make_fx(explicit_forward_backward)(x, w)

print(gm.code)

Snippet C:

import torch
from torch.fx.experimental.proxy_tensor import make_fx
import torch.nn.functional as F

def f(x, w):
  return F.conv2d(x, w).sum()

x = torch.randn(2, 3, 4, 4, device='cuda', requires_grad=True)
w = torch.randn(3, 3, 2, 2, device='cuda', requires_grad=True)

# C
with torch.autocast('cuda'):
  gm = make_fx(lambda x, w: torch.autograd.grad(f(x, w), x))(x, w)

print(gm)

Are you saying the expected behavior should be that A behaves the same way as C?

If we put the autocast key before the functorch keys, then A will behave the same way as C.

NB: it looks like on PyTorch master, A behaves the same way as C, so under my claim above this is a bug.

zou3519, Sep 23 '22 21:09

Yes, I think I am fitting my intuition to what PyTorch master does today. What is muddying the waters is the autocast context manager, as compared to a hypothetical autocast functorch transform.

The claim I'll make is that in conventional PyTorch, when you use autograd and autocast together, autocast ALWAYS gets applied first, no matter where the autocast context manager shows up in the code. This is, of course, the point of dispatch key ordering: it fixes the order in which transforms apply; go use functorch if you want to reorder them. So of course an autocast transform would work differently, since you'd expect the order in which transforms are applied to be respected.

Does that help?

ezyang, Sep 23 '22 23:09