
AOTAutograd makes unsafe assumptions about what the backward pass will look like


Context: how AOTAutograd works today

Given a function f:

  • AOTAutograd traces out run_forward_and_backward_f(*args, *grad_outputs) to produce forward_and_backward_trace (sketched after this list)
  • AOTAutograd partitions forward_and_backward_trace into a forward_trace and a backward_trace
  • AOTAutograd compiles the forward_trace and backward_trace separately
  • The compiled_forward_trace and compiled_backward_trace are stitched into an autograd.Function
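For concreteness, for a single-input f like torch.sin, the joint function that gets traced looks roughly like this (illustrative only, not the actual AOTAutograd internals):

import torch

def run_forward_and_backward_f(arg, grad_output):
    # Illustrative stand-in for step 1: run the forward, then compute the
    # input gradient for the supplied grad_output.
    output = torch.sin(arg)
    (grad_input,) = torch.autograd.grad(output, (arg,), (grad_output,))
    return output, grad_input

# Tracing this function requires concrete (proxy) tensors for both arg and
# grad_output; whatever metadata those proxies carry ends up in the trace.
x = torch.randn(3, requires_grad=True)
run_forward_and_backward_f(x, torch.ones(3))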

The Problem

In order to trace run_forward_and_backward_f(*args, *grad_outputs), AOTAutograd needs to construct a Proxy for each of the grad_outputs. Constructing these proxies ends up baking assumed properties of the grad_outputs into the trace: for example, AOTAutograd assumes that the grad_outputs are contiguous.
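For illustration only (this is not the actual AOTAutograd code), imagine the tracer materializes the grad_output proxies like this; every property of these stand-in tensors then gets baked into the trace:

import torch

def make_proxy_grad_outputs(outputs):
    # Hypothetical helper: stand-in grad_outputs used only while tracing.
    # ones_like yields dense, contiguous tensors, so the resulting trace is
    # silently specialized to dense, contiguous grad_outputs.
    return [torch.ones_like(o) for o in outputs]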

There are some more adversarial examples that we could construct. If the backward formula of at::sin were instead:

def sin_backward(grad_output, input):
  if grad_output.is_sparse:
    return grad_output * input.sin()
  return grad_output * input.cos()

then, depending on the properties of the grad_output, the backward that should get executed is different. If AOTAutograd assumes that the Proxy is dense and contiguous, then the backward pass of the generated autograd.Function would be incorrect whenever the real grad_output is sparse.
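Concretely, tracing sin_backward with a dense, contiguous proxy records only the branch that proxy happens to take. A hand-written sketch of the resulting trace:

def traced_sin_backward(grad_output, input):
    # The is_sparse check has been burned away at trace time; if the real
    # grad_output turns out to be sparse, this computes the wrong formula.
    return grad_output * input.cos()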

Potential proposal

Proposal: delay tracing the backward pass until the backward pass is invoked.

So, given a function f:

  • AOTAutograd constructs a trace of f (that includes intermediates as outputs), forward_trace
  • AOTAutograd constructs an autograd.Function that has compiled(forward_trace) as the forward pass

The autograd.Function's backward pass, when invoked (a sketch follows this list):

  • traces out run_forward_and_backward_f(*args, *grad_outputs) to produce forward_and_backward_trace
  • takes the difference of forward_and_backward_trace and forward_trace to produce backward_trace.
  • compiles backward_trace into compiled_backward_trace
  • then invokes it.
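A rough, runnable sketch of the shape this could take, with placeholder helpers standing in for the real trace/partition/compile machinery (none of the names below are existing AOTAutograd API):

import torch

def f(x):
    return x.sin()

compiled_forward = f  # placeholder for compiled(forward_trace)

def compile_backward(args, grad_outputs):
    # Placeholder for: trace run_forward_and_backward_f(*args, *grad_outputs)
    # now, subtract forward_trace, and compile what remains. Because the real
    # grad_outputs are in hand, nothing about their contiguity/sparsity has to
    # be guessed.
    def backward_fn(*grads):
        with torch.enable_grad():
            detached = [a.detach().requires_grad_(True) for a in args]
            out = f(*detached)
            return torch.autograd.grad(out, detached, grads)
    return backward_fn

class LazyBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return compiled_forward(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # The backward is traced/compiled only now, when it is invoked.
        backward_fn = compile_backward((x,), (grad_output,))
        return backward_fn(grad_output)[0]

x = torch.randn(3, requires_grad=True)
y = LazyBackward.apply(x)
y.sum().backward()   # the backward gets compiled lazily here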

Things that we haven't mentioned that will need to be thought about:

  • how does AOTAutograd's rematerialization come into play here?

Things that we haven't mentioned that should be orthogonal:

  • caching. compiled(forward_trace) needs a cache keyed on the inputs (among other things); compiled(backward_trace) needs a cache keyed on the (inputs, grad_outputs). One possible key shape is sketched after this list.
  • what if the backward is user-defined (e.g., autograd.Function) and isn't traceable? See https://github.com/pytorch/pytorch/issues/93723 for ideas
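Purely as an illustration of what such a cache key could look like (which metadata actually needs to participate is part of the open question):

import torch

def tensor_key(t):
    # Hypothetical: the metadata a compiled trace might be specialized on.
    strided = t.layout == torch.strided
    return (
        tuple(t.shape),
        t.dtype,
        t.device,
        t.layout,
        t.stride() if strided else None,
        t.is_contiguous() if strided else None,
    )

def backward_cache_key(inputs, grad_outputs):
    # compiled(backward_trace) would be looked up by the metadata of both
    # the inputs and the grad_outputs.
    return tuple(tensor_key(t) for t in (*inputs, *grad_outputs))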

Alternatives

Keep the current scheme (AOTAutograd traces out both the forward and backward passes at the time of the forward), but somehow prove to ourselves that the produced trace of the backward pass is always correct.

cc @Chillee @anijain2305 @ezyang @anjali411 @albanD

zou3519 · Jun 01 '22 18:06

What is the requirement on the grad_output normally? Isn't it roughly that it's the same type of tensor (and memory format?) as the output tensor? Would it be valid to construct a proxy grad_output with the same logical metadata as the output?

gchanan · Jun 01 '22 18:06

> Isn't it roughly that it's the same type of tensor (and memory format?) as the output tensor?

Same shape / dtype / layout / device (checks are here https://github.com/pytorch/pytorch/blob/5fbec86faef07d66ab696bc4c4edbaf6259a2189/torch/csrc/autograd/engine.cpp#L694). Note that stride / contiguousness / sparsity are not checked there (because we do want gradients that don't match in some cases). Also, even for the ones currently checked, there are requests to lift these constraints (for subclasses in particular). So these requirements might change in the future.
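A quick sketch of that in action (sin is just for concreteness): the engine accepts a non-contiguous grad_output without complaint, since only shape / dtype / device / layout are checked.

import torch

x = torch.randn(3, 4, requires_grad=True)
y = x.sin()

grad_out = torch.ones(4, 3).t()      # a transposed view: not contiguous
assert not grad_out.is_contiguous()
y.backward(grad_out)                 # shape/dtype/device/layout match, so this is accepted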

albanD · Jun 01 '22 19:06

> def sin_backward(grad_output, input):
>   if input.is_sparse():
>     return x.sin()
>   return x.cos()

What is x? What I'm trying to get at is whether the current rules allow you to write something like this.

gchanan · Jun 01 '22 20:06

> def sin_backward(grad_output, input):
>   if input.is_sparse():
>     return x.sin()
>   return x.cos()
>
> What is x? What I'm trying to get at is whether the current rules allow you to write something like this.

Sorry, I have corrected the example. The current rules allow you to write something like this (as Alban mentioned, sparsity is not checked). Another thing the current rules allow is switching on contiguity, which is also not checked:

def sin_backward(grad_output, input):
  if grad_output.is_contiguous():
    return grad_output * input.sin()
  return grad_output * input.cos()
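And such grad_outputs do show up in ordinary programs. For example (a small sketch; the exact result may depend on the PyTorch version), the gradient reaching sin's backward through a transpose is typically a non-contiguous view:

import torch

x = torch.randn(3, 4, requires_grad=True)
y = x.sin()
# The hook sees the grad_output that sin's backward will receive.
y.register_hook(lambda g: print("contiguous grad_output?", g.is_contiguous()))
y.t().sum().backward()   # typically prints False: the grad arrives as a strided view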

zou3519 · Jun 01 '22 20:06