
Chameleon: add model

Open • zucchini-nlp opened this issue 1 year ago • 27 comments

What does this PR do?

Fixes #31505.

Adds Chameleon, a vision-language model from Meta AI.

from transformers import ChameleonForCausalLM, ChameleonProcessor
from PIL import Image
import requests
import torch

model_path = "MODEL_PATH"  # placeholder: local path or Hub repo id of a converted Chameleon checkpoint
model = ChameleonForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto")
processor = ChameleonProcessor.from_pretrained(model_path)

prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
image = Image.open(requests.get("https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True).raw)

inputs = processor(prompt, images=[image], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
generated_text = processor.batch_decode(out, skip_special_tokens=False)[0]
print(f"Generated text: {generated_text}")

# Multi-image example
prompt = "I used to know a lot about constellations when I was younger, but as I grew older, I forgot most of what I knew. These are the only two constellations that I really remember now.<image><image>I would like for you to tell me about 3 more constellations and give me a little bit of history about the constellation."
image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)

inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, dtype=torch.bfloat16)
out = model.generate(**inputs, max_new_tokens=200, do_sample=False)
generated_text = processor.batch_decode(out, skip_special_tokens=True)[0]
print(f"Generated text: {generated_text}")

Project repo: https://github.com/facebookresearch/chameleon
Paper: https://arxiv.org/abs/2405.09818v1

zucchini-nlp • Jun 21 '24

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@amyeroberts @ArthurZucker

jsm69 • Jun 22 '24

The modeling/processing code is done and passes all the tests with dummy weights. I looked through transformers for a similar model so I could replace the VQEncoder with a simple call to a vision backbone, but nothing like that exists yet. The VQModel seems similar to the one we have in diffusers, and I believe we can add it as a standalone model in a future PR if more models like Chameleon come along.

The remaining issue is generation quality with the actual weights. For text-only generation I found the bug, and it now generates well. For image+text inputs the model generates garbage, and I also cannot get sensible output from the original repo's code, which always refuses to answer my question... @jacobkahn, could you take a look, please?

zucchini-nlp • Jun 24 '24
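For readers unfamiliar with the VQ part: a VQ encoder turns an image into discrete tokens by snapping each latent vector to its nearest codebook entry. Below is a generic sketch of that lookup step only, assuming the latents are already flattened to shape (N, D). It is not the Chameleon or diffusers VQModel code, and `vq_quantize` is a hypothetical name.

```python
import torch

def vq_quantize(latents: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Return, for each latent vector (N, D), the index of its nearest codebook entry (K, D)."""
    # Pairwise Euclidean distances between latents and codes: shape (N, K)
    distances = torch.cdist(latents, codebook)
    # The nearest code's index becomes the discrete token for that latent
    return distances.argmin(dim=1)

codebook = torch.eye(3)                # 3 toy codes in 3-D
latents = codebook[[2, 0, 1]] + 0.05   # slightly perturbed copies of codes 2, 0, 1
print(vq_quantize(latents, codebook).tolist())  # [2, 0, 1]
```

The resulting indices are what a model like this would then map into its vocabulary as image tokens; how Chameleon does that mapping is not shown here.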

@zucchini-nlp — great to hear things are working.

In terms of reliably generating output from images + text, one strategy that works is to narrow the scope of the questions asked about images; for example, the text prompt "Is there a cat in this image?" plus an image of a cat. This works nicely for me with the miniviewer.


I'll run some local tests on my end with these prompts and compare against our reference implementation.

jacobkahn • Jun 24 '24

Ready for review!

The model conversion is fixed, thanks to Arthur for spotting the bug. Now we have to convert the weights and upload them to the Meta org on the Hub so that I can update all the model checkpoints in the docs/tests. cc @jacobkahn

zucchini-nlp • Jun 26 '24

Two quick notes:

  1. The assertion on the test generation in the conversion script fails for the 7B model when run on CPU (it still generates reasonable output):
AssertionError: Generations don't match: The image you provided is a drawing by the surrealist artist Salvador Dalí, titled "The Persistence of Memory" (1931). It is a classic example of Dal != The image you provided is a drawing by the artist, M.C. Escher. Born in the Netherlands in 1898, Escher was a master of optical illusions and impossible drawings

     Should we remove it?
  2. I'm seeing a shape mismatch when I convert the 30B model; it looks like it's related to the weird RoPE embedding layouts I'd asked about a while back... https://gist.github.com/jacobkahn/05d35aba27f1dee66c714f706054c107 Happy to dive in deeper here again if needed.

jacobkahn • Jun 26 '24
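Background on the RoPE layout question: the NOTE comment in the attention code of modeling_chameleon.py says the query/key permutation is done the same way as in the Llama conversion script. That script reorders each head's q/k projection rows from interleaved (x0, x1), (x2, x3), ... pairs into the half-split order that transformers' `rotate_half` RoPE expects. The sketch below is modeled on that script's `permute` helper. `permute_for_rope` is a hypothetical name, and it is not established here that this permutation is what causes the 30B mismatch.

```python
import torch

def permute_for_rope(w: torch.Tensor, n_heads: int) -> torch.Tensor:
    """Reorder q/k projection rows per head from interleaved pairs to a half-split layout."""
    dim1, dim2 = w.shape
    # (heads, pairs, 2, in) -> (heads, 2, pairs, in): even rows first, then odd rows, per head
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

rows = torch.arange(8.0).unsqueeze(1)  # 2 heads x head_dim 4; row i holds value i
print(permute_for_rope(rows, n_heads=2).flatten().tolist())
# [0.0, 2.0, 1.0, 3.0, 4.0, 6.0, 5.0, 7.0]
```

If a checkpoint shards or groups heads differently from what `n_heads` assumes, this reshape either fails or silently mixes heads. That is one reason head layouts matter when converting.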

  1. Yes, maybe we don't need the assertion then. It's a bit weird that the outputs are completely different, though; I'll check it out and change it.
  2. Hmm, that's weird; I'll look into what's happening tomorrow.

zucchini-nlp • Jun 26 '24

The PR is ready. The only remaining steps are uploading the weights to the Hub (once @jacobkahn and I find the issue with the 30B model's image module) and replacing everything in modeling/tests/docs with the actual model IDs.

zucchini-nlp • Jun 27 '24

@zucchini-nlp yay!! Almost there. A few things:

  • 7B model looks great overall, works across a bunch of setups too
  • I can't seem to run the 30B model: I get this when running on a DGX-A100; I'm not familiar with how HF sharding/device mgmt is implemented though

jacobkahn • Jun 28 '24

Hmm, we probably need to manually move the residual to the same device as the hidden states after the attention module. Btw, I was running on a single A100 GPU; it fits perfectly in bf16.

I will not be available for a week, so ping @gante (slack is faster to get a reply) in case you need further help ;)

As highlighted, 30B with images is the only thing left to fix.

zucchini-nlp • Jun 28 '24
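The residual fix suggested above amounts to moving the residual onto whatever device the sublayer output landed on before adding. Here is a minimal sketch with a hypothetical toy block, not Chameleon's decoder layer. It applies the norm after the sublayer, loosely mirroring the swin-norm ordering. On a single device the `.to()` call is a no-op.

```python
import torch
from torch import nn

class ToyResidualBlock(nn.Module):
    """Hypothetical block showing device-safe residual addition."""

    def __init__(self, dim: int):
        super().__init__()
        self.mlp = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm(self.mlp(hidden_states))
        # If the sublayers sit on another GPU (e.g. under device_map="auto"),
        # their output can end up on a different device than `residual`;
        # moving the residual is a no-op when both already share a device.
        return residual.to(hidden_states.device) + hidden_states

block = ToyResidualBlock(4)
out = block(torch.randn(2, 4))
```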

Hi all,

Out of curiosity, is this model code similar to the actual Chameleon code, or does someone have access to the actual model architecture code? As far as I know, the Chameleon team did not officially release the training code, including the model architecture. Did I miss something? If this is the actual code, I'd be happy to fine-tune the model on our dataset and do further research :)

Joeycho • Jul 03 '24

@Joeycho this is inference code modeled roughly off of https://github.com/facebookresearch/chameleon and the existing Llama recipe in Transformers.

jacobkahn • Jul 03 '24

Hey folks, I got the 30B and 7B models working with the following diff:

diff --git a/src/transformers/models/chameleon/configuration_chameleon.py b/src/transformers/models/chameleon/configuration_chameleon.py
index faf56e56b..464e9c9b9 100644
--- a/src/transformers/models/chameleon/configuration_chameleon.py
+++ b/src/transformers/models/chameleon/configuration_chameleon.py
@@ -206,6 +206,7 @@ class ChameleonConfig(PretrainedConfig):
         attention_bias=False,
         attention_dropout=0.0,
         qk_layernorm=True,
+        model_parallel_size=1,
         swin_norm=False,
         vq_config=None,
         vocabulary_map=None,
@@ -235,6 +236,7 @@ class ChameleonConfig(PretrainedConfig):
         self.attention_bias = attention_bias
         self.attention_dropout = attention_dropout
         self.qk_layernorm = qk_layernorm
+        self.model_parallel_size = model_parallel_size
         self.swin_norm = swin_norm
 
         if vq_config is None:
diff --git a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py
index 9687de670..3b43c2937 100644
--- a/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py
+++ b/src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py
@@ -91,6 +91,7 @@ def write_model(model_path, input_base_path, model_size, chameleon_version=1):
     if os.path.isfile(consolidate_params_path):
         params = {**params, **read_json(consolidate_params_path)}
     num_shards = NUM_SHARDS[model_size]
+    model_parallel_size = params["model_parallel_size"]
     params = params.get("model", params)
     n_layers = params["n_layers"]
     n_heads = params["n_heads"]
@@ -165,16 +166,16 @@ def write_model(model_path, input_base_path, model_size, chameleon_version=1):
             if qk_layernorm:
                 state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = loaded[
                     f"layers.{layer_i}.attention.q_normalization.weight"
-                ]
+                ].unsqueeze(0).expand(n_heads, -1).contiguous()
                 state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = loaded[
                     f"layers.{layer_i}.attention.q_normalization.bias"
-                ]
+                ].unsqueeze(0).expand(n_heads, -1).contiguous()
                 state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = loaded[
                     f"layers.{layer_i}.attention.k_normalization.weight"
-                ]
+                ].unsqueeze(0).expand(num_key_value_heads, -1).contiguous()
                 state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = loaded[
                     f"layers.{layer_i}.attention.k_normalization.bias"
-                ]
+                ].unsqueeze(0).expand(num_key_value_heads, -1).contiguous()
 
         else:
             # Sharded
@@ -207,18 +208,19 @@ def write_model(model_path, input_base_path, model_size, chameleon_version=1):
             ).reshape(key_value_dim, dim)
 
             if qk_layernorm:
-                state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = torch.stack(
-                    [l[f"layers.{layer_i}.attention.q_normalization.weight"] for l in loaded]
-                ).mean(dim=0)
-                state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = torch.stack(
-                    [l[f"layers.{layer_i}.attention.q_normalization.bias"] for l in loaded]
-                ).mean(dim=0)
-                state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = torch.stack(
-                    [l[f"layers.{layer_i}.attention.k_normalization.weight"] for l in loaded]
-                ).mean(dim=0)
-                state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = torch.stack(
-                    [l[f"layers.{layer_i}.attention.k_normalization.bias"] for l in loaded]
-                ).mean(dim=0)
+                state_dict[f"model.layers.{layer_i}.self_attn.q_norm.weight"] = torch.cat(
+                    [l[f"layers.{layer_i}.attention.q_normalization.weight"].unsqueeze(0).expand(n_heads, -1) for l in loaded]
+                )
+                state_dict[f"model.layers.{layer_i}.self_attn.q_norm.bias"] = torch.cat(
+                    [l[f"layers.{layer_i}.attention.q_normalization.bias"].unsqueeze(0).expand(n_heads, -1) for l in loaded]
+                )
+                state_dict[f"model.layers.{layer_i}.self_attn.k_norm.weight"] = torch.cat(
+                    [l[f"layers.{layer_i}.attention.k_normalization.weight"].unsqueeze(0).expand(num_key_value_heads, -1) for l in loaded]
+                )
+                state_dict[f"model.layers.{layer_i}.self_attn.k_norm.bias"] = torch.cat(
+                    [l[f"layers.{layer_i}.attention.k_normalization.bias"].unsqueeze(0).expand(num_key_value_heads, -1) for l in loaded]
+                )
+
             state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
                 [
                     loaded[i][f"layers.{layer_i}.attention.wv.weight"].view(
@@ -315,6 +317,7 @@ def write_model(model_path, input_base_path, model_size, chameleon_version=1):
         rope_theta=base,
         max_position_embeddings=max_position_embeddings,
         qk_layernorm=qk_layernorm,
+        model_parallel_size=model_parallel_size,
         swin_norm=swin_norm,
         vq_config=vq_config,
         vocabulary_map=vocabulary_map,
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 3cbad5482..d55da3cbe 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -206,8 +206,10 @@ class ChameleonMLP(nn.Module):
 
     # Ignore copy
     def forward(self, x):
-        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-        return down_proj
+        device = self.down_proj.weight.device
+        agx = self.act_fn(self.gate_proj(x)).to(device)
+        ux = self.up_proj(x).to(device)
+        return self.down_proj(agx * ux)
 
 
 # Copied from transformers.models.llama.modeling_llama.repeat_kv
@@ -247,6 +249,7 @@ class ChameleonAttention(nn.Module):
         self.rope_theta = config.rope_theta
         self.is_causal = True
         self.qk_layernorm = config.qk_layernorm
+        self.model_parallel_size = config.model_parallel_size
 
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
@@ -259,8 +262,8 @@ class ChameleonAttention(nn.Module):
         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
         self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias)
         if self.qk_layernorm:
-            self.q_norm = nn.LayerNorm(self.head_dim)
-            self.k_norm = nn.LayerNorm(self.head_dim)
+            self.q_norm = nn.LayerNorm((self.num_heads * self.model_parallel_size, self.head_dim))
+            self.k_norm = nn.LayerNorm((self.num_key_value_heads * self.model_parallel_size, self.head_dim))
         self._init_rope()
 
     # Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->Chameleon
@@ -309,12 +312,33 @@ class ChameleonAttention(nn.Module):
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            # reshape for layernorm
-            query_states = query_states.view(-1, self.num_heads, self.head_dim)
-            key_states = key_states.view(-1, self.num_key_value_heads, self.head_dim)
-
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
+            if self.model_parallel_size == 1:
+                query_states = query_states.reshape(-1, 1, self.head_dim)
+                query_states = query_states.tile((1, self.num_heads, 1))
+                query_states = self.q_norm(query_states)
+                query_states = query_states[:, 0, :]
+
+                key_states = key_states.reshape(-1, 1, self.head_dim)
+                key_states = key_states.tile((1, self.num_key_value_heads, 1))
+                key_states = self.k_norm(key_states)
+                key_states = key_states[:, 0, :]
+
+            elif self.model_parallel_size > 1:
+                query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+                query_states = query_states.tile((1, self.model_parallel_size, 1))
+                query_states = query_states.reshape(-1, self.model_parallel_size, self.num_heads, self.head_dim).transpose(1, 2)
+                query_states = query_states.reshape(-1, self.model_parallel_size * self.num_heads, self.head_dim)
+                query_states = self.q_norm(query_states)
+                query_states = query_states.reshape(-1, self.model_parallel_size, self.head_dim)
+                query_states = query_states[:, 0, :]
+
+                key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+                key_states = key_states.tile((1, self.model_parallel_size, 1))
+                key_states = key_states.reshape(-1, self.model_parallel_size, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                key_states = key_states.reshape(-1, self.model_parallel_size * self.num_key_value_heads, self.head_dim)
+                key_states = self.k_norm(key_states)
+                key_states = key_states.reshape(-1, self.model_parallel_size, self.head_dim)
+                key_states = key_states[:, 0, :]
 
         # permute key/value to use transformers RoPE implementation (see for more: https://github.com/huggingface/transformers/issues/25199)
         # NOTE: permutation is done same way as in llama conversion script
@@ -400,12 +424,33 @@ class ChameleonFlashAttention2(ChameleonAttention):
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            # reshape for layernorm
-            query_states = query_states.view(-1, self.num_heads, self.head_dim)
-            key_states = key_states.view(-1, self.num_key_value_heads, self.head_dim)
-
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
+            if self.model_parallel_size == 1:
+                query_states = query_states.reshape(-1, 1, self.head_dim)
+                query_states = query_states.tile((1, self.num_heads, 1))
+                query_states = self.q_norm(query_states)
+                query_states = query_states[:, 0, :]
+
+                key_states = key_states.reshape(-1, 1, self.head_dim)
+                key_states = key_states.tile((1, self.num_key_value_heads, 1))
+                key_states = self.k_norm(key_states)
+                key_states = key_states[:, 0, :]
+
+            elif self.model_parallel_size > 1:
+                query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+                query_states = query_states.tile((1, self.model_parallel_size, 1))
+                query_states = query_states.reshape(-1, self.model_parallel_size, self.num_heads, self.head_dim).transpose(1, 2)
+                query_states = query_states.reshape(-1, self.model_parallel_size * self.num_heads, self.head_dim)
+                query_states = self.q_norm(query_states)
+                query_states = query_states.reshape(-1, self.model_parallel_size, self.head_dim)
+                query_states = query_states[:, 0, :]
+
+                key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+                key_states = key_states.tile((1, self.model_parallel_size, 1))
+                key_states = key_states.reshape(-1, self.model_parallel_size, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                key_states = key_states.reshape(-1, self.model_parallel_size * self.num_key_value_heads, self.head_dim)
+                key_states = self.k_norm(key_states)
+                key_states = key_states.reshape(-1, self.model_parallel_size, self.head_dim)
+                key_states = key_states[:, 0, :]
 
         # permute key/value to use transformers RoPE implementation (see for more: https://github.com/huggingface/transformers/issues/25199)
         # NOTE: permutation is done same way as in llama conversion script
@@ -612,12 +657,33 @@ class ChameleonSdpaAttention(ChameleonAttention):
         value_states = self.v_proj(hidden_states)
 
         if self.qk_layernorm:
-            # reshape for layernorm
-            query_states = query_states.view(-1, self.num_heads, self.head_dim)
-            key_states = key_states.view(-1, self.num_key_value_heads, self.head_dim)
-
-            query_states = self.q_norm(query_states)
-            key_states = self.k_norm(key_states)
+            if self.model_parallel_size == 1:
+                query_states = query_states.reshape(-1, 1, self.head_dim)
+                query_states = query_states.tile((1, self.num_heads, 1))
+                query_states = self.q_norm(query_states)
+                query_states = query_states[:, 0, :]
+
+                key_states = key_states.reshape(-1, 1, self.head_dim)
+                key_states = key_states.tile((1, self.num_key_value_heads, 1))
+                key_states = self.k_norm(key_states)
+                key_states = key_states[:, 0, :]
+
+            elif self.model_parallel_size > 1:
+                query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+                query_states = query_states.tile((1, self.model_parallel_size, 1))
+                query_states = query_states.reshape(-1, self.model_parallel_size, self.num_heads, self.head_dim).transpose(1, 2)
+                query_states = query_states.reshape(-1, self.model_parallel_size * self.num_heads, self.head_dim)
+                query_states = self.q_norm(query_states)
+                query_states = query_states.reshape(-1, self.model_parallel_size, self.head_dim)
+                query_states = query_states[:, 0, :]
+
+                key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+                key_states = key_states.tile((1, self.model_parallel_size, 1))
+                key_states = key_states.reshape(-1, self.model_parallel_size, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+                key_states = key_states.reshape(-1, self.model_parallel_size * self.num_key_value_heads, self.head_dim)
+                key_states = self.k_norm(key_states)
+                key_states = key_states.reshape(-1, self.model_parallel_size, self.head_dim)
+                key_states = key_states[:, 0, :]
 
         # permute key/value to use transformers RoPE implementation (see for more: https://github.com/huggingface/transformers/issues/25199)
         # NOTE: permutation is done same way as in llama conversion script
@@ -803,7 +869,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
             **kwargs,
         )
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states = residual + hidden_states
+        hidden_states = residual.to(hidden_states.device) + hidden_states
         # Fully Connected
         residual = hidden_states
         hidden_states = self.mlp(hidden_states)

lshamis • Jul 03 '24
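Why the `model_parallel_size == 1` branch in the diff still gives per-head normalization: if a head vector is tiled `num_heads` times and normalized over `(num_heads, head_dim)`, the mean and (biased) variance are the same as over the vector alone. Because the converted weights are identical rows created with `unsqueeze(0).expand(...)`, the first copy then matches a plain `nn.LayerNorm(head_dim)`. Below is a small numerical check under those assumptions, using toy sizes and random parameters rather than real checkpoint values.

```python
import torch
from torch import nn

torch.manual_seed(0)
num_heads, head_dim = 4, 8

# Per-head LayerNorm over head_dim only (the pre-diff layout)
per_head = nn.LayerNorm(head_dim)
with torch.no_grad():
    per_head.weight.uniform_(0.5, 1.5)
    per_head.bias.uniform_(-0.1, 0.1)

# LayerNorm over (num_heads, head_dim) whose parameters are the per-head ones
# expanded across heads, as the conversion script in the diff does
wide = nn.LayerNorm((num_heads, head_dim))
with torch.no_grad():
    wide.weight.copy_(per_head.weight.unsqueeze(0).expand(num_heads, -1))
    wide.bias.copy_(per_head.bias.unsqueeze(0).expand(num_heads, -1))

x = torch.randn(10, head_dim)  # 10 head vectors, e.g. flattened (batch * seq * heads)
# Mirror the model_parallel_size == 1 branch: tile each vector, normalize, keep copy 0
tiled = x.reshape(-1, 1, head_dim).tile((1, num_heads, 1))
out_wide = wide(tiled)[:, 0, :]
out_per_head = per_head(x)
print(torch.allclose(out_wide, out_per_head, atol=1e-5))  # True
```

The `model_parallel_size > 1` branch is different: it normalizes jointly across heads, so this equivalence does not carry over to it.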

@jacobkahn Thank you for the prompt answer. Does 'the existing Llama recipe' mean it is possible to fine-tune a Chameleon-like model? Or does this PR deal with two separate things (1. inference code for Chameleon, 2. editing LLaMA)?

I'm curious whether I can fine-tune the Chameleon model. Ofc, I have to come up with how to fine-tune (which parts to be frozen and which part to train).

Joeycho • Jul 04 '24

What’s needed to get this PR merged?

EwoutH • Jul 08 '24

@EwoutH Almost there; I just need to apply the changes for sharded inference with the 30B model. I was off for a week and will work on it tomorrow.

zucchini-nlp • Jul 08 '24

Should the code in this pull request be able to generate images, provided that the model is further fine-tuned not to avoid generating images or a new model is trained with the same architecture with image generation capabilities?

ethanc8 • Jul 08 '24

Should the code in this pull request be able to generate images, provided that the model is further fine-tuned not to avoid generating images or a new model is trained with the same architecture with image generation capabilities?

@ethanc8 — the model checkpoints we're releasing don't have image generation capabilities, so the HF implementation won't support it. As always, you can fork and modify as you see fit.

jacobkahn • Jul 08 '24

I'm curious whether I can fine-tune the Chameleon model. Ofc, I have to come up with how to fine-tune (which parts to be frozen and which part to train).

@Joeycho — we're not releasing finetuning code; this PR will only have inference code. If you want to finetune the model, the paper can give you a place to start.

jacobkahn • Jul 08 '24

the model checkpoints we're releasing don't have image generation capabilities

@jacobkahn In what manner does the model not have image generation capabilities? It doesn't look like any components are actually missing (the VQGAN was released), so I assumed that the model was finetuned to refuse to generate images.

ethanc8 • Jul 08 '24

Pushed the changes for qk layernorm and tested that they work for both checkpoints. All tests pass locally except the slow ones, so the last step is to run the slow CI tests for this PR, which requires the model to be available for download from the Hub.

@jacobkahn when are we planning to make the hf hub repo public?

zucchini-nlp • Jul 09 '24

@zucchini-nlp — just made the 7B and 30B repos public; will populate them once this PR is merged. @ArthurZucker any other blockers to merge?

jacobkahn • Jul 09 '24

I just came back from holidays; reviewing in a bit!

ArthurZucker • Jul 09 '24

Added a label and a commit to trigger the slow tests, but I am not sure they will run, because they require an access token.

EDIT: Ah, I see there are no weights yet, so there's no point in running the slow CI.

zucchini-nlp • Jul 10 '24

GAIR has finetuned Chameleon-7B to support image generation. (weights on HuggingFace, in PyTorch format, inference and finetuning code). Will this PR's code be able to support Chameleon models that have been finetuned to generate images such as GAIR's model, or will that need to be addressed in a separate PR?

ethanc8 • Jul 10 '24

@ethanc8 this PR doesn't support image generation; it forces image tokens to have ~0 probability of being chosen as the next token.

zucchini-nlp • Jul 10 '24
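The kind of masking described above (giving image tokens effectively zero probability) can be sketched as a transformation on the next-token logits. This is a minimal illustration, assuming the image-token ids are known. `suppress_token_ids` and the id range are hypothetical and are not this PR's implementation. Using the dtype minimum instead of `-inf` keeps softmax well defined even if a whole row were masked.

```python
import torch

def suppress_token_ids(logits: torch.Tensor, banned_ids: torch.Tensor) -> torch.Tensor:
    """Return a copy of next-token logits (batch, vocab) with banned ids pushed to the dtype minimum."""
    masked = logits.clone()
    masked[:, banned_ids] = torch.finfo(logits.dtype).min
    return masked

torch.manual_seed(0)
logits = torch.randn(2, 10)
image_token_ids = torch.arange(6, 10)  # hypothetical: pretend ids 6..9 are image tokens
probs = torch.softmax(suppress_token_ids(logits, image_token_ids), dim=-1)
print(probs[:, image_token_ids].max().item())  # 0.0
print(probs.argmax(dim=-1))                    # always an id below 6
```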

@zucchini-nlp — weights are up in https://huggingface.co/facebook/chameleon-7b and https://huggingface.co/facebook/chameleon-30b, we can run slow CI now with those repo slugs

jacobkahn • Jul 10 '24