CUDA memory error when using BootstrapFinetune
Hello,
I am using the BootstrapFinetune teleprompter on a task that I have defined with dspy. When I run my code, I always get a CUDA out of memory error, and the teleprompter also incurs a lot of cost on my OpenAI API. I want to ask why this is the case, and how I can control the memory parameters using dspy?
Thanks for your help!
Hi! What GPU do you have, and what model are you trying to fine-tune? How are you initializing the model?
OpenAI shouldn't be using any VRAM (which is what an OOM error points to).
Looking at the BootstrapFinetune docs, it seems like adjusting bsize or enabling bf16 are probably your easiest bets for avoiding the OOM error, assuming you have enough compute.
Hi @darinkishore
Thanks for the response. The model is google/flan-t5-large, and I am using an A100.
Hey @kimianoorbakhsh, thanks for opening the issue. (Also thanks a lot @darinkishore for helping!)
It's easier to help if I can see (some of) your code, but I'll use the example from examples/qa/hotpot/multihop_finetune.ipynb
There, the example has this form:
config = dict(target='t5-large', epochs=2, bf16=True, bsize=6, accumsteps=2, lr=5e-5)
tp = BootstrapFinetune(metric=None)
t5_program = tp.compile(PROGRAM, teacher=FEW_SHOT_PROGRAM, trainset=TRAIN_LIST, **config)
First off, is this similar to your structure?
Second, in order to do supervised finetuning, BootstrapFinetune will take the teacher program and run it on all examples in trainset. This is probably why you're seeing OpenAI costs: your teacher goes to OpenAI.
But if you run it again, it will be fast and free, assuming your cache is still available (i.e., you're on the same server and user).
After the traces are generated from the teacher, the compiler BootstrapFinetune will launch a finetuning job for your local model, in this case google/flan-t5-large.
- You can check the json file it saves before compiling to confirm everything looks like what you need. The path is printed for you when you run it. (If you lost that output, it's also easy to find it in a cache folder under compiler/.) There's a small sketch of this right after this list.
- Often, the inputs are very long, especially if you're doing retrieval or something that has long context. This isn't DSPy-specific. You can avoid OOM errors by (1) using a smaller LM like google/flan-t5-base, (2) reducing bsize and optionally increasing accumsteps, or (3) turning on bf16=True like Darin said.
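For the first point, here's a minimal sketch of inspecting that json file. I'm assuming it's plain JSON on disk; the path below is a placeholder, so use the one printed when you run the compiler (or look under the compiler/ cache folder):
import json

# Placeholder path: substitute the file that BootstrapFinetune prints when it runs.
path = "compiler/<printed-file-name>.json"
with open(path) as f:
    data = json.load(f)

# Eyeball a small slice of the bootstrapped data before kicking off finetuning.
print(json.dumps(data, indent=2)[:2000])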
Concretely, check the json file and then, if all looks good, use google/flan-t5-base with bsize=6, accumsteps=2, bf16=True in the config. Does this help?
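If it helps, here's a minimal sketch of that adjusted config, following the notebook example above. I'm assuming the model name goes in target just like in the notebook; the values are the starting points suggested here, not definitive settings:
# Smaller model plus memory-friendlier settings; lower bsize / raise accumsteps further if you still hit OOM.
config = dict(target='google/flan-t5-base', epochs=2, bf16=True, bsize=6, accumsteps=2, lr=5e-5)
tp = BootstrapFinetune(metric=None)
t5_program = tp.compile(PROGRAM, teacher=FEW_SHOT_PROGRAM, trainset=TRAIN_LIST, **config)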
Hi @okhat,
Thank you very much. This is very helpful! Yes, I was using almost the same setup. I will go ahead and adjust the parameters as you suggested, and try again.
Hey @kimianoorbakhsh, happy to help with this if needed. Let us know how it goes!