torchrec (GitHub issue)

Floating point exception (core dumped) when using quantize_embeddings with the float32 dtype

Status: Open · opened by tiankongdeguiji 8 months ago · 3 comments

Quantizing embeddings with the float32 data type (`quantize_embeddings(..., dtype=torch.float)`) can crash the process with `Floating point exception (core dumped)`. Quantizing with `torch.qint8` works. To reproduce, run `python test_quant.py` (below) in an environment with torchrec==1.1.0+cu124, torch==2.6.0+cu124, and fbgemm-gpu==1.1.0+cu124.

test_quant.py

import torch
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.inference.modules import quantize_embeddings

large_table_cnt = 2
small_table_cnt = 2


def _make_tables(prefix: str, count: int, num_embeddings: int, embedding_dim: int = 64):
    """Build ``count`` sum-pooled ``EmbeddingBagConfig`` objects.

    Table ``i`` is named ``{prefix}_{i}`` and reads the single feature
    ``{prefix}_feature_{i}``. All tables share ``embedding_dim`` and
    ``num_embeddings``.
    """
    return [
        torchrec.EmbeddingBagConfig(
            name=f"{prefix}_{i}",
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            feature_names=[f"{prefix}_feature_{i}"],
            pooling=torchrec.PoolingType.SUM,
        )
        for i in range(count)
    ]


# The two groups differ only in num_embeddings (4096 vs 1024); both use dim 64.
large_tables = _make_tables("large_table", large_table_cnt, num_embeddings=4096)
small_tables = _make_tables("small_table", small_table_cnt, num_embeddings=1024)

class DebugModel(nn.Module):
    """Minimal repro model: an EmbeddingBagCollection feeding a 1-output linear layer.

    The collection covers every table in ``large_tables + small_tables``.
    ``forward`` returns the mean of the linear layer's output.
    """

    def __init__(self, device: torch.device):
        super().__init__()
        tables = large_tables + small_tables
        self.ebc = EmbeddingBagCollection(tables=tables, device=device)
        # Derive the linear input width from the table configs instead of a
        # hard-coded 64 * table_count, so it cannot drift from embedding_dim.
        # NOTE(review): assumes the pooled output concatenates all tables'
        # embeddings along the feature axis — confirm against KeyedTensor.values().
        pooled_dim = sum(table.embedding_dim for table in tables)
        self.linear = nn.Linear(pooled_dim, 1)

    def forward(self, kjt: KeyedJaggedTensor):
        """Pool ``kjt`` through the embedding bags, project, and average to a scalar."""
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))
    
model = DebugModel(device=torch.device("cuda", 0))
# Reproduction: the same call with dtype=torch.qint8 succeeds, but with
# float32 (torch.float) it aborts with "Floating point exception (core dumped)".
quantize_embeddings(model, dtype=torch.float32, inplace=True)

Posted by tiankongdeguiji on Mar 12 '25 at 06:03