FBGEMM
FBGEMM kernel for KeyedTensor (PooledEmbedding) permute mapping
Summary: X-link: https://github.com/pytorch/torchrec/pull/2120
context
- currently we have a working function `permute_pooled_embs_auto_grad` to do a full permute of KTs, including forward and backward. It has several limitations: a) it has to be a full permute; duplicates are not supported; b) in the main use case there has to be a `torch.concat` on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation
- there has been some attempt to support duplicated outputs, but the backward pass doesn't work
- this diff creates a new kernel (named `multi_permute_pooled_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support
operator example usage
- used in Python:

```python
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features); duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# accessory arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the operator
outputs = torch.ops.fbgemm.permute_multi_embedding(
    values,  # list of tensors (on device)
    permutes.to(device=torch.device("cuda")),  # tensor on device
    out_lengths.tolist(),  # List[int] on CPU
    in_lengths.to(device=torch.device("cuda")),  # tensor on device
    out_lengths.to(device=torch.device("cuda")),  # tensor on device
)
```
- the resulting `permutes` tensor:

```python
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],   # f1
        [1, 0, 0, 3, 5, 0],   # f3
        [0, 1, 3, 0, 4, 0],   # f2
        [1, 2, 5, 0, 6, 0],   # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],   # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```
details
- from the example usage above, we can clearly see that the operator takes in the following:
  a) `values`: List[torch.Tensor], which represents the input KTs
  b) `permutes`: torch.Tensor, which contains the permute information; explained below
  c) `output_lengths_list`: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on the device ahead of time
  d) `in_lengths`: torch.Tensor on device, the lengths of the input tensors
  e) `out_lengths`: torch.Tensor on device, the lengths of the output tensors
- the operator returns a list of tensors, which represent the permuted KTs
- `permutes` is the most critical argument to this operator:
  a) it is a 2-D tensor
  b) each row represents one key (feature) permute move
  c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
  d) `jump` is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors
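Given that row format, the forward semantics can be sketched as a CPU reference in plain PyTorch. This is a hypothetical illustration, not the actual CUDA kernel: `jump` is ignored here since it only matters in backward, and the feature lengths are small illustrative values (f1=3 … f6=8, inferred from the `permutes` rows printed above, which use smaller lengths than the 96/256/… example):

```python
import torch

def permute_multi_embedding_ref(values, permutes, out_lengths):
    # hypothetical CPU reference of the forward pass (the real op is a CUDA kernel)
    batch_size = values[0].shape[0]
    outputs = [values[0].new_empty(batch_size, n) for n in out_lengths]
    for in_t, out_t, in_start, out_start, length, _jump in permutes.tolist():
        # copy one feature's slice from its input KT into its place in the output KT
        outputs[out_t][:, out_start:out_start + length] = \
            values[in_t][:, in_start:in_start + length]
    return outputs

# the permutes tensor printed above; assumed lengths f1=3, f2=4, f3=5, f4=6, f5=7, f6=8
permutes = torch.tensor([
    [0, 0, 0, 0, 3, 4],   # f1 -> group 0
    [1, 0, 0, 3, 5, 0],   # f3 -> group 0
    [0, 1, 3, 0, 4, 0],   # f2 -> group 1
    [1, 2, 5, 0, 6, 0],   # f4 -> group 2
    [0, 2, 0, 6, 3, -6],  # f1 (duplicate) -> group 2
    [2, 2, 0, 9, 8, 0],   # f6 -> group 2
    [0, 3, 0, 0, 3, -8],  # f1 (duplicate) -> group 3
    [1, 3, 11, 3, 7, 0],  # f5 -> group 3
])
# input KT widths: 3+4, 5+6+7, 8; output group widths: 3+5, 4, 6+3+8, 3+7
values = [torch.randn(2, 7), torch.randn(2, 18), torch.randn(2, 8)]
outs = permute_multi_embedding_ref(values, permutes, [8, 4, 17, 10])
```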
performance notes
The good:
- the algorithm is designed in a way that it doesn't need to know in advance whether the 1-to-N mapping exists in the permutes.
- `_all_keys_used_once` is no longer needed
- a `torch.cat` before calling the old operator is no longer needed
The same bad:
- it requires several HtoD communications (moving tensors to the device):
  a) 3 tensors (`permutes`, `input_lengths`, and `output_lengths`); they need to be on the device so that the CUDA kernel has access to them
  b) 2 lists of (scalar_t*) pointers, for the input and output tensor lists
  c) didn't find a good way to let the kernel know the addresses of the input/output tensors without also moving those pointer lists to the device
- `tensor.contiguous()` is needed in the backward function; it looks like the grads coming from backward are somehow not contiguous
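The backward pass replays the same permute moves in reverse with accumulation, which is why duplicated keys (the 1-to-N mappings flagged by `jump`) need their gradients summed. A hedged CPU sketch of that semantics, deliberately ignoring how the CUDA kernel uses `jump` to coordinate the duplicate writes:

```python
import torch

def permute_multi_embedding_backward_ref(grad_outputs, permutes, in_lengths):
    # hypothetical CPU reference: scatter-add output grads back into input layout;
    # a key duplicated across outputs accumulates a grad from every occurrence
    batch_size = grad_outputs[0].shape[0]
    grads = [grad_outputs[0].new_zeros(batch_size, n) for n in in_lengths]
    for in_t, out_t, in_start, out_start, length, _jump in permutes.tolist():
        grads[in_t][:, in_start:in_start + length] += \
            grad_outputs[out_t][:, out_start:out_start + length]
    return grads
```

With the example `permutes` above, f1 appears in three output groups, so its gradient slice accumulates three contributions while a non-duplicated key like f2 accumulates one.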
Differential Revision: D57055616
This pull request was exported from Phabricator. Differential Revision: D57055616
This pull request has been merged in pytorch/FBGEMM@87cfbdff45ac661f9d607cf63a39d3e0e0124f86.