mesozoic-egg

20 comments by mesozoic-egg

I have come to realize that naively sharding the model, then calling all-gather to align the axis during training, is the same as just replicating the model across all...
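
A toy illustration of that point (plain NumPy, not the actual tinygrad code; the sizes and the 4-way split are made up): if every device all-gathers its shard before using the weight, each device ends up holding the full tensor again, so peak memory is no better than plain replication.

```python
import numpy as np

full_weight = np.random.randn(1024, 1024).astype(np.float32)  # one "model" weight
shards = np.split(full_weight, 4, axis=0)                     # 1/4 per "device"

# naive all-gather before compute: every device reconstructs the full weight
gathered_per_device = [np.concatenate(shards, axis=0) for _ in range(4)]

print(f"{full_weight.nbytes / 1e6:.1f} MB held by one replica")         # ~4.2 MB
print(f"{gathered_per_device[0].nbytes / 1e6:.1f} MB held per device")  # same ~4.2 MB
```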

I think the reshard is working well. In the `extra/fsdp/gpt2.py` example (model size with Adam optimizer: 13.3 MB), memory usage went from 28 MB (no shard) to 17 MB per device (2 shards). Right now...

The trick I found was to set the device to CPU during initialization, then shard the parameters to each GPU device at the start of training. Also found out that the sharding strategy...
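
A minimal sketch of that trick under stated assumptions: it uses tinygrad's `Tensor.to_`, `Tensor.shard_`, and `nn.state.get_parameters`, but the model, sizes, and device count here are placeholders rather than the `extra/fsdp` code (and the CPU device may be named `CLANG` on older tinygrad versions).

```python
from tinygrad import Tensor, Device, nn
from tinygrad.nn.state import get_parameters

GPUS = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))

class TinyNet:
  def __init__(self):
    self.l1 = nn.Linear(512, 512, bias=False)
    self.l2 = nn.Linear(512, 10, bias=False)
  def __call__(self, x: Tensor) -> Tensor:
    return self.l2(self.l1(x).relu())

model = TinyNet()

# keep the freshly initialized weights on CPU so no single GPU has to hold
# the whole model during setup
for p in get_parameters(model):
  p.to_("CPU")
  p.realize()

# at the start of training, split each parameter across the GPUs;
# axis=0 shards the leading dimension, axis=None would replicate instead
for p in get_parameters(model):
  p.shard_(GPUS, axis=0)
  p.realize()
```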

Training llama 4B (8B model with half the layers)

Was able to do a (partial) training run of a 4B llama3 model (llama-3 8B with half the `n_layers`) on five 4090s: batch size 16, sequence length 16, epoch 500, model...
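
As a rough sanity check (a sketch assuming the usual llama-3 8B hyperparameters; the variable names are illustrative, not the actual args in the repo's llama code), halving `n_layers` from 32 to 16 lands close to the ~4.54 B parameter count reported in the training log below.

```python
dim, n_layers, n_heads, n_kv_heads = 4096, 32 // 2, 32, 8
hidden_dim, vocab = 14336, 128256

head_dim = dim // n_heads
attn = dim * (n_heads * head_dim) + 2 * dim * (n_kv_heads * head_dim) + (n_heads * head_dim) * dim
mlp = 3 * dim * hidden_dim                    # gate, up, down projections
embeddings = 2 * vocab * dim                  # token embedding + output head
total = n_layers * (attn + mlp) + embeddings  # norm weights omitted (tiny)
print(f"{total / 1e9:.2f} B params")          # ~4.54 B
```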

I trained for 30 min on some rented GPUs and here's the result; the test loss is going down consistently:

```
Model: 18.17 GB
Optimizer: 54.49 GB
Model params: 4.54 B
...
```

[llama_fsdp_log_20241102_1.txt](https://github.com/user-attachments/files/17605822/llama_fsdp_log_20241102_1.txt)

After 4 hr of training on tiny18:

```
Epoch 100 loss: 6.42 test_loss: 7.54 elapsed time: 2.4 min
Epoch 200 loss: 7.06 test_loss: 7.23 elapsed time: 3.6 min
...
```

Now I actually think the pattern matcher is a better approach for removing unnecessary transfer chunks, instead of coding up the transfer logic directly. I'll work on that at...
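
A sketch of the idea only, in plain Python rather than tinygrad's actual `PatternMatcher`/`UPat` machinery: a transfer whose source already lives on the destination device is expressed as a rewrite rule and deleted by the matcher, instead of being special-cased in the transfer logic.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str        # e.g. "COPY", "ADD", "LOAD"
  device: str    # device the result lives on
  srcs: tuple = ()

def drop_redundant_copy(n: Node):
  # rule: COPY(x) -> x when x is already on the destination device
  if n.op == "COPY" and n.srcs and n.srcs[0].device == n.device:
    return n.srcs[0]
  return None

RULES = [drop_redundant_copy]

def rewrite(n: Node) -> Node:
  # rewrite sources bottom-up, then apply rules until nothing matches
  n = Node(n.op, n.device, tuple(rewrite(s) for s in n.srcs))
  for rule in RULES:
    replaced = rule(n)
    if replaced is not None:
      return rewrite(replaced)
  return n

x = Node("LOAD", "GPU:0")
graph = Node("ADD", "GPU:0", (Node("COPY", "GPU:0", (x,)), x))
print(rewrite(graph))  # the redundant COPY node is gone
```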

I removed the extra conditional in the DEFINE_GLOBAL case since it's always a ptr. The dtype variable itself is still needed though. Normally the register var has the dtype of...
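
An illustrative sketch of that simplification (a hypothetical renderer helper, not the actual code from the PR): since DEFINE_GLOBAL always carries a pointer dtype, the non-pointer branch is dead, but the dtype is still needed to render the pointed-to element type.

```python
from dataclasses import dataclass

@dataclass
class PtrDType:
  base: str      # element type name, e.g. "float"

@dataclass
class UOpLike:
  op: str
  dtype: PtrDType
  arg: int       # buffer index

def render_define_global(u: UOpLike) -> str:
  # no isinstance(u.dtype, PtrDType) check needed: it is always a ptr here,
  # but u.dtype.base is still used for the element type
  return f"{u.dtype.base}* data{u.arg}"

print(render_define_global(UOpLike("DEFINE_GLOBAL", PtrDType("float"), 0)))  # float* data0
```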

Thx! It's not immediately clear to me how the removal of dtype would be done, but maybe after I make more contributions it will become more obvious. Will open a PR...