MLX default weight init seems incorrect
The default MLX weight initialization appears to be incorrect. Specifically, see mlx/python/mlx/nn/layers/linear.py on GitHub, and other layers that initialize the same way.
Generally, the starting point for weight init is a normal distribution with mean = 0 and variance = 1 / input_dim, or 2 / input_dim with ReLU. The issue is discussed in papers like Glorot & Bengio, "Understanding the difficulty of training deep feedforward neural networks", and He et al., "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification".
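For concreteness, here is a quick NumPy sketch of those two targets (the fan_in value and shapes are arbitrary):

```python
import numpy as np

fan_in = 512

# Glorot-style target: zero-mean normal with variance 1 / fan_in.
w_glorot = np.random.normal(0.0, np.sqrt(1.0 / fan_in), size=(256, fan_in))

# He-style target for ReLU networks: variance 2 / fan_in.
w_he = np.random.normal(0.0, np.sqrt(2.0 / fan_in), size=(256, fan_in))

print(w_glorot.var() * fan_in)  # ~1.0
print(w_he.var() * fan_in)      # ~2.0
```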
Like many libraries, MLX uses a uniform distribution to approximate this normal distribution. In this case, MLX uses a uniform distribution with bounds ±sqrt(1 / input_dim). Mathematically, the variance of a uniform distribution on [a, b] is (b - a)^2 / 12. To maintain a zero mean, set a = -b, which gives variance b^2 / 3. Achieving variance = 1 / input_dim = b^2 / 3 therefore requires bound b = sqrt(3 / input_dim). With the current bound of sqrt(1 / input_dim), the variance comes out to 1 / (3 * input_dim), a third of the target.
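A quick NumPy check makes the discrepancy concrete (again, fan_in and shapes are arbitrary):

```python
import numpy as np

fan_in = 512

# Current MLX bound: sqrt(1 / fan_in) -> variance = b^2 / 3 = 1 / (3 * fan_in).
b_current = np.sqrt(1.0 / fan_in)
w_current = np.random.uniform(-b_current, b_current, size=(256, fan_in))

# Corrected bound: sqrt(3 / fan_in) -> variance = b^2 / 3 = 1 / fan_in.
b_fixed = np.sqrt(3.0 / fan_in)
w_fixed = np.random.uniform(-b_fixed, b_fixed, size=(256, fan_in))

print(w_current.var() * fan_in)  # ~0.333, a third of the target
print(w_fixed.var() * fan_in)    # ~1.0, matches the target
```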
As a comparison, look at pytorch/torch/nn/modules/linear.py on GitHub: the function reset_parameters() uses init.kaiming_uniform_(). Then within pytorch/torch/nn/init.py, kaiming_uniform_() multiplies the bound by sqrt(3) for exactly this reason.
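For illustration, a minimal sketch of what the adjusted init could look like in MLX, using mx.random.uniform (the linear_weight helper name here is hypothetical, not part of MLX):

```python
import math
import mlx.core as mx

def linear_weight(output_dims: int, input_dims: int) -> mx.array:
    # Hypothetical helper: bound = sqrt(3 / fan_in) gives
    # Var(W) = bound^2 / 3 = 1 / fan_in, matching the normal-init target.
    bound = math.sqrt(3.0 / input_dims)
    return mx.random.uniform(low=-bound, high=bound, shape=(output_dims, input_dims))
```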
MLX should make a similar adjustment. Training accuracy should generally improve as a result.