Make fx tracing use input tensor copies to prevent side effects
Some side effects can occur during `flat_fn` and `make_fx`, making unsound changes to the arguments (e.g. module buffers).
Here I clone the arguments used for compilation, so that tracing never mutates them and we always operate on the user's arguments through the compiled function.
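A minimal sketch of the idea (the wrapper name `trace_with_cloned_inputs` is mine, for illustration; it is not the code in this PR):

```python
import torch
from functorch import make_fx

def trace_with_cloned_inputs(fn, *args):
    # Trace against clones so that any in-place ops executed while
    # tracing (e.g. buffer updates) hit the copies, not the caller's
    # tensors.
    cloned = [a.clone() if isinstance(a, torch.Tensor) else a for a in args]
    return make_fx(fn)(*cloned)
```

The returned `GraphModule` can then be invoked on the user's original, still-unmutated arguments.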
Hi, thanks for the PR :)
Most of the work on this tracer is now going on here: https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/proxy_tensor.py
However, about the specific proposal in this PR... I don't know how I feel about it. I suppose in principle it makes sense, but if you're relying on the inputs to your function not being mutated in the first place, then I feel like you might be using `make_fx` wrong.
A little more context: this issue hits me when I try to trace through batch norm in training mode, because the `num_batches_tracked` buffer is incremented twice during tracing. My compiled function is still able to capture the in-place increments, though, by copying the traced program's return values back into the arguments, like the approach used for LTC.
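A minimal repro sketch of the symptom (my own illustration, assuming functorch's `make_fx`; exact counts may vary by version):

```python
import torch
from functorch import make_fx

bn = torch.nn.BatchNorm1d(4)   # training mode by default
x = torch.randn(8, 4)

print(bn.num_batches_tracked)  # tensor(0)
traced = make_fx(lambda inp: bn(inp))(x)
# Tracing executed the module's Python forward, so the in-place
# `num_batches_tracked.add_(1)` already ran against the live buffer:
print(bn.num_batches_tracked)  # incremented, yet no real forward has run
```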
I agree with you: this seems more like a band-aid fix until the functionalization pass lands in core.
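For reference, once functionalization is available the mutation problem goes away at the graph level; a sketch assuming the `functionalize` API from `functorch.experimental` (not part of this PR):

```python
import torch
from functorch import make_fx
from functorch.experimental import functionalize

def f(x):
    y = x.clone()
    y.add_(1)      # in-place op on an intermediate
    return y

# functionalize rewrites mutations into out-of-place ops, so the
# traced graph carries no side effects to begin with.
gm = make_fx(functionalize(f))(torch.randn(3))
print(gm.code)     # shows out-of-place `add`, not `add_`
```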