[SPMD] Remove Gradient tensor clones added during DTensor comm collective insertion

Open lessw2020 opened this issue 3 years ago • 0 comments

After DTensor communication operations are expanded, fx inserts a clone operation that copies the gradient tensor. The clone costs extra time and memory, but it is technically not needed.

I have verified that we can rewrite the graph to remove the clone by updating its consumers to use the grad tensor directly.

Thus we need to either: (a) modify the comm subgraph so that fx does not insert the clone operation, or (b) add an optimization pass that makes the clone nodes unnecessary by relinking the comm primitives to the original grad tensor.
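
As a rough illustration of the optimization-pass option, here is a minimal sketch of a pass that relinks every user of a clone node to the clone's input and then erases the clone. It is not the SPMD implementation: `remove_clone_nodes` and `toy_step` are hypothetical names, the toy graph comes from `fx.symbolic_trace` rather than from the DTensor expansion, and the pass assumes nothing downstream mutates the tensor in place (otherwise dropping the clone would change behavior).

```python
import torch
import torch.fx as fx


def remove_clone_nodes(gm: fx.GraphModule) -> int:
    """Relink users of each clone node to the clone's input, then erase it.

    Hypothetical helper: assumes no downstream op mutates the tensor in place.
    """
    removed = 0
    for node in list(gm.graph.nodes):
        is_clone = (
            node.op == "call_function"
            and node.target in (torch.clone, torch.ops.aten.clone.default)
        ) or (node.op == "call_method" and node.target == "clone")
        if not is_clone:
            continue
        source = node.args[0]               # the tensor that was being cloned
        node.replace_all_uses_with(source)  # consumers now read the original
        gm.graph.erase_node(node)           # safe: the clone has no users left
        removed += 1
    gm.graph.lint()   # sanity-check the rewritten graph
    gm.recompile()    # regenerate forward() from the graph
    return removed


def toy_step(grad: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for a comm subgraph that clones the gradient
    # before consuming it.
    cloned = grad.clone()
    return cloned * 2.0


gm = fx.symbolic_trace(toy_step)
num_removed = remove_clone_nodes(gm)
```

A real pass would likely need to match only the clones introduced by the comm expansion rather than every clone in the graph, since some clones may be there for correctness.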

lessw2020 · Nov 18 '22 18:11