simplify usage of cudnn
- Simplify graph cache and usage of cudnn.
- Fix failures on H100
Hello, ty for the PR. I'm not an expert in cudnn use; do you have a short explanation for some of these changes? Also, I noticed you edited the dev/cuda files, but do you potentially have any suggestions for train_gpt2.cu (the prod file)?
Hi,
Thanks for taking a look.
Change 1 (L63 - L83):
cudnn::frontend::graph::execute has two flavors for taking the variant pack (the device pointers).
Currently, at your top of tree, the variant pack is a map of shared_ptr<TensorDescriptor> -> device pointer. This means we have to keep the tensor descriptor handles around and remember which pointer belongs to which tensor. Instead, if we manually assign a uid to each tensor, execute can take a map of uid -> device pointer, which is easier to read and understand. cudnn-frontend is updating its examples to showcase this behavior.
Change 2 (L146):
The forward graph does require a workspace on H100 (although a very small one, 32B or less).
Change 3 (~L135 and ~L200):
Using graph.key() is a good starting point for a plan-cache key. However, it involves some CPU overhead, because the graph has to be created just to compute the key. A custom key (a tuple of B, H, seqlen, hidden dimension) should be enough in this case; a similar custom key exists in Transformer Engine. This will also benefit from the compile-time improvements coming in a future release of cudnn-frontend.
I will take a closer look at train_gpt2.cu and see what can be further optimized.
Running the test with this PR
make test_gpt2cu USE_CUDNN=1 && ./test_gpt2cu
actually fails; specifically, the error on the qkvw tensor grows from 1.1e-1 to 1.4e-1, so we'd have to bump the tolerances on this tensor. I'm ~ok bumping the tolerance if this is just noise, mostly. The PR looks fairly harmless, and algorithmically everything looks ok and untouched?
But maybe it's worth looking one more time to see if anything may have changed to trip the error thresholds we have right now.
I think this was not flagged by our CI because it does not turn on USE_CUDNN=1 in the make command.
Sorry for the spam. I noticed that it's not this PR that is "flipping" the test from FAIL to PASS; it's the way we compile, without USE_CUDNN=1. Master is "broken" in the same way. I'll bump the thresholds here and merge.