AOTAutograd makes unsafe assumptions about what the backward pass will look like
Context: how AOTAutograd works today
Given a function f:
- AOTAutograd traces out run_forward_and_backward_f(*args, *grad_outputs) to produce forward_and_backward_trace
- AOTAutograd partitions forward_and_backward_trace into a forward_trace and a backward_trace
- AOTAutograd compiles the forward_trace and backward_trace separately
- The compiled_forward_trace and compiled_backward_trace are stitched into an autograd.Function
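For concreteness, here is a rough sketch of the first step on a toy f, using make_fx directly as a stand-in for AOTAutograd's tracer (the partitioning and autograd.Function stitching steps are omitted). The point to notice is that a grad_output stand-in has to be conjured up before any backward has actually run:

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x)

# Joint forward+backward function that gets traced in one go.
def run_forward_and_backward_f(x, grad_out):
    out = f(x)
    grad_x, = torch.autograd.grad(out, x, grad_out)
    return out, grad_x

x = torch.randn(3, requires_grad=True)
# The grad_output must be guessed up front; here: dense, contiguous, same shape as out.
guessed_grad_out = torch.ones(3)
joint = make_fx(run_forward_and_backward_f)(x, guessed_grad_out)
print(joint.code)  # contains both forward (sin) and backward (cos, mul) ops
```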
The Problem
In order to trace run_forward_and_backward_f(*args, *grad_outputs), AOTAutograd needs to construct a Proxy for the grad_outputs. This ends up assuming properties of the grad_output: for example, AOTAutograd assumes that the grad_outputs are contiguous.
There are some more adversarial examples that we could construct. If the backward formula of at::sin were instead:
```python
def sin_backward(grad_output, input):
    if grad_output.is_sparse:
        return grad_output * input.sin()
    return grad_output * input.cos()
```
then, depending on the properties of the grad_output, the backward that should be executed is different. If AOTAutograd assumes that the Proxy is dense and contiguous, then the backward pass of the generated autograd.Function would be incorrect.
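To make the failure mode concrete, here is a small illustration (not AOTAutograd itself) of what tracing through such a branch with a dense, contiguous stand-in does: the cos branch gets baked into the graph, and the sparse branch simply disappears.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def sin_backward(grad_output, input):
    # Hypothetical backward formula that switches on sparsity.
    if grad_output.is_sparse:
        return grad_output * input.sin()
    return grad_output * input.cos()

# Trace with a dense, contiguous stand-in for grad_output.
graph = make_fx(sin_backward)(torch.randn(3), torch.randn(3))
print(graph.code)  # only the cos branch is captured; the sparse branch is gone
```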
Potential proposal
Proposal: delay tracing the backward pass until the backward pass is invoked.
So, given a function f:
- AOTAutograd constructs a trace of f (that includes intermediates as outputs), forward_trace
- AOTAutograd constructs an autograd.Function that has compiled(forward_trace) as the forward pass

The autograd.Function's backward pass, when invoked:
- traces out run_forward_and_backward_f(*args, *grad_outputs) to produce forward_and_backward_trace
- takes the difference of forward_and_backward_trace and forward_trace to produce backward_trace
- compiles backward_trace into compiled_backward_trace
- then invokes it
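A minimal sketch of this shape, specialized to f = torch.sin so it stays self-contained (DeferredBackwardSin is a hypothetical toy, make_fx stands in for the real tracer, and the "difference of the joint trace and the forward trace" step is collapsed into tracing just the backward formula):

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

class DeferredBackwardSin(torch.autograd.Function):
    # Toy version of the proposal for f = torch.sin. The forward runs eagerly
    # (standing in for compiled(forward_trace)); the backward is only traced
    # once it is invoked, with the real grad_output in hand.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.sin(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors

        def sin_backward(g, inp):
            # Because g is the actual grad_output, any branch on its properties
            # (sparsity, contiguity, ...) resolves correctly at trace time.
            return g * inp.cos()

        backward_graph = make_fx(sin_backward)(grad_output, x)
        # A real implementation would compile and cache backward_graph here.
        return backward_graph(grad_output, x)

x = torch.randn(3, requires_grad=True)
DeferredBackwardSin.apply(x).sum().backward()
print(x.grad)
```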
Things that we haven't mentioned that will need to be thought about:
- how does AOTAutograd's rematerialization come into play here?
Things that we haven't mentioned that should be orthogonal:
- caching. compiled(forward_trace) needs a cache that uses the inputs as keys (among other things); compiled(backward_trace) needs a cache that takes the (inputs, grad_outputs) as keys (see the sketch after this list).
- what if the backward is user-defined (e.g., autograd.Function) and isn't traceable? See https://github.com/pytorch/pytorch/issues/93723 for ideas
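As a rough sketch of what the backward-side cache key could look like (the exact set of properties is an assumption; it should cover whatever the traced backward can specialize on, and compile_fn is a placeholder for "trace + compile"):

```python
from typing import Any, Dict, Tuple
import torch

def tensor_key(t: torch.Tensor) -> Tuple[Any, ...]:
    # Include anything the traced backward may have specialized on.
    contiguous = t.is_contiguous() if t.layout == torch.strided else None
    return (tuple(t.shape), t.dtype, t.device, t.layout, contiguous)

_backward_cache: Dict[Tuple[Any, ...], Any] = {}

def lookup_compiled_backward(inputs, grad_outputs, compile_fn):
    key = tuple(tensor_key(t) for t in (*inputs, *grad_outputs))
    if key not in _backward_cache:
        _backward_cache[key] = compile_fn(inputs, grad_outputs)
    return _backward_cache[key]
```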
Alternatives
Keep the current scheme (AOTAutograd traces out both the forward+backward pass at the time of the forward), but somehow prove to ourselves that the produced trace of the backward pass is always correct.
cc @Chillee @anijain2305 @ezyang @anjali411 @albanD
What is the requirement on the grad_output normally? Isn't it roughly that it's the same type of tensor (and memory format?) as the output tensor? Would it be valid to construct a proxy grad_output with the same logical metadata as the output?
Isn't it roughly that it's the same type of tensor (and memory format?) as the output tensor?
Same shape / dtype / layout / device (checks are here https://github.com/pytorch/pytorch/blob/5fbec86faef07d66ab696bc4c4edbaf6259a2189/torch/csrc/autograd/engine.cpp#L694). Note that stride / contiguousness / sparsity are not checked there (because we do want gradients that don't match in some cases). Also, even for the ones currently checked, there are requests to lift these constraints (for subclasses in particular). So these requirements might change in the future.
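(For illustration, a quick check that the engine indeed accepts a grad_output whose strides differ from the output's:)

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
out = x.sin()

# Same shape/dtype/device/layout as out, but non-contiguous: accepted.
g = torch.randn(3, 2).t()
assert not g.is_contiguous()
out.backward(gradient=g)
```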
```python
def sin_backward(grad_output, input):
    if input.is_sparse():
        return x.sin()
    return x.cos()
```
What is x? What I'm trying to get at is whether the current rules allow you to write something like this.
Sorry, I have corrected the example. The current rules allow you to write something like this (as Alban mentioned, sparsity is not checked). Another thing the current rules allow for is switching on contiguity, which is also not checked:
```python
def sin_backward(grad_output, input):
    if grad_output.is_contiguous():
        return grad_output * input.sin()
    return grad_output * input.cos()
```
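Note that a non-contiguous grad_output can show up in completely ordinary programs, so such a branch really can go either way at runtime; for example (illustrative):

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
out = x.sin()
out.register_hook(lambda g: print("grad_output is_contiguous:", g.is_contiguous()))

# The transpose downstream makes the gradient flowing back into sin's backward
# non-contiguous.
out.transpose(0, 1).sum().backward()  # prints: grad_output is_contiguous: False
```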