[Feature] Quantized Convolution 2D - Possible PR?
Hi @awni, I've been working on a quantizable Conv2d layer that is a drop-in replacement for nn.Conv2d (for a large convolutional UNet, ~5 GB, with self-attention, cross-attention, and a bunch of other stuff). I've already gone through bf16, quantized linear layers, and optimizing until I was blue in the face, but I've still only gotten it down to 2.5 GB of weights with about 6 GB of runtime memory required. I've implemented a naive quantized convolution using pure MLX ops, as below:
```python
def quantize_weights(self):
    if self.quantized:
        return
    if self.groups == 1:
        # Flatten (O, KH, KW, C) -> (O, KH*KW*C) and pad the last axis to a
        # multiple of group_size so mx.quantize accepts it.
        w_2d = self.weight.reshape(self.weight.shape[0], -1)
        original_size = w_2d.shape[-1]
        if original_size % self.group_size != 0:
            padded_size = ((original_size + self.group_size - 1) // self.group_size) * self.group_size
            w_2d = mx.pad(w_2d, [(0, 0), (0, padded_size - original_size)])
        self.padded_input_size = w_2d.shape[-1]
        q_w, scales, q_b = mx.quantize(w_2d, group_size=self.group_size, bits=self.bits)
        self.q_weights = [q_w]
        self.scales = [scales]
        self.q_biases = [q_b]
    else:
        # Quantize each weight group separately.
        self.padded_input_size = []
        self.q_weights, self.scales, self.q_biases = [], [], []
        out_channels_per_group = self.weight.shape[0] // self.groups
        for g in range(self.groups):
            start_out = g * out_channels_per_group
            end_out = start_out + out_channels_per_group
            w_group = self.weight[start_out:end_out, :, :, :]
            w_2d = w_group.reshape(out_channels_per_group, -1)
            original_size = w_2d.shape[-1]
            if original_size % self.group_size != 0:
                padded_size = ((original_size + self.group_size - 1) // self.group_size) * self.group_size
                w_2d = mx.pad(w_2d, [(0, 0), (0, padded_size - original_size)])
            self.padded_input_size.append(w_2d.shape[-1])
            q_w, scales, q_b = mx.quantize(w_2d, group_size=self.group_size, bits=self.bits)
            self.q_weights.append(q_w)
            self.scales.append(scales)
            self.q_biases.append(q_b)
    self.quantized = True
```
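For reference, a quick way to sanity-check the quantization round trip would be something like the following (a hypothetical helper, not part of the layer itself, for the groups == 1 case only):

```python
import mlx.core as mx

def check_quantization_error(layer):
    # Hypothetical helper: dequantize the stored (groups == 1) weights and
    # report the max absolute reconstruction error vs. the float weights.
    w_2d = layer.weight.reshape(layer.weight.shape[0], -1)
    w_hat = mx.dequantize(
        layer.q_weights[0],
        layer.scales[0],
        layer.q_biases[0],
        group_size=layer.group_size,
        bits=layer.bits,
    )
    # Drop the zero padding added to reach a multiple of group_size.
    w_hat = w_hat[:, : w_2d.shape[-1]]
    return mx.abs(w_hat - w_2d).max()
```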
The forward path uses an equally naive im2col, shown below:
```python
def img2col(self, x, kh, kw, stride, padding):
    batch, height, width, channels = x.shape
    sh, sw = stride
    ph, pw = padding
    x_padded = mx.pad(x, [(0, 0), (ph, ph), (pw, pw), (0, 0)])
    # Output spatial dims (handy for reshaping the matmul result in the caller).
    out_h = (height + 2 * ph - kh) // sh + 1
    out_w = (width + 2 * pw - kw) // sw + 1
    patches = []
    for b in range(batch):
        for y in range(0, height + 2 * ph - kh + 1, sh):
            for x_pos in range(0, width + 2 * pw - kw + 1, sw):
                patch = x_padded[b, y:y+kh, x_pos:x_pos+kw, :].flatten()
                patches.append(patch)
    # Shape: (batch * out_h * out_w, kh * kw * channels)
    return mx.stack(patches)
```
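The triple Python loop is clearly the main bottleneck. One way it could be vectorized, looping only over the kh x kw kernel taps instead of every output position, is sketched below (untested; `img2col_vectorized` is just an illustrative name, and it assumes MLX strided slicing behaves like NumPy):

```python
import mlx.core as mx

def img2col_vectorized(self, x, kh, kw, stride, padding):
    # Same output as img2col above, but each strided slice grabs one kernel
    # tap for all output positions at once.
    batch, height, width, channels = x.shape
    sh, sw = stride
    ph, pw = padding
    x_padded = mx.pad(x, [(0, 0), (ph, ph), (pw, pw), (0, 0)])
    out_h = (height + 2 * ph - kh) // sh + 1
    out_w = (width + 2 * pw - kw) // sw + 1
    taps = []
    for i in range(kh):
        for j in range(kw):
            taps.append(x_padded[:, i : i + out_h * sh : sh, j : j + out_w * sw : sw, :])
    # Concatenating in (kh, kw) order matches the (kh, kw, channels) flatten
    # order of the naive version above.
    cols = mx.concatenate(taps, axis=-1)
    return cols.reshape(batch * out_h * out_w, kh * kw * channels)
```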
The im2col output then plugs into mx.quantized_matmul, doing either group-wise matmuls or a single matmul in the forward pass (I'm assuming a backward pass can't be run through a quantized matrix), roughly as in the sketch below.
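A simplified sketch of the groups == 1 forward path (stride/padding shown as tuples, bias optional; the grouped path just repeats this per weight group and concatenates the outputs along the channel axis):

```python
import mlx.core as mx

def __call__(self, x):
    # Simplified sketch, groups == 1 only.
    kh, kw = self.weight.shape[1], self.weight.shape[2]
    batch, height, width, _ = x.shape
    sh, sw = self.stride
    ph, pw = self.padding
    out_h = (height + 2 * ph - kh) // sh + 1
    out_w = (width + 2 * pw - kw) // sw + 1

    cols = self.img2col(x, kh, kw, self.stride, self.padding)  # (B*out_h*out_w, kh*kw*C)
    # Zero-pad the patch axis to match the padding applied to the weights
    # before quantization.
    pad = self.padded_input_size - cols.shape[-1]
    if pad > 0:
        cols = mx.pad(cols, [(0, 0), (0, pad)])

    # The weights were quantized as (out_channels, padded_input_size), so
    # transpose=True computes cols @ w.T.
    out = mx.quantized_matmul(
        cols,
        self.q_weights[0],
        scales=self.scales[0],
        biases=self.q_biases[0],
        transpose=True,
        group_size=self.group_size,
        bits=self.bits,
    )
    out = out.reshape(batch, out_h, out_w, -1)
    if "bias" in self:  # assumes an nn.Module subclass with an optional bias
        out = out + self.bias
    return out
```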
Now, as expected, this is painfully slow. But it works and the outputs match nn.Conv2d exactly:

```
Benchmarking convolution layers:
  MLX Conv2d:          0.481 ms/forward
  Implemented Conv2d: 39.193 ms/forward
  Quantized Conv2d:   37.448 ms/forward

Speedup vs MLX Conv2d: Implemented 0.01x, Quantized 0.01x
```

This is on an MBP M3 Pro; each run used a 5 s warmup followed by 100 timed iterations. The stored weights now take roughly 33% of the unquantized size.
Questions:
- How can I speed this up? Is a custom Metal kernel acceptable for a layer like this, or is there C++ backing this somewhere? (I'm not good at C++; I'm only asking so I can answer the next question.)
- Are you looking for PRs on this topic (and are there any guidelines for submitting a PR that adds a new layer)?
- Is there already a better implementation I missed, making my work redundant and me look dumb?
Thanks for your attention!