
Running train.py on 2060 GPU

lzeladam opened this issue 2 years ago • 6 comments

"Hello! I've been trying to run the train.py on a 2060 GPU, but this device does not support dtype=torch.bfloat16. What changes would I have to make to achieve my goal? Or can I only train on an Ampere architecture GPU for now? Thank you very much for sharing this project!"

lzeladam avatar Jan 02 '23 15:01 lzeladam

Two options:

  • use dtype=torch.float32 to disable mixed precision training. Will work on anything, but slow.
  • use dtype=torch.float16 to use fp16 instead of bf16. Because the range of fp16 is small, this requires adding a gradient scaler (see the sketch after this list). It's only a few lines of code. I'm not sure if I should add support for it in stock train.py; potentially the answer is yes. I haven't done so yet because I didn't want to bloat the training file with more options, but this might be common enough that it's worth it. Thinking it through...
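
For reference, a minimal sketch of what the fp16 path could look like, wrapping a generic model/optimizer loop with PyTorch's torch.cuda.amp.GradScaler; model, optimizer, and loader are assumed to exist here, and this is illustrative rather than the stock train.py code:

```python
import torch

# fp16's narrow range can underflow gradients; bf16 does not need a scaler
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
scaler = torch.cuda.amp.GradScaler(enabled=(dtype == torch.float16))

for X, Y in loader:  # hypothetical loader yielding input/target batches
    with torch.autocast(device_type='cuda', dtype=dtype):
        logits, loss = model(X, Y)
    scaler.scale(loss).backward()  # scale the loss so small grads survive fp16
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next step
    optimizer.zero_grad(set_to_none=True)
```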

karpathy avatar Jan 02 '23 17:01 karpathy

Hi @karpathy,

Thank you for your help. I made the change, and now I have a problem detecting CUDA in my WSL environment:

debug_wrapper raised RuntimeError: CUDA: Error- no device

I don't know why, because the GPU is detected by the nvidia-smi command:

[screenshot: nvidia-smi output showing the detected GPU]

So, I will try to solve it.
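
(For anyone hitting the same thing: nvidia-smi only shows that the driver sees the GPU; PyTorch also needs a CUDA-enabled build. A quick sketch to check what PyTorch itself detects:)

```python
import torch

print(torch.cuda.is_available())  # False here, despite nvidia-smi working
print(torch.version.cuda)         # None if a CPU-only wheel is installed
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # e.g. "NVIDIA GeForce RTX 2060"
```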

lzeladam avatar Jan 03 '23 00:01 lzeladam

What are the min requirements to run nanoGPT?

jcherrera avatar Jan 03 '23 08:01 jcherrera

@jcherrera try changing these parameters: batch_size = 12 to 16, and block_size = 1024 to 512.
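
In train.py terms that is roughly the following — illustrative, untested values, and gradient_accumulation_steps is only an assumption, in case your copy of train.py exposes such a knob:

```python
# The overrides suggested above, as they would appear in train.py
# (a sketch; tune for your card):
batch_size = 16   # sequences per forward/backward pass (was 12)
block_size = 512  # context length in tokens (was 1024); activation memory scales with this
# Hypothetical: if train.py supports gradient accumulation, it can recover
# the effective batch size lost to a smaller per-step batch:
# gradient_accumulation_steps = 4
```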

Note: This project doesn't work on Windows because PyTorch 2.0 currently only supports Linux. Another alternative is to rent an A10 or A100 instance at Lambdalabs.com ... maybe I could do a post 🤔

lzeladam avatar Jan 07 '23 00:01 lzeladam

@jcherrera set

compile = False # use PyTorch 2.0 to compile the model to be faster

in train.py
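
(Alternatively, since train.py picks up --key=value overrides through configurator.py, the same change works from the command line without editing the file, e.g. python train.py --compile=False.)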

adammarples avatar Jan 11 '23 15:01 adammarples

To add one data point: I'm running unmodified python train.py with --batch_size=8 on ~22 GB of VRAM.

jorahn avatar Jan 19 '23 13:01 jorahn