mlx-swift

[BUG] Swift bindings don't allow you to specify the datatype for many MLXNN modules.

Open iankronquist opened this issue 2 weeks ago • 2 comments

Describe the bug

The Swift bindings don't allow you to specify the datatype for many MLXNN modules.

For instance, RMSNorm declares its weights as a let constant and has no initializer that lets you specify the weight array (which would be awkward anyway): https://github.com/ml-explore/mlx-swift/blob/0.25.6/Source/MLXNN/Normalization.swift#L124

Training in bf16 is significantly faster than training in float32 on my M1, but I can't use bfloat16 with the MLXNN.RMSNorm Swift bindings.

To Reproduce

import MLX
import MLXNN

class Transformer: MLXNN.Module {
    let ln_f: MLXNN.RMSNorm
    init(n_embd: Int) {
        self.ln_f = MLXNN.RMSNorm(dimensions: n_embd, dtype: MLX.DType.bfloat16) // Error: Extra argument 'dtype' in call
        self.ln_f.weight.dtype = MLX.DType.bfloat16 // Error: Cannot assign to property: 'dtype' is a get-only property
    }
}

Expected behavior

It should be possible to create MLXNN modules such as RMSNorm with bfloat16 (or any other dtype) parameters.

Desktop:

  • OS Version: 15.3 (24D34)
  • Device: Apple M1 Max

iankronquist, Nov 20 '25 00:11

One fix would be to add a dtype parameter to the initializers of most modules with weights and biases, defaulting to float32 (or to some unset value). For example:

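open class RMSNorm: Module, UnaryLayer {

    public let weight: MLXArray
    public let eps: Float

    public init(dimensions: Int, eps: Float = 1e-5, dtype: MLX.DType = .float32) {
        // float32 by default, so existing callers keep today's behavior
        self.weight = MLXArray.ones([dimensions], dtype: dtype)
        self.eps = eps
        super.init()
    }

    // callAsFunction(_:) stays as it is today
}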
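Until something like this exists in MLXNN, a module with the same shape can live in user code. The sketch below is a hypothetical TypedRMSNorm (not part of MLXNN) whose initializer takes a dtype for its weight. It writes RMS normalization out by hand, x * rsqrt(mean(x²) + eps) * weight, rather than calling the library's fused kernel, so results may differ slightly from MLXNN.RMSNorm.

```swift
import MLX
import MLXNN

/// Hypothetical user-side RMSNorm whose weight dtype is chosen at construction.
/// A sketch of the proposal above, not an MLXNN API.
final class TypedRMSNorm: Module, UnaryLayer {
    // Declared with `let` like MLXNN's own layers; Module discovers it as a parameter.
    let weight: MLXArray
    let eps: Float

    init(dimensions: Int, eps: Float = 1e-5, dtype: DType = .float32) {
        self.weight = MLXArray.ones([dimensions], dtype: dtype)
        self.eps = eps
        super.init()
    }

    func callAsFunction(_ x: MLXArray) -> MLXArray {
        // Mean of squares over the last axis, kept as a size-1 axis for broadcasting.
        let meanSquare = mean(x.square(), axis: -1, keepDims: true)
        // Normalize by the root mean square, then scale by the learned weight.
        return x * rsqrt(meanSquare + eps) * weight
    }
}

// Usage: a bfloat16 norm for the final layer of the model in the reproduction.
let norm = TypedRMSNorm(dimensions: 8, dtype: .bfloat16)
let output = norm(MLXArray.ones([2, 8], dtype: .bfloat16))
print(norm.weight.dtype, output.shape)
```

Because eps is a Swift Float scalar, MLX treats it as weakly typed, so a bfloat16 input stays bfloat16 through the computation.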

iankronquist, Nov 20 '25 00:11

We try to match the Python API, and it doesn't take a dtype. That isn't an entirely fair comparison, though, since the Python version lets you arbitrarily reassign properties.

I think the way you would do this today is something like this:

// filter: select items whose key is "weight"; map: convert each selected array to bfloat16
model.apply { module, key, item in
    key == "weight"
} map: { array in
    array.asType(.bfloat16)
}

or by examining the dtype:

// Convert float32 parameters to bfloat16 and leave all others unchanged
model.apply { array in
    if array.dtype == .float32 {
        array.asType(.bfloat16)
    } else {
        array
    }
}
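To reuse the dtype-based conversion across models, it can be wrapped in a small helper. This is a sketch: castFloat32Parameters and the Linear-plus-RMSNorm Transformer are hypothetical names chosen for illustration, and the helper relies on apply with its default filter visiting the parameters of nested modules. Printing model.parameters().flattened() afterwards is one way to confirm which arrays were converted.

```swift
import MLX
import MLXNN

// Hypothetical helper (not part of MLXNN): cast every float32 parameter of
// `module` and its submodules to `dtype`, leaving other dtypes untouched.
func castFloat32Parameters(of module: Module, to dtype: DType) {
    module.apply { array in
        array.dtype == .float32 ? array.asType(dtype) : array
    }
}

// Hypothetical model in the spirit of the reproduction in the issue.
class Transformer: Module {
    let proj: Linear
    let ln_f: RMSNorm

    init(n_embd: Int) {
        self.proj = Linear(n_embd, n_embd)
        self.ln_f = RMSNorm(dimensions: n_embd)
        super.init()
    }
}

let model = Transformer(n_embd: 64)
castFloat32Parameters(of: model, to: .bfloat16)

// Print each flattened parameter path with its dtype to check the cast.
for (path, array) in model.parameters().flattened() {
    print(path, array.dtype)
}
```

The conversion writes back through the module's parameter update, which is why it also works on properties declared with let, such as RMSNorm's weight.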

davidkoski, Nov 21 '25 19:11