try removing `ltorch.copy_`
It seems that I've added ltorch.copy_ in #1063 but I don't remember exactly why I did so.
I even suspect that it was an unwanted change that came with some cost.
So in this PR, I just want to try reverting the change and see if things can be simplified.
isn't torch.Tensor.copy_ a legit method?
We don't seem to have a solid way to support the op with reasonable implementation complexity and compile-time/performance cost, and it doesn't seem that we have many uses for it at the moment.
How's this branch going?
Still, we could write an equivalent function even without ltorch.copy_ by using prims.copy_, but that function would not be executed in eager mode.
BTW, would it be possible to audit prims.copy_ calls in a region on the nvfuser executor side?
Another alternative could be:
diff --git a/thunder/core/functionalization.py b/thunder/core/functionalization.py
index aac6f150..a91e502a 100644
--- a/thunder/core/functionalization.py
+++ b/thunder/core/functionalization.py
@@ -315,6 +315,40 @@ def canonicalize_bsym_args(
return intermediate_trace, reverse_swap_map
+# Filter redundant `prims.copy_` bound symbols out of a computation trace.
+#
+# A copy bsym is dropped when its destination's only consumers placed after it
+# (other than RETURN) are themselves COPY_ bsyms, i.e. the copy's effect is
+# overwritten before any real use. `check` raises if a copy's output proxy is
+# consumed inside the trace, or if the destination has a later consumer that
+# is neither RETURN nor COPY_.
+def audit_raw_copy(computation_trace: Trace, copy_bsyms: list[BoundSymbol]) -> Trace:
+ # NOTE(review): producer_map, copy_outs, and copy_dsts are assigned but never
+ # used below — likely dead code left over from an earlier draft.
+ producer_map, consumer_map = producers(computation_trace), consumers(computation_trace)
+ copy_outs, copy_dsts = [], []
+ # Position of each bsym in the trace, so "comes after the copy" can be tested by index.
+ bsym_to_id: dict[BoundSymbol, int] = {bsym: i for i, bsym in enumerate(computation_trace.bound_symbols)}
+ # Copy bsyms proven redundant; removed from the rebuilt trace at the end.
+ copy_bsym_to_filter: list[BoundSymbol] = []
+
+ bsym: BoundSymbol
+ for bsym in copy_bsyms:
+ idx_of_bsym = bsym_to_id[bsym]
+ # The copy's output must not be consumed by anything in the trace.
+ out = bsym.flat_proxy_outs[0]
+ check(out not in consumer_map, lambda: f"`prims.copy_` output {out=} is used inside a trace")
+ # assumes flat_proxy_args[1] is the copy destination — TODO confirm against prims.copy_'s signature
+ dst = bsym.flat_proxy_args[1]
+ if dst not in consumer_map:
+ continue
+ # Consumers of the destination that appear after this copy, ignoring RETURN.
+ consumer_of_dst = tuple(filter(lambda bsym: bsym.sym.id != prims.PrimIDs.RETURN and bsym_to_id[bsym] > idx_of_bsym, consumer_map[dst]))
+ if not consumer_of_dst:
+ continue
+ # Every remaining later consumer must itself be a COPY_ (the destination is
+ # overwritten again); anything else means this copy's result is actually used.
+ check(
+ all(bsym.sym.id == prims.PrimIDs.COPY_ for bsym in consumer_of_dst),
+ lambda: f"copy destination of {dst} has consumers other than {prims.PrimIDs.RETURN} and {prims.PrimIDs.COPY_}",
+ )
+ copy_bsym_to_filter.append(bsym)
+
+ # Nothing redundant found: return the trace unchanged.
+ if not copy_bsym_to_filter:
+ return computation_trace
+
+ # Rebuild the trace without the redundant copies and record provenance.
+ trace = from_trace(computation_trace)
+ set_of_redundant_copy_bsym = set(copy_bsym_to_filter)
+ trace.bound_symbols.extend(list(filter(lambda bsym: bsym not in set_of_redundant_copy_bsym, computation_trace.bound_symbols)))
+ trace.set_provenance(TraceProvenance("`prims.copy_` audit"))
+
+ return trace
+
+
def create_functional_bsym_from(inplace_bsym: BoundSymbol) -> BoundSymbol:
from thunder.torch import _inplace_to_out_of_place, setitem_, setitem
@@ -911,6 +945,8 @@ def functionalize_inplace_ops(
"""
if not any(is_functionalizable(bsym) for bsym in computation_trace.bound_symbols):
+ if (copy_bsyms := [bsym for bsym in computation_trace.bound_symbols if bsym.sym.id == prims.PrimIDs.COPY_]):
+ return [audit_raw_copy(computation_trace, copy_bsyms)]
return []
# Step 0:
diff --git a/thunder/tests/test_inplace_copy.py b/thunder/tests/test_inplace_copy.py
index e94a6624..0190e35f 100644
--- a/thunder/tests/test_inplace_copy.py
+++ b/thunder/tests/test_inplace_copy.py
@@ -182,9 +182,5 @@ def test_inplace_copy_dst_copy_returned_issue_1109(executor, device, dtype):
assert_close(t0, expected)
assert_close(actual_t2, expected_t2)
- # FIXME(crcrpar): Since there's no `ltorch.Tensor.copy_`, functions like `func` would not
- # be observed and executed with pytorch eager mode. Though there should be either an audit of
- # `prims.copy_` in a nvfuser region and/or what #1110 did.
- assert actual_t1.data_ptr() == actual_t2.data_ptr()
- with pytest.raises(AssertionError):
- assert_close(actual_t1, expected_t1)
+ assert actual_t1.data_ptr() != actual_t2.data_ptr()
+ assert_close(actual_t1, expected_t1)