BitNet
Expected BitLinear weight to be 1 or -1
Hello. According to the BitNet paper, I presume the weights should be -1 or 1. But:
import torch
from bitnet import BitLinearNew

# Create a random input tensor of shape (2, 10, 10)
x = torch.randn(2, 10, 10)

# Create an instance of the BitLinearNew class with input size 10 and output size 20
layer = BitLinearNew(
    10,
    20,
)

# Perform a forward pass through the BitLinearNew layer with input x
output = layer(x)

print(layer.weight.dtype)
print(layer.weight)
Output:
torch.float32
Parameter containing:
tensor([[ 0.1634, 0.2419, -0.0605, 0.1592, 0.2348, -0.1431, -0.1634, 0.0171,
-0.1672, -0.1526],
[-0.0848, 0.0079, -0.2014, -0.0492, 0.2833, 0.1290, -0.2156, -0.1515,
-0.0473, -0.0839],
[ 0.2230, 0.1434, -0.1410, -0.0626, 0.1189, -0.1652, -0.2978, -0.0287,
0.1025, 0.2458],
[-0.1670, -0.0222, -0.0272, -0.2312, 0.1880, -0.2040, -0.0305, 0.1009,
-0.2247, 0.0124],
[ 0.1351, -0.2926, 0.1891, -0.1614, 0.2894, -0.2931, 0.0802, 0.2884,
0.0454, -0.1398],
[-0.2954, 0.2651, -0.0062, -0.1592, 0.2138, -0.2038, 0.2965, -0.2545,
0.0505, -0.0811],
[-0.3062, -0.1191, -0.1521, 0.1021, -0.1865, -0.1102, 0.2120, -0.2865,
0.1754, 0.1763],
[ 0.1375, -0.2975, 0.0399, -0.1723, -0.0526, -0.2694, 0.1838, -0.1826,
0.2806, -0.1438],
[-0.3150, 0.2163, 0.1946, -0.0244, 0.0657, -0.1531, -0.0310, 0.0071,
0.2590, 0.0985],
[ 0.0402, 0.0704, -0.1441, -0.1929, -0.2450, 0.2408, -0.0750, 0.0238,
0.3030, 0.0516],
[ 0.1537, -0.2231, -0.0092, -0.1068, 0.3131, 0.0859, -0.1692, -0.2364,
0.2257, 0.2601],
[-0.0478, -0.2978, -0.2025, -0.2411, -0.3061, -0.2566, 0.0564, -0.0906,
0.2113, 0.3118],
[-0.1048, 0.2073, -0.2126, -0.1883, 0.0463, -0.1716, -0.3052, 0.0548,
0.2079, 0.2587],
[-0.1387, 0.1778, -0.1886, 0.1239, 0.0265, -0.0421, -0.1020, 0.2481,
-0.0840, 0.1879],
[ 0.0707, -0.0534, 0.0623, 0.0803, 0.3135, 0.2192, -0.1202, 0.3139,
0.0781, -0.0512],
[ 0.2812, 0.2515, -0.0371, 0.0248, 0.0231, -0.0437, 0.0875, 0.3085,
-0.0482, -0.0092],
[ 0.1735, 0.2584, -0.0900, -0.1616, 0.1253, 0.1352, 0.1841, 0.1416,
-0.0686, -0.0269],
[-0.3121, -0.1050, 0.0265, 0.0242, 0.1973, 0.1816, -0.0084, 0.2866,
0.2559, -0.2523],
[ 0.1272, -0.2361, 0.0847, -0.0724, 0.2531, 0.0948, -0.0765, -0.1252,
-0.0459, -0.0133],
[-0.0660, 0.0650, 0.2529, -0.1763, -0.1248, -0.1073, -0.2926, 0.1837,
0.1265, -0.0090]], requires_grad=True)
Am I missing something?
I'm pretty sure the reason the parameters show fp32 values is that you need the original floating-point weights for backpropagation during training (1.58-bit quantization destroys the gradient). The weights are then re-quantized on the forward pass at every training step. When you actually deploy the model, I believe you would simply take the quantized weights and use those.
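For illustration, here is a minimal sketch of that pattern (my own toy code, not this repo's BitLinear): the fp32 master weights live in the parameter, they are binarized on every forward pass, and a straight-through estimator routes the gradient back to the fp32 copy.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyBinaryLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Master weights stay in fp32 so the optimizer can accumulate small updates.
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.02)

    def forward(self, x):
        w = self.weight
        scale = w.abs().mean()
        # Binarize around the mean, as in the snippets discussed in this thread.
        w_bin = torch.sign(w - w.mean()) * scale
        # Straight-through estimator: forward uses w_bin, gradient flows to the fp32 weights.
        w_ste = w + (w_bin - w).detach()
        return F.linear(x, w_ste)

layer = ToyBinaryLinear(10, 20)
out = layer(torch.randn(2, 10))
print(layer.weight.dtype)  # torch.float32 -- the stored parameter itself is never binary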
The answer from @rolson24 is correct in the sense that the weights are only quantized right before the linear operation and then dequantized, just as the original paper describes. However, digging a little deeper into the code, I found that the weight quantization is not actually being done correctly: the sign operation is applied before the scale multiplication. Here's how to test it:
import torch

# https://github.com/kyegomez/BitNet/blob/f56addac025b8bd58a02f9f00eb4cf530770658a/bitnet/bitlinear.py#L20C1-L24C13
def current_weight_quant(w):
    scale = w.abs().mean()
    e = w.mean()
    u = (w - e).sign() * scale  # values come out as +/- scale
    return u

def correct_weight_quant(w):
    scale = w.abs().mean()
    e = w.mean()
    u = torch.sign((w - e) * scale)  # values come out as +/- 1
    return u

weights = torch.rand(15)
print("Original weights: ", weights)
print("Repo quant fn: ", current_weight_quant(weights))
print("Correct quant fn:", correct_weight_quant(weights))
Output:
Original weights: tensor([0.5857, 0.0053, 0.8400, 0.5586, 0.8302, 0.9758, 0.6332, 0.4917, 0.4092,
0.6722, 0.5738, 0.1896, 0.5210, 0.6124, 0.1334])
Repo quant fn: tensor([ 0.5355, -0.5355, 0.5355, 0.5355, 0.5355, 0.5355, 0.5355, -0.5355,
-0.5355, 0.5355, 0.5355, -0.5355, -0.5355, 0.5355, -0.5355])
Correct quant fn: tensor([ 1., -1., 1., 1., 1., 1., 1., -1., -1., 1., 1., -1., -1., 1.,
-1.])
I've just created PR #59 to get it fixed.
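For context, here is a rough sketch of the "quantize to +/-1, then rescale" forward pass that the paper describes; this is my own illustration rather than the repo's code, and it leaves out the activation quantization and its scale.

import torch
import torch.nn.functional as F

def paper_style_forward(x, w):
    # Binarize: only the signs are kept, so the quantized weights are in {-1, +1}.
    alpha = w.mean()
    w_bin = torch.sign(w - alpha)
    # Per-tensor scale (beta in the paper is the mean absolute value of W).
    beta = w.abs().mean()
    # The scale is applied to the matmul output instead of being baked into the weights.
    return F.linear(x, w_bin) * beta

x = torch.randn(2, 10)
w = torch.randn(20, 10)
print(paper_style_forward(x, w).shape)  # torch.Size([2, 20])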
What about the zero case, @jmbrito01? The paper mentions the weights can be {-1, 0, 1}, but I only see {-1, 1}.
Edit: OK, my bad. I read an updated version of the article and realized that b1 and b1.58 are different architectures; in this case b1 refers to the BitLinear implementation, and its values are always in {-1, 1}.
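For reference, here is a minimal sketch of how the {-1, 0, 1} (b1.58) case differs; this is my own toy code following the absmean quantization described in the b1.58 paper, not anything implemented in this repo.

import torch

def ternary_weight_quant(w, eps=1e-5):
    # absmean scaling: gamma is the mean absolute value of W.
    gamma = w.abs().mean()
    # RoundClip(W / gamma, -1, 1): every weight becomes -1, 0, or +1.
    return (w / (gamma + eps)).round().clamp_(-1, 1)

w = torch.randn(4, 4)
print(ternary_weight_quant(w))  # entries are in {-1., 0., 1.}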
Tip: We should have a unit test suite to avoid this kind of issue and validate future PRs, e.g. something along the lines of the sketch below.
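A minimal, hypothetical example of such a test (assuming pytest as the runner; the test name and the standalone quant function are mine, mirroring the snippet above):

import pytest
import torch

def correct_weight_quant(w):
    # Same fix proposed above: sign of the centered weights, values in {-1, 1}.
    scale = w.abs().mean()
    e = w.mean()
    return torch.sign((w - e) * scale)

@pytest.mark.parametrize("shape", [(15,), (20, 10), (4, 10, 10)])
def test_quantized_weights_are_binary(shape):
    torch.manual_seed(0)
    w = torch.randn(shape)
    q = correct_weight_quant(w)
    # sign() can emit 0 only when a weight exactly equals the mean, which does not
    # happen for this random input; every value should be -1 or +1.
    assert torch.all((q == 1) | (q == -1))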