torchrec
Floating point exception (core dumped) when using quantize_embeddings with float32 dtype
Calling quantize_embeddings with the float32 dtype can lead to a Floating point exception (core dumped). This can be reproduced by running python test_quant.py (below) in an environment with torchrec==1.1.0+cu124, torch==2.6.0+cu124, and fbgemm-gpu==1.1.0+cu124.
test_quant.py
import torch
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.inference.modules import quantize_embeddings
large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    )
    for i in range(small_table_cnt)
]
class DebugModel(nn.Module):
    def __init__(self, device: torch.device):
        super().__init__()
        self.ebc = EmbeddingBagCollection(tables=large_tables + small_tables, device=device)
        # Put the dense layer on the same device as the embeddings so forward works.
        self.linear = nn.Linear(64 * (small_table_cnt + large_table_cnt), 1, device=device)

    def forward(self, kjt: KeyedJaggedTensor):
        emb = self.ebc(kjt)
        # emb.values() is the concatenated pooled embeddings, shape (batch, 64 * 4).
        return torch.mean(self.linear(emb.values()))
model = DebugModel(device=torch.device("cuda:0"))
# Quantizing with dtype=torch.qint8 works; dtype=torch.float crashes with
# "Floating point exception (core dumped)".
quantize_embeddings(model, dtype=torch.float, inplace=True)
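
For comparison, the qint8 path mentioned in the comment above does complete. Below is a minimal sketch, assuming the DebugModel and table configs from test_quant.py are already defined: it quantizes to torch.qint8 and runs one forward pass with a hand-built KeyedJaggedTensor to confirm the quantized module still produces output. The feature indices and lengths are arbitrary illustrative values, not from the original report.

import torch
from torchrec.inference.modules import quantize_embeddings
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

model = DebugModel(device=torch.device("cuda:0"))
# torch.qint8 completes without the floating point exception.
quantize_embeddings(model, dtype=torch.qint8, inplace=True)

# Batch of 2 with one index per feature per sample (4 features x 2 samples = 8 lengths).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=[
        "large_table_feature_0",
        "large_table_feature_1",
        "small_table_feature_0",
        "small_table_feature_1",
    ],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device="cuda:0"),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1, 1, 1], device="cuda:0"),
)
print(model(kjt))  # scalar tensor; forward succeeds after qint8 quantization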