
Update cost model to include prefetch_compute

Open · damianr99 opened this issue 1 year ago · 1 comment

Summary: Updated version of D50162035 that uses the new CacheStatistics interface and supports more sharding types.

Jobs with large prefetch streams can be slowed down by large prefetch_compute blocking backward: https://www.internalfb.com/mlhub/pipelines/runs/mast/aps-leiyuz-test-ip40-b6589377e7?job_attempt=0&version=0&tab=execution_details&env=PRODUCTION

This happens because the prefetch work is not balanced across ranks: the sharder is unaware of it.

This diff adds a new dimension to the Perf cost model to hold prefetch_compute. This is only non-zero when using embedding offloading via the prefetch pipeline.
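As a rough illustration of the shape of this change, here is a sketch of a `Perf` cost record with the new dimension. The field names and the `total` property are assumptions for illustration, not the exact TorchRec definition:

```python
from dataclasses import dataclass


@dataclass
class Perf:
    """Sketch of a per-shard cost estimate (field names are illustrative)."""

    fwd_compute: float
    fwd_comms: float
    bwd_compute: float
    bwd_comms: float
    # New dimension: time spent prefetching embedding rows into the cache.
    # Zero unless embedding offloading with the prefetch pipeline is used.
    prefetch_compute: float = 0.0

    @property
    def total(self) -> float:
        # Per-shard cost sums fwd and prefetch, even though they may
        # overlap at runtime (see discussion below).
        return (
            self.fwd_compute
            + self.fwd_comms
            + self.bwd_compute
            + self.bwd_comms
            + self.prefetch_compute
        )
```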

We use the cache statistics to estimate the expected number of cache fetches and then convert this to a prefetch duration estimate based on memory bandwidth.
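A minimal sketch of that conversion, assuming a simple bandwidth model where prefetch time is bytes fetched divided by bandwidth (the function name, parameters, and units are hypothetical):

```python
def estimate_prefetch_compute(
    expected_cache_fetches: float,  # from CacheStatistics, per iteration
    emb_dim: int,                   # embedding dimension of the table
    element_size_bytes: int,        # e.g. 4 for fp32 rows
    bandwidth_bytes_per_ms: float,  # DDR-to-HBM transfer bandwidth
) -> float:
    """Convert an expected cache-fetch count into a prefetch duration (ms)."""
    bytes_to_fetch = expected_cache_fetches * emb_dim * element_size_bytes
    return bytes_to_fetch / bandwidth_bytes_per_ms
```

For example, 1000 expected fetches of 128-dim fp32 rows at 1e6 bytes/ms yields a 0.512 ms prefetch estimate.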

For per-shard costing, we use the sum of fwd and prefetch, even though these will attempt to run in parallel; see the comment below for the reasoning. In the future, if and when we can correctly model dense fwd, we can update NoopPerfModel to account for the parallelism between forward and prefetch (see https://www.internalfb.com/diff/D51461833?dst_version_fbid=3192784921018181&transaction_fbid=709488211080104 for why we can't do this today).
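The two combining strategies can be contrasted in a small sketch (both functions are illustrative, not the planner's actual API):

```python
def shard_cost_conservative(fwd_compute: float, prefetch_compute: float) -> float:
    # Current approach: assume no overlap, so the per-shard cost is the sum.
    # This overestimates, but is safe without a model of dense fwd compute.
    return fwd_compute + prefetch_compute


def shard_cost_overlapped(fwd_compute: float, prefetch_compute: float) -> float:
    # Hypothetical future model: if fwd and prefetch truly run in parallel,
    # only the longer of the two sits on the critical path.
    return max(fwd_compute, prefetch_compute)
```

The conservative sum still improves balance, since a shard with heavy prefetch is no longer costed the same as one without.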

More details in: https://docs.google.com/document/d/1rcYh6yi23N0USp_T4nySxVnTxllBwiQlRam6EWGIy-g/edit

Reviewed By: henrylhtsang

Differential Revision: D51461817

damianr99 · Jan 02 '24 19:01

This pull request was exported from Phabricator. Differential Revision: D51461817

facebook-github-bot · Jan 02 '24 19:01