[SPMD] Remove Gradient tensor clones added during DTensor comm collective insertion
After expansion of the DTensor communication operations, fx inserts a clone operation to copy the gradient tensor. The clone hurts performance and adds memory overhead, but is technically unnecessary.
I have verified that we can rewrite the graph to remove it by updating the downstream nodes to use the gradient tensor directly.
Thus we need to either: (a) modify the comm subgraph so that fx does not insert the clone operation, or (b) add an optimization pass that removes the clone nodes by relinking the comm primitives directly to the gradient tensor (sketched below).
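
For option (b), here is a minimal sketch of such a pass, assuming the clones show up as `aten.clone.default` call_function nodes in the expanded graph; the helper name and the safety caveat are mine, not existing PiPPy/SPMD API:

```python
import torch
from torch import fx

def remove_redundant_clones(gm: fx.GraphModule) -> fx.GraphModule:
    # Hypothetical pass: relink consumers of each clone node to the
    # cloned source (the gradient tensor), then erase the clone.
    # A production pass would first verify each clone is safe to elide,
    # e.g. that nothing mutates the source in place before its last use.
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.ops.aten.clone.default:
            src = node.args[0]
            node.replace_all_uses_with(src)
            gm.graph.erase_node(node)
    gm.graph.lint()
    gm.recompile()
    return gm
```

This is the standard fx dead-node rewrite: `replace_all_uses_with` repoints the comm primitives at the gradient tensor, after which the clone has no users and can be erased.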