lightning-thunder
`state_dict` transform for `fsdp(jit(model))`
What does this PR do?
Adds a transform for `fsdp(jit(model)).state_dict()` that all-gathers the sharded parameters and strips the padding added during sharding (unpadding).
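A minimal single-process sketch of the idea, assuming each parameter is sharded along dim 0 after dim 0 has been zero-padded up to a multiple of the world size. Parameters are plain lists of rows, and "all-gather" is simulated by concatenating every rank's shard. The helpers `shard_dim0`, `all_gather_dim0`, `unpad_dim0` and `full_state_dict` are hypothetical illustrations. They are not thunder's or PyTorch's API, and the real transform operates on tensors with a collective all-gather across ranks.

```python
import math
from typing import Dict, List

Row = List[float]
Tensor2D = List[Row]  # a parameter stored as a list of rows (dim 0 = rows)


def shard_dim0(param: Tensor2D, world_size: int) -> List[Tensor2D]:
    """Zero-pad dim 0 to a multiple of world_size, then split it evenly across ranks.

    Assumption: this padding scheme is illustrative, not taken from thunder's fsdp.
    """
    rows_per_rank = math.ceil(len(param) / world_size)
    width = len(param[0])
    pad_rows = rows_per_rank * world_size - len(param)
    padded = param + [[0.0] * width for _ in range(pad_rows)]
    return [padded[r * rows_per_rank:(r + 1) * rows_per_rank] for r in range(world_size)]


def all_gather_dim0(shards: List[Tensor2D]) -> Tensor2D:
    """Single-process stand-in for an all-gather: concatenate all shards along dim 0."""
    gathered: Tensor2D = []
    for shard in shards:
        gathered.extend(shard)
    return gathered


def unpad_dim0(gathered: Tensor2D, original_rows: int) -> Tensor2D:
    """Drop the padding rows that sharding appended to the end of dim 0."""
    return gathered[:original_rows]


def full_state_dict(
    sharded: Dict[str, List[Tensor2D]], original_rows: Dict[str, int]
) -> Dict[str, Tensor2D]:
    """Rebuild unsharded parameters: all-gather each one, then unpad it to its original size."""
    return {
        name: unpad_dim0(all_gather_dim0(shards), original_rows[name])
        for name, shards in sharded.items()
    }


if __name__ == "__main__":
    weight = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 rows: not divisible by 2 ranks
    shards = shard_dim0(weight, world_size=2)  # rank 1 holds one zero padding row
    restored = full_state_dict({"weight": shards}, {"weight": len(weight)})
    print(restored["weight"] == weight)  # True
```

Unpadding must come after the all-gather. Only the concatenated tensor has the padding at a single known position (the tail of dim 0), which makes it cheap to slice away using the original size.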