Add Parameter Group Support to MLX Framework
Issue Description:
Feature Request
Summary: I propose adding support for parameter groups in MLX to enhance the flexibility and customization of model optimization.
Details: The addition of parameter groups would enable users to group and apply different optimization configurations to specific subsets of model parameters. This is a common feature in many deep learning frameworks and can significantly improve the efficiency of training and fine-tuning models.
Expected Behavior:
- Introduce a mechanism to create parameter groups.
- Allow users to specify different optimization configurations for each parameter group.
- Ensure that the optimization process correctly applies the specified configurations to the corresponding parameter groups during training.
Motivation:
- Parameter groups provide a powerful tool for users to fine-tune optimization strategies for specific parts of a model.
- Enhances the adaptability of MLX to a wider range of deep learning tasks and scenarios.
Example Usage:
```python
import mlx.nn as nn
import mlx.optimizers as optim

# A small model with named layers (sizes are illustrative)
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(32, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, 10)

model = Model()
optimizer = optim.SGD(learning_rate=0.01)

# Proposed API -- `add_param_groups` does not exist in mlx.optimizers today
parameter_groups = [
    {"params": model.layer1.parameters(), "lr": 0.001},
    {"params": model.layer2.parameters(), "lr": 0.005},
    {"params": model.layer3.parameters(), "lr": 0.01},
]
optimizer.add_param_groups(parameter_groups)
```
@awni your thoughts please!
Thanks @m0saan. I'm not sure we need parameter groups yet. Let's keep this issue open, but I would mark it as low priority until we see evidence otherwise.
In MLX it's a lot easier to have multiple optimizers each working on a subset of the model since things are a bit more decoupled than in PyTorch. So this is an instance where the added functionality doesn't make as much sense for us.