`EmbeddingBagCollection` output raises an error when calling `to_dict()`
I tried to print the output from a quantized EBC layer in a model. However, when I call `.to_dict()` on the layer output I get the error: `RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]`. This bug happens when I set the `output_dtype` to `torch.qint8` or `torch.quint8`, but not `torch.float32`.

Below are the reproduction code and the log for the error. The code creates a model (with a dense layer, a sparse layer, a weighted sparse layer, and an over layer), quantizes the model, and runs a forward pass on random inputs. My environment is Python 3.10.14, torch 2.3.0+cu121, torchrec 0.7.0.

I guess the additional dimension might be some parameters added to support integer quantization. I wonder if this is the case, and if so, whether `to_dict()` can be fixed to handle the additional dimension and produce the correct output dictionary. Thanks!
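For what it's worth, the numbers are consistent with that guess: in both tables the values tensor is exactly 8 columns wider than `length_per_key` expects (920 − 912 = 8 and 812 − 804 = 8), which would fit, for example, a `float32` scale and a `float32` zero point fused onto each row. This is just my reading of the shapes, not something I have confirmed in the torchrec internals:

```python
# Sanity check on the observed shapes (my interpretation, not confirmed
# from the torchrec source): each row seems to carry 8 extra bytes,
# which would fit two fused float32 values such as a scale and a zero point.
EXTRA_BYTES = 8
for embedding_dim, observed_dim in [(912, 920), (804, 812)]:
    assert observed_dim == embedding_dim + EXTRA_BYTES
```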
Reproduction code:
```python
import traceback
from typing import Protocol, cast, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torchrec.distributed.embedding_types import EmbeddingTableConfig
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.distributed.test_utils.infer_utils import TestQuantEBCSharder
from torchrec.inference.modules import quantize_embeddings
from torchrec.modules.embedding_modules import EmbeddingBagCollection


class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        table_params = [
            [777, 912],
        ]
        weighted_table_params = [
            [941, 804],
        ]
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=table_params[i][0],
                embedding_dim=table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
                data_type=torch.int64,
            )
            for i in range(len(table_params))
        ]
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=weighted_table_params[i][0],
                embedding_dim=weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(weighted_table_params))
        ]
        self.dense = nn.Linear(in_features=914, out_features=930, bias=True)
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )
        in_features_concat = (
            self.dense.out_features
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.tables
                ]
            )
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.weighted_tables
                ]
            )
        )
        self.over = nn.Linear(
            in_features=in_features_concat, out_features=224, bias=True
        )

    def forward(
        self,
        input,
    ):
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [
            feature for table in self.tables for feature in table.feature_names
        ]
        _weighted_features = [
            feature
            for table in self.weighted_tables
            for feature in table.feature_names
        ]
        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        over_r = self.over(ret_concat)
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        if self.training:
            return (
                torch.nn.functional.binary_cross_entropy_with_logits(
                    pred, input.label
                ),
                pred,
                (dense_r, sparse_r, sparse_weighted_r, over_r),
            )
        else:
            return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)


def sharding_single_rank_test(
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    device = torch.device("cuda:0")
    model = model.to(device)
    model = quantize_embeddings(
        model, dtype=quant_dtype, inplace=True, output_dtype=quant_output_dtype
    )
    global_input_train = inputs[0][0].to(device)
    local_model = model
    planner = EmbeddingShardingPlanner(
        topology=Topology(world_size, device.type)
    )
    plan: ShardingPlan = planner.plan(local_model, sharders)
    local_model = DistributedModelParallel(
        local_model,
        env=ShardingEnv.from_local(world_size=world_size, rank=0),
        plan=plan,
        sharders=sharders,
        device=device,
        init_data_parallel=False,
    )
    local_pred, (
        local_dense_r,
        local_sparse_r,
        local_sparse_weighted_r,
        local_over_r,
    ) = local_model(global_input_train)
    print("local_sparse_r: ")
    print(local_sparse_r.values().shape)
    print(local_sparse_r.keys())
    print(local_sparse_r.length_per_key())
    try:
        print(local_sparse_r.to_dict())
    except Exception as e:
        print(e)
        traceback.print_exc()
    print(local_sparse_weighted_r.values().shape)
    print(local_sparse_weighted_r.keys())
    print(local_sparse_weighted_r.length_per_key())
    try:
        print(local_sparse_weighted_r.to_dict())
    except Exception as e:
        print(e)
        traceback.print_exc()


class ModelInputCallable(Protocol):
    def __call__(
        self,
        batch_size: int,
        world_size: int,
        num_float_features: int,
        tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        pooling_avg: int = 10,
        dedup_tables: Optional[
            Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
        ] = None,
        variable_batch_size: bool = False,
        long_indices: bool = True,
    ) -> Tuple["ModelInput", List["ModelInput"]]: ...


def main_test(
    sharders: List[ModuleSharder[nn.Module]],
    world_size: int = 2,
    quant_dtype=torch.qint8,
    quant_output_dtype=torch.float32,
) -> None:
    model = TestModel()
    batch_size = 2400
    world_size = 1
    num_float_features = model.dense.in_features
    tables = model.tables
    weighted_tables = model.weighted_tables
    inputs = [
        cast(ModelInputCallable, ModelInput.generate)(
            world_size=world_size,
            tables=tables,
            weighted_tables=weighted_tables or [],
            num_float_features=num_float_features,
            batch_size=batch_size,
        )
    ]
    sharding_single_rank_test(
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        quant_dtype=quant_dtype,
        quant_output_dtype=quant_output_dtype,
    )


def main():
    backend = "nccl"
    world_size = 2
    dtype = torch.qint8
    output_dtype = torch.qint8
    sharding_type = "table_wise"
    kernel_type = "quant"
    sharders = [TestQuantEBCSharder(sharding_type, kernel_type)]
    main_test(
        sharders=sharders,
        world_size=world_size,
        quant_dtype=dtype,
        quant_output_dtype=output_dtype,
    )


if __name__ == "__main__":
    main()
```
Logs:

```
local_sparse_r: 
torch.Size([2400, 920])
['feature_0']
[912]
split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
Traceback (most recent call last):
  File "/mnt/util/reproduce_quant_unit_32.py", line 168, in sharding_single_rank_test
    print(local_sparse_r.to_dict())
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
    split_values = self._values.split(lengths, dim=self._key_dim)
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 920 (input tensor's size at dimension 1), but got split_sizes=[912]
torch.Size([2400, 812])
['weighted_feature_0']
[804]
split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
Traceback (most recent call last):
  File "/mnt/util/reproduce_quant_unit_32.py", line 177, in sharding_single_rank_test
    print(local_sparse_weighted_r.to_dict())
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torchrec/sparse/jagged_tensor.py", line 2313, in to_dict
    split_values = self._values.split(lengths, dim=self._key_dim)
  File "/root/miniconda3/envs/pttrlatest/lib/python3.10/site-packages/torch/_tensor.py", line 921, in split
    return torch._VF.split_with_sizes(self, split_size, dim)
RuntimeError: split_with_sizes expects split_sizes to sum exactly to 812 (input tensor's size at dimension 1), but got split_sizes=[804]
```
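In case it's useful while this is being looked at: the workaround I'm using for now is to trim the values tensor before splitting. This is only a sketch under my assumption above (single table per EBC, with the extra bytes sitting at the tail of each row); `to_dict_trimmed` is my own helper, not a torchrec API, and the returned tensors are still the raw quantized bytes, not dequantized embeddings:

```python
from typing import Dict

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor


def to_dict_trimmed(kt: KeyedTensor) -> Dict[str, torch.Tensor]:
    # Hypothetical helper for the single-table case in the repro above.
    # Assumes the fused quantization-parameter bytes come after the
    # embedding bytes, so dropping the trailing columns realigns the splits.
    values = kt.values()
    expected = sum(kt.length_per_key())  # 912 or 804 in the repro
    trimmed = values[:, :expected]  # drop the assumed 8 trailing bytes
    splits = trimmed.split(kt.length_per_key(), dim=1)  # key dim is 1 here
    return dict(zip(kt.keys(), splits))
```

With multiple tables the fused parameters are presumably stored per table rather than once at the end, so a real fix in `to_dict()` would probably need to account for the per-key padding instead of trimming the tail like this.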