torchrec
Inconsistent results from quantized sparse layer and weighted sparse layer.
There are large inconsistencies in the results when running a single forward pass of a torchrec model (a dense layer, a sparse layer, a weighted sparse layer, and an over layer) under distributed and non-distributed settings.
Below is the code to reproduce the inconsistency. In the code, I create a model and inputs, and quantize the model with dtype=torch.qint8 and output_dtype=torch.qint8. I then run a forward pass with the distributed model and with the non-distributed model. Since the model's weights are copied, I expect their results to be the same; however, there are large inconsistencies in the results, shown in the logs below. The environment is Python 3.10.14, torch 2.3.0+cu121, torchrec 0.7.0.
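For reference, the quantization step boils down to the following call, using the quantize_embeddings helper defined in the reproduction code below applied to the ReproduceModel instance:

model = quantize_embeddings(model, dtype=torch.qint8, inplace=True, output_dtype=torch.qint8)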
Note that this code is adapted from a script originally written against torchrec 0.2.0; when the code below is run under 0.2.0, the sparse layer produces NaN output (see the second set of logs).
These inconsistencies should be bugs: the distributed model and the non-distributed model have the same parameters and inputs, so a single forward pass should return the same results.
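Even a loose tolerance-based comparison would flag this. As a minimal sketch (local_pred and global_pred are the tensors computed in sharding_single_rank_test below):

torch.testing.assert_close(local_pred, global_pred, rtol=0, atol=1e-2)

would fail given the Linf of 0.9984 in the logs, which is far beyond any plausible int8 quantization error.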
Reproduction code
import copy
import traceback
from typing import Any, Type, Dict, List, Optional, Union, Tuple
import torch
import torch.nn as nn
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.distributed.planner import (
    EmbeddingShardingPlanner,
    Topology,
)
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder
from torchrec.distributed.test_utils.test_model import (
    ModelInput,
)
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingEnv,
    ShardingPlan,
)
from torchrec.sparse.jagged_tensor import KeyedTensor
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.distributed.test_utils.test_sharding import copy_state_dict
from torchrec.inference.modules import quantize_embeddings

# A dense arch, an unweighted EmbeddingBagCollection, a weighted EmbeddingBagCollection, and an over arch.
class ReproduceModel(nn.Module):
    def __init__(self):
        super().__init__()
        table_params = [
            [187, 128],
            [844, 288],
            [310, 444],
            [870, 20],
            [704, 512],
        ]
        weighted_table_params = [
            [975, 316],
            [439, 612],
            [855, 284],
        ]
        self.tables = [
            EmbeddingBagConfig(
                num_embeddings=table_params[i][0],
                embedding_dim=table_params[i][1],
                name="table_" + str(i),
                feature_names=["feature_" + str(i)],
                data_type=torch.int64,
            )
            for i in range(len(table_params))
        ]
        self.weighted_tables = [
            EmbeddingBagConfig(
                num_embeddings=weighted_table_params[i][0],
                embedding_dim=weighted_table_params[i][1],
                name="weighted_table_" + str(i),
                feature_names=["weighted_feature_" + str(i)],
            )
            for i in range(len(weighted_table_params))
        ]
        self.dense = nn.Linear(in_features=984, out_features=551, bias=True)
        self.sparse = EmbeddingBagCollection(
            tables=self.tables,
            is_weighted=False,
        )
        self.sparse_weighted = EmbeddingBagCollection(
            tables=self.weighted_tables,
            is_weighted=True,
        )
        in_features_concat = (
            self.dense.out_features
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.tables
                ]
            )
            + sum(
                [
                    table.embedding_dim * len(table.feature_names)
                    for table in self.weighted_tables
                ]
            )
        )
        self.over = nn.Linear(in_features=in_features_concat, out_features=21, bias=True)

    def forward(
        self,
        input,
    ):
        dense_r = self.dense(input.float_features)
        sparse_r = self.sparse(input.idlist_features)
        sparse_weighted_r = self.sparse_weighted(input.idscore_features)
        result = KeyedTensor(
            keys=sparse_r.keys() + sparse_weighted_r.keys(),
            length_per_key=sparse_r.length_per_key()
            + sparse_weighted_r.length_per_key(),
            values=torch.cat([sparse_r.values(), sparse_weighted_r.values()], dim=1),
        )
        _features = [
            feature for table in self.tables for feature in table.feature_names
        ]
        _weighted_features = [
            feature for table in self.weighted_tables for feature in table.feature_names
        ]
        # Concatenate the dense output with the per-feature pooled embeddings and feed the over arch.
        ret_list = []
        ret_list.append(dense_r)
        for feature_name in _features:
            ret_list.append(result[feature_name])
        for feature_name in _weighted_features:
            ret_list.append(result[feature_name])
        ret_concat = torch.cat(ret_list, dim=1)
        over_r = self.over(ret_concat)
        pred = torch.sigmoid(torch.mean(over_r, dim=1))
        if self.training:
            return (
                torch.nn.functional.binary_cross_entropy_with_logits(pred, input.label),
                pred, (dense_r, sparse_r, sparse_weighted_r, over_r),
            )
        else:
            return pred, (dense_r, sparse_r, sparse_weighted_r, over_r)

# Note: this local definition shadows the quantize_embeddings imported from torchrec.inference.modules above.
def quantize_embeddings(
    module: nn.Module,
    dtype: torch.dtype,
    inplace: bool,
    additional_qconfig_spec_keys: Optional[List[Type[nn.Module]]] = None,
    additional_mapping: Optional[Dict[Type[nn.Module], Type[nn.Module]]] = None,
    output_dtype: torch.dtype = torch.float32,
) -> nn.Module:
    import torch.quantization as quant
    import torchrec as trec
    import torchrec.quant as trec_quant

    qconfig = quant.QConfig(
        activation=quant.PlaceholderObserver.with_args(dtype=output_dtype),
        weight=quant.PlaceholderObserver.with_args(dtype=dtype),
    )
    qconfig_spec: Dict[Type[nn.Module], quant.QConfig] = {
        trec.EmbeddingBagCollection: qconfig,
    }
    mapping: Dict[Type[nn.Module], Type[nn.Module]] = {
        trec.EmbeddingBagCollection: trec_quant.EmbeddingBagCollection,
    }
    if additional_qconfig_spec_keys is not None:
        for t in additional_qconfig_spec_keys:
            qconfig_spec[t] = qconfig
    if additional_mapping is not None:
        mapping.update(additional_mapping)
    return quant.quantize_dynamic(
        module,
        qconfig_spec=qconfig_spec,
        mapping=mapping,
        inplace=inplace,
    )

def sharding_single_rank_test(
    world_size: int,
    model,
    inputs,
    sharders: List[ModuleSharder[nn.Module]],
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    device = torch.device("cuda:0")
    model = model.to(device)
    model = quantize_embeddings(model, dtype=quant_dtype, inplace=True, output_dtype=quant_output_dtype)
    # Build a non-sharded (global) copy and a sharded (local) copy of the same quantized model.
    global_model = copy.deepcopy(model)
    global_model = global_model.to(device)
    global_input = inputs[0][0].to(device)
    local_model = copy.deepcopy(model)
    planner = EmbeddingShardingPlanner(
        topology=Topology(
            world_size, device.type
        )
    )
    plan: ShardingPlan = planner.plan(local_model, sharders)
    local_model = DistributedModelParallel(
        local_model,
        env=ShardingEnv.from_local(world_size=world_size, rank=0),
        plan=plan,
        sharders=sharders,
        device=device,
        init_data_parallel=False,
    )
    copy_state_dict(local_model.state_dict(), global_model.state_dict())
    # Run one forward pass with each model and compare the outputs of every layer.
    local_pred, (local_dense_r, local_sparse_r, local_sparse_weighted_r, local_over_r) = gen_full_pred_after_one_step(local_model, global_input)
    global_pred, (global_dense_r, global_sparse_r, global_sparse_weighted_r, global_over_r) = gen_full_pred_after_one_step(global_model, global_input)
    print("Linf: ", torch.max(torch.abs(global_pred - local_pred)))
    print("Linf dense: ", torch.max(torch.abs(global_dense_r - local_dense_r)))
    print("Linf sparse: ", torch.max(torch.abs(local_sparse_r.values() - global_sparse_r.values())))
    print("Linf sparse weighted: ", torch.max(torch.abs(local_sparse_weighted_r.values() - global_sparse_weighted_r.values())))
    print("Linf over: ", torch.max(torch.abs(global_over_r - local_over_r)))

def gen_full_pred_after_one_step(
    model: nn.Module,
    input: ModelInput,
) -> torch.Tensor:
    # Run a single eval-mode forward pass.
    with torch.no_grad():
        model.train(False)
        full_pred, intermediate_list = model(input)
        return full_pred, intermediate_list

from torchrec.distributed.embedding_types import EmbeddingTableConfig
from typing import Protocol, cast


class ModelInputCallable(Protocol):
    def __call__(
        self,
        batch_size: int,
        world_size: int,
        num_float_features: int,
        tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        weighted_tables: Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]],
        pooling_avg: int = 10,
        dedup_tables: Optional[
            Union[List[EmbeddingTableConfig], List[EmbeddingBagConfig]]
        ] = None,
        variable_batch_size: bool = False,
        long_indices: bool = True,
    ) -> Tuple["ModelInput", List["ModelInput"]]: ...

def main_test_quant(
    sharders: List[ModuleSharder[nn.Module]],
    world_size: int = 2,
    quant_dtype=None,
    quant_output_dtype=None,
) -> None:
    model = ReproduceModel()
    inputs = [
        (
            cast(ModelInputCallable, ModelInput.generate)(
                world_size=world_size,
                tables=model.tables,
                weighted_tables=model.weighted_tables,
                num_float_features=model.dense.in_features,
                batch_size=256,
            )
        )
    ]
    # Cast the generated int64 indices to int32.
    inputs[0][0].idlist_features._values = inputs[0][0].idlist_features._values.to(dtype=torch.int32)
    inputs[0][0].idscore_features._values = inputs[0][0].idscore_features._values.to(dtype=torch.int32)
    sharding_single_rank_test(
        world_size=world_size,
        model=model,
        inputs=inputs,
        sharders=sharders,
        quant_dtype=quant_dtype,
        quant_output_dtype=quant_output_dtype,
    )

class TestQuantEBCSharder(QuantEmbeddingBagCollectionSharder):
    def __init__(self, sharding_type: str, kernel_type: str) -> None:
        super().__init__()
        self._sharding_type = sharding_type
        self._kernel_type = kernel_type

    def sharding_types(self, compute_device_type: str) -> List[str]:
        return [self._sharding_type]

    def compute_kernels(
        self, sharding_type: str, compute_device_type: str
    ) -> List[str]:
        return [self._kernel_type]

    @property
    def fused_params(self) -> Optional[Dict[str, Any]]:
        return None

def main():
    # backend = "nccl"
    world_size = 3
    dtype = torch.qint8
    output_dtype = torch.qint8
    sharding_type = "table_wise"
    kernel_type = "quant"
    sharders = [TestQuantEBCSharder(sharding_type, kernel_type)]
    main_test_quant(
        sharders=sharders,
        world_size=world_size,
        quant_dtype=dtype,
        quant_output_dtype=output_dtype,
    )


if __name__ == "__main__":
    main()
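For reference, the script runs in a single process on a single CUDA device (the sharded model is created with ShardingEnv.from_local on cuda:0), so it can be launched directly, e.g. python reproduce.py (placeholder file name); no torchrun or other distributed launcher is needed.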
Logs
Linf: tensor(0.9984, device='cuda:0')
Linf dense: tensor(0., device='cuda:0')
Linf sparse: tensor(255, device='cuda:0', dtype=torch.uint8)
Linf sparse weighted: tensor(255, device='cuda:0', dtype=torch.uint8)
Linf over: tensor(130.5735, device='cuda:0')
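Note that with output_dtype=torch.qint8 the EmbeddingBagCollection outputs are raw uint8 tensors, so the element-wise differences above wrap around in uint8 arithmetic; a reported 255 can come either from a genuinely large byte-level mismatch or from a small negative difference that wrapped. A sketch of a comparison that avoids the wraparound (names as in sharding_single_rank_test above):

diff = (local_sparse_r.values().to(torch.int16) - global_sparse_r.values().to(torch.int16)).abs().max()

Either way, the float-level gaps (Linf over 130.5735, Linf pred 0.9984) show a real discrepancy.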
Logs in torchrec 0.2.0
Linf: tensor(nan, device='cuda:0')
Linf dense: tensor(0., device='cuda:0')
Linf sparse: tensor(nan, device='cuda:0')
Traceback (most recent call last):
  File "/root/miniconda3/envs/pt112tr02/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/root/miniconda3/envs/pt112tr02/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/mnt/util/reproduce_quant_nan_2.py", line 312, in <module>
    main()
  File "/mnt/util/reproduce_quant_nan_2.py", line 303, in main
    main_test_quant(
  File "/mnt/util/reproduce_quant_nan_2.py", line 265, in main_test_quant
    sharding_single_rank_test(
  File "/mnt/util/reproduce_quant_nan_2.py", line 213, in sharding_single_rank_test
    print("Linf sparse weighted: ", torch.max(torch.abs(local_sparse_weighted_r.values() - global_sparse_weighted_r.values())))
RuntimeError: The size of tensor a (1212) must match the size of tensor b (1236) at non-singleton dimension 1
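For what it's worth, 1212 is exactly the sum of the weighted tables' embedding dims (316 + 612 + 284), while 1236 = 1212 + 3 × 8, i.e. eight extra bytes per weighted table. That would be consistent with per-row scale/bias bytes of the int8 rows being included in one of the two outputs, but this is only an observation from the numbers above, not a confirmed root cause.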