improve shape accuracy in transform output by providing better tooling

Open · t-vi opened this issue 1 year ago · 5 comments

Some transforms, notably the FSDP and TensorParallel ones, change shapes but currently do not update them completely: they update the shape of the linear that follows, but not that of the activation, etc. We might consider either making it easy for the transform to update the shapes on the fly, or updating them at the end of the transform, by providing a mechanism similar to, or based on, the interpret_trace_to_trace introduced in #1164 (which would then become stricter by default).
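To make the symptom concrete, here is a toy model. The `Op` class and the helpers are hypothetical and are not Thunder's trace or proxy classes. A trace is a list of ops with recorded output shapes, and a checker re-runs shape inference to flag entries that a transform left stale. The toy assumes a column-parallel split of the weight over 2 ranks, which halves the linear's output features. It is one possible shape for the "stricter by default" check, not the behavior of interpret_trace_to_trace.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    # Toy stand-in for a trace entry (hypothetical, not Thunder's BoundSymbol).
    name: str         # "linear" or "relu"
    args: tuple       # names of the input tensors
    out: str          # name of the output tensor
    out_shape: tuple  # output shape recorded in the trace


def infer_shape(name, arg_shapes):
    """Shape rules for the two toy ops."""
    if name == "linear":
        x, w = arg_shapes  # x: (..., in_features), w: (out_features, in_features)
        assert x[-1] == w[-1], "in_features mismatch"
        return x[:-1] + (w[0],)
    if name == "relu":
        return arg_shapes[0]
    raise NotImplementedError(name)


def find_stale_shapes(inputs, trace):
    """Return the outputs whose recorded shape disagrees with shape inference."""
    shapes = dict(inputs)  # tensor name -> shape as recorded in the trace
    stale = []
    for op in trace:
        expected = infer_shape(op.name, [shapes[a] for a in op.args])
        if expected != op.out_shape:
            stale.append(op.out)
        # Downstream ops see the recorded shape, just as they would in the trace.
        shapes[op.out] = op.out_shape
    return stale


# Original trace: linear with 32 output features, followed by relu.
inputs = {"x": (8, 16), "w": (32, 16)}
trace = [
    Op("linear", ("x", "w"), "t0", (8, 32)),
    Op("relu", ("t0",), "t1", (8, 32)),
]
print(find_stale_shapes(inputs, trace))  # []

# After a column-parallel split over 2 ranks, the transform updates the
# weight and the linear's output, but leaves the relu entry untouched.
sharded_inputs = {"x": (8, 16), "w": (16, 16)}
sharded_trace = [
    Op("linear", ("x", "w"), "t0", (8, 16)),
    Op("relu", ("t0",), "t1", (8, 32)),
]
print(find_stale_shapes(sharded_inputs, sharded_trace))  # ['t1']
```

In this toy, the linear passes the check and only the activation's output is flagged. That matches the pattern described above.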

cc: @crcrpar

cc @carmocca @borda

t-vi · Sep 20 '24 12:09

Not updating the shapes is odd; I'm curious why this is labeled an 'enhancement' rather than a real bug. Marking triage review to discuss!

tfogal · Sep 20 '24 18:09

I don't object to having it labeled a bug, but I don't think users will typically hit it today: for things needing gradients, the augmented forward pass will fix it, and I'm not sure Tensor Parallel sees heavy use in inference just yet. Among the issues I found while working on #1164, it seems to be one of the milder ones.

t-vi · Sep 20 '24 18:09

Some transforms, notably FSDP and TensorParallel ones, change shapes, but currently do not completely update them (it does for the linear that follows, but not for the activation etc.).

I thought the logic implemented with visitor_transform in https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L118, more specifically https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/distributed/tensor_parallel/common.py#L145, updates shapes automatically. Does it not?

crcrpar · Sep 26 '24 03:09

Yeah, @crcrpar, this is why I mentioned better tooling; maybe using the visitor transform pattern more is the solution. Looking at the code, I would probably have the same expectation as you, but if I take out this

https://github.com/Lightning-AI/lightning-thunder/blob/f70b0fe22d27a2e2a0a6ac31cf2c52c9d4e4e35d/thunder/core/trace_interpreter.py#L120-L127

there seem to be inconsistencies. Maybe it is some other part that does not update shapes fully (and anything I wrote will certainly have the problem).

t-vi · Sep 26 '24 07:09

Looking at what happened: I think that while the visitor re-executes everything it does not replace, it does so with the old outputs rather than the new ones (in contrast to interpret_trace / interpret_trace_to_trace, which use an env to keep track of what to swap).

To my mind, it is worth checking whether an extended visitor pattern or interpret_trace_to_trace itself is the better way to solve this. WDYT?
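The difference between the two approaches can be sketched with a toy model. The classes and helpers (`Op`, `shard_linear`, `visit_copying`, `interpret_with_env`) are hypothetical and are not Thunder's `visitor_transform` or `interpret_trace_to_trace`. Both passes apply the same rewrite to the linear. The visitor-style pass copies every other op verbatim. The interpreter-style pass keeps an env of current shapes and re-derives the output shape of every op it does not replace. Here the env maps tensor names to shapes, which is a simplification of the proxy-swapping env mentioned above.

```python
from dataclasses import dataclass, replace as dc_replace


@dataclass(frozen=True)
class Op:
    # Toy stand-in for a trace entry (hypothetical, not Thunder's BoundSymbol).
    name: str
    args: tuple       # names of the input tensors
    out: str          # name of the output tensor
    out_shape: tuple  # output shape recorded in the trace


def infer_shape(name, arg_shapes):
    if name == "linear":
        x, w = arg_shapes  # x: (..., in_features), w: (out_features, in_features)
        assert x[-1] == w[-1], "in_features mismatch"
        return x[:-1] + (w[0],)
    if name == "relu":
        return arg_shapes[0]
    raise NotImplementedError(name)


def shard_linear(op, env):
    """Hypothetical stand-in for a column-parallel rewrite: only linear ops are
    replaced, with the output shape recomputed from the (sharded) inputs.
    Returns None for ops it leaves alone."""
    if op.name != "linear":
        return None
    return dc_replace(op, out_shape=infer_shape("linear", [env[a] for a in op.args]))


def visit_copying(inputs, trace, rewrite):
    """Visitor-style pass: ops that are not rewritten are copied verbatim,
    so they keep the output shapes recorded in the old trace."""
    env = dict(inputs)
    new_trace = []
    for op in trace:
        new_op = rewrite(op, env) or op
        env[new_op.out] = new_op.out_shape
        new_trace.append(new_op)
    return new_trace


def interpret_with_env(inputs, trace, rewrite):
    """Interpreter-style pass: every op that is not rewritten is re-evaluated
    against the env, so its output shape follows the new inputs."""
    env = dict(inputs)
    new_trace = []
    for op in trace:
        new_op = rewrite(op, env)
        if new_op is None:
            shape = infer_shape(op.name, [env[a] for a in op.args])
            new_op = dc_replace(op, out_shape=shape)
        env[new_op.out] = new_op.out_shape
        new_trace.append(new_op)
    return new_trace


trace = [
    Op("linear", ("x", "w"), "t0", (8, 32)),
    Op("relu", ("t0",), "t1", (8, 32)),
]
sharded_inputs = {"x": (8, 16), "w": (16, 16)}  # w split over 2 ranks

print([op.out_shape for op in visit_copying(sharded_inputs, trace, shard_linear)])
# [(8, 16), (8, 32)]  relu output is stale
print([op.out_shape for op in interpret_with_env(sharded_inputs, trace, shard_linear)])
# [(8, 16), (8, 16)]
```

Under these assumptions, only the env-based pass propagates the halved feature dimension to the relu. The visitor-style pass leaves `t1` at `(8, 32)`, which mirrors the symptom described in this issue. An "extended visitor" would need the same kind of env to re-evaluate the ops it does not replace.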

t-vi · Sep 27 '24 08:09