
repeat_interleave or alternative needed to unpack quantized weights

Open fpgaminer opened this issue 1 year ago • 3 comments

I'm working on a Triton kernel to compute matmuls on quantized linear layers, in particular layers where more than one parameter is packed into a single value of an int32 Tensor.

The issue is that I could not find a way to "unpack" such Tensors in Triton. For example, imagine I have an int32 Tensor of size [1, N//8], where each int32 represents eight 4-bit parameters. Inside a Triton kernel how do I expand this into a [1, N] Tensor?

Something like PyTorch's repeat_interleave would work, as it would allow one to unroll the packed tensor. From there one can apply shifting and masking to get the correct values unpacked at each index.
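To make the repeat-then-shift-and-mask idea concrete, here is a minimal host-side sketch in NumPy rather than Triton. `np.repeat` stands in for `repeat_interleave`. The helper names `pack_int4` and `unpack_int4` are hypothetical. The sketch also assumes unsigned 32-bit words and that value `i` lives in bits `4*(i % 8)` of its word, which matches the shift used in the hack below.

```python
import numpy as np

def pack_int4(values: np.ndarray) -> np.ndarray:
    """Pack a 1-D array of 4-bit values (0..15) into uint32 words, eight per word.

    Assumed layout: value i occupies bits 4*(i % 8) .. 4*(i % 8) + 3 of word i // 8.
    """
    assert values.size % 8 == 0
    v = values.astype(np.uint32).reshape(-1, 8)
    shifts = np.arange(8, dtype=np.uint32) * 4
    return np.bitwise_or.reduce(v << shifts, axis=1)

def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Expand [N//8] packed words into [N] 4-bit values."""
    # Step 1: repeat each word eight times (what repeat_interleave would do).
    repeated = np.repeat(packed, 8)
    # Step 2: shift the wanted nibble down to the low bits and mask it out.
    shifts = (np.arange(repeated.size, dtype=np.uint32) % 8) * 4
    return (repeated >> shifts) & np.uint32(0xF)

values = np.arange(16, dtype=np.uint32) % 16
packed = pack_int4(values)
print(hex(int(packed[0])))   # 0x76543210
print(unpack_int4(packed))
```

The round trip recovers the original sixteen values from two packed words. This is the behaviour being asked for, expressed on the host.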

My current hack is the following:

b_ptrs = b_ptr + ((offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn)   # (BLOCK_SIZE_K, BLOCK_SIZE_N)
shifter = (offs_k % 8) * 4

...
# Inside the inner loop:
	b = tl.load(b_ptrs)   # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated

	# Now we need to unpack b (which is 4-bit values) into 32-bit values
	b = (b >> shifter[:, None]) & 0xF  # Extract the 4-bit values
	b = b * scales[None, :] - zeros[None, :]  # Scale and shift

This is based on the matmul tutorial code. The major difference is that I divide the b_ptrs indexes by // 8. This causes them to repeat along the K axis. So I'm basically making tl.load act like repeat_interleave for me. Then I can finish unpacking the values like normal.
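The effect of dividing the K offsets by 8 can be checked outside Triton. The following is a NumPy emulation of the snippet above, not Triton code: fancy indexing plays the role of `tl.load` on `b_ptrs`, and the function name and argument shapes are assumptions made for illustration.

```python
import numpy as np

def load_unpack_dequant(b_packed, scales, zeros, offs_k, offs_n):
    """NumPy stand-in for the kernel fragment above.

    b_packed: (K // 8, N) uint32, eight 4-bit values per word along K.
    scales, zeros: (N,) per-column dequantization parameters.
    """
    # Same arithmetic as b_ptrs: the row index is offs_k // 8, so every
    # packed row is fetched eight times in a row.
    rows = offs_k[:, None] // 8
    cols = offs_n[None, :]
    b = b_packed[rows, cols]                  # (BLOCK_K, BLOCK_N), repeated
    shifter = (offs_k % 8) * 4
    b = (b >> shifter[:, None]) & 0xF         # extract the 4-bit values
    return b * scales[None, :] - zeros[None, :]

rng = np.random.default_rng(0)
b_packed = rng.integers(0, 2**32, size=(2, 3), dtype=np.uint32)
offs_k, offs_n = np.arange(16), np.arange(3)

# The gather with offs_k // 8 is exactly a repeat along the K axis.
gathered = b_packed[offs_k[:, None] // 8, offs_n[None, :]]
print(np.array_equal(gathered, np.repeat(b_packed, 8, axis=0)))   # True
```

The final check confirms that the divided index reproduces a repeat along K, which is why the load behaves like `repeat_interleave`.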

The downside is that, as far as I'm aware, this issues 8x as many loads as fetching the packed Tensor directly, since the packed Tensor is 8x smaller.

Having a built-in similar to repeat_interleave would allow me to unpack those values in SRAM and save the bandwidth. Or maybe a way to index a Tensor? Then I could build an interleaved index and do b[indexes]. But I didn't see any examples of indexing Tensors like that, so I assumed it wasn't possible in the language.
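For reference, the two ideas in the paragraph above look like this in NumPy. This is only an illustration of the intended semantics: it says nothing about what Triton can express, and both function names are hypothetical. The first builds the interleaved index and gathers with `b[indexes]`. The second avoids a gather by broadcasting the eight shifts against the packed words and flattening, under the same assumed nibble order (value `i` in bits `4*(i % 8)`).

```python
import numpy as np

def unpack_by_index(packed: np.ndarray) -> np.ndarray:
    """Gather with an interleaved index, then shift and mask."""
    n = packed.size * 8
    indexes = np.arange(n) // 8              # 0,0,0,0,0,0,0,0,1,1,...
    shifts = (np.arange(n) % 8) * 4
    return (packed[indexes].astype(np.int64) >> shifts) & 0xF

def unpack_by_broadcast(packed: np.ndarray) -> np.ndarray:
    """Broadcast the eight shifts across each word, then flatten."""
    shifts = np.arange(8, dtype=np.uint32) * 4
    nibbles = (packed[:, None] >> shifts[None, :]) & np.uint32(0xF)  # (N//8, 8)
    return nibbles.reshape(-1)                                       # (N,)

packed = np.array([0x76543210, 0xFEDCBA98], dtype=np.uint32)
print(unpack_by_index(packed))
print(unpack_by_broadcast(packed))
```

Both functions return the same sixteen values. The broadcast form reads each packed word once and needs a reshape instead of tensor indexing.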

Does this functionality already exist? Is there a better implementation? Or should this be a feature request?

Thank you!

fpgaminer avatar Mar 27 '23 20:03 fpgaminer

#974:

Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap though, but it's a pretty advanced feature so we haven't come up with a specific timeline yet.

Looks like we might see indexing support in the future

julian-q avatar Sep 17 '23 06:09 julian-q

#974:

Yeah, on-chip indexing through shared memory isn't supported yet. It's on the roadmap though, but it's a pretty advanced feature so we haven't come up with a specific timeline yet.

Looks like we might see indexing support in the future

Hello, could you confirm when this feature is expected to be supported?

vivienfanghuagood avatar Mar 19 '24 06:03 vivienfanghuagood

bumping this

jselvam11 avatar May 01 '24 05:05 jselvam11

Advanced tensor indexing feature wanted!

zzb66666666x avatar Aug 22 '24 21:08 zzb66666666x