Train working
Currently this trains llama-3.2-3B on MLX, but I had to pull a different version of it because I guess the MLX ops needed to train quantized models are not yet implemented.
Definitely still a toy, but it does as advertised: this is distributed training across an exo network cluster.
Very interesting. @dtnewman what's your take on the training?
My overall take is that @blindcrone is a brilliant engineer, doing amazing work here. I think it will be awesome when we can all train our own models, locally. I can't wait to see this built out further.
Some high level notes though:
- Running the command `exo train` doesn't work. It throws `Error: This train ain't leaving the station without a model`. I think this would be more helpful if it were `Error: This train ain't leaving the station without a model (e.g. command 'exo train llama-3.2-1b')`.
- The readme should probably be updated to explain how this works (although I do understand this is a work in progress, so maybe that's for later).
- Could use some better error handling. `_process_example` in standard_node.py currently returns `None` when an exception is caught. So when I run this with `exo train llama-3.2-1b`, it fails, but the error gets somewhat obscured since the function returns anything at all. I'm not sure it's worth returning anything in the error case there, rather than just raising an exception (see the sketch below this list).
- @blindcrone I'm assuming you built this against llama-3.2-3b since that's what's modified in the `model_cards` in models.py. I'm assuming that if you flipped `mlx-community/Llama-3.2-1B-Instruct-4bit` to `mlx-community/Llama-3.2-1B-Instruct` it would probably work for the smaller model too? (A sketch of that change also follows the list.)
- I am running it now and it seems to be working so far, although I'm only 12/1000 steps in. I might need to run this overnight instead :)
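To make the error-handling suggestion concrete, here's roughly the pattern I mean. This is a sketch only, not exo's actual code; the body and the helper name are made up:

```python
# Sketch only: the real _process_example body differs. The point is to
# re-raise with context instead of catching the exception and returning None,
# so the failure surfaces where it happens rather than as a mysterious None later.
def _process_example(self, example):
    try:
        return self._encode_example(example)  # hypothetical helper
    except Exception as e:
        raise RuntimeError(f"Failed to process training example: {example!r}") from e
```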
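And for the model-card question, the change I'm imagining is just swapping the repo ID. The shape of `model_cards` below is assumed for illustration; only the repo names come from the discussion:

```python
# Assumed/simplified shape of the entry in models.py; the substantive change
# is pointing the MLX entry at the full-precision repo so gradients can flow.
model_cards = {
  "llama-3.2-1b": {
    # "repo": "mlx-community/Llama-3.2-1B-Instruct-4bit",  # quantized: can't be trained
    "repo": "mlx-community/Llama-3.2-1B-Instruct",          # full weights: trainable
  },
}
```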
Awesome feedback @dtnewman, indeed great work from @blindcrone, thank you so much for the effort!
I'm probably around 2/1000 steps in reading this PR, so I'm trying to see what the architecture choices would be if training worked as well as inference (i.e. what kind of bottlenecks to expect).
- Aye, like "exo run" this is mostly meant for testing and demo purposes, you'll need to feed it a specific model as it doesn't yet hook up to the UX
- Sure is. I think this needs tinygrad support and either some explanation or UX hooks before it's ready for most people to use, as there are definitely some caveats (see 4)
- A lot of this was intended to get the training working so we have something to build an actual workflow on, and that was a pretty involved process that included tons of debug output which would be really obnoxious to keep in the main repo. I've cleaned out most of the instrumentation I used to develop this, but I agree that better error messages are necessary to support this being used more widely, and those will come from targeting use cases rather than a minimal example
- Quantized models seemingly cannot be trained in MLX right now (which makes sense in theory: you'd need some bespoke ops to get gradients out of them), so you need a model with the full original weights to train it. It's also probably why MLX tends to only suggest training LoRAs, whereas this PR does full fine-tuning on the specified layers. Further work on this will probably include offering both paths to users, since I want to support LoRA training, but also requantization as a preprocessing step for switching models between training and inference modes (a rough sketch of that idea follows this list). I also have some more bizarre experiments I'm targeting to take advantage of the distributed nature of exo, but that's future work for sure
- Yea training takes a long time, especially without big GPUs.
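Here's the rough shape of the requantization idea from point 4, assuming the standard `mlx_lm.load` helper and `mlx.nn.quantize`; this is a sketch of the concept, not the pipeline in this PR, and the model name and quantization settings are illustrative:

```python
# Rough sketch of the train-then-requantize idea, not exo's actual code.
import mlx.nn as nn
from mlx_lm import load

# Full-precision weights are needed for gradients; the 4-bit MLX variants
# currently can't be trained directly.
model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct")

# ... run the fine-tuning loop on the specified layers here ...

# After training, quantize in place so the checkpoint matches what the
# inference path expects.
nn.quantize(model, group_size=64, bits=4)
```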
The bottlenecks are interesting here, because the asynchronous nature of the communication, combined with how heavy an operation backpropagation is, means that most of the communication overhead is effectively "hidden": it can happen while the system overall is waiting on training steps
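As a toy illustration of that "hiding" effect (this is not exo's actual code, just the overlap pattern in plain asyncio, with made-up timings):

```python
import asyncio, time

async def send_gradients(step: int) -> None:
    await asyncio.sleep(0.5)              # simulated network transfer
    print(f"step {step}: gradients delivered")

def local_backprop(step: int) -> None:
    time.sleep(1.0)                       # simulated heavy training step

async def train(steps: int = 3) -> None:
    for step in range(steps):
        send = asyncio.create_task(send_gradients(step))
        # The training step runs in a worker thread, so the event loop keeps
        # driving the send: the 0.5s of "network" hides inside the 1.0s of
        # "compute" instead of adding to it.
        await asyncio.to_thread(local_backprop, step)
        await send                        # usually already finished by now

asyncio.run(train())
```

Each step ends up costing roughly the max of the two durations rather than their sum, which is why the communication overhead mostly disappears as long as compute dominates.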
Since training on exo clusters is a whole new use case, I think the features of it will need to be built out over time, and in conversation with the community as they play with this capability on their clusters. However, some deeper architectural decisions need to be made in order to support this functionality, so I'm requesting this merge specifically because I'm considering the future and want to avoid massive conflicts with other work people are doing as more of it gets built out
Please see the unresolved threads @blindcrone
Okay, so at this point it seems like removing the abstract base class as part of the requirements for merging this is starting to produce conflicts. I think I've resolved every outstanding issue and will try to integrate the changes from main into the reconciled node class, but maintaining this as a separate branch will become less and less feasible as more changes happen elsewhere, now that this refactor is part of it
Fixed the conflicts AFAICT, hopefully this can be merged now