lightning-thunder
allow tracing models on meta device
We want to facilitate running models that can only fit into memory after transforms. The main mechanism PyTorch currently offers for instantiating a model without allocating memory is the meta device.
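For context, meta-device instantiation in plain PyTorch looks like this (a minimal sketch, independent of Thunder):

```python
import torch
import torch.nn as nn

# constructing a module under the meta device records shapes and dtypes
# but allocates no storage, so arbitrarily large models "fit"
with torch.device("meta"):
    m = nn.Linear(4096, 4096)

print(m.weight.device)  # meta
print(m.weight.shape)   # torch.Size([4096, 4096]), with no memory backing it
```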
So the plan is to enable:
- Trace through a model on the meta device with real inputs (say, on the GPU).
- This runs into trouble when checking for "same device". We currently have a few hacks around it here and there, but it would be good to have a more structured way.
- My idea would be to add `__thunder_device` to the meta tensors (and/or maybe as a default option for compilation) and then use that device when proxying meta tensors (see the sketch after this list).
- Have transforms handle things on meta, ideally staying on meta. (I'll work this out for the bnb quant, and then we can see if/how to adapt the distributed transforms.)
- Have a materialization transform that materializes models, optionally with initialization from the original code (using `transform_state_dict_for_submodule`).
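To make the third bullet concrete, here is a minimal sketch of the idea; the `effective_device` helper is purely illustrative, not an existing Thunder API:

```python
import torch

def effective_device(t: torch.Tensor) -> torch.device:
    # Hypothetical helper: for meta tensors, report the device they are
    # destined for instead of "meta", so "same device" checks can pass.
    if t.device.type == "meta":
        return getattr(t, "__thunder_device", t.device)
    return t.device

with torch.device("meta"):
    w = torch.empty(16, 16)

# tag the meta tensor with the device it will eventually be materialized on
w.__thunder_device = torch.device("cuda")

print(effective_device(w))  # cuda
```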
cc @apaz-cli
Here is how I see the materialization transform being used:
```python
# imports as in thunder's test suite (module paths may have moved since)
import torch
from torch.testing import assert_close

import thunder
from thunder.tests import litgpt_model
from thunder.tests.framework import requiresCUDA


@requiresCUDA
def test_materialization():
    from thunder.transforms import MaterializationTransform

    config = litgpt_model.Config.from_name("llama2-like")
    with torch.device("cuda"):
        ref_m = litgpt_model.GPT(config).to(torch.bfloat16)
    with torch.device("meta"):
        m = litgpt_model.GPT(config).to(torch.bfloat16)

    # tag meta tensors with the device they should eventually live on
    for p in m.parameters():
        p.__thunder_device = torch.device("cuda")
    for b in m.buffers():
        b.__thunder_device = torch.device("cuda")

    init_from_sd = MaterializationTransform.from_original_state_dict(ref_m.state_dict())
    jm = thunder.jit(
        m,
        transforms=MaterializationTransform("cuda", init=init_from_sd),
    )

    x = torch.randint(1, 255, (1, 10), device="cuda")
    input_pos = torch.arange(10, device="cuda")

    expected = ref_m(x, input_pos)
    actual = jm(x, input_pos)

    assert_close(actual, expected)
```
wdyt?
Looks super neat! To simplify, we could also do:
```python
jm = thunder.jit(
    m,
    transforms=MaterializationTransform("cuda", init=ref_m.state_dict()),
)
```
I thought that would be neat, too, but here is what kept me from proposing it: we could have
- no init (but garbage on first run, so only for people who do evil stuff between creating traces / prologue and running),
- orig module init (through sd transforms),
- orig state dict init (through sd transforms),
- thunder module state dict.
And with that, we have two different state dicts that I would not know how to tell apart.
aah makes sense
With #867 and #868 we have initial support, but three of the four modes are yet to be fleshed out:
- [ ] no init (but garbage on first run, so only for people who do evil stuff between creating traces / prologue and running),
- [ ] orig module init (through sd transforms); can probably borrow heavily from the distributed materialization,
- [x] orig state dict init (through sd transforms), included in #868,
- [ ] thunder module state dict.
What's the benefit of allowing meta device tensors to interact with other devices instead of erroring out? What would the alternatives be? Have you considered using PyTorch's FakeTensor for initialization instead of plain meta tensors?
The advantage is being able to trace CUDA (or CPU) inputs through a network with meta weights and then transform the traced module to have CUDA (or CPU) weights. Fake tensors might work, too, but meta tensors seem to give the best user experience for now.
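For comparison, plain PyTorch already lets you do this materialization step by hand (a sketch without Thunder, using the standard `nn.Module.to_empty` and `load_state_dict` methods); the transform is meant to do roughly this inside the jitted module:

```python
import torch
import torch.nn as nn

# build on meta: no memory is allocated yet
with torch.device("meta"):
    m = nn.Linear(8, 8)

# allocate uninitialized storage on the target device, then fill in real weights
m = m.to_empty(device="cpu")
checkpoint = nn.Linear(8, 8).state_dict()  # stand-in for a real checkpoint
m.load_state_dict(checkpoint)
```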
This has now been achieved.