gpt-2
What is the minimum GPU memory needed to train the 345M model with a batch_size greater than 1 using train.py?
- I am using an ml.p3.2xlarge instance on AWS with one 16 GB V100 GPU. Training the 345M model with batch_size 2 gets an OOM error; it works with batch_size 1.
- I would like to use a batch_size of 2 to 8. What size of GPU do I need to make this happen? If anyone has run into this situation, sharing your experience would be helpful (a rough memory estimate is sketched after the command below).
- I am using this command to train it.
```
python train.py --dataset Dataset/data.npz --sample_every 10 --sample_num 3 --batch_size 1 --learning_rate 0.0001 --model_name 345M
```
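For context, here is a back-of-envelope sketch of why 16 GB fills up so quickly. It assumes fp32 training with Adam and takes the nominal 345M parameter count at face value; the real footprint also depends on TensorFlow's allocator and the context length, so treat these numbers as rough illustration only:

```python
# Rough memory estimate for fine-tuning GPT-2 345M with Adam in fp32.
# All figures are assumptions for illustration, not measured values.

PARAMS = 345e6          # nominal parameter count from the model name
BYTES_PER_FLOAT = 4     # fp32

weights  = PARAMS * BYTES_PER_FLOAT        # model weights
grads    = PARAMS * BYTES_PER_FLOAT        # gradients
adam_m_v = 2 * PARAMS * BYTES_PER_FLOAT    # Adam's two moment buffers

static_gb = (weights + grads + adam_m_v) / 1e9
print(f"Static weights + optimizer state: ~{static_gb:.1f} GB")   # ~5.5 GB

# Whatever remains on a 16 GB V100 must hold activations, which grow
# roughly linearly with batch_size. If batch_size=1 already fits
# tightly, batch_size=2 can plausibly exceed the remaining headroom.
print(f"Left for activations on a 16 GB GPU: ~{16 - static_gb:.1f} GB")
```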
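As a possible workaround: if this train.py is the widely used nshepperd fork, it appears to expose --accumulate (gradient accumulation) and --memory_saving_gradients (gradient checkpointing). Those flag names are an assumption here, so verify with `python train.py --help` before relying on them. A sketch of a command that keeps batch_size at 1 while approximating a larger batch:

```
python train.py --dataset Dataset/data.npz --batch_size 1 --accumulate 4 --memory_saving_gradients --learning_rate 0.0001 --model_name 345M
```

With accumulation set to 4, gradients from four successive batch_size-1 steps are summed before a single optimizer update, which approximates batch_size 4 without the extra activation memory that a true batch_size 4 would need.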