FBGEMM
Fix int_nbit inference int8 nobag kernel meta function
Summary: TL;DR
Fix the int8 nobag case in the TBE inference meta function so that:
- output shape is {total_L, D + kINT8QparamsBytes}
- kINT8QparamsBytes = 4
Detail
For nobag int8, the output shape should be {total_L, D + kINT8QparamsBytes}, since the total_L dimension already accounts for T. The T * factor was unintentionally added in D36018114.
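A minimal sketch of the shape arithmetic, assuming total_L is the sum of lookup counts over all T tables (so T is already counted once). The helper names and example sizes are hypothetical, not FBGEMM's API. The "previous" shape combines the extra T * factor with the 8 qparams bytes the meta function inherited from CUDA, as described below.

```python
CPU_INT8_QPARAMS_BYTES = 4   # kINT8QparamsBytes on CPU, per this PR
CUDA_INT8_QPARAMS_BYTES = 8  # value used by the CUDA implementation


def nobag_int8_output_shape(lengths_per_table, D):
    """Corrected nobag int8 shape: {total_L, D + kINT8QparamsBytes}.

    lengths_per_table (hypothetical input): rows looked up in each of the
    T tables. Summing them gives total_L, which already includes T.
    """
    total_L = sum(lengths_per_table)
    return (total_L, D + CPU_INT8_QPARAMS_BYTES)


def previous_nobag_int8_output_shape(lengths_per_table, D):
    """Previous meta shape: an extra T * factor (counting T twice) and
    CUDA's 8 qparams bytes."""
    T = len(lengths_per_table)
    total_L = sum(lengths_per_table)
    return (T * total_L, D + CUDA_INT8_QPARAMS_BYTES)


# Three tables with 5, 2 and 3 lookups, D = 16.
print(nobag_int8_output_shape([5, 2, 3], 16))           # (10, 20)
print(previous_nobag_int8_output_shape([5, 2, 3], 16))  # (30, 24)
```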
kINT8QparamsBytes is 4 on CPU, since the qparams are stored as half. CUDA, however, uses 8.
Our meta implementation follows the CUDA implementation and therefore does not match the CPU one.
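A minimal sketch of the byte arithmetic behind D + kINT8QparamsBytes, using Python's struct module. It assumes (this is not stated above) that the qparams are two values per row, a scale and a bias. The row layout in the hypothetical `pack_int8_row`, including where the qparams sit, is chosen only for illustration. Two halves take 4 bytes and two 32-bit floats would take 8. The description does not say why CUDA uses 8.

```python
import struct

# '<' selects standard sizes with no padding; 'e' is IEEE 754 half, 'f' is float32.
HALF_QPARAMS_BYTES = struct.calcsize("<2e")   # scale + bias as fp16 -> 4 bytes
FLOAT_QPARAMS_BYTES = struct.calcsize("<2f")  # scale + bias as fp32 -> 8 bytes


def pack_int8_row(values, scale, bias):
    """Hypothetical row: D int8 values followed by an fp16 scale and bias."""
    return struct.pack(f"<{len(values)}b2e", *values, scale, bias)


row = pack_int8_row([1, -2, 3, 4], scale=0.5, bias=-1.0)  # D = 4
print(HALF_QPARAMS_BYTES, FLOAT_QPARAMS_BYTES, len(row))  # 4 8 8
```

With half-precision qparams, each row is D + 4 bytes wide, which is the width the fixed meta function reports.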
This diff removes the T * factor from the output shape and changes kINT8QparamsBytes to 4 in the meta implementation, to match CPU and production.
This has not caused any issues because neither our meta kernel nor our CUDA kernel is currently used in production.
CUDA kernel changes will be in the next diff.
Note that the meta function currently in use is fbgemm_int_nbit_split_embedding_codegen_lookup_function_meta, which has different logic for the int8 and nobag cases.
The discrepancy has not been an issue because:
- Nobag
  - split_dispatcher: D = average D
  - FBGEMM: D = max(max_D of each dtype)
  -> All embedding dimensions are the same, so average D = max D.
- Int8 pooled
  - split_dispatcher: [B, total_D]
  - FBGEMM: [B, total_D + T * 8]
  -> This path is not used in production.
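A small sketch of why the two paths agree today and where they diverge. The formulas come from the list above. The function names, per-table dims, and dtype grouping are hypothetical examples, not split_dispatcher or FBGEMM code.

```python
def nobag_D_split_dispatcher(dims):
    # split_dispatcher: D = average D across all tables
    return sum(dims) / len(dims)


def nobag_D_fbgemm(dims_by_dtype):
    # FBGEMM: D = max over dtypes of each dtype's max_D
    return max(max(dims) for dims in dims_by_dtype.values())


def int8_pooled_shape_split_dispatcher(B, dims):
    # split_dispatcher: [B, total_D]
    return (B, sum(dims))


def int8_pooled_shape_fbgemm(B, dims):
    # FBGEMM: [B, total_D + T * 8]
    return (B, sum(dims) + len(dims) * 8)


uniform = {"int8": [64, 64, 64]}
mixed = {"int8": [32, 96], "int4": [64]}

# Uniform dims: average D == max D, so the nobag shapes agree.
print(nobag_D_split_dispatcher([64, 64, 64]), nobag_D_fbgemm(uniform))  # 64.0 64
# Mixed dims: the two definitions of D diverge.
print(nobag_D_split_dispatcher([32, 96, 64]), nobag_D_fbgemm(mixed))    # 64.0 96

# Int8 pooled: the extra T * 8 columns differ even with uniform dims.
print(int8_pooled_shape_split_dispatcher(4, [64, 64, 64]))  # (4, 192)
print(int8_pooled_shape_fbgemm(4, [64, 64, 64]))            # (4, 216)
```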
This will become a problem if embedding dimensions are mixed or if int8 pooled output is used.
Differential Revision: D75808485
Deploy Preview for pytorch-fbgemm-docs ready!
| Name | Link |
|---|---|
| Latest commit | 80720f9e7d58fcdcc279b725a58820b1987221f2 |
| Latest deploy log | https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/684a6b70867ca400086b74ec |
| Deploy Preview | https://deploy-preview-4333--pytorch-fbgemm-docs.netlify.app |
This pull request was exported from Phabricator. Differential Revision: D75808485