mlx-examples
[Feature Request] Example of MLLM using MLX
With ml-ferret out, it would be great to include an MLLM example in this repo, either with ml-ferret or just LLaVA itself. Since these models are LLaMA-based, I think this would primarily require implementing the image encoding step.
Agreed, that would be super cool. If anyone is interested in contributing it let us know!
I would like to contribute this example soon (perhaps this weekend). Some time ago, I proposed implementing ViT (CLIP) in https://github.com/ml-explore/mlx-examples/issues/143. It would be nice to have a ViT (CLIP) implementation that exactly matches the one used in ferret, since that would let us implement ferret quickly. According to the ferret implementation, its image encoder just wraps CLIP.
However, implementing ferret is not just a matter of implementing image encoding, because ferret also relies on bilinear sampling via torch.nn.functional.grid_sample, which MLX does not currently have. I did recently implement bilinear interpolation for Upsample2d (https://github.com/ml-explore/mlx/pull/414).
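For context, this is the kind of call that would need an MLX counterpart; a plain PyTorch snippet just to illustrate the sampling semantics:

```python
import torch
import torch.nn.functional as F

# An input of shape (N, C, H, W) and a sampling grid of shape (N, H_out, W_out, 2),
# where the last dimension holds (x, y) coordinates normalized to [-1, 1].
image = torch.arange(16.0).reshape(1, 1, 4, 4)
grid = torch.zeros(1, 2, 2, 2)  # all zeros -> every output location samples the image center
out = F.grid_sample(image, grid, mode="bilinear", align_corners=True)
print(out.shape)  # torch.Size([1, 1, 2, 2])
```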
LLaVA might be easier to start with, as v1.5 is literally just a fine-tuned LLaMA model + CLIP + a small projector layer.
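For reference, the v1.5 projector is a small two-layer MLP with a GELU in between. A minimal MLX sketch; the dimensions are only illustrative (CLIP ViT-L/14 hidden size into LLaMA-7B hidden size) and the names are not the transformers ones:

```python
import mlx.core as mx
import mlx.nn as nn

class LLaVAProjector(nn.Module):
    """Maps CLIP vision features into the language model's embedding space."""

    def __init__(self, vision_dim: int = 1024, text_dim: int = 4096):
        super().__init__()
        self.linear_1 = nn.Linear(vision_dim, text_dim)
        self.gelu = nn.GELU()
        self.linear_2 = nn.Linear(text_dim, text_dim)

    def __call__(self, image_features: mx.array) -> mx.array:
        # image_features: (batch, num_patches, vision_dim)
        return self.linear_2(self.gelu(self.linear_1(image_features)))
```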
I think this is a great idea. I will probably start with ViT-CLIP and then implement LLaVA.
That would be awesome! At a bare minimum, an image CLIP encoder will be really helpful. I saw that the Stable Diffusion example has a text CLIP encoder, so the image half should only need a few small changes.
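For what it's worth, the main addition on the image side is the patch/class/position embedding front-end. A rough MLX sketch; the sizes follow CLIP ViT-L/14 and the attribute names are assumptions, not the final example's API:

```python
import mlx.core as mx
import mlx.nn as nn

class VisionEmbeddings(nn.Module):
    """Patch + class + position embeddings, i.e. the front-end the text encoder lacks."""

    def __init__(self, hidden_dim: int = 1024, image_size: int = 224, patch_size: int = 14):
        super().__init__()
        # Placeholder init; real values come from the pretrained checkpoint.
        self.class_embedding = mx.zeros((hidden_dim,))
        # MLX convolutions are channels-last: input is (N, H, W, C).
        self.patch_embedding = nn.Conv2d(
            3, hidden_dim, kernel_size=patch_size, stride=patch_size, bias=False
        )
        num_positions = (image_size // patch_size) ** 2 + 1
        self.position_embedding = nn.Embedding(num_positions, hidden_dim)

    def __call__(self, pixel_values: mx.array) -> mx.array:
        # (N, 224, 224, 3) -> (N, 16, 16, hidden) -> (N, 256, hidden)
        patches = self.patch_embedding(pixel_values)
        N, h, w, d = patches.shape
        patches = patches.reshape(N, h * w, d)
        cls = mx.broadcast_to(self.class_embedding, (N, 1, d))
        x = mx.concatenate([cls, patches], axis=1)  # (N, 257, hidden)
        return x + self.position_embedding(mx.arange(x.shape[1]))
```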
I was semi-looking at this just to see how hard it is. As a reference, the original LLaVA implementation is likely to be easier to work with; the transformers v4.35 rewrite made quite a few changes.
Hey, posting in here since I'm interested in multi-modal. I'm currently trying to convert bakLlava. The model doesn't matter too much to me, but this one worked out of the box for me, so I went with it.
I'm trying to understand the best strategy for converting a model you want into MLX. Is there a guide somewhere?
Right now I'm looking at the model's definition from HF (https://github.com/huggingface/transformers/blob/bc72b4e2cdcbc80d5f56731f35dbc9c18b4c8de6/src/transformers/models/llava/modeling_llava.py#L237) and trying to build versions of these modules in MLX. I've just started, but the idea is to make modules that look like the following, right?
```python
import mlx.nn as nn

class VisionTower(nn.Module):
    pass  # TODO

class MultiModalProjector(nn.Module):
    pass  # TODO

class LanguageModel(nn.Module):
    pass  # TODO

class Model(nn.Module):
    def __init__(self, args: ModelArgs):  # ModelArgs: config dataclass, TODO
        super().__init__()
        self.vision_tower = VisionTower(args)
        self.multi_modal_projector = MultiModalProjector(args)
        self.language_model = LanguageModel(args)
        self.pad_token_id = args.pad_token_id
        self.vocab_size = args.vocab_size

    def __call__(self, input_ids, pixel_values):
        pass  # TODO
```
(This is all I've done so far)
Curious if you all had a different strategy in mind? Seems like starting with CLIP is the best approach, so I'm also going to keep an eye on how that's going / let me know if you need any help.
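As a follow-up to the skeleton above, here is a rough sketch of how `__call__` might combine the pieces. The helper names (e.g. `embed_tokens`) and the image-token handling are assumptions for illustration, not the transformers API:

```python
import mlx.core as mx

def encode_multimodal(model, input_ids: mx.array, pixel_values: mx.array) -> mx.array:
    # Encode the image into patch features, e.g. (1, num_patches, vision_dim).
    image_features = model.vision_tower(pixel_values)
    # Project the patch features into the language model's embedding space.
    image_embeds = model.multi_modal_projector(image_features)
    # Embed the text tokens, e.g. (1, seq_len, text_dim); `embed_tokens` is assumed.
    text_embeds = model.language_model.embed_tokens(input_ids)
    # Simplification: prepend the image embeddings instead of splicing them in at
    # the position of the <image> placeholder token, which is what LLaVA does.
    return mx.concatenate([image_embeds, text_embeds], axis=1)
```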
Hi @nkasmanoff :) I am also very interested in porting multimodal models. My end goal is to implement ferret in MLX and I will probably need some help. Here is my (draft) pull request for CLIP (https://github.com/ml-explore/mlx-examples/pull/315). I think it is going well and I am confident that it will be done soon. After we have CLIP, it is relatively simple to port LLaVA to MLX.
In terms of the general approach, I think a good starting point is to design the model structure (e.g. your Python class hierarchy) so that it closely follows the transformers implementation and the HuggingFace weights. The next step is to write a script that converts the HuggingFace weights into your weight layout. This resource is useful: https://github.com/ml-explore/mlx-examples/issues/155. Here is how I did it in https://github.com/ml-explore/mlx-examples/pull/315; I based my work on the Stable Diffusion example (https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
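As a rough illustration of the conversion step, something along these lines; the repo id and the key rename are only placeholders, and depending on the layer you may also need to transpose weights (PyTorch convolutions are OIHW while MLX convolutions are channels-last):

```python
import mlx.core as mx
from transformers import CLIPModel

def convert(hf_repo: str = "openai/clip-vit-large-patch14", out_path: str = "weights.npz"):
    # Pull the reference PyTorch weights from the Hub.
    state_dict = CLIPModel.from_pretrained(hf_repo).state_dict()
    weights = {}
    for name, tensor in state_dict.items():
        # Example rename only; the real mapping depends on how the MLX module
        # tree is named relative to the transformers one.
        name = name.replace("vision_model.", "vision_encoder.")
        weights[name] = mx.array(tensor.detach().numpy())
    # The MLX model can then load this archive, e.g. with model.load_weights(out_path).
    mx.savez(out_path, **weights)
```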
@gboduljak Thank you this is a lot of great info! Will try to catch myself up and help :-)
Thank you :)
Here are some concrete things where I need some help:
1. Verifying that the initialization of CLIP (https://github.com/ml-explore/mlx-examples/pull/315) is correct (i.e. that it matches the method implemented in https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/models/clip/)
2. Completing the implementation of `CLIPTokenizer` (the implementation in the PR should work for most cases, but it is not complete)
3. An implementation of CLIP image input pre-processing (we currently use the `transformers` `CLIPProcessor`; see the sketch after this list)
4. Removing the dependency on `transformers`
5. Proofreading `README.md`
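For item 3, here is a rough sketch of what a `transformers`-free pre-processing function could look like, using only PIL and NumPy. The normalization constants are CLIP's published values, but interpolation and rounding details may differ slightly from `CLIPProcessor`:

```python
import numpy as np
from PIL import Image

# CLIP's published normalization statistics.
MEAN = np.array([0.48145466, 0.4578275, 0.40821073], dtype=np.float32)
STD = np.array([0.26862954, 0.26130258, 0.27577711], dtype=np.float32)

def preprocess(path: str, size: int = 224) -> np.ndarray:
    """Resize the short side to `size`, center crop, rescale to [0, 1], normalize."""
    image = Image.open(path).convert("RGB")
    w, h = image.size
    scale = size / min(w, h)
    image = image.resize((round(w * scale), round(h * scale)), Image.BICUBIC)
    w, h = image.size
    left, top = (w - size) // 2, (h - size) // 2
    image = image.crop((left, top, left + size, top + size))
    x = np.asarray(image, dtype=np.float32) / 255.0
    return (x - MEAN) / STD  # (size, size, 3), channels last
```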
If you have time, I would appreciate your general review of https://github.com/ml-explore/mlx-examples/pull/315. The sooner we have a satisfactory CLIP implementation, the sooner we can move on to LLaVA and/or ferret.
@gboduljak No problem, but after looking it over, I'm not sure I can be extraordinarily helpful beyond some simpler tasks; this is much lower-level coding than I'm used to :-). That being said, I can help with 3, 5, and a general review. Don't the tests you already have verify that 1 is complete?
I'm also considering making a demo `train.py` file; have you considered something like this? If I come up with something that looks like it works using the code you have here, I'll submit a PR to the branch you already made.
Thank you for taking a look at my code. The tests unfortunately do not verify that our model initialization is correct; they verify whether our model inference is correct. In other words, we test whether our implementation with pretrained weights produces the same output as the `transformers` implementation initialized with the same pretrained weights.
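As a rough illustration of that kind of parity check (not the actual tests in the PR; `our_model` stands in for a hypothetical MLX text encoder that takes token ids and returns the last hidden state):

```python
import numpy as np
import mlx.core as mx
import torch
from transformers import CLIPTextModel, CLIPTokenizer

def check_text_encoder_parity(our_model, repo: str = "openai/clip-vit-large-patch14"):
    tokenizer = CLIPTokenizer.from_pretrained(repo)
    reference = CLIPTextModel.from_pretrained(repo)
    tokens = tokenizer(["a photo of a cat"], return_tensors="pt")
    with torch.no_grad():
        expected = reference(**tokens).last_hidden_state.numpy()
    # `our_model` is a placeholder for the MLX text encoder loaded with the same weights.
    ours = our_model(mx.array(tokens["input_ids"].numpy()))
    assert np.allclose(np.array(ours), expected, atol=1e-4)
```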
Porting `CLIPProcessor` would be very useful, because then we can eliminate the dependency on `transformers`. A `train.py` would also be very interesting. However, we likely cannot run the full training :).
Also, please feel free to ask me to explain the code. Perhaps we should move further discussion of CLIP to its issue: https://github.com/ml-explore/mlx-examples/issues/143.
Closing as we added #461.