improve shape accuracy in transform output by providing better tooling
Some transforms, notably the FSDP and TensorParallel ones, change shapes but currently do not completely update them (they do update the shape for the linear that follows, but not for the activation, etc.).
We might consider making it easy for the transform to update the shapes either on the fly or at the end of the transform, by providing a mechanism similar to, or based on, the interpret_trace_to_trace introduced in #1164 (which would then become stricter by default).
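The "update at the end of the transform" and "stricter by default" ideas can be illustrated with a post-transform shape check over a toy trace. This is a minimal sketch under loud assumptions: `Proxy`, `Op`, `infer_shape`, and `check_shapes` are hypothetical stand-ins, not Thunder's classes or the interpret_trace_to_trace API, and the shape rules cover only a linear and a ReLU. The check recomputes each output shape from its inputs' recorded shapes. In strict mode it raises on any mismatch, which catches the case above where the linear's output was updated but the activation's was not.

```python
from dataclasses import dataclass


# Hypothetical toy IR; NOT Thunder's TensorProxy / BoundSymbol.
@dataclass(frozen=True)
class Proxy:
    name: str
    shape: tuple


@dataclass(frozen=True)
class Op:
    kind: str      # "linear" or "relu" in this toy
    inputs: tuple  # tuple of Proxy
    output: Proxy


def infer_shape(kind, input_shapes):
    """Toy shape rules: linear(x, w) with w of shape (out, in); relu keeps the shape."""
    if kind == "linear":
        x, w = input_shapes
        if x[-1] != w[1]:
            raise ValueError(f"linear: input {x} incompatible with weight {w}")
        return x[:-1] + (w[0],)
    if kind == "relu":
        (x,) = input_shapes
        return x
    raise NotImplementedError(kind)


def check_shapes(ops, strict=True):
    """Recompute every output shape from its inputs and collect stale ones."""
    mismatches = []
    for op in ops:
        expected = infer_shape(op.kind, [p.shape for p in op.inputs])
        if expected != op.output.shape:
            mismatches.append((op.output.name, op.output.shape, expected))
    if strict and mismatches:
        raise RuntimeError(f"stale shapes after transform: {mismatches}")
    return mismatches


# A trace after a (hypothetical) 2-way weight sharding that updated the
# linear's output shape but left the activation's output shape stale.
x = Proxy("x", (8, 16))
w_shard = Proxy("w", (16, 16))   # originally (32, 16)
t0 = Proxy("t0", (8, 16))        # updated
t1 = Proxy("t1", (8, 32))        # stale
ops = [Op("linear", (x, w_shard), t0), Op("relu", (t0,), t1)]

print(check_shapes(ops, strict=False))  # [('t1', (8, 32), (8, 16))]
```

Defaulting to `strict=True` mirrors the suggestion that the check become stricter by default, while a report-only mode could help while existing transforms are being fixed.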
cc: @crcrpar
cc @carmocca @borda
Not updating the shapes is odd; I'm curious why this is labeled an 'enhancement' rather than a real bug. Marking for triage review to discuss!
I don't object to having it labeled a bug, but I don't think users will typically hit it today: for things needing gradients, the augmented forward pass will fix it. I'm not sure Tensor Parallel sees heavy use in inference just yet. Among the issues I found while working on #1164, this seems one of the milder ones.
Some transforms, notably and TensorParallel ones, change shapes, but currently do not completely update them (it does for the linear that follows, but not for the activation etc.).
I thought the logic implemented with visitor_transform in https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L118, more specifically https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L145, updates shapes automatically. Doesn't it?
Yeah, @crcrpar, this is why I mentioned better tooling; maybe using the visitor transform pattern more is the solution. Looking at the code, I would probably have the same expectation as you, but if I take out this
https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/core/trace_interpreter.py#L120-L127
there seem to be inconsistencies. Maybe it is some other part that does not fully update shapes (and certainly anything I wrote will have the problem).
Looking at what happened: I think that while the visitor re-executes everything that was not replaced, it does so with the old outputs rather than the new ones (in contrast to interpret_trace / interpret_trace_to_trace, which use an env mapping to keep track of what to swap).
To my mind, it would be worthwhile to see whether an extended visitor pattern or interpret_trace_to_trace itself is the best way to solve this. WDYT?
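The difference between re-emitting untouched ops with their old outputs and swapping through an env mapping can be sketched on a toy trace. Everything here (`Proxy`, `Op`, `visitor_rewrite`, `interpret_rewrite`, `shard_linear`) is hypothetical and only loosely modeled on visitor_transform and interpret_trace_to_trace; it is not their actual code. `visitor_rewrite` keeps every op it does not replace exactly as it was. `interpret_rewrite` keeps an env from each original output name to its new proxy, substitutes inputs through it, and re-infers the output shape of every op it does not replace. The replacement shards a linear's weight along its output dimension across two ranks, roughly what a column-parallel transform does.

```python
from dataclasses import dataclass


# Hypothetical toy IR; NOT Thunder's TensorProxy / BoundSymbol.
@dataclass(frozen=True)
class Proxy:
    name: str
    shape: tuple


@dataclass(frozen=True)
class Op:
    kind: str
    inputs: tuple
    output: Proxy


def infer_shape(kind, input_shapes):
    if kind == "linear":
        x, w = input_shapes           # w has shape (out, in)
        return x[:-1] + (w[0],)
    if kind == "relu":
        return input_shapes[0]
    raise NotImplementedError(kind)


def shard_linear(op, world_size=2):
    """Hypothetical column-parallel rewrite: shard the weight's output dim."""
    if op.kind != "linear":
        return None
    x, w = op.inputs
    w_shard = Proxy(w.name, (w.shape[0] // world_size, w.shape[1]))
    out = Proxy(op.output.name, x.shape[:-1] + (w_shard.shape[0],))
    return Op("linear", (x, w_shard), out)


def visitor_rewrite(ops, replace):
    """Ops that are not replaced are re-emitted with their OLD inputs/outputs."""
    return [replace(op) or op for op in ops]


def interpret_rewrite(ops, replace):
    """Swap inputs through an env and re-infer shapes of non-replaced ops."""
    env = {}  # original output name -> new Proxy
    new_ops = []
    for op in ops:
        inputs = tuple(env.get(p.name, p) for p in op.inputs)
        new_op = replace(Op(op.kind, inputs, op.output))
        if new_op is None:
            shape = infer_shape(op.kind, [p.shape for p in inputs])
            new_op = Op(op.kind, inputs, Proxy(op.output.name, shape))
        env[op.output.name] = new_op.output
        new_ops.append(new_op)
    return new_ops


x, w = Proxy("x", (8, 16)), Proxy("w", (32, 16))
t0, t1 = Proxy("t0", (8, 32)), Proxy("t1", (8, 32))
ops = [Op("linear", (x, w), t0), Op("relu", (t0,), t1)]

print(visitor_rewrite(ops, shard_linear)[-1].output.shape)    # (8, 32)
print(interpret_rewrite(ops, shard_linear)[-1].output.shape)  # (8, 16)
```

With the visitor-style rewrite the ReLU still reports the unsharded `(8, 32)` shape, while the env-based rewrite propagates the sharded `(8, 16)` shape downstream.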