
FBGEMM kernel for KeyedTensor (PooledEmbedding) permute mapping

TroyGarden opened this issue 1 year ago · 2 comments

Summary: X-link: https://github.com/pytorch/torchrec/pull/2120

context

  • currently we have a working function permute_pooled_embs_auto_grad that performs a full permute of KTs, including forward and backward
  • it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main use case there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which requires a split operation afterwards
  • there have been some attempts to support duplicated outputs, but the backward doesn't work
  • this diff creates a new kernel (named permute_multi_embedding, as used in the example below) to support a multiple-KT to multiple-KT mapping operation with backward support

operator example usage

  • used in Python:
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the operator
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values, # list of tensors (on device)
    permutes.to(device=torch.device("cuda")), # tensor on device
    out_lengths.tolist(), # List[int] on CPU
    in_lengths.to(device=torch.device("cuda")), # tensor on device
    out_lengths.to(device=torch.device("cuda")), # tensor on device
)
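As a sanity check (our arithmetic on the example above, not part of the original PR), each output width should equal the sum of its group's feature lengths:

# expected output widths, derived from `lengths` and `groups` above:
# ["f1","f3"] -> 96+512=608; ["f2"] -> 256;
# ["f4","f1","f6"] -> 128+96+1024=1248; ["f1","f5"] -> 96+768=864
for out, width in zip(outputs, [608, 256, 1248, 864]):
    assert out.shape == (batch_size, width)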
  • the permutes tensor computed for the mapping above; note that this particular tensor comes from a smaller test where the feature lengths of f1…f6 are 3, 4, 5, 6, 7, 8
permutes = tensor(
            [
                [0, 0, 0, 0, 3, 4],  # f1
                [1, 0, 0, 3, 5, 0],  # f3
                [0, 1, 3, 0, 4, 0],  # f2
                [1, 2, 5, 0, 6, 0],  # f4
                [0, 2, 0, 6, 3, -6],  # f1
                [2, 2, 0, 9, 8, 0],  # f6
                [0, 3, 0, 0, 3, -8],  # f1
                [1, 3, 11, 3, 7, 0],  # f5
            ]
)
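Reading one row against the field layout described under "details" below (this decoding is our annotation, not part of the PR):

# [1, 0, 0, 3, 5, 0] -> take f3 from input tensor 1 at column 0,
#   write it to output tensor 0 at column 3 (right after f1's 3 columns),
#   feature length 5; jump=0 because f3 appears only once.
# the three f1 rows carry non-zero jump values, since f1 is duplicated
# across three output groups and backward must handle its gradients.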

details

  1. from the above example usage, we can clearly see that the operator takes in the following: a) values: List[torch.Tensor], which represents the input KTs; b) permutes: torch.Tensor, which contains the permute information, explained below; c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time; d) in_lengths: torch.Tensor, lengths of the input tensors, resident on the device; e) out_lengths: torch.Tensor, lengths of the output tensors, resident on the device
  2. the operator returns a list of tensors, which represents the permuted KTs
  3. permutes is the most critical argument of this operator: a) it is a 2-D tensor; b) each row represents one key (feature) permute move; c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]; d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors (a reference sketch of the forward semantics follows this list)
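To make the permute-move semantics concrete, here is a minimal pure-Python sketch of the forward pass, based on our reading of the row format above (the actual CUDA kernel of course differs; permute_multi_embedding_ref is a hypothetical name):

import torch

def permute_multi_embedding_ref(values, permutes, out_lengths):
    # allocate outputs: one tensor per output KT, same batch size as inputs;
    # new_empty is safe because every output column is written exactly once
    batch_size = values[0].shape[0]
    outputs = [values[0].new_empty(batch_size, L) for L in out_lengths]
    # each permute row is one move: copy `length` columns of one feature
    # from an input tensor into its slot in an output tensor
    for in_id, out_id, in_start, out_start, length, _jump in permutes.tolist():
        outputs[out_id][:, out_start : out_start + length] = values[in_id][
            :, in_start : in_start + length
        ]
    return outputs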

performance notes

The good:

  1. the algorithm is designed so that it does not need to know in advance whether a 1-to-N mapping exists in the permutes
  2. the _all_keys_used_once check is no longer needed
  3. a torch.cat is no longer needed before calling the operator, as it was with the old one

The bad (same as before):

  1. it still requires several HtoD transfers (moving tensors to the device): a) the 3 tensors permutes, in_lengths, and out_lengths, which need to be on the device so the CUDA kernel can access them; b) 2 lists of (scalar_t*) pointers, for the input and output tensor lists; c) we didn't find a good way to let the kernel know the addresses of the input/output tensors without this, because those pointer lists also need to be on the device
  2. tensor.contiguous is needed in the backward function; it looks like the grads coming into backward are somehow not contiguous (see the naive sketch after this list)
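For reference, a naive Python sketch of the backward semantics, under our assumption that it can ignore the jump optimization and simply accumulate gradients wherever a feature was duplicated (permute_multi_embedding_backward_ref is a hypothetical name):

def permute_multi_embedding_backward_ref(grad_outputs, permutes, in_lengths):
    batch_size = grad_outputs[0].shape[0]
    # start from zeros: a duplicated input feature sums the gradients
    # from all of its output placements
    grad_inputs = [grad_outputs[0].new_zeros(batch_size, L) for L in in_lengths]
    for in_id, out_id, in_start, out_start, length, _jump in permutes.tolist():
        # .contiguous() mirrors the note above about non-contiguous grads
        grad_inputs[in_id][:, in_start : in_start + length] += grad_outputs[
            out_id
        ][:, out_start : out_start + length].contiguous()
    return grad_inputs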

Differential Revision: D57055616

— TroyGarden, Jun 17 '24

Deploy Preview for pytorch-fbgemm-docs failed.

Latest commit: e1f45da7558b52cd507e19714ff1241949feae8f
Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/666fcd900b05d000080b224e

— netlify[bot], Jun 17 '24

This pull request was exported from Phabricator. Differential Revision: D57055616

— facebook-github-bot, Jun 17 '24

This pull request has been merged in pytorch/FBGEMM@87cfbdff45ac661f9d607cf63a39d3e0e0124f86.

— facebook-github-bot, Jul 09 '24