torchrec
Update cost model to include prefetch_compute
Summary: Updated version of D50162035 that uses the new CacheStatistics interface and extends support to more sharding types.
Jobs with large prefetch streams can be slowed down by large prefetch_compute blocking backward: https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-leiyuz-test-ip40-b6589377e7?job_attempt=0&version=0&tab=execution_details&env=PRODUCTION
This happens because the prefetch work is not balanced across ranks: the sharder is unaware of it.
This diff adds a new dimension to the Perf cost model to hold prefetch_compute. This is only non-zero when using embedding offloading via the prefetch pipeline.
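A minimal sketch of what adding this dimension to the cost model looks like. The field names here are illustrative approximations of the planner's `Perf` structure, not the exact torchrec definitions:

```python
from dataclasses import dataclass


@dataclass
class Perf:
    """Illustrative per-shard cost model with a prefetch term.

    prefetch_compute is zero unless the table uses embedding offloading
    via the prefetch pipeline.
    """

    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float
    prefetch_compute: float = 0.0

    @property
    def total(self) -> float:
        # Prefetch is summed with the other terms; see the discussion of
        # parallelism between forward and prefetch below.
        return (
            self.fwd_compute
            + self.fwd_comms
            + self.bwd_compute
            + self.bwd_comms
            + self.prefetch_compute
        )
```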
We use the cache statistics to estimate the expected number of cache fetches and then convert this to a prefetch duration estimate based on memory bandwidth.
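The conversion step can be sketched as follows, assuming the fetch cost is memory-bandwidth bound. The function name and parameters are hypothetical; the expected fetch count would come from the CacheStatistics interface:

```python
def estimate_prefetch_compute_ms(
    expected_cache_fetches: float,  # expected misses, from cache statistics
    emb_dim: int,  # embedding row width
    element_size_bytes: int,  # e.g. 4 for fp32, 2 for fp16
    mem_bw_bytes_per_ms: float,  # memory bandwidth of the fetch path
) -> float:
    """Convert an expected number of cache fetches into a prefetch
    duration estimate: bytes fetched divided by memory bandwidth.
    Illustrative sketch, not the exact torchrec estimator."""
    fetch_bytes = expected_cache_fetches * emb_dim * element_size_bytes
    return fetch_bytes / mem_bw_bytes_per_ms
```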
For per-shard costing, we use the sum of fwd & prefetch, even though these will attempt to run in parallel. See the comment below for the reasoning. In the future, if and when we can correctly model dense fwd, we can update NoopPerfModel to correctly account for the parallelism between forward & prefetch (see https://www.internalfb.com/diff/D51461833?dst_version_fbid=3192784921018181&transaction_fbid=709488211080104 for why we can't do this today).
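The trade-off between the two costing choices can be sketched as below. These helper names are hypothetical, not the NoopPerfModel API; the point is that modeling the overlap requires an accurate dense-forward estimate we do not have today:

```python
def shard_cost_no_overlap(fwd_ms: float, prefetch_ms: float) -> float:
    # Conservative choice used here: assume forward and prefetch do not
    # overlap, so their durations add.
    return fwd_ms + prefetch_ms


def shard_cost_full_overlap(fwd_ms: float, prefetch_ms: float) -> float:
    # Ideal cost if the two fully overlapped; only trustworthy with a
    # correct dense-forward duration model.
    return max(fwd_ms, prefetch_ms)
```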
More details in: https://docs.google.com/document/d/1rcYh6yi23N0USp_T4nySxVnTxllBwiQlRam6EWGIy-g/edit
Reviewed By: henrylhtsang
Differential Revision: D51461817
This pull request was exported from Phabricator. Differential Revision: D51461817