Add structured logging for tensor fakeification
Stack from ghstack (oldest at bottom):
- -> #126879
This adds dumps of MetaTensorDesc and MetaStorageDesc to structured logs when they are triggered from Dynamo. The logs look like this:
```
V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:195] {"describe_storage": {"id": 0, "describer_id": 0, "size": 32}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:220] {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [8], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "<built-in method _view_func_unsafe of Tensor object at 0x7f882959e840>", "describer_id": 0}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
V0522 08:13:25.268000 140224882566144 torch/_subclasses/meta_utils.py:1594] {"describe_source": {"describer_id": 0, "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0}
```
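These lines are plain JSON payloads after a glog-style prefix, so they are straightforward to post-process. A minimal sketch of extracting the payload, assuming the JSON always begins at the first `{` on the line (the helper name is mine, not part of tlparse):

```python
import json

def parse_structured_log_line(line):
    # Assumption: the JSON payload begins at the first '{' in the line,
    # which holds for the example log lines above.
    start = line.index("{")
    return json.loads(line[start:])

sample = (
    'V0522 08:13:25.267000 140224882566144 torch/_subclasses/meta_utils.py:195] '
    '{"describe_storage": {"id": 0, "describer_id": 0, "size": 32}, '
    '"frame_id": 0, "frame_compile_id": 0, "attempt": 0}'
)
record = parse_structured_log_line(sample)
print(record["describe_storage"]["size"])  # -> 32
```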
The describer_id is used to disambiguate ids. We expect it to be unique per frame id, but if there is a bug it may not be. Note that you will get redundant dumps when evaluation restarts.
tlparse can use this to visualize a model's input tensors; you could also use it to generate example inputs for running graphs.
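For generating example inputs, the describe_tensor fields carry enough information to derive how much storage each tensor needs. A rough sketch of that computation, using a hand-rolled (abbreviated) dtype-to-itemsize table; field names follow the log lines above, but this is illustrative, not tlparse code:

```python
# Map dtype strings from the logs to element sizes in bytes (abbreviated;
# a real tool would cover every torch dtype).
ITEMSIZE = {"torch.float32": 4, "torch.float64": 8, "torch.int64": 8}

def min_storage_bytes(desc):
    # Smallest storage (in bytes) that can back a tensor with this
    # size/stride/dtype: offset of the last element, plus one element.
    size, stride = desc["size"], desc["stride"]
    if any(s == 0 for s in size):
        return 0
    last = sum((s - 1) * st for s, st in zip(size, stride))
    return (last + 1) * ITEMSIZE[desc["dtype"]]

# Fields taken from the describe_tensor record in the example logs.
desc = {"id": 0, "ndim": 1, "dtype": "torch.float32", "size": [8],
        "stride": [1], "storage": 0}
print(min_storage_bytes(desc))  # -> 32, matching describe_storage "size": 32
```

The result matching the describe_storage record's size is a useful sanity check for any tool that reconstructs inputs from these dumps.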
Some care is taken to avoid dumping the tensor metadata multiple times, which would otherwise happen because AOTAutograd re-fakeifies everything after Dynamo to deal with metadata mutation.
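The deduplication can be pictured as a per-describer memo keyed by object id. This is only an illustrative sketch of the idea, not the actual meta_utils.py implementation:

```python
# Hypothetical sketch: each describer remembers which (kind, id) pairs it has
# already logged, so re-fakeification after Dynamo does not emit duplicate
# describe_tensor / describe_storage records.
class DescriberMemo:
    def __init__(self):
        self._dumped = set()

    def should_dump(self, kind, obj_id):
        key = (kind, obj_id)
        if key in self._dumped:
            return False  # already logged once for this describer
        self._dumped.add(key)
        return True

memo = DescriberMemo()
first = memo.should_dump("tensor", 0)   # True: first sighting, emit the record
second = memo.should_dump("tensor", 0)  # False: skip the redundant dump
```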
Partially fixes https://github.com/pytorch/pytorch/issues/126644
Signed-off-by: Edward Z. Yang [email protected]
:x: 2 New Failures, 2 Unrelated Failures
As of commit cfeec3e1a3ff54193056ee1befb65717667d886f with merge base 0910429d7262daf67dc3aa1d4e4aa939752ae675:
NEW FAILURES - The following jobs have failed:
- pull / linux-focal-cuda12.1-py3.10-gcc9 / test (default, 4, 5, linux.4xlarge.nvidia.gpu) (gh)
  test_ops_fwd_gradients.py::TestFwdGradientsCUDA::test_fn_fwgrad_bwgrad_linalg_lu_factor_cuda_complex128
- trunk / macos-13-py3-arm64 / test (default, 2, 3, macos-m1-stable) (gh)
  An error was encountered when uploading logs-runattempt1-test-default-2-3-macos-m1-stable_25569204117.zip. There was 1 item that failed to upload.
UNSTABLE - The following jobs failed but were likely due to flakiness present on trunk and have been marked as unstable:
- inductor / cuda12.1-py3.10-gcc9-sm86 / test (dynamic_inductor_timm, 2, 2, linux.g5.4xlarge.nvidia.gpu) (gh) (#127438)
  sebotnet33ts_256
- inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal, unstable) (gh) (#126993)
  Process completed with exit code 1.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cc @bhack
Thanks,
Is this enough to isolate a failing compiled function in a minimal repro format?
E.g., if we consider a recently reported stack trace (https://github.com/pytorch/pytorch/issues/126614#issuecomment-2122567229), that failure could have been generated by a decorated parent function or not.
The simplest case is a directly decorated function, as at https://github.com/pytorch/pytorch/issues/121504#issue-2176370853.
As you can see, the compiled forward is `def forward(self, q, k, v):`.
So if we really want to create a minimal repro in Python and isolate that function from the rest of the code, I suppose we need a way to create fake q, k, v tensors, but also to serialize something from the class for `self`.
Or do you think we could have another solution for quickly feeding minimal repros into compiler tickets?
If we are really not interested in the Python source code at all for compilation error reporting (though I am really not sure about this point), we could probably just prompt the user to copy/save the Inductor-generated code for the specific failure.
In the mentioned case, e.g., the compiled code already knows that we failed at:
File "/tmp/torchinductor_root/ut/cutmbnzthsr64p23ilpnn2ym54twqj4lwpqj5v3shylgqucshcur.py", line 660
So we could just suggest that the user attach that file to the ticket, plus this tensor fakeification info.
I fixed the memo problem: I can't actually hold on to MetaTensorDesc as it will keep real tensors live lol
CI is green here
@pytorchbot merge -i
Merge started
Your change will be merged while ignoring the following 4 checks: pull / linux-focal-cuda12.1-py3.10-gcc9 / test (default, 4, 5, linux.4xlarge.nvidia.gpu), inductor / linux-jammy-cpu-py3.8-gcc11-inductor / test (inductor_torchbench_cpu_smoketest_perf, 1, 1, linux.24xl.spr-metal, unstable), inductor / cuda12.1-py3.10-gcc9-sm86 / test (dynamic_inductor_timm, 2, 2, linux.g5.4xlarge.nvidia.gpu), trunk / macos-13-py3-arm64 / test (default, 2, 3, macos-m1-stable)
Learn more about merging in the wiki.
Questions? Feedback? Please reach out to the PyTorch DevX Team.
Check the merge workflow status here.