
Adding support for mamba

Goekdeniz-Guelmez opened this issue 1 year ago • 3 comments

Goekdeniz-Guelmez · Jul 23 '24 10:07

Awesome! Does it work yet?

awni · Jul 23 '24 21:07

> Awesome! Does it work yet?

Almost. I can create and generate with a freshly initialized (seed-created) model, but I still have some issues loading a pretrained one. The problem arises from the Conv1d layer: it expects shape [a, b, a] but gets [a, b, c]. I'll post the original error later, when I'm on my Mac.
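
For reference, a common source of exactly this kind of mismatch is the weight layout: PyTorch's Conv1d stores weights as (out_channels, in_channels, kernel_size), while MLX's nn.Conv1d expects (out_channels, kernel_size, in_channels). Below is a minimal sketch of the kind of remapping a pretrained checkpoint usually needs when loaded into MLX; the helper name and loop are illustrative, not the code in this draft.

```python
import mlx.core as mx

def remap_conv1d_weights(weights: dict) -> dict:
    # PyTorch Conv1d weight layout: (out_channels, in_channels, kernel_size)
    # MLX nn.Conv1d weight layout:  (out_channels, kernel_size, in_channels)
    # Swap the last two axes of every conv1d weight in the checkpoint.
    for k, v in weights.items():
        if "conv1d.weight" in k and v.ndim == 3:
            weights[k] = mx.swapaxes(v, 1, 2)
    return weights
```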

Goekdeniz-Guelmez · Jul 24 '24 09:07

Update!!!

The loading problem is fixed, but now I get this error during inference:

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/generate.py", line 161, in <module>
    main()
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/generate.py", line 148, in main
    generate(
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/utils.py", line 315, in generate
    for (token, logprobs), n in zip(
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/utils.py", line 220, in generate_step
    y, logprobs = model.generate_step(prompt)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/models/mamba.py", line 371, in generate_step
    next_token_logits, caches = self(input_ids, caches)
                                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/models/mamba.py", line 314, in __call__
    out, cache = self.backbone(inputs, cache)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/models/mamba.py", line 301, in __call__
    h, cache[i] = layer(tokens, cache[i])
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/models/mamba.py", line 286, in __call__
    output, cache = self.mixer(self.norm(inputs), cache)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gokdenizgulmez/Desktop/mlx-examples-drafts/llms/mlx_lm/models/mamba.py", line 271, in __call__
    x = self.conv1d(mx.concatenate([inputs, x_cache], axis=1))[:, self.conv_kernel_size-1, :]
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: [concatenate] All the input arrays must have the same number of dimensions. However, got arrays with dimensions 3 and 4.

The input shapes are:

Shape of inputs: (1, 3, 1536)
Shape of x_cache: (1, 1, 5, 3072)
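
A minimal reproduction of that ValueError with the reported shapes (stand-in arrays only, not the model code): mx.concatenate requires every input to have the same number of dimensions, and matching sizes on every axis except the concatenation axis, so the 4-D cache would need to be stored or reshaped as a 3-D (batch, kernel_size - 1, channels) array whose channel width matches whatever it is concatenated with.

```python
import mlx.core as mx

# Stand-ins with the shapes reported above
inputs = mx.zeros((1, 3, 1536))      # 3-D: (batch, seq_len, hidden)
x_cache = mx.zeros((1, 1, 5, 3072))  # 4-D: one leading axis too many

try:
    mx.concatenate([inputs, x_cache], axis=1)
except ValueError as e:
    print(e)  # same "[concatenate] ... same number of dimensions" error

# Dropping the extra axis fixes the ndim mismatch, but the last axis
# (1536 vs 3072) still has to agree, which suggests the concatenation
# belongs on the projected x (width 3072) rather than the raw hidden state.
x_cache_3d = x_cache.reshape(1, 5, 3072)
```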

Goekdeniz-Guelmez · Jul 24 '24 11:07