functorch
Error when applying `make_fx()` to a function that calls `optim.step()`
The following test used to pass but has recently started failing: tracing it with `make_fx()` raises a `RuntimeError` from inside `optim.step()`.
def test_make_fx_model_train_with_optim(self, device):
class Foo(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(5, 5)
def forward(self, x):
return self.linear(x).relu()
model = Foo()
optim = torch.optim.SGD(model.parameters(), lr=1e-4)
def f(args, params, buffers):
if not isinstance(args, Iterable):
args = [args]
params_and_buffers = {**params, **buffers}
out = stateless.functional_call(model, params_and_buffers, args)
out.sum().backward()
optim.step()
# TODO: this causes graph to show an output with many incoming edges. Shall we try `return None` or simply don't return?
return list(params.values())
input = torch.randn(3, 5, requires_grad=True)
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
fx_f = make_fx(f)(input, params, buffers)
# TODO: what assert statement should we add here?
assert(fx_f(input, params, buffers) is not None)
The test fails with the following traceback:
File "/Users/distiller/project/test/test_pythonkey.py", line 152, in test_make_fx_model_train_with_optim
fx_f = make_fx(f)(input, params, buffers)
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 407, in wrapped
t = dispatch_trace(wrap_key(f, args), tracer=fx_tracer, concrete_args=tuple(phs))
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 246, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 714, in trace
(self.create_arg(fn(*args)),),
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/_symbolic_trace.py", line 549, in flatten_fn
tree_out = root_fn(*tree_args)
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/fx/experimental/proxy_tensor.py", line 270, in wrapped
out = f(*tree_args)
File "/Users/distiller/project/test/test_pythonkey.py", line 144, in f
optim.step()
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/optim/optimizer.py", line 113, in wrapper
with torch.autograd.profiler.record_function(profile_name):
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/autograd/profiler.py", line 477, in __exit__
torch.ops.profiler._record_function_exit(self.handle)
File "/Users/distiller/project/env/lib/python3.10/site-packages/torch/_ops.py", line 164, in __call__
return self._op(*args, **kwargs or {})
RuntimeError: Expected temporary cpp type wrapper of type at::RecordFunction