
`EmbeddingBagCollection` output raises an error when calling `to_dict()`


I tried to print the output of a quantized EBC layer in a model. However, when I call `.to_dict()` on the layer output, I get: `RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]`. The bug occurs when I set `output_dtype` to `torch.qint8` or `torch.quint8`, but not with `torch.float32`.

Below are the reproduction code and the error log. The code builds a model (with a dense layer, a sparse layer, a weighted sparse layer, and an over layer), quantizes it, and runs a forward pass on random inputs. My environment is Python 3.10.14, torch 2.3.0+cu121, torchrec 0.7.0.

I guess the extra columns might be per-row parameters appended to support integer quantization. I wonder if this is the case, and if so, whether `to_dict()` can be fixed to handle the extra columns and produce the correct output dictionary. Thanks!
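For illustration, here is a minimal sketch of the kind of handling I have in mind. This is not a torchrec API: `to_dict_with_qparams` and `PARAM_COLS` are names I made up, the value 8 comes from the width mismatch in the logs below, and the sketch assumes every key's slice carries the same fixed number of trailing parameter columns.

from typing import Dict

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

PARAM_COLS = 8  # assumption: e.g. a float32 scale and a float32 zero point per row

def to_dict_with_qparams(kt: KeyedTensor) -> Dict[str, torch.Tensor]:
    # Pad each key's length so the split sizes cover the full width of
    # the value tensor, then drop the trailing parameter columns.
    padded = [length + PARAM_COLS for length in kt.length_per_key()]
    splits = kt.values().split(padded, dim=1)
    return {
        key: values[:, :-PARAM_COLS]
        for key, values in zip(kt.keys(), splits)
    }

With the single-key outputs in the logs below, the padded split sizes become [920] and [812], which is exactly what split_with_sizes asks for.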

Reproduction code:

import traceback
from typing import Protocol, cast, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedTensor

from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder
from torchrec.inference.modules import quantize_embeddings
from torchrec.modules.embedding_modules import EmbeddingBagCollection

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()

        table_params = [
            # [num_embeddings, embedding_dim]
            [777, 912],
        ]

        weighted_table_params = [
            # [num_embeddings, embedding_dim]
            [941, 804],
        ]

        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=table_params[i][0],
                embedding_dim=table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
                data_type=torch.int64,
            )
            for i in range(len(table_params))
        ]
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=weighted_table_params[i][0],
                embedding_dim=weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(weighted_table_params))
        ]

        # 914 float features in; main_test reads num_float_features from here.
        self.dense = nn.Linear(in_features=914, out_features=930, bias=True)

        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )

        in_features_concat = (
            self.dense.out_features
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.tables
                ]
            )
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.weighted_tables
                ]
            )
        )

        self.over = nn.Linear(in_features=in_features_concat, out_features=224, bias=True)

    def forward(
        self,
        input,
    ):
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )

        _features = [
            feature for table in self.tables for feature in table.feature_names
        ]
        _weighted_features = [
            feature for table in self.weighted_tables for feature in table.feature_names
        ]

        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)

        over_r = self.over(ret_concat)
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        if self.training:
            return (
                torch.nn.functional.binary_cross_entropy_with_logits(pred, input.label),
                pred, (dense_r, sparse_r, sparse_weighted_r, over_r),
            )
        else:
            return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    # The repro assumes a single CUDA device.
    device = torch.device("cuda:0")
    model = model.to(device)

    # Quantize the EBC modules in place; output_dtype controls the dtype
    # of the embedding output, which is what triggers the bug.
    model = quantize_embeddings(model, dtype=quant_dtype, inplace=True, output_dtype=quant_output_dtype)

    global_input_train = inputs[0][0].to(device)

    local_model = model

    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size, device.type
        )
    )
    plan: ShardingPlan = planner.plan(local_model, sharders)

    local_model = DistributedModelParallel(
        local_model,
        env=ShardingEnv.from_local(world_size=world_size, rank=0),
        plan=plan,
        sharders=sharders,
        device=device,
        init_data_parallel=False,
    )

    local_pred, (local_dense_r, local_sparse_r, local_sparse_weighted_r, local_over_r) = local_model(global_input_train)

    print("local_sparse_r: ")
    print(local_sparse_r.values().shape)
    print(local_sparse_r.keys())
    print(local_sparse_r.length_per_key())
    try:
        print(local_sparse_r.to_dict())
    except Exception as e:
        print(e)
        traceback.print_exc()
    
    print(local_sparse_weighted_r.values().shape)
    print(local_sparse_weighted_r.keys())
    print(local_sparse_weighted_r.length_per_key())
    try:
        print(local_sparse_weighted_r.to_dict())
    except Exception as e:
        print(e)
        traceback.print_exc()

class ModelInputCallable(Protocol):
    def __call__(
        self,
        batch_size: int,
        world_size: int,
        num_float_features: int,
        tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        pooling_avg: int = 10,
        dedup_tables: Optional[
            Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
        ] = None,
        variable_batch_size: bool = False,
        long_indices: bool = True,
    ) -> Tuple["ModelInput", List["ModelInput"]]: ...

def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    world_size: int = 2,
    quant_dtype=torch.qint8,
    quant_output_dtype=torch.float32,
) -> None:
    model = TestModel()
    batch_size = 2400
    world_size = 1  # note: overrides the world_size argument for a single-rank run
    num_float_features = model.dense.in_features
    tables = model.tables
    weighted_tables = model.weighted_tables

    inputs = [
        (
            cast(ModelInputCallable, ModelInput.generate)(
                world_size=world_size,
                tables=tables,
                weighted_tables=weighted_tables or [],
                num_float_features=num_float_features,
                batch_size=batch_size,
            )
        )
    ]

    sharding_single_rank_test(
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        quant_dtype=quant_dtype,
        quant_output_dtype=quant_output_dtype,
    )


def main():
    world_size = 2  # note: main_test overrides this to 1 internally

    # qint8 weights with a qint8 output dtype reproduces the error;
    # output_dtype=torch.float32 does not.
    dtype = torch.qint8
    output_dtype = torch.qint8

    sharding_type = "table_wise"
    kernel_type = "quant"
    sharders = [TestQuantEBCSharder(sharding_type, kernel_type)]

    main_test(
        sharders=sharders,
        world_size=world_size,
        quant_dtype=dtype,
        quant_output_dtype=output_dtype,
    )


if __name__ == "__main__":
    main()

Logs:

local_sparse_r:
torch.Size([2400, 920])
['feature_0']
[912]
split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
Traceback (most recent call last):
  File "/mnt/util/reproduce_quant_unit_32.py", line 168, in sharding_single_rank_test
    print(local_sparse_r.to_dict())
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
    split_values = self._values.split(lengths, dim=self._key_dim)
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
torch.Size([2400, 812])
['weighted_feature_0']
[804]
split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
Traceback (most recent call last):
  File "/mnt/util/reproduce_quant_unit_32.py", line 177, in sharding_single_rank_test
    print(local_sparse_weighted_r.to_dict())
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
    split_values = self._values.split(lengths, dim=self._key_dim)
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
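
In both failures the value tensor is exactly 8 columns wider than the sum of `length_per_key()` (920 - 912 = 8 and 812 - 804 = 8), which is consistent with two extra float32 values per row, e.g. a fused scale and zero point.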

jiannanWang, May 26 '24 23:05