mlx-swift
[BUG] Swift bindings don't allow you to specify the datatype for many MLXNN modules.
Describe the bug
The Swift bindings don't let you specify the dtype for many MLXNN modules.
For instance, RMSNorm declares its weight as a let constant and has no initializer that accepts a weight array (which would be awkward anyway). https://github.com/ml-explore/mlx-swift/blob/0.25.6/Source/MLXNN/Normalization.swift#L124
Training in bfloat16 is significantly faster than training in float32 on my M1, but I can't use bfloat16 with the MLXNN.RMSNorm Swift bindings.
To Reproduce
class Transformer: MLXNN.Module {
    let ln_f: MLXNN.RMSNorm
    init(n_embd: Int) {
        // Attempt 1: pass a dtype to the initializer.
        self.ln_f = MLXNN.RMSNorm(dimensions: n_embd, dtype: MLX.DType.bfloat16) // Error: Extra argument 'dtype' in call
        // Attempt 2: set the dtype on the existing weight.
        self.ln_f.weight.dtype = MLX.DType.bfloat16 // Error: Cannot assign to property: 'dtype' is a get-only property
        super.init()
    }
}
Desktop (please complete the following information):
- OS Version: 15.3 (24D34)
- Device: Apple M1 Max
- Version [e.g. 0.10.0]
One fix would be to add a dtype argument to most modules that have weights and biases, defaulting to float32 or to some unset value. For example:
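open class RMSNorm: Module, UnaryLayer {
    public let weight: MLXArray
    public let eps: Float
    // Proposed: a dtype parameter that defaults to the current behavior.
    public init(dimensions: Int, eps: Float = 1e-5, dtype: MLX.DType = .float32) {
        self.weight = MLXArray.ones([dimensions], dtype: dtype)
        self.eps = eps
        super.init()
    }
    // The rest of the class stays as it is today.
}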
We try to match the Python API, and it doesn't take a dtype. That comparison isn't entirely fair, though, since the Python version lets you reassign properties arbitrarily.
I think the way you would do this today is to filter by parameter key, something like this:
model.apply { module, key, item in
    key == "weight"
} map: { array in
    array.asType(.bfloat16)
}
or by examining the dtype of each array:
model.apply { array in
    if array.dtype == .float32 {
        array.asType(.bfloat16)
    } else {
        array
    }
}
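To check that this workaround does what is wanted, the sketch below builds a small model, casts its float32 parameters to bfloat16 with apply, and prints the dtype of every parameter. TinyModel and its property names are hypothetical. The sketch assumes the installed mlx-swift release accepts a module in eval and exposes parameters().flattened(). It has not been checked against 0.25.6 specifically.

```swift
import MLX
import MLXNN

// Hypothetical stand-in model; only the layer types come from MLXNN.
class TinyModel: Module {
    let norm: RMSNorm
    let proj: Linear

    init(dimensions: Int) {
        self.norm = RMSNorm(dimensions: dimensions)
        self.proj = Linear(dimensions, dimensions)
        super.init()
    }
}

let model = TinyModel(dimensions: 8)

// Cast every float32 parameter to bfloat16 and leave other dtypes untouched.
model.apply { array in
    array.dtype == .float32 ? array.asType(.bfloat16) : array
}

// Force evaluation of the converted arrays, then list each parameter's dtype.
eval(model)
for (key, value) in model.parameters().flattened() {
    print(key, value.dtype)
}
```

If the cast took effect, every printed dtype should be bfloat16, including norm.weight, even though RMSNorm declares weight with let: the update replaces the array's contents rather than reassigning the property.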