
Make fx tracing use input tensor copies to prevent side effects

Open · byronyi opened this issue 2 years ago · 2 comments

Side effects can occur while running `flat_fn` under `make_fx`, making unsound changes to the arguments (e.g. module buffers).

This PR clones the arguments used for compilation, so that the user's actual arguments are only ever operated on by the compiled function.
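A dependency-free toy model of the idea, for illustration only. It is not the functorch or `make_fx` API: `toy_trace`, `toy_trace_on_copies` and `bn_forward` are hypothetical names, and a plain dict stands in for a module buffer. The toy tracer executes the function once on its inputs (as a real-mode tracer does), so an in-place counter bump leaks to the caller unless tracing runs on copies.

```python
import copy
from typing import Any, Callable


def toy_trace(fn: Callable[..., Any], *args: Any) -> Callable[..., Any]:
    # Hypothetical stand-in for a tracer: it executes fn once on the given
    # inputs to "record" it, so any in-place side effects really happen.
    fn(*args)
    return fn  # in this toy, the "compiled" function simply replays fn


def toy_trace_on_copies(fn: Callable[..., Any], *args: Any) -> Callable[..., Any]:
    # The proposed fix: trace on deep copies, so mutations made while tracing
    # land on the copies instead of the caller's objects (e.g. module buffers).
    return toy_trace(fn, *(copy.deepcopy(a) for a in args))


def bn_forward(state: dict, x: float) -> float:
    # Mimics BatchNorm in training mode, which bumps a batch counter in place.
    state["num_batches_tracked"] += 1
    return x * 2.0


# Without copies: the tracing run plus one real call bump the counter twice.
naive_state = {"num_batches_tracked": 0}
toy_trace(bn_forward, naive_state, 1.0)(naive_state, 1.0)
print(naive_state["num_batches_tracked"])  # 2

# With copies: only the real call is visible to the caller.
safe_state = {"num_batches_tracked": 0}
toy_trace_on_copies(bn_forward, safe_state, 1.0)(safe_state, 1.0)
print(safe_state["num_batches_tracked"])  # 1
```

With real tensors the analogous step would be cloning each tensor argument before handing it to the tracer; deep copies are used here only because the toy state is a plain dict.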

byronyi · May 13 '22 14:05

Hi, thanks for the PR :)

Most of the work on this tracer is now going on here: https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/proxy_tensor.py

As for the specific proposal in this PR, though... I'm not sure how I feel about it. In principle it makes sense, but if you're relying on the inputs to your function not being mutated in the first place, then I feel like you might be using make_fx wrong.

Chillee · May 20 '22 23:05

A bit more context: I hit this issue when tracing through batch norm in training mode, because the `num_batches_tracked` buffer is incremented twice during tracing. My compiled function does capture the in-place increments, though, by copying the SSA program's return values back into the arguments (similar to the approach used for LTC, Lazy Tensor Core).
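A minimal toy sketch of the copy-back approach, assuming a plain dict in place of a module buffer. `functional_bn_step` and `run_with_copy_back` are hypothetical names, not part of functorch or LTC. The program itself is pure (SSA-style): it returns the updated counter as an extra output instead of mutating it, and a runtime wrapper writes that output back into the caller's argument. Because the pure step mutates nothing, running it during tracing cannot disturb the real state.

```python
from typing import Dict, Tuple


def functional_bn_step(num_batches_tracked: int, x: float) -> Tuple[float, int]:
    # Pure, SSA-style step: instead of mutating the counter in place,
    # return its updated value as an extra output.
    return x * 2.0, num_batches_tracked + 1


def run_with_copy_back(state: Dict[str, int], x: float) -> float:
    # Hypothetical runtime wrapper: call the pure program, then copy the
    # returned buffer value back into the caller's argument.
    y, new_count = functional_bn_step(state["num_batches_tracked"], x)
    state["num_batches_tracked"] = new_count
    return y


state = {"num_batches_tracked": 0}
functional_bn_step(state["num_batches_tracked"], 3.0)  # e.g. a trace-time call
print(state["num_batches_tracked"])  # 0: the pure step left the state alone
out = run_with_copy_back(state, 3.0)
print(out, state["num_batches_tracked"])  # 6.0 1
```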

I agree with you: this seems more like a band-aid fix until the functionalization pass lands in core.

byronyi · May 21 '22 02:05