torchrec
Refactor stride_per_key_per_rank to support torch.Tensor
Summary:
stride_per_key_per_rank should be a variable whose value is determined dynamically after input_dist.
This change updates its type to Union[Optional[torch.Tensor], Optional[List[List[int]]]], which keeps the existing List[List[int]] form valid for backward compatibility.
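For illustration only, here is a minimal sketch of how calling code might handle both forms of the new type. The alias StridePerKeyPerRank and the helper stride_as_lists are hypothetical names that are not part of this diff, and the sketch assumes the tensor form is a 2D integer tensor.

```python
from typing import TYPE_CHECKING, List, Optional, Union

if TYPE_CHECKING:  # torch is only needed for the annotations in this sketch
    import torch

# The union type named in the summary (alias name is hypothetical).
StridePerKeyPerRank = Union[Optional["torch.Tensor"], Optional[List[List[int]]]]


def stride_as_lists(stride: StridePerKeyPerRank) -> Optional[List[List[int]]]:
    """Hypothetical helper: return the strides as nested Python lists."""
    if stride is None:
        return None
    if isinstance(stride, list):
        # Legacy form: already List[List[int]], returned unchanged.
        return stride
    # Otherwise assumed to be a 2D torch.Tensor; Tensor.tolist() yields nested lists.
    return stride.tolist()
```

Callers that only need to read the strides could go through a helper like this, so they would not have to care which representation they were given.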
Differential Revision: D72658640
This pull request was exported from Phabricator. Differential Revision: D72658640