Proposal: add dtype on nn module initialization
Currently, all modules such as nn.Linear create weights and biases in the default float32 dtype. If I am not mistaken, to use other dtypes one would first initialize the model and then change the dtype, for example with:
l = nn.Linear(1, 2)
l.update(tree_map(lambda x: x.astype(mx.float16), l.parameters()))
I am proposing to avoid allocating all the memory in float32 first by adding a dtype parameter to all the nn classes, like PyTorch does.
The only thing to align on is whether to make the default explicit and transparent:
def __init__(self, input_dims: int, output_dims: int, bias: bool = True, dtype: mx.Dtype = mx.float32):
or let the internal functions decide on the default (like PyTorch).
def __init__(self, input_dims: int, output_dims: int, bias: bool = True, dtype: Optional[mx.Dtype] = None):
I'd prefer the more explicit form dtype: mx.Dtype = mx.float32 even though this spreads the default all over the code base.
Having dtype: Optional[mx.Dtype] = None would require some additional changes as, for example, mx.random.uniform does not accept dtype=None at the moment. But of course that can be fixed.
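For illustration, here is a minimal sketch of how the None default could be resolved internally before the random initializer is called; the class body and scale are assumptions, not the actual MLX implementation:

import math
from typing import Optional

import mlx.core as mx


class Linear:  # illustrative stand-in for nn.Linear
    def __init__(self, input_dims: int, output_dims: int, bias: bool = True,
                 dtype: Optional[mx.Dtype] = None):
        # Resolve the default in one place, since mx.random.uniform
        # expects a concrete dtype rather than None.
        dtype = mx.float32 if dtype is None else dtype
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(-scale, scale,
                                        (output_dims, input_dims), dtype=dtype)
        if bias:
            self.bias = mx.random.uniform(-scale, scale, (output_dims,), dtype=dtype)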
Let me know if you would be interested in a PR on this feature.
I recently discovered that the MLX approach for accomplishing this task is as follows:
model.apply(lambda x: x.astype(mx.float16))
Given that parameters are initialized lazily, it seems unnecessary to pass the data type to __init__. A more sensible approach might be to introduce a method like:
class Module(dict):
    def set_dtype(self, dtype: mx.Dtype):
        self.apply(lambda x: x.astype(dtype))
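As a usage sketch, assuming set_dtype is added to nn.Module as above:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(16, 4)
model.set_dtype(mx.float16)  # casts all parameters in place
print(model.weight.dtype)    # float16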
This is reminiscent of PyTorch's type method. However, I find the naming choice made by PyTorch less appealing.
It is indeed a commonly used method. Just thinking out loud: what do you think about Module.astype?
Also torch has Module.to, which I think is the more commonly used method for casting to a different type. I don't mind that name as much.
FYI, astype is part of the Python array API specification (which allows interoperability with scikit-learn, einops, etc.), so it'd be nice to support that for consistency.
@altaic this is at the nn.Module level. We already have astype at the array level :). I find the consistency with that pleasing, but I'm curious about others' thoughts.
I would anticipate that both to and astype would return a copy of the module cast to the requested type, mirroring the behavior seen in arrays. While this holds true for arrays using astype, it diverges for torch modules, which are modified in place when to is invoked.
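To make the in-place vs. copy distinction concrete, a purely illustrative comparison:

import torch
import mlx.core as mx

m = torch.nn.Linear(2, 2)
m.to(torch.float16)        # torch modules are cast in place (and also returned)
print(m.weight.dtype)      # torch.float16, even without reassigning m

a = mx.zeros((2, 2))
b = a.astype(mx.float16)   # mlx arrays: astype returns a new array
print(a.dtype, b.dtype)    # float32 float16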
From an implementation perspective, opting for in-place modification proves significantly more straightforward. Here, I outline several options along with accompanying comments for such a method:
- astype might cause confusion if it performs in-place modification on modules but returns a copy for arrays.
- The term 'to' is quite generic and, in my perception, implies the creation of a new object, akin to array.tolist or numpy's tostring, tobytes, and tofile.
- set_dtype aligns closely with what NumPy uses for in-place modification, such as setflags or setfield. While setdtype is an alternative, it appears less readable to me.
- Using dtype closely resembles PyTorch's type method but offers more expressiveness from my standpoint. The term 'type' personally reminds me of the object's type, as in type(obj), and lacks consistency with the other instances where it is termed 'dtype'.
- apply_dtype remains a viable option, considering the implementation's reliance on 'apply'.
Consequently, my personal preference leans toward set_dtype.
I am prepared to create a pull request for this tiny feature, employing the preferred name that resonates with everyone.
That is a very good point about in-place vs copy.
In my opinion set_dtype > setdtype > apply_dtype (which I don't really like).
Unless anyone feels differently, I would suggest we go with set_dtype. You've convinced me it's the best option. It would be great if you could send a PR.
And thank you for the super thoughtful commentary / discussion!
Regarding the name, I agree that set_dtype > the rest. I am not too big of a fan of adding too many methods on nn.Module but this is common enough that it might warrant a method.
Given that we are adding a method, I was thinking that we may make it a bit more useful by providing a type list to change. Something like the following:
def set_dtype(self, dtype, types_to_change: List[mx.Dtype] = [mx.float32, mx.float16, mx.bfloat16]):
    self.apply(lambda x: x.astype(dtype) if x.dtype in types_to_change else x)
WDYT? This would allow calling set_dtype on quantized models, for instance, and it would still work fine. Or you could store integer parameters for something like a mask and they wouldn't be affected (see the sketch below).
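For illustration, a sketch of that behavior with a hypothetical module carrying an integer mask, assuming the filtered set_dtype above:

import mlx.core as mx
import mlx.nn as nn


class MaskedLinear(nn.Module):  # hypothetical module with an integer parameter
    def __init__(self, dims: int):
        super().__init__()
        self.linear = nn.Linear(dims, dims)
        self.mask = mx.ones((dims, dims), dtype=mx.int32)


model = MaskedLinear(8)
model.set_dtype(mx.float16)        # only float32/float16/bfloat16 params are cast
print(model.linear.weight.dtype)   # float16
print(model.mask.dtype)            # still int32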
In the NumPy world, I would opt for a more generic approach:
def set_dtype(self, dtype, predicate = lambda x: mx.issubdtype(x, mx.floating)):
    self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)
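A usage sketch, assuming the predicate-based variant above:

import mlx.core as mx
import mlx.nn as nn

model = nn.Linear(8, 8)
# Custom predicate: only cast parameters that are currently float32.
model.set_dtype(mx.float16, lambda d: d == mx.float32)
# Default predicate: cast every floating-point parameter, regardless of width.
model.set_dtype(mx.bfloat16)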
However, implementing this approach would necessitate mirroring the type hierarchy found in NumPy. I've provided a draft #427 to illustrate this direction. Please let me know if this goes in the right direction. If so, I can proceed to address the open TODOs on tests and docs.