torchrec
Support nested preproc modules that are called multiple times with different args
Summary: Ran into 3 issues while enabling pipeline for a model:
- The current pipeline logic for finding and swapping a preproc module only works if the preproc module exists at the model's top level. If the preproc lives inside one of the model's child modules, e.g. `model._sparse_arch._preproc_module`, this logic breaks down: finding the module fails because it uses `getattr` on the model, and swapping it fails because it uses `setattr` on the model. Solution: replaced `getattr` and `setattr` with `_find_preproc_module_recursive` and `_swap_preproc_module_recursive` respectively.
- In this model, the same preproc module was called twice with two different sets of arguments passed to `forward()`. The current logic wouldn't handle this correctly because a) we only ever created one instance of each pipelined preproc with its captured arg info from tracing (even though this should differ per invocation), and b) we cached results based on the preproc module's FQN only. Solution:
  - If we see another call to a `PipelinedPreproc`, we still capture its argument's graph and add the `List[ArgInfo]` to the `PipelinedPreproc`'s arg info list via `preproc_module.register_args(preproc_args)`.
  - Each time we call the pipelined preproc's forward during pipeline execution, we need to fetch the right arg info list. So I added `self._call_idx` to `PipelinedPreproc`, incremented on each forward call, and simply index into the arg info list using it.
  - Changed the cache key to `self._fqn + str(self._call_idx)`. Ideally, we would have a different `cache_key` for each FQN + args + kwargs combination, but materializing this into a str/object could be too expensive as these args are large model-input KJTs/tensors.
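The per-call-site bookkeeping above can be sketched as follows. This is a hypothetical simplification, not torchrec's actual `PipelinedPreproc` (which wraps an `nn.Module` and resolves `ArgInfo` lists against the traced input batch); it only shows the `register_args` / `_call_idx` / cache-key interaction:

```python
# Sketch: one arg-info list per traced call site, selected by _call_idx,
# with a cache key that distinguishes call sites sharing a single FQN.

class PipelinedPreprocSketch:
    def __init__(self, fqn: str):
        self._fqn = fqn
        self._args = []       # one List[ArgInfo]-like entry per traced call site
        self._call_idx = 0    # which call site the next forward() belongs to
        self._cache = {}

    def register_args(self, arg_info_list) -> None:
        # Called once for each call site discovered during tracing.
        self._args.append(arg_info_list)

    def forward(self, run_preproc):
        # Pick the arg info captured for *this* invocation, not just the first.
        args = self._args[self._call_idx]
        # FQN alone would collide across the two call sites; append the index.
        cache_key = self._fqn + str(self._call_idx)
        if cache_key not in self._cache:
            self._cache[cache_key] = run_preproc(args)
        self._call_idx += 1
        return self._cache[cache_key]

preproc = PipelinedPreprocSketch("_sparse_arch._preproc_module")
preproc.register_args(["arg_info_for_call_0"])
preproc.register_args(["arg_info_for_call_1"])

first = preproc.forward(lambda args: f"ran with {args[0]}")
second = preproc.forward(lambda args: f"ran with {args[0]}")
```

Resetting `_call_idx` (and the cache) between batches is omitted here for brevity; a real pipeline would need to do that at the start of each iteration.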
- The logic doesn't support an arg to a preproc module that is a constant (e.g. `self.model.constant_value`), since we skip args that aren't `torch.fx.Node` values. However, we should be able to pipeline these cases. Solution: add a new field to `ArgInfo` called `objects` of type `List[Optional[object]]`. After fx tracing, you will have fx immutable collections, such as `torch.fx.immutable_dict` for an immutable `Dict`. Creating a copy converts it back to the mutable original value, so we capture this variable in `ArgInfo`. The potential downside is extra memory overhead, but for this model in particular it was just a small string value.
Reviewed By: joshuadeng
Differential Revision: D61155773
This pull request was exported from Phabricator. Differential Revision: D61155773