
allow tracing models on meta device


We want to facilitate running models that only fit into memory after transforms have been applied. The main mechanism PyTorch currently offers for instantiating a model without allocating memory is the meta device.

So the plan is to enable:

  • Trace through model on meta device with real input (say on GPU).
    • This runs into trouble when checking for "same device". We currently have a few hacks around it here and there, but it would be good to have a more structured way.
    • My idea would be to add a __thunder_device attribute to the meta tensors (and/or maybe offer it as a default device option for compilation) and then use that device when proxying meta tensors (see the sketch after this list).
  • Have transforms handle things on meta, ideally staying on meta. (I'll work this out for the bnb quant, and then we can see if/how to adapt the distributed transforms.)
  • Have a materialization transform that materializes models, optionally with initialization from the original code (using the transform_state_dict_for_submodule).
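
As a rough illustration of the __thunder_device idea from the list above (a sketch only; the resolve_device helper is hypothetical and not part of Thunder, and the last lines need a GPU):

    import torch

    def resolve_device(t: torch.Tensor) -> torch.device:
        # For meta tensors, prefer the device the tensor is destined for,
        # recorded in a __thunder_device attribute; otherwise use the
        # tensor's actual device.
        if t.device.type == "meta":
            return getattr(t, "__thunder_device", t.device)
        return t.device

    with torch.device("meta"):
        w = torch.empty(4, 4)
    w.__thunder_device = torch.device("cuda")

    x = torch.randn(4, 4, device="cuda")
    # The "same device" check during tracing can now compare resolved devices
    # instead of failing on meta vs. cuda:
    assert resolve_device(w).type == resolve_device(x).type == "cuda"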

cc @apaz-cli

t-vi avatar Jul 25 '24 08:07 t-vi

I see this as the materialization transform:

import torch
from torch.testing import assert_close

import thunder
from thunder.tests import litgpt_model
from thunder.tests.framework import requiresCUDA


@requiresCUDA
def test_materialization():
    from thunder.transforms import MaterializationTransform

    config = litgpt_model.Config.from_name("llama2-like")
    with torch.device("cuda"):
        ref_m = litgpt_model.GPT(config).to(torch.bfloat16)
    with torch.device("meta"):
        m = litgpt_model.GPT(config).to(torch.bfloat16)
    for p in m.parameters():
        p.__thunder_device = torch.device("cuda")
    for b in m.buffers():
        b.__thunder_device = torch.device("cuda")

    init_from_sd = MaterializationTransform.from_original_state_dict(ref_m.state_dict())
    jm = thunder.jit(
        m,
        transforms=MaterializationTransform("cuda", init=init_from_sd)
    )
    x = torch.randint(1, 255, (1, 10), device="cuda")
    input_pos = torch.arange(10, device="cuda")

    expected = ref_m(x, input_pos)
    actual = jm(x, input_pos)
    
    assert_close(actual, expected)
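
For intuition, roughly what such a transform has to do per module can be sketched in plain PyTorch (the helper name is made up for illustration; this is not the transform's actual implementation):

    import torch
    import torch.nn as nn

    def materialize_from_state_dict(module: nn.Module, state_dict: dict, device: str) -> nn.Module:
        # Give every meta parameter/buffer real (uninitialized) storage on the
        # target device, then copy in values from the reference state dict.
        module.to_empty(device=device)
        module.load_state_dict(state_dict)
        return module

    with torch.device("meta"):
        m = nn.Linear(8, 8)
    ref = nn.Linear(8, 8)  # reference module with real, initialized weights
    materialize_from_state_dict(m, ref.state_dict(), device="cpu")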

wdyt?

t-vi avatar Jul 25 '24 13:07 t-vi

Looks super neat. To simplify, we could also do:

    jm = thunder.jit(
        m,
        transforms=MaterializationTransform("cuda", init=ref_m.state_dict())
    )

lantiga avatar Jul 25 '24 13:07 lantiga

I thought that would be neat too, but this is what kept me from offering it: we could have

  • no init (but garbage on first run, so only for people who do evil stuff between creating traces / prologue and running),
  • orig module init (through sd transforms),
  • orig state dict init (through sd transforms),
  • thunder module state dict.

And there we have two state dicts I would not know how to differentiate.
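
One way to keep the simpler call site while still telling the cases apart would be named constructors for the init argument; this is purely a hypothetical sketch, not the actual MaterializationTransform API:

    class MaterializationInit:
        # Each constructor names the source of the values, so the two
        # state-dict flavours cannot be confused at the call site.
        def __init__(self, kind, payload=None):
            self.kind = kind
            self.payload = payload

        @classmethod
        def from_original_state_dict(cls, sd):
            return cls("original_state_dict", sd)

        @classmethod
        def from_thunder_module_state_dict(cls, sd):
            return cls("thunder_module_state_dict", sd)

        @classmethod
        def from_original_module_init(cls):
            return cls("original_module_init")

        @classmethod
        def no_init(cls):
            return cls("no_init")

The transform could then branch on init.kind without having to guess which kind of state dict it was handed.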

t-vi avatar Jul 25 '24 14:07 t-vi

aah makes sense

lantiga avatar Jul 25 '24 15:07 lantiga

With #867 and #868 we have initial support, but three of the four modes are yet to be fleshed out:

  • [ ] no init (but garbage on first run, so only for people who do evil stuff between creating traces / prologue and running),
  • [ ] orig module init (through sd transforms) - can probably borrow heavily from the distributed materialization (conceptual sketch after this list),
  • [x] orig state dict init (through sd transforms) included in #868,
  • [ ] thunder module state dict.
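
Conceptually, the "orig module init" mode referenced above boils down to re-running each submodule's own initialization after allocating real storage (a plain-PyTorch sketch, not Thunder's actual code path through the sd transforms):

    import torch
    import torch.nn as nn

    def init_from_original_module(meta_module: nn.Module, device: str) -> nn.Module:
        # Allocate uninitialized storage on the target device for all meta
        # parameters/buffers, then re-run each submodule's reset_parameters()
        # so values come from the module's original initialization code.
        meta_module.to_empty(device=device)
        for submodule in meta_module.modules():
            if hasattr(submodule, "reset_parameters"):
                submodule.reset_parameters()
        return meta_module

Unlike the state-dict modes, this does not need a reference copy of the weights in memory.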

t-vi avatar Jul 25 '24 21:07 t-vi

What's the benefit of allowing meta-device tensors to interact with tensors on other devices instead of erroring out? What could the alternatives be? Have you considered using PyTorch's FakeTensor for initialization instead of plain meta tensors?

IvanYashchuk avatar Aug 02 '24 07:08 IvanYashchuk

The advantage is being able to trace with CUDA (or CPU) inputs through a network with meta weights and then transform the traced module to have CUDA (or CPU) weights. Fake tensors might work too, but meta tensors seem to give the best user experience for now.

t-vi avatar Nov 11 '24 08:11 t-vi

This has now been achieved.

t-vi avatar Nov 11 '24 08:11 t-vi