
Supported nested preproc modules which are called multiple times with different args

Open • sarckk opened this issue 1 year ago • 1 comment

Summary: Ran into three issues while enabling the pipeline for a model:

  1. The current pipeline logic for finding and swapping a preproc module only works if the preproc module is a direct attribute of the model. If the preproc lives inside one of the model's child modules, e.g. model._sparse_arch._preproc_module, the logic breaks down: finding the module fails because it used getattr on the model, and swapping it fails because it used setattr on the model. Solution:
    • Replaced getattr and setattr with _find_preproc_module_recursive and _swap_preproc_module_recursive respectively.
  2. In this model, the same preproc module was called twice, with two different sets of arguments passed to forward(). The current logic didn't handle this correctly because a) we only ever created one instance of each pipelined preproc, holding the arg info captured from tracing (even though this should differ for each invocation), and b) we cached results based on the preproc module's FQN only. Solution:
    • If we see another call to the same PipelinedPreproc, we still capture its arguments' graph and append the resulting List[ArgInfo] to the PipelinedPreproc's arg info list via preproc_module.register_args(preproc_args).
    • Each time the pipelined preproc's forward is called during pipeline execution, we need to fetch the right arg info list. So I added self._call_idx to PipelinedPreproc, which is incremented on each forward call, and used it to index into the arg info list.
    • Changed the cache key to self._fqn + str(self._call_idx). Ideally, we would use a different cache key for each combination of FQN, args, and kwargs, but materializing that into a str / object could be too expensive, as these args are large model inputs (KJTs / tensors).
  3. The logic doesn't support a preproc module argument that is a constant (e.g. self.model.constant_value), because we skip args that aren't torch.fx.Node values. However, we should be able to pipeline these cases too. Solution:
    • Added a new field to ArgInfo called objects, of type List[Optional[object]]. After fx tracing, such values appear as fx immutable collections, e.g. torch.fx.immutable_collections.immutable_dict for an immutable Dict; creating a copy converts it back to the original mutable type. We capture this value in ArgInfo. The potential downside is extra memory overhead, but for this model in particular it was just a small string value.
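The nested lookup and swap from item 1 can be sketched in plain Python. The helpers `find_module_by_fqn` and `swap_module_by_fqn` are hypothetical (they are not torchrec's `_find_preproc_module_recursive` / `_swap_preproc_module_recursive`), and the toy `Model` / `SparseArch` / `Preproc` classes stand in for `torch.nn.Module` subclasses. The idea: walk the dotted FQN one component at a time with `getattr`, then call `setattr` on the parent module rather than on the root model.

```python
from typing import Any


def find_module_by_fqn(root: Any, fqn: str) -> Any:
    """Resolve a dotted FQN such as "_sparse_arch._preproc_module" from root.

    A single getattr(root, fqn) only reaches top-level attributes; walking one
    path component at a time also reaches nested children.
    """
    module = root
    for name in fqn.split("."):
        module = getattr(module, name)
    return module


def swap_module_by_fqn(root: Any, fqn: str, new_module: Any) -> None:
    """Replace the module at a dotted FQN by calling setattr on its parent."""
    parent_fqn, _, leaf = fqn.rpartition(".")
    # A top-level FQN has no parent path, so the root itself is the parent.
    parent = find_module_by_fqn(root, parent_fqn) if parent_fqn else root
    setattr(parent, leaf, new_module)


# Toy classes standing in for nn.Module submodules (hypothetical).
class Preproc:
    pass


class SparseArch:
    def __init__(self) -> None:
        self._preproc_module = Preproc()


class Model:
    def __init__(self) -> None:
        self._sparse_arch = SparseArch()


model = Model()
original = find_module_by_fqn(model, "_sparse_arch._preproc_module")
wrapper = object()  # placeholder for the pipelined preproc wrapper
swap_module_by_fqn(model, "_sparse_arch._preproc_module", wrapper)
```

The same walk should carry over to real `torch.nn.Module` trees, since child modules are reachable as attributes and assigning a module attribute registers it as a child; `nn.Module.get_submodule` covers the lookup half as well.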
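A minimal sketch of the item 2 fix, assuming a heavily simplified wrapper: `PipelinedPreprocSketch`, `start_batch`, and the use of plain argument names in place of `List[ArgInfo]` are all hypothetical. One wrapper is shared by every call site of the same preproc module; each traced call site appends its arg info via `register_args`, and `forward` uses `_call_idx` both to pick that arg info and to build the `fqn + str(call_idx)` cache key. Resetting the counter and cache once per batch is an assumption the summary does not spell out, and the sketch fills the cache lazily for simplicity.

```python
from typing import Any, Callable, Dict, List


class PipelinedPreprocSketch:
    """Hypothetical stand-in for a pipelined preproc wrapper."""

    def __init__(self, fqn: str, preproc_fn: Callable[..., Any]) -> None:
        self._fqn = fqn
        self._preproc_fn = preproc_fn
        # One entry per call site, in the order the calls happen in forward().
        self._args_info_list: List[List[str]] = []
        self._call_idx = 0
        self._cache: Dict[str, Any] = {}

    def register_args(self, arg_names: List[str]) -> None:
        # Called once per traced call site of the same preproc module.
        self._args_info_list.append(arg_names)

    def start_batch(self) -> None:
        # Assumption: counter and cache are reset at the start of every batch.
        self._call_idx = 0
        self._cache.clear()

    def forward(self, batch: Dict[str, Any]) -> Any:
        # The call index selects this call site's arg info and cache slot.
        arg_names = self._args_info_list[self._call_idx]
        cache_key = self._fqn + str(self._call_idx)
        self._call_idx += 1
        if cache_key not in self._cache:
            args = [batch[name] for name in arg_names]
            self._cache[cache_key] = self._preproc_fn(*args)
        return self._cache[cache_key]


def scale(x: int, factor: int) -> int:
    return x * factor


preproc = PipelinedPreprocSketch("_sparse_arch._preproc_module", scale)
preproc.register_args(["a", "k"])  # first call site: scale(a, k)
preproc.register_args(["b", "k"])  # second call site: scale(b, k)

preproc.start_batch()
batch = {"a": 2, "b": 5, "k": 10}
first = preproc.forward(batch)   # 20
second = preproc.forward(batch)  # 50
```

Keying on the call index instead of the argument values keeps the key cheap to build, which is the trade-off the summary describes; it relies on the preproc module being called in the same order every iteration as it was during tracing.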
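For item 3, a rough sketch of how a captured constant might flow through arg info. `ArgInfoSketch`, `resolve_arg`, `to_mutable`, the `""` placeholder attr, and the `FrozenDict` stand-in are hypothetical; only the `objects: List[Optional[object]]` field mirrors the summary. Because torch.fx's immutable collections subclass `dict` and `list`, copying one into the base type yields an ordinary mutable container, which is what the sketch stores.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class ArgInfoSketch:
    # Keys used to walk from the batch to the argument value.
    input_attrs: List[str]
    # Captured constants, one slot per step; None means "read from the batch".
    objects: List[Optional[object]]


def to_mutable(value: Any) -> Any:
    # fx immutable collections subclass dict/list, so copying into the base
    # type gives back an ordinary mutable container.
    if isinstance(value, dict):
        return dict(value)
    if isinstance(value, list):
        return list(value)
    return value


def batch_arg_info(*attrs: str) -> ArgInfoSketch:
    return ArgInfoSketch(input_attrs=list(attrs), objects=[None] * len(attrs))


def constant_arg_info(value: Any) -> ArgInfoSketch:
    # The "" attr is a placeholder; the captured object short-circuits lookup.
    return ArgInfoSketch(input_attrs=[""], objects=[to_mutable(value)])


def resolve_arg(info: ArgInfoSketch, batch: Dict[str, Any]) -> Any:
    value: Any = batch
    for attr, obj in zip(info.input_attrs, info.objects):
        if obj is not None:
            return obj  # constant captured at trace time
        value = value[attr]
    return value


class FrozenDict(dict):
    """Stand-in for torch.fx's immutable_dict (hypothetical)."""

    def __setitem__(self, key: Any, value: Any) -> None:
        raise NotImplementedError("immutable")


batch = {"features": {"ids": [1, 2, 3]}}
ids_info = batch_arg_info("features", "ids")
mode_info = constant_arg_info("sum")  # e.g. a small string constant
config_info = constant_arg_info(FrozenDict({"pooling": "mean"}))

args = [resolve_arg(info, batch) for info in (ids_info, mode_info, config_info)]
```

One consequence of using `None` as the "no constant" marker is that a constant whose value is itself `None` cannot be told apart from a batch lookup in this encoding.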

Reviewed By: joshuadeng

Differential Revision: D61155773

sarckk • Aug 22 '24 21:08

This pull request was exported from Phabricator. Differential Revision: D61155773

facebook-github-bot • Aug 22 '24 21:08