
AOT Autograd - LSTM - grads not generated (model tts_angular)

Open · anijain2305 opened this issue 3 years ago • 1 comment

This is a subgraph from the tts_angular model.

The generated backward pass has many None outputs, suggesting that requires_grad is somehow not propagated correctly when an LSTM cell is used.
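
A quick way to see the symptom (this helper is not from the original report; report_missing_grads is a hypothetical name) is to walk the module's parameters after backward() and list the ones whose .grad was never populated:

import torch

def report_missing_grads(mod: torch.nn.Module):
    # Any parameter that requires grad but whose .grad is still None after
    # a backward pass never received a gradient from autograd.
    for name, p in mod.named_parameters():
        if p.requires_grad and p.grad is None:
            print(f"no grad generated for {name}")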

import copy
import itertools

import torch
from torch.nn import LSTM, Parameter

from functorch.compile import aot_module, print_compile

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_lstm = LSTM(40, 768, batch_first=True)
        self.weight = Parameter(torch.randn(256, 768))

    def forward(self, x):
        self_lstm = self.self_lstm(x);  x = None
        getitem = self_lstm[0];  self_lstm = None
        linear = torch.nn.functional.linear(getitem, self.weight, bias=None);  getitem = None
        return (linear,)

def reduce_out(out):
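    # Collapse a tensor, or a nested tuple/list of tensors, into a single
    # scalar so that .backward() can be called on it.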
    if isinstance(out, torch.Tensor):
        return torch.sigmoid(out).sum()
    elif isinstance(out, (tuple, list)):
        return sum([reduce_out(x) for x in out])
    raise NotImplementedError(f"Don't know how to reduce {type(out)}")


def checkpoint_params(gm):
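    # Snapshot the RNG state plus a clone of every parameter and buffer.
    # param._version is bumped on every in-place mutation, so restore() only
    # needs to copy back the tensors that actually changed.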
    rng_state = torch.clone(torch.random.get_rng_state())
    saved_state = []
    for param in itertools.chain(gm.parameters(), gm.buffers()):
        saved_state.append((param, param._version, torch.clone(param)))

    def restore():
        with torch.no_grad():
            torch.random.set_rng_state(rng_state)
            for param, version, original_value in saved_state:
                if param._version != version:
                    param.copy_(original_value)

    return restore


def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)

def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    for param in model.parameters():
        results.append(clone_me(param.grad))
    for example in example_inputs:
        if isinstance(example, list):
            for inp in example:
                results.append(clone_me(inp.grad))
        else:
            results.append(clone_me(example.grad))
    return results

def same(a, b):
    """Check correctness to see if a and b match"""
    if isinstance(a, (list, tuple, torch.nn.ParameterList)):
        if not isinstance(b, (list, tuple)) or len(a) != len(b):
            return False
        return all(same(ai, bi) for ai, bi in zip(a, b))
    elif isinstance(a, torch.Tensor):
        assert isinstance(b, torch.Tensor)
        close = torch.allclose(a, b, atol=1e-5, rtol=1e-5)
        if not close:
            print(a.flatten()[1], b.flatten()[1])
            print(a.size())
        return close
    elif isinstance(a, (int, float, type(None), bool, torch.device)):
        return a == b
    else:
        raise RuntimeError(f"unsupported type: {type(a).__name__}")


def clone_inputs(inputs):
    clones = [clone_me(x) for x in inputs]
    for c in clones:
        if c is not None:
            c.grad = None
    return clones


def get_results(mod, inputs):
    cloned_inputs = clone_inputs(inputs)
    mod.zero_grad(True)
    ref = mod(*cloned_inputs)
    l = reduce_out(ref)
    l.backward()
    ref_results = collect_results(mod, ref, l, cloned_inputs)
    return ref_results

def test_module():
    inp0 = torch.randn(64, 50, 40, device="cuda", requires_grad=True)
    inputs = [inp0, ]

    mod = Bar().to(device="cuda")
    restore = checkpoint_params(mod)
    orig_mod_results = get_results(mod, inputs)

    restore()
    new_mod = copy.deepcopy(mod)
    copy_mod_results = get_results(new_mod, inputs)
    print("Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    # assert same(orig_mod_results, copy_mod_results), "Deepcopy of a mod fails, what the hell"

    restore()
    aot_mod = aot_module(mod, fw_compiler=print_compile)
    aot_mod_results = get_results(aot_mod, inputs)

    print("Recheck Are Orig_mod and Copy_mod same:", same(orig_mod_results, copy_mod_results))
    print("Are Orig_mod and AOT_mod same:", same(orig_mod_results, aot_mod_results))
    print("Are Copy_mod and AOT_mod same:", same(copy_mod_results, aot_mod_results))

test_module()
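
A side note not in the original report: aot_module also accepts a bw_compiler, so passing print_compile there as well should dump the generated backward graph, which is where the None outputs described above would show up (assuming the functorch API of the time):

# Hedged sketch: print the backward graph in addition to the forward one.
aot_mod = aot_module(mod, fw_compiler=print_compile, bw_compiler=print_compile)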

anijain2305 avatar Mar 10 '22 01:03 anijain2305

This is now erroring:

   File "/raid/ezyang/pytorch-scratch2/torch/nn/modules/rnn.py", line 774, in forward                                                       
    result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/utils/_python_dispatch.py", line 74, in wrapped                                                
    return f(self, *args, **kwargs)                                                                                                        
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 408, in __torch_dispatch__                              
    return proxy_call(self, func_overload, args, kwargs)                                                                                   
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 168, in proxy_call                                      
    proxy_res = func_overload(*proxy_args, **proxy_kwargs)                                                                                 
  File "/raid/ezyang/pytorch-scratch2/torch/_ops.py", line 60, in __call__                                                                 
    return self._op(*args, **kwargs or {})                                                                                                 
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 321, in __torch_function__                                                  
    return tracer.create_proxy('call_function', orig_method, args, kwargs,                                         
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 65, in create_proxy                                                         
    args_ = self.create_arg(args)                                                                                                          
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 347, in create_arg                                      
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/_symbolic_trace.py", line 343, in create_arg                                                
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 127, in create_arg                                                          
    return type(a)(self.create_arg(elem) for elem in a)                                                                                    
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 127, in <genexpr>                                                           
    return type(a)(self.create_arg(elem) for elem in a)                                                                                    
  File "/raid/ezyang/pytorch-scratch2/torch/fx/experimental/proxy_tensor.py", line 347, in create_arg                                      
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/_symbolic_trace.py", line 343, in create_arg                                                
    return super().create_arg(a)                                                                                                           
  File "/raid/ezyang/pytorch-scratch2/torch/fx/proxy.py", line 153, in create_arg                                                          
    raise NotImplementedError(f"argument of type: {type(a)}")        
NotImplementedError: argument of type: <class 'torch.storage.UntypedStorage'>   

Last time I saw this, it was because we tried to copy a Proxy.
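
For context, a hypothetical minimal illustration (not the original failure path): torch.fx's create_arg only accepts a fixed set of argument types and raises exactly this NotImplementedError for anything else, e.g. an untyped storage:

import torch
import torch.fx

# Hypothetical illustration: hand create_arg an argument type the tracer
# does not understand. Assumes torch.UntypedStorage, the public name in
# recent PyTorch releases.
tracer = torch.fx.Tracer()
try:
    tracer.create_arg(torch.UntypedStorage(4))
except NotImplementedError as e:
    print(e)  # argument of type: <class 'torch.storage.UntypedStorage'>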

ezyang avatar Aug 12 '22 15:08 ezyang

I think the minimal repro may no longer be valid. When I run the original tts_angular benchmark, it passes:

$ python benchmarks/torchbench.py --training --devices=cuda --accuracy-aot-nop --use-eval-mode -k tts_angular  

ezyang avatar Aug 23 '22 19:08 ezyang

OK, it turns out Inductor still triggers this:

./benchmarks/torchbench.py --inductor  -dcuda --no-skip -k tts_angular

ezyang avatar Aug 23 '22 20:08 ezyang