[Feature Request] Example of MLLM using MLX

fozziethebeat opened this issue on Dec 31, 2023

With ml-ferret out, it would be great to include an MLLM example in this repo, either ml-ferret itself or just LLaVA. Since both are LLaMA based, I think this would primarily require implementing the image encoding step.

fozziethebeat avatar Dec 31 '23 03:12 fozziethebeat

Agreed, that would be super cool. If anyone is interested in contributing it let us know!

awni avatar Dec 31 '23 14:12 awni

I would like to contribute this example soon (perhaps this weekend). Some time ago, I proposed implementing ViT (CLIP) in https://github.com/ml-explore/mlx-examples/issues/143. It would be useful to have a CLIP (ViT) implementation that exactly matches the one used in ferret, since that would let us implement ferret quickly: according to the ferret implementation, the image encoder is just a wrapper around CLIP.

However, implementing ferret is not just a matter of image encoding, because ferret also uses bilinear interpolation via torch.nn.functional.grid_sample, which MLX does not currently have. That said, I recently implemented bilinear interpolation for Upsample2d (https://github.com/ml-explore/mlx/pull/414).
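For anyone unfamiliar with grid_sample, here is a minimal PyTorch sketch (purely illustrative) of the operation we would need to reproduce in MLX: given a feature map and a grid of normalized sampling coordinates, it bilinearly samples the feature map at those locations.

import torch
import torch.nn.functional as F

feats = torch.randn(1, 3, 8, 8)  # [N, C, H, W] feature map
# grid holds (x, y) sampling locations normalized to [-1, 1], shape [N, H_out, W_out, 2]
grid = torch.rand(1, 4, 4, 2) * 2 - 1
sampled = F.grid_sample(feats, grid, mode="bilinear", align_corners=False)
print(sampled.shape)  # torch.Size([1, 3, 4, 4])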

gboduljak avatar Jan 11 '24 19:01 gboduljak

LLaVA might be easier to start with, as v1.5 is literally just a fine-tuned LLaMA model + CLIP + a small projector layer.
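Roughly speaking, the projector is just a small MLP that maps CLIP features into the LLaMA embedding space. A minimal MLX sketch (the layer count and dimensions below are illustrative, not taken from the actual LLaVA config):

import mlx.core as mx
import mlx.nn as nn

class Projector(nn.Module):
    # Illustrative dims: CLIP ViT-L hidden size 1024 -> LLaMA hidden size 4096.
    def __init__(self, vision_dims: int = 1024, text_dims: int = 4096):
        super().__init__()
        self.linear_1 = nn.Linear(vision_dims, text_dims)
        self.linear_2 = nn.Linear(text_dims, text_dims)

    def __call__(self, x: mx.array) -> mx.array:
        return self.linear_2(nn.gelu(self.linear_1(x)))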

fozziethebeat avatar Jan 12 '24 04:01 fozziethebeat

I think this is a great idea. I will probably start with ViT-CLIP and then implement LLaVA.

gboduljak avatar Jan 12 '24 11:01 gboduljak

That would be awesome! At a bare minimum, an image CLIP encoder will be really helpful. I saw that the Stable Diffusion example has a CLIP text encoder, so the image half should only need a few small changes.

I was semi-looking at this just to see how hard it is. As a reference, the original LLaVA implementation is likely easier to work with; the transformers v4.35 rewrite made quite a few changes.

fozziethebeat avatar Jan 13 '24 00:01 fozziethebeat

Hey, posting in here since I'm interested in multi-modal. I'm currently trying to convert bakLlava. The model doesn't matter too much to me, but this one worked out of the box for me, so I went with it.

I'm trying to understand the best strategy for converting a model you want into MLX. Is there a guide somewhere?

Right now I'm looking at the model's implementation on HF (https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/llava/modeling_llava.py#L237) and trying to build versions of these modules in MLX. I've just started, but the idea is to make modules that look like the following, right?


from dataclasses import dataclass

import mlx.nn as nn


@dataclass
class ModelArgs:
    pad_token_id: int
    vocab_size: int
    # ... vision / text config fields go here


class VisionTower(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # TODO: CLIP vision encoder


class MultiModalProjector(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # TODO: project vision features into the language model's embedding space


class LanguageModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        # TODO: LLaMA-style decoder


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.vision_tower = VisionTower(args)
        self.multi_modal_projector = MultiModalProjector(args)
        self.language_model = LanguageModel(args)

        self.pad_token_id = args.pad_token_id
        self.vocab_size = args.vocab_size

    def __call__(self, input_ids, pixel_values):
        # TODO: encode the image, project it, merge with token embeddings, decode
        pass

(This is all I've done so far)

Curious if you all had a different strategy in mind. Starting with CLIP seems like the best approach, so I'm also going to keep an eye on how that's going; let me know if you need any help.

nkasmanoff avatar Jan 14 '24 14:01 nkasmanoff

Hi @nkasmanoff :) I am also very interested in porting multimodal models. My end goal is to implement ferret in MLX and I will probably need some help. Here is my (draft) pull request for CLIP (https://github.com/ml-explore/mlx-examples/pull/315). I think it is going well and I am confident that it will be done soon. After we have CLIP, it is relatively simple to port LLaVA to MLX.

In terms of the general approach, I think a good starting point is to design the model structure (i.e. your Python class hierarchy) so that it closely follows the transformers implementation and the HuggingFace weights. The next step is to write a script that converts the HuggingFace weights into your layout. This resource is useful: https://github.com/ml-explore/mlx-examples/issues/155. You can see how I did it in https://github.com/ml-explore/mlx-examples/pull/315; I based my work on the Stable Diffusion example (https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
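In rough pseudocode, the conversion step boils down to loading the HuggingFace checkpoint, renaming parameters to match your MLX module tree, and saving them in a format MLX can load. Here map_key is a placeholder for the renaming logic, which depends entirely on your class hierarchy; see the PR for the real script.

import numpy as np
from transformers import CLIPModel

def convert(hf_repo: str, out_path: str = "weights.npz"):
    model = CLIPModel.from_pretrained(hf_repo)
    weights = {}
    for name, param in model.named_parameters():
        # map_key is a placeholder: rename HuggingFace parameter names (e.g.
        # "vision_model.encoder.layers.0.self_attn.q_proj.weight") to whatever
        # the corresponding MLX modules expect.
        weights[map_key(name)] = param.detach().numpy()
    np.savez(out_path, **weights)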

gboduljak avatar Jan 14 '24 14:01 gboduljak

@gboduljak Thank you this is a lot of great info! Will try to catch myself up and help :-)

nkasmanoff avatar Jan 14 '24 20:01 nkasmanoff

Thank you :)

Here are some concrete things where I need some help:

  1. Verifying that the initialization of CLIP (https://github.com/ml-explore/mlx-examples/pull/315) is correct (i.e. that it matches the method implemented in https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/clip/)
  2. Completing the implementation of CLIPTokenizer (the implementation in the PR should work for most cases, but it is not complete)
  3. Implementing CLIP image input pre-processing (we are now using the transformers CLIPProcessor; a rough sketch of what this involves is below)
  4. Removing the dependency on transformers
  5. Proofreading README.md
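As a sketch of item 3, CLIP preprocessing is essentially resize, center-crop, and normalize; the exact resize/crop/normalization behaviour should be checked against the transformers CLIPProcessor, but it looks roughly like this:

import numpy as np
from PIL import Image

# Standard CLIP normalization statistics.
CLIP_MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
CLIP_STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def preprocess(path: str, size: int = 224) -> np.ndarray:
    image = Image.open(path).convert("RGB")
    # Resize so the short side equals `size`, then take a center crop.
    w, h = image.size
    scale = size / min(w, h)
    image = image.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
    w, h = image.size
    left, top = (w - size) // 2, (h - size) // 2
    image = image.crop((left, top, left + size, top + size))
    x = np.asarray(image, dtype=np.float32) / 255.0
    return (x - CLIP_MEAN) / CLIP_STD  # [H, W, C], float32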

If you have time, I would appreciate a general review of https://github.com/ml-explore/mlx-examples/pull/315. The sooner we have a satisfactory CLIP implementation, the sooner we can move on to LLaVA and/or ferret.

gboduljak avatar Jan 14 '24 22:01 gboduljak

@gboduljak No problem, but after looking it over, I'm not sure I can be extraordinarily helpful beyond some simpler tasks; this is a lot lower-level coding than I'm used to :-). That being said, I can help with 3, 5, and a general review. Don't the tests you already have verify that 1 is complete?

I'm also considering making a demo train.py file; have you considered something like this? If I come up with something that looks like it works using the code you have here, I'll submit a PR to the branch you already made.

nkasmanoff avatar Jan 15 '24 17:01 nkasmanoff

Thank you for taking a look at my code. Unfortunately, the tests do not verify that our model initialization is correct; they verify that our model inference is correct. In other words, we test whether our implementation, loaded with pretrained weights, produces the same output as the transformers implementation loaded with the same pretrained weights.
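Roughly, a test of that form looks like the sketch below, where mlx_vision_model is a placeholder for our module loaded with converted weights (the input may need transposing depending on the layout the MLX model expects); the real tests are in the PR.

import numpy as np
import torch
import mlx.core as mx
from transformers import CLIPVisionModel

def check_vision_parity(mlx_vision_model, hf_repo: str, pixel_values: np.ndarray):
    # pixel_values: [N, C, H, W] float32, already preprocessed.
    hf_model = CLIPVisionModel.from_pretrained(hf_repo)
    with torch.no_grad():
        expected = hf_model(torch.from_numpy(pixel_values)).last_hidden_state.numpy()
    # mlx_vision_model is a placeholder; transpose the input if your model expects NHWC.
    actual = np.array(mlx_vision_model(mx.array(pixel_values)))
    assert np.allclose(actual, expected, atol=1e-4)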

Porting CLIPProcessor would be very useful, because it would let us eliminate the dependency on transformers. A train.py would also be very interesting, although we likely cannot run the full training :).

Also, please feel free to ask me to explain the code. Perhaps we should move further discussion of CLIP to its issue: https://github.com/ml-explore/mlx-examples/issues/143.

gboduljak avatar Jan 15 '24 18:01 gboduljak

Closing as we added #461.

awni avatar Mar 02 '24 14:03 awni