
HF LLaVa support

Status: Open. riccardofelluga opened this issue 1 year ago • 6 comments

🚀 Model / language coverage

The idea is to support the LLaVa model from Hugging Face (HF). This issue is mainly for tracking its status.

Blocking issues:

  • [ ] #735
  • [ ] #124

Minimal Repro

First, install the transformers library with `pip install transformers`, then run this script:

```python
import torch
import thunder
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.to("cuda")

input_ids = torch.randint(1, 32000, (1, 22), device="cuda")
attention_mask = torch.ones((1, 22), dtype=torch.int64, device="cuda")
pixel_values = torch.randn((1, 3, 336, 336), device="cuda")
labels = torch.randint(-100, 32000, (1, 22), device="cuda")

# Setup fake image id
input_ids[0, 0] = 1
input_ids[0, 5] = 32000

model = thunder.jit(model, executors=thunder.get_default_executors())

out = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, labels=labels)
```

(riccardofelluga, Sep 19 '24 17:09)

```python
left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
```

Note that this looks pretty bad from a "data-dependent control flow" perspective, and it was indeed changed in transformers four months ago.

(t-vi, Sep 19 '24 18:09)
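A toy illustration of why the `left_padding` line quoted above is a problem for a tracing compiler. It assumes a tracer that represents tensors as symbolic proxies carrying no concrete values; the `Proxy` class and `left_padding` function below are hypothetical and are not thunder's API. The comparison and the sum can be recorded as operations, but `not` needs an actual boolean, which only the runtime data can provide (in eager PyTorch, `bool()` on a CUDA tensor reads the value back to the host).

```python
class Proxy:
    """Toy symbolic value (hypothetical): records operations, never holds data."""

    def __init__(self, name, trace):
        self.name = name
        self.trace = trace  # shared list of recorded operations

    def _record(self, op, *args):
        out = Proxy(f"t{len(self.trace)}", self.trace)
        arg_names = ", ".join(a.name if isinstance(a, Proxy) else repr(a) for a in args)
        self.trace.append(f"{out.name} = {op}({arg_names})")
        return out

    def __eq__(self, other):
        return self._record("eq", self, other)

    def sum(self):
        return self._record("sum", self)

    def __bool__(self):
        # `not`, `if`, and `while` need a concrete truth value, which a
        # symbolic proxy cannot supply without the actual data.
        raise RuntimeError(f"data-dependent control flow on {self.name}")


def left_padding(last_tokens, pad_token_id):
    # Same structure as: not torch.sum(input_ids[:, -1] == pad_token_id)
    return not (last_tokens == pad_token_id).sum()


trace = []
try:
    left_padding(Proxy("last_tokens", trace), 0)
except RuntimeError as err:
    print(err)   # data-dependent control flow on t1
print(trace)     # ['t0 = eq(last_tokens, 0)', 't1 = sum(t0)']
```

The ops before the branch are captured fine; tracing only stops at the point where Python control flow depends on a tensor's value.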

@t-vi

> Note that this looks pretty bad from a "data dependent control flow perspective" and has, indeed, been changed in transformers four months ago.

Indeed, it does look kinda bad :( What do you mean by "it has been changed"? The line still seems to be there in the file:

https://github.com/huggingface/transformers/blob/4d8908df272c0a9db2e5fbcc8aaed73cdf75442a/src/transformers/models/llava/modeling_llava.py#L284

(riccardofelluga, Sep 19 '24 18:09)

Right, I'm stupid. They changed it in `modeling_llava_next.py`, not `modeling_llava.py`. :(

(t-vi, Sep 19 '24 18:09)

Updated the description with the relevant blocking issues.

(riccardofelluga, Sep 25 '24 08:09)

@kshitij12345 does the splitter correctly route these ops to the inductor path?

(csarofeen, Sep 30 '24 13:09)
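The question above concerns how the ThunderFX splitter partitions a captured graph. Conceptually, as the question implies, runs of ops that thunder can handle become thunder regions, while other ops are routed to the inductor path. The sketch below is a toy model of that grouping step only: `split_into_regions` and the op names are hypothetical, a real splitter operates on torch.fx graphs captured by Dynamo, and Dynamo itself typically breaks the graph at Python-level data-dependent control flow such as the `left_padding` check before any splitting happens.

```python
from itertools import groupby


def split_into_regions(ops, supported):
    """Group consecutive ops into regions and tag each with a target backend.

    Toy model only (hypothetical): a real splitter works on an FX graph,
    not on a flat list of op names.
    """
    regions = []
    for is_supported, group in groupby(ops, key=lambda op: op in supported):
        backend = "thunder" if is_supported else "inductor"
        regions.append((backend, list(group)))
    return regions


ops = ["linear", "unsupported_op", "gelu", "linear"]
for backend, region in split_into_regions(ops, supported={"linear", "gelu"}):
    print(backend, region)
# thunder ['linear']
# inductor ['unsupported_op']
# thunder ['gelu', 'linear']
```

Grouping consecutive ops keeps regions as large as possible, so each supported stretch can be compiled as one unit instead of op by op.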

thunderFX sidesteps the data-dependent ops and works on a slightly modified version of the above snippet:

```python
import torch
import thunder
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
)
model.to("cuda")

input_ids = torch.randint(1, 100, (1, 22), device="cuda")
attention_mask = torch.rand((1, 22), device="cuda") > 0.5
pixel_values = torch.randn((1, 3, 336, 336), device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, 100, (1, 22), device="cuda")

# Setup fake image id
input_ids[0, 0] = 1
input_ids[0, 5] = 32000

# model = thunder.jit(model, executors=thunder.get_default_executors())
# model = torch.compile(model)

import thunder.dynamo
backend = thunder.dynamo.ThunderCompiler(executors=thunder.get_default_executors())
model = torch.compile(model, backend=backend)

out = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, labels=labels)
print(out.loss)  # Loss is detached from the graph.
```

However, I see that `out.loss` is detached from the computation graph, so we can't call `backward` on it. This is caused by a bug in the splitter: it doesn't correctly handle regions under `torch.no_grad`. I will file a separate issue for this and look into fixing it. (EDIT: issue filed at https://github.com/Lightning-AI/lightning-thunder/issues/1219)

(kshitij12345, Sep 30 '24 17:09)
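A minimal plain-PyTorch sketch (thunder is not involved) of what "detached" means in the last comment: a result computed inside a `torch.no_grad()` region has `requires_grad=False` and no `grad_fn`, so calling `backward()` on it raises a `RuntimeError`. The helper `is_differentiable` is hypothetical, not part of thunder or PyTorch; applying a check like it to `out.loss` is one possible way to spot the symptom before calling `backward`.

```python
import torch


def is_differentiable(loss):
    # A tensor computed entirely under torch.no_grad() has requires_grad=False
    # and grad_fn=None, so backward() on it would raise a RuntimeError.
    return loss.requires_grad and loss.grad_fn is not None


w = torch.randn(3, requires_grad=True)

attached = (w * 2).sum()          # recorded by autograd
with torch.no_grad():
    detached = (w * 2).sum()      # not recorded: no grad_fn

print(is_differentiable(attached))  # True
print(is_differentiable(detached))  # False

attached.backward()  # works: w.grad is now tensor([2., 2., 2.])
```

If a compiler wrongly applies a `no_grad` region to the part of the graph that produces the loss, the loss ends up in the `detached` state shown here even though the model parameters still require gradients.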