Distributed Data Parallel Training
What is here
Fixes support for distributed training with data parallelism. Previously, torchmetrics would attempt to synchronize across processes during the validation callback and crash. Also, the final model returned by the FinetuneStep was on CPU rather than GPU, unlike in the non-distributed case; the returned model is now on GPU.
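A minimal sketch of the two fixes described above, not the actual FinetuneStep code: the metric is constructed so that compute() does not try to synchronize state across ranks (using the torchmetrics sync_on_compute flag; exact keyword arguments may vary with the torchmetrics version), and a hypothetical helper moves the fine-tuned model back onto the GPU before it is returned. The metric class, num_classes value, and helper name are illustrative assumptions.

import torch
from torchmetrics.classification import MulticlassAccuracy

# Keep compute() from blocking or crashing when only one rank runs the
# validation callback under DDP.
accuracy = MulticlassAccuracy(num_classes=2, sync_on_compute=False)

def return_finetuned_model(model: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical helper: match the non-distributed behavior by handing the
    # caller a model on GPU rather than CPU.
    return model.cuda() if torch.cuda.is_available() else model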
Limitations
Validation is still done in a single process, even with data parallelism.
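Concretely, single-process validation under DDP usually amounts to a rank-0 guard like the sketch below. This is purely illustrative of the limitation, not catwalk's actual implementation; the function name is hypothetical.

import torch.distributed as dist

def run_validation_on_one_process(validate_fn):
    is_distributed = dist.is_available() and dist.is_initialized()
    if not is_distributed or dist.get_rank() == 0:
        metrics = validate_fn()  # rank 0 sees the full validation set
    else:
        metrics = None           # other ranks skip validation entirely
    if is_distributed:
        dist.barrier()           # keep ranks in step before training resumes
    return metrics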
Reproduction
Running with and without multiple devices produces exactly the same validation metrics, though the printout is slightly different due to tasks being copied:
python -m catwalk.train --model rc::gpt2 --task piqa --device_count 1 --batch_size 16
....
Running log-likelihood queries: 100%|##########| 2000/2000 [00:07<00:00, 270.13it/s]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 4003.12it/s]85.04it/s]
Metrics for piqa: acc: 0.647###### | 804/1000 [00:00<00:00, 4020.71it/s]
Running log-likelihood queries: 100%|##########| 2000/2000 [00:07<00:00, 260.74it/s]_val_loss=3.02, val_loss=3.02]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 4000.07it/s]76.95it/s]
Metrics for piqa: acc: 0.648#####9 | 799/1000 [00:00<00:00, 3991.17it/s]
...
python -m catwalk.train --model rc::gpt2 --task piqa --device_count 2 --batch_size 16
...
Running log-likelihood queries: 100%|##########| 2000/2000 [00:07<00:00, 271.48it/s]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 3683.20it/s]08.65it/s]
Metrics for <catwalk.tasks.eleuther.EleutherTask object at 0x7fe861c05490>: acc: 0.647
Running log-likelihood queries: 100%|##########| 2000/2000 [00:06<00:00, 288.04it/s]_val_loss=3.02, val_loss=3.02]
Calculating metrics: 100%|##########| 1000/1000 [00:00<00:00, 3767.45it/s]25.77it/s]
Metrics for <catwalk.tasks.eleuther.EleutherTask object at 0x7fe861c05490>: acc: 0.648
...
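The "<catwalk.tasks.eleuther.EleutherTask object at 0x...>" lines are the printout difference mentioned above: the task each worker holds is a copy, not the object the pretty name was registered against, so the default repr is printed instead of "piqa". The snippet below is only a guess at the mechanism, with hypothetical names; catwalk's actual lookup may differ.

import copy

class TaskSketch:
    # Hypothetical stand-in; not catwalk's EleutherTask.
    def __init__(self, name):
        self.name = name

task = TaskSketch("piqa")
pretty_names = {id(task): "piqa"}   # lookup keyed by object identity
copied = copy.deepcopy(task)        # what a DDP worker receives
print(pretty_names.get(id(copied), repr(copied)))  # falls back to <... object at 0x...>
# Looking the name up via task.name (or defining __str__ on the task) would
# keep the printout readable after the copy.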
I don't like this. In this version of the callback, I didn't have to do this. I think the trick is to make sure that each worker runs through the same data.
But also, consider that in my latest training version (which I have only in a branch as yet), I don't even have the callback anymore. Can we just get rid of that whole problem area by making sure the trainable model computes its metrics during forward()?
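A rough sketch of that idea, assuming a torchmetrics metric inside a plain torch.nn.Module wrapper (the class name, arguments, and metric choice are illustrative, not catwalk's actual trainable model): the model updates its own metric whenever labels are passed to forward(), so each worker only accumulates state for the examples it actually saw and no separate validation callback is needed.

from typing import Optional

import torch
from torchmetrics.classification import MulticlassAccuracy

class TrainableModelSketch(torch.nn.Module):
    def __init__(self, model: torch.nn.Module, num_classes: int):
        super().__init__()
        self.model = model
        # torchmetrics keeps per-process state; a single compute() at the end
        # of the epoch reduces across workers.
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, inputs: torch.Tensor, labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        logits = self.model(inputs)
        if labels is not None and not self.training:
            self.val_accuracy.update(logits, labels)
        return logits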
I agree, getting rid of the validation callback altogether would be the best solution. I'm concerned that's a bit beyond the scope of what I can accomplish this week. All this distributed processing stuff has me mostly feeling around in the dark because of my lack of systems background.
This PR is not a necessary dependency of the IA3 PR #81, so if it's going to be superseded by your rework of the training code in that branch, then perhaps we should just skip this PR?
I've reverted these changes in the IA3 PR #81, as they are not actually necessary for that PR, and I don't want this one to block it.
I will revisit this after https://github.com/allenai/catwalk/pull/84 is merged.