Add Gemma 3n
Edit: I can quantize the model using mlx-vlm, so I will be able to test this after we fix the sanitization.
I think the configs are now being created correctly, but I'm running into a lot of problems related to loading the model.
I'm going to leave it here for now and wait for experts to add their input. I can't do this alone.
I need help understanding what's going wrong during the model loading. You can test it by running llm-tool in Xcode. This is the current debug output:
Loading mlx-community/gemma-3n-E2B-it-bf16...
🔍 Gemma3n.sanitize: Starting with 1556 weights
🔍 Gemma3n.sanitize: After prefix removal, have 1556 weights
Error: Key originalInvFreq not found in Gemma3nRotaryEmbedding
Program ended with exit code: 1
Resolving this error will reveal many more like it, so there must be a more systemic fix.
Error: Key originalInvFreq not found in Gemma3nRotaryEmbedding
This is the loading code confusing model weights with computed weights. The pattern (in mlx python as well) is to name the computed parameters with a leading underscore:
let _invFreq: MLXArray
let _originalInvFreq: MLXArray
this matches the python code:
self._inv_freq = mx.array(inv_freq, dtype=mx.float32)
self._original_inv_freq = mx.array(inv_freq, dtype=mx.float32)
As it is, the code expects those "weights" to be present on load.
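The convention can be sketched in plain Swift (this is a stand-in illustration of the key-filtering idea, not the actual MLXNN implementation):

```swift
// Pure-Swift sketch of the naming convention (illustration only, not the
// actual MLXNN code): keys with a leading underscore are computed state
// and should be skipped when verifying that all model weights were set.
func isLoadableParameter(_ key: String) -> Bool {
    !key.hasPrefix("_")
}

let keys = ["invFreq", "_invFreq", "_originalInvFreq", "weight"]
let loadable = keys.filter(isLoadableParameter)
// loadable == ["invFreq", "weight"]
```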
Working on an update, stand by
See also https://github.com/ml-explore/mlx-swift/issues/251 -- this might make doing some of this easier.
OK, I think I fixed these key issues:
- computed values need to be named with a leading underscore
- some of the keys on the modules were missing
I found this in the python code:
self.query = NamedSequential()
self.query.add_module(
    "proj",
    create_conv2d(
        dim,
        self.num_heads * self.key_dim,
        kernel_size=1,
    ),
)
corresponding to:
private class MultiQueryAttention2d: Module {
    @ModuleInfo var queryProj: Conv2d
but I am not sure what that does (yet). The weights didn't have these values (it must be optional).
The current failure is:
Error: Mismatched parameter weight shape. Actual [1536, 1024], expected [1536, 640]
FWIW, here is how I debugged these. First I set breakpoints on the throws of the errors seen here:
open func update(parameters: ModuleParameters, verify: VerifyUpdate) throws -> Self {

    func apply(key: String, _ item: ModuleItem, _ value: NestedItem<String, MLXArray>) throws {
        if case .none = value, !verify.contains(.allModelKeysSet) {
            return
        }

        // item: single item from `items()`
        // value: single item with matching structure from `parameters()`
        //
        // match them up and apply the MLXArrays from value -> item
        switch (item, value) {
        case (.value(.parameters(let p)), .value(let newArray)):
            if verify.contains(.all), p.shape != newArray.shape {
                throw UpdateError.mismatchedSize(
                    key: key, expectedShape: p.shape, actualShape: newArray.shape)
            }
            p._updateInternal(newArray)

        case (.value(.parameters(let p)), .none):
            if Self.parameterIsValid(key) {
                throw UpdateError.keyNotFound(base: describeType(self), key: key)
            } else {
                // ignore it -- this isn't a parameter that requires update
            }
Then I can look at the call stack and notice the "key" value. You can also print the module value to see which module it is. E.g. in frame 3:
(lldb) po self
ConvNormAct(outChannels=768) {
bn: RMSNormAct2d(applyAct=true, eps=1e-05) {
act: GELU(approximation=none),
drop: Identity,
},
conv: Conv2d(bias=nil, groups=768),
}
This gives the following error (there are a couple of these, and since the keys are in a dictionary you hit them in random order):
Error: Mismatched parameter weight shape. Actual [768, 5, 5, 1], expected [768, 5, 5, 768]
Looking into these mismatched sizes, I am not sure how this is supposed to work -- I wonder if the python side doesn't verify sizes?
For example:
self.dw_start = ConvNormAct(
    in_chs,
    in_chs,
The input and output channels are the same, so the conv block under that should have shape [X, .., .., X], but the safetensors has [768, 5, 5, 1] (per the error above).
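The lldb output shows Conv2d(groups=768), i.e. a depthwise conv. Here's a sketch of the grouped-conv weight shape that would explain the mismatch, assuming MLX lays conv weights out as [out, kH, kW, in / groups] (the divide-by-groups is my assumption here):

```swift
// Sketch (assumption: MLX conv weight layout is [out, kH, kW, in / groups]).
// For a depthwise conv, groups == channels, so the trailing dim is 1.
func convWeightShape(out: Int, inChannels: Int, kernel: Int, groups: Int) -> [Int] {
    [out, kernel, kernel, inChannels / groups]
}

let depthwise = convWeightShape(out: 768, inChannels: 768, kernel: 5, groups: 768)
// depthwise == [768, 5, 5, 1] -- the shape actually in the safetensors
let ungrouped = convWeightShape(out: 768, inChannels: 768, kernel: 5, groups: 1)
// ungrouped == [768, 5, 5, 768] -- what the Swift side expected
```

If that layout assumption is right, the Swift side is failing to divide the input-channel dim by groups when computing the expected shape.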
If I disable the validation of sizes in the call to update:
try model.update(parameters: parameters, verify: [])
it makes it past that but fails here:
Error: keyNotFound(CodingKeys(stringValue: "vision_soft_tokens_per_image", intValue: nil), Swift.DecodingError.Context(codingPath: [], debugDescription: "No value associated with key CodingKeys(stringValue: \"vision_soft_tokens_per_image\", intValue: nil) (\"vision_soft_tokens_per_image\").", underlyingError: nil))
which is curious because 1) that key is present (it could be looking in a different file) and 2) the value is optional, so why the complaint?
// MLX Swift currently doesn't have custom Metal kernel creation capabilities like Python's
// mx.fast.metal_kernel(). Consider optimizing with vectorized MLX operations or requesting
// custom kernel support from the MLX Swift team for better performance.
Wait, isn't it supported via MLXFast? https://github.com/ml-explore/mlx-swift/blob/b79c74ce773440b86a81ef925ea78dd5023a16c0/Source/MLXFast/MLXFastKernel.swift#L29
Example https://github.com/ml-explore/mlx-swift/blob/b79c74ce773440b86a81ef925ea78dd5023a16c0/Tests/MLXTests/MLXFastKernelTests.swift#L38-L51
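For reference, here is a rough sketch of that API paraphrased from the linked test -- treat the exact parameter names and call signature as approximate, not authoritative:

```swift
import MLX
import MLXFast

// Rough sketch based on the linked MLXFastKernelTests -- the exact
// signatures may differ from what is shown here.
let kernel = MLXFast.metalKernel(
    name: "double",
    inputNames: ["a"],
    outputNames: ["out"],
    source: """
        uint elem = thread_position_in_grid.x;
        out[elem] = a[elem] + a[elem];
        """)

let a = MLXArray([1.0, 2.0, 3.0, 4.0] as [Float])
let outputs = kernel(
    [a],
    grid: (4, 1, 1),
    threadGroup: (4, 1, 1),
    outputShapes: [[4]],
    outputDTypes: [.float32])
```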
Thank you! As you can see, that's something I meant to follow up on once the model is working. Claude 4 Sonnet was convinced that it's not possible to write custom Metal kernels in MLX Swift, which is why you see the incorrect comment there.
LLMs are getting better at writing MLX code in Swift, but clearly they still have gaps in their knowledge. I'm hoping that by adding more Swift ports, we can fix that for future versions.
I added some complicated fixes for sanitization, which have resolved some but not all of the problems. These are probably not the right solution, since the sanitization in Python is much simpler. But the original logic that closely followed the Python implementation wasn't working.
Currently it's failing when it tries to assign to blocks in VisionTower.
@Blaizzy, can you help me understand what's going wrong with the weights sanitization here? Is this something that should already be handled during the conversion of the model to MLX format? I used mlx-vlm to quantize the model to 4 bits.
When I try to run mlx-community/gemma-3n-E2B-it-4bit in mlx-vlm, I get this error: ValueError: [conv] Invalid input array with type uint32. Convolution currently only supports floating point types
I'm testing the same model in this PR, since I can't run the bf16 model on my MacBook Pro with 16 GB of RAM.
@DePasqualeOrg see https://github.com/Blaizzy/mlx-vlm/issues/400 and https://github.com/Blaizzy/mlx-vlm/pull/398
I wanted to see how the blocks.blocks were loaded so I tried this too but I can't load the model:
python -m mlx_vlm.generate --model mlx-community/gemma-3n-E4B-bf16 --prompt "describe these images in english" --image /Users/dkoski/Desktop/IMG_0691.jpeg
...
  File "/Users/dkoski/miniconda3/envs/mlx/lib/python3.11/site-packages/mlx/nn/layers/base.py", line 178, in load_weights
    raise ValueError(f"Received parameters not in model: {extras}.")
ValueError: Received parameters not in model: language_model.lm_head.weight.
Sure enough, there is no lm_head property in the model (I am building from source, ebafa5a789ed1a8e050b8366ae4e845dbe640b90)
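My hedged guess is that Gemma ties lm_head to the input embeddings, so the checkpoint's language_model.lm_head.weight has no matching property, and sanitize is expected to drop such tied keys. A toy sketch of that idea (plain Swift with string values standing in for MLXArrays -- an assumption about the fix, not the actual mlx-vlm code):

```swift
// Toy sketch: drop tied-embedding keys during sanitize. String values
// stand in for MLXArray; this illustrates the idea, not mlx-vlm itself.
func sanitize(_ weights: [String: String]) -> [String: String] {
    weights.filter { !$0.key.hasSuffix("lm_head.weight") }
}

let sanitized = sanitize([
    "language_model.lm_head.weight": "tied",
    "language_model.model.embed_tokens.weight": "kept",
])
// sanitized contains only the embed_tokens key
```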
Are you able to successfully load a model? If so, which one?
I already updated this for https://github.com/Blaizzy/mlx-vlm/pull/398, but I see that today there were some significant changes in mlx-vlm. I'll wait for the Python implementation to stabilize before proceeding here.
OK, I went back to 7a36f2eda1a304e4ef89fef874971c94352ab5d4 (398) and used mlx-community/gemma-3n-E2B-it-bf16 and that loads, so I can look at how that works.
In mlx-community/gemma-3n-E2B-it-bf16 the blocks look like this:
"model.vision_tower.timm_model.blocks.0.0.bn1.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.0.bn2.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.0.conv_exp.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.0.conv_pwl.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.1.bn1.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.1.bn2.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.1.conv_exp.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.1.conv_pwl.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.2.bn1.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.2.bn2.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.2.conv_exp.weight": "model-00001-of-00003.safetensors",
"model.vision_tower.timm_model.blocks.0.2.conv_pwl.weight": "model-00001-of-00003.safetensors",
which looks like an array of arrays or an array of tuples (the latter is pretty common in vision models).
That corresponds to:
blocks = []
in_chs = self.conv_stem.out_chs
for stage, block_config in enumerate(gemma3n_mobilenet_def()):
    block_group = []
    for config in block_config:
        ...
blocks is an array of block_group.
However the sanitize() is converting that from blocks.0.0 to blocks.blocks.0, which I guess is ultimately VisionTower.sanitize(weights:)
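To make the addressing concrete, here is how one of those flat keys decomposes onto the nested arrays (plain Swift, illustration only):

```swift
// Illustration only: how a flat safetensors key addresses the nested
// array-of-arrays structure built by the Python loop above.
let key = "model.vision_tower.timm_model.blocks.0.2.conv_exp.weight"
let parts = key.split(separator: ".").map(String.init)
let blocksIndex = parts.firstIndex(of: "blocks")!
let stage = Int(parts[blocksIndex + 1])!   // outer array: the block group
let index = Int(parts[blocksIndex + 2])!   // inner array: block within group
// (stage, index) == (0, 2), so no remapping to "blocks.blocks.*" should
// be needed if the Swift module tree mirrors the Python nesting.
```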
OK, next issue is:
@ModuleInfo var weight: MLXArray?
should be:
@ParameterInfo var weight: MLXArray?
or simply omit the ParameterInfo since we don't need to override the key.
After that the NamedSequential in MultiQueryAttention2d handles output.proj.weight where it would normally be output_proj.weight -- this needs to be added (I had called this out earlier as a bit of a mystery as I couldn't yet see how it worked).
Anyway, this is a little bit tricky to implement as it does this in python:
def add_module(self, name, module):
    setattr(self, name, module)
so it dynamically adds properties that will later be picked up by Module. We can do something like that by overriding items():
private class NamedSequential: Module, UnaryLayer {
    var _items = ModuleItems()
    var _names = [String]()

    override init() {
    }

    init(_ name: String, _ module: UnaryLayer) {
        super.init()
        self.add(name, module)
    }

    func add(_ name: String, _ module: UnaryLayer) {
        _items[name] = .value(.module(module))
        _names.append(name)
    }

    override func items() -> ModuleItems {
        _items
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = x
        for name in _names {
            guard case .value(.module(let module)) = _items[name],
                let layer = module as? UnaryLayer
            else {
                fatalError("Cannot find \(name) in items")
            }
            x = layer(x)
        }
        return x
    }
}
items() is the cache of the introspected values -- we are going to supply our own custom build items.
But, seeing how it is actually used, I think this is simpler and more clear:
private class ProjectionBlock: Module, UnaryLayer {
    @ModuleInfo(key: "down_conv") var down: Conv2d?
    @ModuleInfo var norm: RMSNormAct2d?
    @ModuleInfo var proj: Conv2d
}
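For completeness, here is the forward pass I would expect for that block -- my sketch, not tested against the actual weights; the optional stages are simply skipped when absent:

```swift
// My sketch of ProjectionBlock's forward pass -- optional stages are
// applied only when present. Not tested against the actual weights.
extension ProjectionBlock {
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        var x = x
        if let down { x = down(x) }
        if let norm { x = norm(x) }
        return proj(x)
    }
}
```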
Working through some more issues:
- mlx-swift: Conv2d with groups was missing a divide by the groups
- let imgShape = img.shape.suffix(2) produces a slice, so the array indices are not 0 and 1 -- replaced the interpolation with Upsample
- some config confusion -- vocabSize needs to come from the right config
I commented out (with a comment to delete when ready) some of the code in the vision tower sanitize -- the code now matches the python code, I think.
It loads the model and starts to run but fails with:
Shapes (1,682,16,48) and (1,2048,1,1) cannot be broadcast
in rmsNorm2d().
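The failure is consistent with the standard NumPy-style broadcasting rule, which can be checked in plain Swift (illustration only, not MLX code):

```swift
// Plain-Swift check of the broadcasting rule: right-align the shapes;
// each pair of dims must be equal or one of them must be 1.
func broadcastable(_ a: [Int], _ b: [Int]) -> Bool {
    let n = max(a.count, b.count)
    for i in 0..<n {
        let x = i < a.count ? a[a.count - 1 - i] : 1
        let y = i < b.count ? b[b.count - 1 - i] : 1
        if x != y && x != 1 && y != 1 { return false }
    }
    return true
}

let ok = broadcastable([1, 682, 16, 48], [1, 2048, 1, 1])
// ok == false: the 682 vs 2048 pair can never broadcast, which suggests
// the norm scale is sized for a different axis (or channel count) than
// the activations it is being applied to.
```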
Getting close, I think. This is with mlx-community/gemma-3n-E2B-it-bf16.
BTW, you may need to update your mlx-swift dependency -- once we are ready here I will tag that, but for now there were a couple changes there.
Thank you! It looks like the 4-bit quantized model is now loading, and I get this error: Shapes (1,8) and (1,32) cannot be broadcast.
I will wait until @Blaizzy confirms that the Python implementation is stable before proceeding here.
Heads up: mlx-vlm finally fixed the audio and vision issues (according to my checks) and released https://github.com/Blaizzy/mlx-vlm/releases/tag/v0.3.0
It's stable @DePasqualeOrg 👌🏽
v0.3.0 fixes all Gemma3n bugs
@xlab yep, audio was working fine; the only issues were:
- The multimodal merging didn't allow for audio + vision features because we were returning each separately.
- A few parts of the vision module convolution needed padding.
- The Jax conv weights were transposed whilst converting to Torch so OCR is broken unless you transpose the image's HW before processing. Note: The deepmind team is aware and they will fix the weights soon.
I'll wait for the Gemma team to fix the remaining issues, but also there are at least two other efforts to port this model that I'm aware of. I think we should pick one to focus on, to avoid duplicating labor. I'm happy to let others take this on if they're interested in completing the task.
I tried to verify the functionality of the model on iOS and unfortunately ended up with some errors.
IMHO, the default value of the head dimension param should be 2048:
public var headDim: Int {
    _headDim ?? 2048
}
And this fatal error, which I don't know how to resolve:
MLXNN/Module.swift:519: Fatal error: Unable to set vision_tower.timm_model.conv_stem.conv.bias on Gemma3n.Gemma3nVisionModel.VisionTower.ConvNormAct.Conv2d: none not compatible with [64]