
Proposal: add dtype on nn module initialization

Open dastrobu opened this issue 1 year ago • 4 comments

Currently, all modules such as nn.Linear create weights and biases in the default dtype, float32. So, if I am not mistaken, to use other dtypes one would first initialize the model and then change the dtype, for example with:

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_map

l = nn.Linear(1, 2)
l.update(tree_map(lambda x: x.astype(mx.float16), l.parameters()))

I am proposing to avoid allocating all the memory in float32 first by adding a dtype parameter to all the nn classes, like PyTorch does.

The only thing to align on is whether to make the default explicit and transparent:

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True, dtype: mx.Dtype = mx.float32):

or let the internal functions decide on the default (like PyTorch).

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True, dtype: Optional[mx.Dtype] = None):

I'd prefer the more explicit form dtype: mx.Dtype = mx.float32, even though this spreads the default all over the code base. Having dtype: Optional[mx.Dtype] = None would require some additional changes, as, for example, mx.random.uniform does not accept dtype=None at the moment. But of course that can be fixed.
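To make the second variant concrete, here is a tiny plain-Python sketch (no MLX involved; DEFAULT_DTYPE and resolve_dtype are made-up names) of how the internals could resolve a dtype=None default, PyTorch-style:

```python
from typing import Optional

# Stand-in for mx.float32; purely illustrative.
DEFAULT_DTYPE = "float32"

def resolve_dtype(dtype: Optional[str]) -> str:
    # The dtype=None variant: the internals pick the default when the
    # caller passes nothing, instead of spelling out the default in
    # every module signature.
    return DEFAULT_DTYPE if dtype is None else dtype

print(resolve_dtype(None))       # float32
print(resolve_dtype("float16"))  # float16
```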

Let me know if you would be interested in a PR on this feature.

dastrobu avatar Dec 25 '23 17:12 dastrobu

I recently discovered that the MLX approach for accomplishing this task is as follows:

model.apply(lambda x: x.astype(mx.float16))

Given that parameters are initialized lazily, it seems unnecessary to pass the data type to __init__. A more sensible approach might be to introduce a method like:

class Module(dict):
    def set_dtype(self, dtype: mx.Dtype):
        self.apply(lambda x: x.astype(dtype))

This is reminiscent of PyTorch's type method. However, I find the naming choice made by PyTorch less appealing.
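As a rough illustration of the proposed semantics, here is a self-contained toy (FakeArray and the stripped-down Module below are mocks, not the real MLX classes) showing set_dtype casting every parameter in place via apply:

```python
from dataclasses import dataclass

# Minimal stand-in for an mlx array: only dtype and astype.
@dataclass
class FakeArray:
    dtype: str

    def astype(self, dtype: str) -> "FakeArray":
        return FakeArray(dtype)

class Module(dict):
    def apply(self, fn):
        # Replace every parameter with fn(parameter), in place.
        for name, param in self.items():
            self[name] = fn(param)

    def set_dtype(self, dtype: str):
        self.apply(lambda x: x.astype(dtype))

m = Module(weight=FakeArray("float32"), bias=FakeArray("float32"))
m.set_dtype("float16")
print(m["weight"].dtype)  # float16
```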

dastrobu avatar Jan 09 '24 00:01 dastrobu

It is indeed a commonly used method. Just thinking out loud: what do you think about Module.astype?

Also torch has Module.to, which I think is the more commonly used method for casting to a different type. I don't mind that name as much.

awni avatar Jan 09 '24 00:01 awni

> It is indeed a commonly used method. Just thinking out loud: what do you think about Module.astype?
>
> Also torch has Module.to, which I think is the more commonly used method for casting to a different type. I don't mind that name as much.

FYI, astype is part of the python array API specification (which allows interoperability with scikit-learn, einops, etc.), so it'd be nice to support that for consistency.

altaic avatar Jan 09 '24 00:01 altaic

@altaic this is at the nn.Module level. We already have astype at the array level :). I find the consistency with that pleasing, but I'm curious on other's thoughts.

awni avatar Jan 09 '24 01:01 awni

I would anticipate that both to and astype would return a copy of the module cast to the requested type, mirroring the behavior of arrays. While this holds for arrays using astype, it diverges for torch modules, which are modified in place when to is invoked.

From an implementation perspective, opting for in-place modification proves significantly more straightforward. Here, I outline several options along with accompanying comments for such a method:

  • astype might cause confusion if it performs in-place modification on modules but returns a copy for arrays.
  • The term 'to' is quite generic and, in my perception, implies the creation of a new object, akin to array.tolist or numpy's tostring, tobytes, and tofile.
  • set_dtype aligns closely with what NumPy uses for in-place modification, such as setflags or setfield. While setdtype is an alternative, it appears less readable to me.
  • Using dtype closely resembles PyTorch's type method but offers more expressiveness from my standpoint. The term 'type' reminds me personally of the object's type, as in type(obj), and lacks consistency with the other instances where it is termed 'dtype'.
  • apply_dtype remains a viable option, considering the implementation's reliance on 'apply'.

Consequently, my personal preference leans toward set_dtype. I am prepared to create a pull request for this tiny feature, using whichever name everyone prefers.

dastrobu avatar Jan 09 '24 20:01 dastrobu

That is a very good point about in-place vs copy.

awni avatar Jan 09 '24 21:01 awni

In my opinion set_dtype > setdtype > apply_dtype (which I don't really like).

awni avatar Jan 09 '24 21:01 awni

Unless anyone feels differently, I would suggest we go with set_dtype. You've convinced me it's the best option. Would be great if you would send a PR.

And thank you for the super thoughtful commentary / discussion!

awni avatar Jan 09 '24 21:01 awni

Regarding the name, I agree that set_dtype > the rest. I am not too big of a fan of adding too many methods on nn.Module but this is common enough that it might warrant a method.

Given that we are adding a method, I was thinking that we may make it a bit more useful by providing a type list to change. Something like the following:

def set_dtype(self, dtype: mx.Dtype, types_to_change: List[mx.Dtype] = [mx.float32, mx.float16, mx.bfloat16]):
    self.apply(lambda x: x.astype(dtype) if x.dtype in types_to_change else x)

WDYT? This allows calling set_dtype on quantized models for instance and it would still work fine. Or storing integer parameters for something like a mask and they wouldn't be affected.
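Sketching that idea with the same kind of mock (FakeArray and the minimal Module below are stand-ins, not MLX types), the allow-list leaves non-float parameters untouched:

```python
from dataclasses import dataclass

# Minimal stand-in for an mlx array: only dtype and astype.
@dataclass
class FakeArray:
    dtype: str

    def astype(self, dtype: str) -> "FakeArray":
        return FakeArray(dtype)

# Stand-ins for mx.float32, mx.float16, mx.bfloat16.
FLOAT_TYPES = ["float32", "float16", "bfloat16"]

class Module(dict):
    def apply(self, fn):
        for name, param in self.items():
            self[name] = fn(param)

    def set_dtype(self, dtype, types_to_change=FLOAT_TYPES):
        # Only cast parameters whose current dtype is in the allow-list;
        # integer parameters (e.g. a mask, or quantization indices)
        # pass through untouched.
        self.apply(lambda x: x.astype(dtype) if x.dtype in types_to_change else x)

m = Module(weight=FakeArray("float32"), mask=FakeArray("int32"))
m.set_dtype("float16")
print(m["weight"].dtype, m["mask"].dtype)  # float16 int32
```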

angeloskath avatar Jan 09 '24 21:01 angeloskath

In the NumPy world, I would opt for a more generic approach:

def set_dtype(self, dtype, predicate = lambda x: mx.issubdtype(x, mx.floating)):
    self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)

However, implementing this approach would necessitate mirroring the type hierarchy found in NumPy. I've provided a draft #427 to illustrate this direction. Please let me know if this is going in the right direction. If so, I can proceed to address the open TODOs on tests and docs.
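The predicate variant can be mocked the same way (is_floating stands in for mx.issubdtype(x, mx.floating); FakeArray and Module are toys, not the real MLX classes):

```python
from dataclasses import dataclass

# Minimal stand-in for an mlx array: only dtype and astype.
@dataclass
class FakeArray:
    dtype: str

    def astype(self, dtype: str) -> "FakeArray":
        return FakeArray(dtype)

def is_floating(dtype: str) -> bool:
    # Stand-in for mx.issubdtype(dtype, mx.floating).
    return dtype.startswith(("float", "bfloat"))

class Module(dict):
    def apply(self, fn):
        for name, param in self.items():
            self[name] = fn(param)

    def set_dtype(self, dtype, predicate=is_floating):
        # Any parameter whose dtype fails the predicate is left as-is,
        # so callers can plug in their own filtering logic.
        self.apply(lambda x: x.astype(dtype) if predicate(x.dtype) else x)

m = Module(weight=FakeArray("bfloat16"), steps=FakeArray("int64"))
m.set_dtype("float16")
print(m["weight"].dtype, m["steps"].dtype)  # float16 int64
```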

dastrobu avatar Jan 10 '24 23:01 dastrobu