The default setting of `param_scan_axis=1` hurts performance and memory consumption on GPUs
The default setting of `param_scan_axis=1` causes Flax to transpose the model parameters for the scan while also keeping the untransposed copy for the optimizer state update.
Compared to `param_scan_axis=0` on Llama2-7b, the extra memory consumption is ~13 GB (~81 GB vs ~68 GB total) and the performance hit is ~3%.
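For context, the decoder layers are scanned with Flax's `nn.scan`, and `param_scan_axis` selects the axis along which the per-layer parameters are stacked (via `variable_axes={'params': ...}`). Below is a minimal standalone sketch of that mechanism; `Block` and `ScannedStack` are illustrative stand-ins, not MaxText's actual modules:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn


class Block(nn.Module):
    """One scanned layer, standing in for a decoder block."""
    features: int

    @nn.compact
    def __call__(self, carry, _):
        carry = nn.Dense(self.features)(carry)
        return carry, None


class ScannedStack(nn.Module):
    features: int
    num_layers: int
    param_scan_axis: int  # mirrors the MaxText config knob

    @nn.compact
    def __call__(self, x):
        # variable_axes={'params': axis} picks the axis along which the
        # per-layer parameters are stacked. With axis=1 the stored layout no
        # longer matches the layer-major order the scan iterates over, which
        # is where the extra transposed copy comes from.
        scanned = nn.scan(
            Block,
            variable_axes={'params': self.param_scan_axis},
            split_rngs={'params': True},
            length=self.num_layers,
        )
        x, _ = scanned(features=self.features)(x, None)
        return x


x = jnp.ones((2, 16))
for axis in (0, 1):
    params = ScannedStack(features=16, num_layers=4, param_scan_axis=axis).init(
        jax.random.PRNGKey(0), x)
    # axis=0 stacks each Dense kernel as (4, 16, 16); axis=1 as (16, 4, 16).
    print(axis, jax.tree_util.tree_map(jnp.shape, params))
```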
Here is a log from a run with the default `param_scan_axis=1`:
completed step: 1, seconds: 0.637, TFLOP/s/device: 148.271, Tokens/s/device: 6434.697, total_weights: 8192, loss: 10.863
completed step: 2, seconds: 0.450, TFLOP/s/device: 209.671, Tokens/s/device: 9099.331, total_weights: 8192, loss: 9.510
completed step: 3, seconds: 0.449, TFLOP/s/device: 210.105, Tokens/s/device: 9118.189, total_weights: 8192, loss: 7.975
completed step: 4, seconds: 0.450, TFLOP/s/device: 209.736, Tokens/s/device: 9102.141, total_weights: 8192, loss: 6.078
completed step: 5, seconds: 0.450, TFLOP/s/device: 209.950, Tokens/s/device: 9111.435, total_weights: 8192, loss: 4.459
completed step: 6, seconds: 0.450, TFLOP/s/device: 209.673, Tokens/s/device: 9099.412, total_weights: 8192, loss: 3.240
completed step: 7, seconds: 0.450, TFLOP/s/device: 209.882, Tokens/s/device: 9108.497, total_weights: 8192, loss: 2.379
completed step: 8, seconds: 0.450, TFLOP/s/device: 209.835, Tokens/s/device: 9106.431, total_weights: 8192, loss: 1.819
completed step: 9, seconds: 0.450, TFLOP/s/device: 209.915, Tokens/s/device: 9109.915, total_weights: 8192, loss: 1.484
Output size: 40430494092, temp size: 40316314088, argument size: 40430575628, host temp size: 0, in bytes.
Here is a log with `param_scan_axis=0`:
completed step: 1, seconds: 0.620, TFLOP/s/device: 152.335, Tokens/s/device: 6611.047, total_weights: 8192, loss: 10.863
completed step: 2, seconds: 0.435, TFLOP/s/device: 217.068, Tokens/s/device: 9420.337, total_weights: 8192, loss: 9.510
completed step: 3, seconds: 0.434, TFLOP/s/device: 217.292, Tokens/s/device: 9430.074, total_weights: 8192, loss: 7.975
completed step: 4, seconds: 0.434, TFLOP/s/device: 217.400, Tokens/s/device: 9434.745, total_weights: 8192, loss: 6.079
completed step: 5, seconds: 0.434, TFLOP/s/device: 217.451, Tokens/s/device: 9436.983, total_weights: 8192, loss: 4.461
completed step: 6, seconds: 0.435, TFLOP/s/device: 217.040, Tokens/s/device: 9419.145, total_weights: 8192, loss: 3.241
completed step: 7, seconds: 0.434, TFLOP/s/device: 217.333, Tokens/s/device: 9431.833, total_weights: 8192, loss: 2.380
completed step: 8, seconds: 0.434, TFLOP/s/device: 217.454, Tokens/s/device: 9437.114, total_weights: 8192, loss: 1.820
completed step: 9, seconds: 0.434, TFLOP/s/device: 217.345, Tokens/s/device: 9432.376, total_weights: 8192, loss: 1.485
Output size: 40430494092, temp size: 27134419432, argument size: 40430575628, host temp size: 0, in bytes.
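For reference, the gap between the two compiler-reported temp sizes matches the ~13 GB figure above, and the steady-state TFLOP/s numbers give the few-percent slowdown; a quick sanity check on the logged values:

```python
# Compiler-reported temp sizes from the two runs above (bytes).
temp_axis1 = 40_316_314_088  # param_scan_axis=1 (default)
temp_axis0 = 27_134_419_432  # param_scan_axis=0
print(f"extra temp memory: {(temp_axis1 - temp_axis0) / 1e9:.1f} GB")  # ~13.2 GB

# Steady-state TFLOP/s/device from steps 2-9 of each run.
tflops_axis1, tflops_axis0 = 209.8, 217.3
print(f"throughput hit: {1 - tflops_axis1 / tflops_axis0:.1%}")  # ~3.5%
```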
@khatwanimohit, can we merge your PR and close this?