Obtaining the bit representation of a byte
I would like to obtain a bit representation of a tensor. Using PyTorch, I can do it like this:
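```python
import torch

n_bits = 7
# The scaling by (2**n_bits - 1) below expects a value normalised to [-1, 1],
# so 45 / (2**n_bits - 1) quantises back to 45. (With the integer 45 itself,
# the product 127 * 45 = 5715 would not fit in a uint8.)
inp = torch.tensor([45 / (2**n_bits - 1)])
# Bit masks 1, 2, 4, ..., 2**(n_bits - 1) as a column of shape (n_bits, 1)
mask = (2 ** torch.arange(n_bits)).to(torch.uint8).unsqueeze(-1)
bitsliced_input = (
    ((2**n_bits - 1) * inp)  # scale to [-(2**n_bits - 1), 2**n_bits - 1]
    .abs()                   # keep the magnitude only
    .unsqueeze(-2)           # (N,) -> (1, N) so it broadcasts against mask
    .round()                 # quantise to the nearest integer
    .byte()                  # convert to uint8
    .bitwise_and(mask)       # isolate each bit -> shape (n_bits, N)
    .ne(0)                   # non-zero -> True
    .byte()                  # bools back to 0/1
)
print(bitsliced_input)  # one row per bit, least significant bit first
```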
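For a single non-negative integer, the masking step above reduces to testing one bit at a time: `bitwise_and(mask)` with `mask = 2**i`, followed by `.ne(0)`, is the same test as `(x >> i) & 1`. A minimal plain-Python sketch (the helper name `to_bits` is hypothetical) that returns bits least-significant first, in the same order as the rows produced by the PyTorch code:

```python
def to_bits(value: int, n_bits: int) -> list[int]:
    """Return the lowest n_bits bits of a non-negative integer, LSB first."""
    if value < 0:
        raise ValueError("value must be non-negative")
    # (value & (1 << i)) != 0 mirrors bitwise_and(mask).ne(0) with mask = 2**i
    return [int((value & (1 << i)) != 0) for i in range(n_bits)]


print(to_bits(45, 7))  # 45 = 0b0101101 -> [1, 0, 1, 1, 0, 1, 0]
```

Reading the list back as `sum(b << i for i, b in enumerate(bits))` recovers the original value, which is a quick way to check the bit order.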
Is there a way to perform bitwise operations in Triton? With those, I think I could get closer to a solution.
Or is it better to create that bit matrix outside the kernel? I was trying to avoid this, since it would increase memory consumption by a factor of n_bits.
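To put numbers on that factor: assuming the inputs are stored as one byte each and the precomputed bit matrix uses one uint8 per bit (which is what the final `.byte()` in the PyTorch code produces), the matrix needs n_bits times the memory of the input. A tiny sketch of the arithmetic, with a hypothetical helper name:

```python
def bit_matrix_bytes(num_elements: int, n_bits: int) -> tuple[int, int]:
    """Return (input_bytes, matrix_bytes) for uint8 inputs and a uint8 bit matrix."""
    input_bytes = num_elements * 1            # one uint8 per input value (assumed)
    matrix_bytes = num_elements * n_bits * 1  # one uint8 per extracted bit
    return input_bytes, matrix_bytes


print(bit_matrix_bytes(1_000_000, 7))  # (1000000, 7000000)
```

Extracting the bits inside the kernel would avoid materialising the second number, since each bit can be produced from the loaded byte and consumed right away.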
Thanks for your interest in Triton, but this is a bug tracker, not a place to ask for programming advice. Consider directing your questions to the triton channel on the GPU Mode Discord server.
FWIW, though, Triton has `tl.arange`, broadcasting, and bitwise operators (`&`, `|`, `^`, `<<`, `>>`), so I don't see any problem with writing this in Triton.
Ah, thank you, I was looking for something like this.
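The arange + broadcasting + bitwise combination mentioned in the reply would look roughly like `bits = (x[:, None] >> tl.arange(0, BLOCK_BITS)[None, :]) & 1` inside a kernel, where `x` is a loaded block of values. One detail to keep in mind: `tl.arange` only accepts power-of-two ranges, so for n_bits = 7 the range would likely be padded to 8 and the extra column masked out. The sketch below is not Triton code; it emulates that indexing with plain Python lists so it runs without a GPU, and `BLOCK_BITS` and `bit_block` are hypothetical names.

```python
N_BITS = 7       # bits wanted, as in the question
BLOCK_BITS = 8   # hypothetical padded width: tl.arange needs a power-of-two length


def bit_block(xs: list[int]) -> list[list[int]]:
    """Emulate (x[:, None] >> shifts[None, :]) & 1 for a block of non-negative ints.

    Returns one row per input value with N_BITS bits each, least significant first.
    """
    shifts = list(range(BLOCK_BITS))     # stands in for tl.arange(0, BLOCK_BITS)
    keep = [s < N_BITS for s in shifts]  # column mask; in a kernel it would go to tl.store(..., mask=...)
    rows = []
    for x in xs:                         # each x is one row of x[:, None]
        full = [(x >> s) & 1 for s in shifts]               # shift by every column, then AND with 1
        rows.append([b for b, k in zip(full, keep) if k])  # drop the padding column
    return rows


print(bit_block([45, 83, 127]))
# [[1, 0, 1, 1, 0, 1, 0], [1, 1, 0, 0, 1, 0, 1], [1, 1, 1, 1, 1, 1, 1]]
```

In an actual kernel the padded column would more likely be kept in the register tile and simply excluded by the store mask, rather than filtered out as done here.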