Autotune: Inputs more representative of the key
For matmul in the WGPU backend, we use autotune with a key that implicitly tells us the range in which the matmul inputs fall. For instance, if we have [3, 2, 63, 61] x [3, 3, 61, 62], the key will be:

```
{
    round: false,       // not all of m, k, n are multiples of 64
    broadcast: true,    // because 2 and 3 are not the same dim
    anchored_m: 64,     // power of 2 above 63
    anchored_k: 64,     // power of 2 above 61
    anchored_n: 64,     // power of 2 above 62
    anchored_batch: 16, // power of 2 above 3x3 (the max of the multiplied batches)
}
```
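To make the anchoring concrete, here is a minimal sketch of how such a key could be computed from the two input shapes. The names `MatmulAutotuneKey`, `anchor` and `key_from_shapes` are illustrative assumptions, not Burn's actual API; the logic only encodes what the key above states (powers of 2, multiples of 64, max batch product).

```rust
// A minimal sketch, not Burn's actual implementation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct MatmulAutotuneKey {
    round: bool,           // true only if m, k and n are all multiples of 64
    broadcast: bool,       // true if the batch dims of lhs and rhs differ
    anchored_m: usize,     // next power of 2 >= m
    anchored_k: usize,     // next power of 2 >= k
    anchored_n: usize,     // next power of 2 >= n
    anchored_batch: usize, // next power of 2 >= max of the batch products
}

fn anchor(x: usize) -> usize {
    x.next_power_of_two()
}

fn key_from_shapes(lhs: &[usize], rhs: &[usize]) -> MatmulAutotuneKey {
    let (m, k) = (lhs[lhs.len() - 2], lhs[lhs.len() - 1]);
    let n = rhs[rhs.len() - 1];
    let lhs_batch: usize = lhs[..lhs.len() - 2].iter().product();
    let rhs_batch: usize = rhs[..rhs.len() - 2].iter().product();

    MatmulAutotuneKey {
        round: m % 64 == 0 && k % 64 == 0 && n % 64 == 0,
        broadcast: &lhs[..lhs.len() - 2] != &rhs[..rhs.len() - 2],
        anchored_m: anchor(m),
        anchored_k: anchor(k),
        anchored_n: anchor(n),
        anchored_batch: anchor(lhs_batch.max(rhs_batch)),
    }
}
```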
But the autotuning itself will be run on inputs of exactly [3, 2, 63, 61] x [3, 3, 61, 62]. Afterwards, any input with the same key will reuse the algorithm chosen by that autotune, which may not be very representative. For instance, [3, 2, 33, 33] x [3, 3, 33, 33] has the same key even though its shapes are far from the ones that were benchmarked.
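Using the hypothetical `key_from_shapes` from the sketch above, the collision is easy to check: both shapes produce the same key, so whichever one triggers autotuning first decides the kernel used for both.

```rust
fn main() {
    // Same key, so both inputs reuse whatever kernel autotune picked
    // for the first of them that was executed.
    let key_a = key_from_shapes(&[3, 2, 63, 61], &[3, 3, 61, 62]);
    let key_b = key_from_shapes(&[3, 2, 33, 33], &[3, 3, 33, 33]);
    assert_eq!(key_a, key_b);
}
```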
There is also a limit on anchored_batch: it cannot exceed 256, so all matmuls with more than 256 batches can fall on the same key. But if the first one to be tried has 20,000 batches, then that is the input that will be autotuned on, which will be uselessly slow.
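A hedged sketch of the batch anchoring described above, with an assumed `anchor_clamped` helper: once the cap of 256 is reached, wildly different batch counts collapse to the same key.

```rust
// Hypothetical clamped anchoring: anchored_batch never exceeds the cap.
fn anchor_clamped(x: usize, max: usize) -> usize {
    x.next_power_of_two().min(max)
}

fn main() {
    // 300 batches and 20,000 batches both anchor to 256, hence the same key,
    // even though autotuning on the 20,000-batch input would be uselessly slow.
    assert_eq!(anchor_clamped(300, 256), 256);
    assert_eq!(anchor_clamped(20_000, 256), 256);
}
```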
We would need to create the tensors used for autotune (in autotunables()) so that they are more representative of the "center" of the key, for instance with m, k, n around 48 when the anchors are 64. For batches it is even more important: if there are more than 256 batches, the benchmarks should run with around 256 batches. A sketch of this idea is below.
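The following is one possible reading of "representative of the center of the key"; `representative_dim` and `representative_batch` are hypothetical helpers for choosing benchmark shapes from the key, not an existing autotunables() API.

```rust
// Hypothetical helpers for picking autotune shapes from the key rather than
// from the first input seen.
fn representative_dim(anchored: usize) -> usize {
    // The anchor covers sizes in (anchored / 2, anchored], so pick roughly
    // the midpoint of that range, e.g. 48 for an anchor of 64.
    anchored / 2 + anchored / 4
}

fn representative_batch(anchored_batch: usize) -> usize {
    // If the anchor is saturated at 256, benchmark with 256 batches instead
    // of however many batches the real input has.
    if anchored_batch >= 256 {
        256
    } else {
        representative_dim(anchored_batch)
    }
}

fn main() {
    assert_eq!(representative_dim(64), 48); // for anchored m, k, n of 64
    assert_eq!(representative_batch(16), 12);
    assert_eq!(representative_batch(256), 256);
}
```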