torchrec

Refactor stride_per_key_per_rank to support torch.Tensor

Open · jd7-tr opened this pull request 7 months ago · 1 comment

Summary: stride_per_key_per_rank should be a variable whose value is determined dynamically after input_dist.

This change updates its type to Union[Optional[torch.Tensor], Optional[List[List[int]]]] so that code still passing the existing List[List[int]] form keeps working (backward compatible).
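A minimal sketch of how code that still expects the legacy nested-list form could accept either variant of the new Union type. The helper name normalize_stride_per_key_per_rank, the alias StridePerKeyPerRank, and the [num_keys, num_ranks] tensor layout are hypothetical assumptions for illustration, not the TorchRec API or the change made in this PR.

```python
from typing import TYPE_CHECKING, List, Optional, Union

if TYPE_CHECKING:  # torch is only needed for the type annotation in this sketch
    import torch

# Hypothetical alias mirroring the Union type described in the summary.
StridePerKeyPerRank = Union[Optional["torch.Tensor"], Optional[List[List[int]]]]


def normalize_stride_per_key_per_rank(
    value: StridePerKeyPerRank,
) -> Optional[List[List[int]]]:
    """Return the legacy List[List[int]] form regardless of which variant was passed.

    Hypothetical helper (assumption): a tensor of shape [num_keys, num_ranks]
    is converted with .tolist(); a nested list is copied; None stays None.
    """
    if value is None:
        return None
    if isinstance(value, list):
        # Legacy representation: copy so callers cannot mutate the original.
        return [list(row) for row in value]
    # Assumed tensor path: torch.Tensor.tolist() on a 2-D integer tensor
    # returns a nested list of Python ints.
    return [[int(s) for s in row] for row in value.tolist()]


if __name__ == "__main__":
    # Legacy input: one inner list per key, one stride per rank.
    print(normalize_stride_per_key_per_rank([[2, 3], [4, 1]]))
```

Calling .tolist() on a CUDA tensor copies it to the host and synchronizes, so a sketch like this is only a compatibility shim; code that can consume the tensor form directly avoids that cost.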

Differential Revision: D72658640

jd7-tr · Apr 08 '25 18:04

This pull request was exported from Phabricator. Differential Revision: D72658640

facebook-github-bot · Apr 08 '25 18:04