tiny-cuda-nn
How big can you make a fully fused MLP while retaining performance benefits?
The speedup in this repo relies on keeping memory traffic close to the chip (caches, registers, etc.). This has to stop working once an MLP is sufficiently large, but I'm unclear where the boundary is. (A simplified sketch of the idea is below.)
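For concreteness, here is a deliberately simplified sketch of what "fully fused" means as I understand it. This is my own illustration, not the repo's kernel (which additionally uses half precision, tensor cores, and warp-level tiling); the width and layer count are placeholder values. The point is that per-element activations never leave on-chip storage between layers, and only the weights are staged through shared memory:

```cuda
#include <cuda_runtime.h>

constexpr int WIDTH = 64;    // hypothetical MLP width
constexpr int N_LAYERS = 4;  // hypothetical hidden-layer count

// One thread per batch element; launch as e.g.
// fused_mlp_forward<<<(batch_size + 127) / 128, 128>>>(...).
__global__ void fused_mlp_forward(
    const float* __restrict__ weights, // [N_LAYERS][WIDTH][WIDTH]
    const float* __restrict__ input,   // [batch_size][WIDTH]
    float* __restrict__ output,        // [batch_size][WIDTH]
    int batch_size)
{
    __shared__ float w[WIDTH * WIDTH]; // one layer's weights at a time

    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // Activations live on-chip for the whole network; they are never
    // written to global memory between layers.
    float act[WIDTH];
    if (i < batch_size) {
        for (int k = 0; k < WIDTH; ++k) act[k] = input[i * WIDTH + k];
    }

    for (int layer = 0; layer < N_LAYERS; ++layer) {
        // Cooperatively stage this layer's weight matrix in shared memory.
        __syncthreads();
        for (int k = threadIdx.x; k < WIDTH * WIDTH; k += blockDim.x) {
            w[k] = weights[layer * WIDTH * WIDTH + k];
        }
        __syncthreads();

        if (i < batch_size) {
            float next[WIDTH];
            for (int row = 0; row < WIDTH; ++row) {
                float sum = 0.0f;
                for (int col = 0; col < WIDTH; ++col) {
                    sum += w[row * WIDTH + col] * act[col];
                }
                next[row] = fmaxf(sum, 0.0f); // ReLU
            }
            for (int k = 0; k < WIDTH; ++k) act[k] = next[k];
        }
    }

    // Only the final activations ever travel back to global memory.
    if (i < batch_size) {
        for (int k = 0; k < WIDTH; ++k) output[i * WIDTH + k] = act[k];
    }
}
```

At 64 wide, the per-thread arrays above would mostly spill out of registers; as far as I understand, the actual kernel avoids this by distributing each activation matrix across a warp and multiplying with tensor cores. But the data flow is the same, and it makes the question obvious: once WIDTH grows past what registers and shared memory can hold, the fusion advantage should disappear.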
Does anyone know the answers to these questions:
- How big can you make an MLP while retaining the performance benefits? (Has anyone tested this?)
- Can you "trade off" a smaller batch size for a larger MLP and still keep the benefits?
- Would using more powerful hardware (e.g. an A100, which has 40 MB of L2 cache, over an RTX 3090, which has 6 MB) expand this performance window?
I could potentially help out with testing (3).
Hi there!
- It's a continuum, which I would classify as follows:
  a. Let's use CUTLASS's matrix multiplication routines (implemented in `CutlassMLP`) as a baseline, since these avoid unrelated overheads of Python frameworks.
  b. Compared to that baseline, I've observed significant speedups for 64-wide and smaller MLPs, moderate speed-ups for 128-wide MLPs, and no speedup for 256-wide MLPs, all measured on an RTX 3090. I've hand-tuned the low-level kernel configurations for each of these widths, so I'm reasonably confident in this. (A config sketch for such an A/B comparison follows this list.)
- Unfortunately not with the structure of computation that the current implementation exploits.
- Again, unfortunately no. The L2 cache is shared across multiple SMs and thus equally benefits traditional (non-fused) matmuls and the fully fused approach. What would help is a larger register file, L1 cache, and shared memory. To fully exploit these, a few of the low-level kernel parameters in `fully_fused_mlp.cu` need to be tuned to whichever sizes are available.
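For anyone wanting to reproduce the comparison from (b), something along these lines should work. The JSON option names are the ones from this repo's network documentation; the benchmark harness itself (timing, synthetic data, network creation) is omitted, and the chosen width/depth values are just examples:

```cpp
#include <tiny-cuda-nn/config.h>
#include <utility>

using nlohmann::json;

// Returns a pair of network configs that differ only in "otype", for timing
// FullyFusedMLP against the CutlassMLP baseline at a given size.
std::pair<json, json> make_ab_configs(int n_neurons, int n_hidden_layers) {
	json fused = {
		{"otype", "FullyFusedMLP"},          // fully fused kernels
		{"activation", "ReLU"},
		{"output_activation", "None"},
		{"n_neurons", n_neurons},            // e.g. 64, 128, 256 to probe the boundary
		{"n_hidden_layers", n_hidden_layers}
	};
	json cutlass = fused;
	cutlass["otype"] = "CutlassMLP";         // same architecture, CUTLASS matmuls
	return {fused, cutlass};
}
```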
For reference, here are the specs of an A100 vs a 3090:
- A100: 256 KB register file per SM, 164 KB shared memory per SM, 108 SMs (source)
- 3090: 256 KB register file per SM, 128 KB shared memory per SM, 82 SMs (source)
These numbers are quite similar on a per-SM basis, although the A100 has significantly more SMs. Do you think this would make much of a difference, provided the kernel parameters are tuned appropriately?
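If it helps, here's a minimal snippet (plain CUDA runtime API, nothing repo-specific) to read those per-SM numbers and the L2 size off whatever GPU is at hand:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0); // device 0

    printf("%s\n", prop.name);
    printf("  SMs:                  %d\n", prop.multiProcessorCount);
    // regsPerMultiprocessor counts 32-bit registers, hence * 4 for bytes.
    printf("  Registers per SM:     %d KB\n", prop.regsPerMultiprocessor * 4 / 1024);
    printf("  Shared memory per SM: %zu KB\n", prop.sharedMemPerMultiprocessor / 1024);
    printf("  L2 cache:             %d MB\n", prop.l2CacheSize / (1024 * 1024));
    return 0;
}
```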