
feat: Support weight-stripped engine and REFIT_IDENTICAL flag

Open · zewenli98 opened this issue 1 year ago · 16 comments

Description

  1. Supported weight-stripped engine
  2. Added REFIT_IDENTICAL flag

Fixes #3146

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • [ ] My code follows the style guidelines of this project (You can use the linters)
  • [ ] I have performed a self-review of my own code
  • [ ] I have commented my code, particularly in hard-to-understand areas and hacks
  • [ ] I have made corresponding changes to the documentation
  • [ ] I have added tests to verify my fix or my feature
  • [ ] New and existing unit tests pass locally with my changes
  • [ ] I have added the relevant labels to my PR so that relevant reviewers are notified

zewenli98 · Sep 19 '24 15:09

@narendasan OK, initially the overall design was:

In TRTInterpreter.run():

if compilation_settings.strip_engine_weights is True:
    if engine_cache not hit:
        1. build a weight-stripped engine
        2. save the weight-stripped engine if engine_cache is set
        3. return the weight-stripped engine (not yet refit)
    else:
        load and return the weight-stripped engine (not yet refit)
else:
    if engine_cache not hit:
        1. build a weight-included engine
        2. save the weight-included engine if engine_cache is set
        3. return the weight-included engine (don't need to refit)
    else:
        load and return the weight-included engine (not yet refit)

Then, in TRTModule, refit if necessary before inference. I didn't put the refitting part into TRTInterpreter.run() because I want to avoid repeated de/serialization of TRT engines: (1) deserializing in TRTInterpreter.run() for refitting and then serializing, and (2) deserializing again in TRTModule.
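To make the branching of this first design concrete, here is a plain-Python simulation of it. It is only a sketch: `Engine`, `build_engine`, `run_interpreter`, and the dict standing in for the engine cache are hypothetical, and the returned `needs_refit` flag marks the cases the pseudocode labels "not yet refit" (which TRTModule would refit before inference).

```python
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

CacheKey = Tuple[str, bool]  # (graph hash, strip_engine_weights)


@dataclass
class Engine:
    # Hypothetical stand-in for a serialized TensorRT engine.
    graph_hash: str
    has_weights: bool


def build_engine(graph_hash: str, strip_weights: bool) -> Engine:
    # Hypothetical builder: a weight-stripped engine carries no weight values.
    return Engine(graph_hash=graph_hash, has_weights=not strip_weights)


def run_interpreter(
    graph_hash: str,
    strip_engine_weights: bool,
    engine_cache: Optional[Dict[CacheKey, Engine]] = None,
) -> Tuple[Engine, bool]:
    """Return (engine, needs_refit) following the first design."""
    key = (graph_hash, strip_engine_weights)
    if engine_cache is not None and key in engine_cache:
        # Cache hit: the cached engine may hold no weights or stale ones,
        # so it is returned "not yet refit" in both branches.
        return engine_cache[key], True
    engine = build_engine(graph_hash, strip_engine_weights)
    if engine_cache is not None:
        engine_cache[key] = engine
    # A fresh weight-included engine needs no refit; a fresh stripped one does.
    return engine, strip_engine_weights
```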

zewenli98 · Sep 20 '24 04:09

@narendasan The design was updated.

From the users' perspective, they can set make_refittable and refit_identical_engine_weights: make_refittable enables general refitting, and refit_identical_engine_weights (which only takes effect when make_refittable is set) restricts refitting to identical weights.

import tensorrt as trt
from packaging import version

if self.compilation_settings.make_refittable:
    # REFIT_IDENTICAL is only available from TensorRT 10.0
    if version.parse(trt.__version__) >= version.parse("10.0"):
        if self.compilation_settings.refit_identical_engine_weights:
            builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
        else:
            builder_config.set_flag(trt.BuilderFlag.REFIT)
    else:
        builder_config.set_flag(trt.BuilderFlag.REFIT)

Besides, users can specify strip_engine_weights. If strip_engine_weights is True, TRTInterpreter.run() returns a weight-stripped engine; otherwise, it returns a regular engine (with weights).

For the 3 workflows mentioned above,

  1. Using the args above, users can call convert_exported_program_to_trt_engine with strip_engine_weights=True to get a weight-stripped engine.

  2. For engine caching, weight stripping is hidden from users: the engine caching mechanism (1) always saves a weight-stripped engine regardless of the user's settings (make_refittable must be True) and (2) loads and refits the weight-stripped engine when reusing a cached engine. If strip_engine_weights is True, the engine is not refitted; the weight-stripped engine is returned as is.

  3. If users specify strip_engine_weights=True, calling torch.compile() or torch_trt.dynamo.compile() returns a weight-stripped compiled program. Running it on inputs produces all-zero results. Calling refit_module_weights then restores the weights, e.g.:

from torch_tensorrt.dynamo._refit import refit_module_weights
refitted_trt_gm = refit_module_weights(trt_gm, exp_program)
refitted_output = refitted_trt_gm(*inputs)
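The behavior described in workflow 3 (all-zero outputs until the weights are refit) can be illustrated with a toy model that needs neither PyTorch nor TensorRT. `ToyStrippedLinear` is hypothetical; its `refit` returns a new object so that the same stripped module can later be refit with a different set of weights.

```python
from typing import List


class ToyStrippedLinear:
    """Toy y = W @ x whose weights are all zeros until refit (hypothetical)."""

    def __init__(self, rows: int, cols: int) -> None:
        # A weight-stripped engine keeps the shapes but not the values.
        self.weight: List[List[float]] = [[0.0] * cols for _ in range(rows)]

    def refit(self, weight: List[List[float]]) -> "ToyStrippedLinear":
        # Return a new module; the stripped one stays reusable for other weights.
        refitted = ToyStrippedLinear(len(weight), len(weight[0]))
        refitted.weight = [list(row) for row in weight]
        return refitted

    def __call__(self, x: List[float]) -> List[float]:
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]
```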

Please see more details in the tests.

zewenli98 · Sep 20 '24 17:09

> The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

I think that we need to separate the runtime and the compiler, so I'm willing to spend the time serializing and deserializing.

I think we should frame this PR around moving TRTInterpreter to default to building weight-stripped engines. There will be 3 kinds of engines now.

  1. weight strip + refittable (strip_weights + kREFIT) - should move towards this being the default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
  3. non_refittable

The first 2 need separate cache entries. So we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL

We should prefer case 1 in the long term, as it allows us to reuse the most work; case 2 would be the next preference. Case 2 should produce faster engines than case 1, so there remains a need to support kREFIT_IDENTICAL.
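A minimal sketch of how the two refittable engine kinds could get separate cache entries, hashing the weights into the key only when the engine is built for fixed weights (the REFIT_IDENTICAL case the thread later settles on). `engine_cache_key`, the string graph signature, and the name-to-bytes weight dict are assumptions for illustration, not Torch-TensorRT's actual hashing scheme.

```python
import hashlib
from typing import Dict


def engine_cache_key(graph_signature: str, refit_identical: bool, weights: Dict[str, bytes]) -> str:
    """Hash the graph; also hash the weights when the engine only accepts identical weights."""
    h = hashlib.sha256()
    fields = [graph_signature.encode("utf-8"), b"REFIT_IDENTICAL" if refit_identical else b"REFIT"]
    if refit_identical:
        # A REFIT_IDENTICAL engine is only valid for these exact weights, so key on them.
        for name in sorted(weights):
            fields.append(name.encode("utf-8"))
            fields.append(weights[name])
    for field in fields:
        # Length-prefix each field so different splits of the same bytes cannot collide.
        h.update(len(field).to_bytes(8, "little"))
        h.update(field)
    return h.hexdigest()
```

With plain REFIT the weights are ignored, so one cached entry serves any weights and is refit on load; with REFIT_IDENTICAL each weight set gets its own entry.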

> Do you mean we use make_refittable to control both kREFIT and kREFIT_IDENTICAL?

The case for type 3 engines is now only valid if building a non-refittable engine is faster than building a refit_identical engine and then refitting the weights. If it is not faster by a significant enough margin, I propose we remove that workflow and just have refit or refit_identical engines.

So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name if needed here), since now both engines are refittable they just have different weight constraints.

narendasan · Sep 20 '24 17:09

Some of the open questions are:

  • How do we determine whether the weights have been refit before running the engine? Can TRT tell us without an error?
  • How can we tell if a user is trying to refit an engine built with REFIT_IDENTICAL using different weights?
  • Is building a weight-stripped REFIT_IDENTICAL engine and then refitting it slower than just building a regular engine?

narendasan · Sep 20 '24 18:09

> The reason that I didn't put the refitting part into TRTInterpreter.run() is that I want to avoid repeated de/serializations of TRT engines: (1) deserialize in TRTInterpreter.run() for refitting and then serialize (2) deserialize in TRTModule again.

> I think that we need to separate the runtime and the compiler, so I'm willing to spend the time serializing and deserializing.

> I think we should frame this PR around moving TRTInterpreter to default to building weight-stripped engines. There will be 3 kinds of engines now.

>   1. weight strip + refittable (strip_weights + kREFIT) - should move towards this being the default
>   2. weight strip + refittable with original weights (strip_weights + kREFIT_INDIVIDUAL)
>   3. non_refittable

> The first 2 need separate cache entries. So we need to be able to hash on the weights in the case that the model is being built with kREFIT_INDIVIDUAL

> We should look to prefer case 1 in the long term as it allows us to reuse the most work, case 2 would be the next preference. Case 2 should produce faster engines than Case 1 so there remains a need to support kREFIT_IDENTICAL

Are you referring to kREFIT_IDENTICAL or kREFIT_INDIVIDUAL? The updated design only considers kREFIT_IDENTICAL. kREFIT_INDIVIDUAL is for fine-grained control, which is not considered yet.

zewenli98 · Sep 20 '24 18:09

>   • how we determine if the weights have been refit prior to running the engine. Can TRT tell us without an error?

My current design is: if users specify strip_engine_weights=True in compile, the weights will not be refitted, and they will get a weight-stripped engine. However, if they obtain an engine elsewhere, they can call get_missing_weights() to check whether any weights have not been refitted.
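TensorRT's `Refitter` exposes `get_missing_weights()`, which lists weights that still need values. A wrapper could apply the same idea to refuse to run until every weight has been supplied, instead of silently returning zeros. The sketch is a toy, not the TensorRT API: `GuardedAffine` is hypothetical and computes `w * x + b` from two named weights.

```python
from typing import Dict, List


class GuardedAffine:
    """Toy y = w * x + b that raises instead of running with missing weights."""

    REQUIRED = ("b", "w")

    def __init__(self) -> None:
        self._weights: Dict[str, float] = {}

    def refit(self, weights: Dict[str, float]) -> None:
        unknown = set(weights) - set(self.REQUIRED)
        if unknown:
            raise KeyError(f"unknown weights: {sorted(unknown)}")
        self._weights.update(weights)

    def get_missing_weights(self) -> List[str]:
        # Same idea as trt.Refitter.get_missing_weights(): names still unset.
        return sorted(set(self.REQUIRED) - set(self._weights))

    def __call__(self, x: float) -> float:
        missing = self.get_missing_weights()
        if missing:
            raise RuntimeError(f"weights not refit yet: {missing}")
        return self._weights["w"] * x + self._weights["b"]
```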

>   • How can we tell if a user is trying to refit an engine with different weights to an engine built with REFIT_IDENTICAL?

I also thought about this earlier. The TRT doc says "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures." My understanding is that the engine itself cannot tell whether the refit weights are identical to the build-time weights, because a weight-stripped engine does not compare the two sets of weights or emit any warning. So users need to be clear about what they are refitting.

>   • If building strip weights refit identical + refit is slower than just building?

I will investigate it.

zewenli98 · Sep 20 '24 18:09

> The case for type 3 engines now is only valid if building a non refittable engine is faster than building a refit_identical engine then refitting the weights. If it is not by a significant enough margin I propose we remove that workflow and just have refit or refit_identical engines.

@narendasan I tested building ResNet18 and VGG16 via two paths: (1) strip weights + refit_identical + refit and (2) non-refittable. The build times of the two paths are almost the same (diff < 1%), and so are the engine sizes (diff < 0.1%). Even though build time, engine size, and performance are the same, I'm not sure whether non-refittable engines have other benefits, e.g., in deployments where weights must not be changed for safety reasons?

zewenli98 · Sep 23 '24 16:09

@narendasan I just confirmed with the TRT team: an engine built with STRIP_PLAN + REFIT_IDENTICAL and then refit is almost the same as a non-refittable engine. Do you prefer to remove the non-refittable engine path? If so, the paths would be:

  1. weight strip + refittable (strip_weights + kREFIT) - default
  2. weight strip + refittable with original weights (strip_weights + kREFIT_IDENTICAL)

> So assuming that we can remove type 3 engines, make_refittable really means "allows the weights to be changed" (we can change the name if needed here), since now both engines are refittable they just have different weight constraints.

I think we can rename make_refittable to refit_mode: Literal["general", "identical"] (possibly easier to extend in the future) or to refit_identical_weights: bool. Then we can remove the refit_identical_engine_weights arg that has already been committed in this PR.

On top of this, STRIP_PLAN will always be on while building engines, and the strip_engine_weights arg lets users control whether they get weight-stripped engines.
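If make_refittable were replaced by the proposed refit_mode, a `typing.Literal` type would document the allowed values, but Python does not enforce it at runtime, so an explicit check is still needed. A hypothetical settings sketch: `RefitSettings` and `builder_flag_names` are illustrative names, and flags are returned as strings so it runs without TensorRT, with STRIP_PLAN always on as proposed here.

```python
from dataclasses import dataclass
from typing import FrozenSet, Literal, get_args

RefitMode = Literal["general", "identical"]


@dataclass(frozen=True)
class RefitSettings:
    refit_mode: RefitMode = "general"
    strip_engine_weights: bool = False

    def __post_init__(self) -> None:
        # Literal is only a type hint, so validate the value explicitly.
        allowed = get_args(RefitMode)
        if self.refit_mode not in allowed:
            raise ValueError(f"refit_mode must be one of {allowed}, got {self.refit_mode!r}")

    def builder_flag_names(self) -> FrozenSet[str]:
        # STRIP_PLAN is always set; strip_engine_weights only decides whether the
        # caller receives the stripped engine or a refit one.
        refit_flag = "REFIT_IDENTICAL" if self.refit_mode == "identical" else "REFIT"
        return frozenset({"STRIP_PLAN", refit_flag})
```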

In summary, the 3 workflows mentioned above would be:

  1. Users just want a weight-stripped engine. They can call convert_exported_program_to_trt_engine with strip_engine_weights=True to get one. This also works when the engine is loaded from the engine cache.

  2. We want to use weight stripping to get a lighter-weight engine cache. The weight-stripped implementation is hidden from users. However, engines built with kREFIT and with kREFIT_IDENTICAL are considered different engines and are cached separately.

  3. Users want a compiled program with stripped weights. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If the compiled program is run on inputs right away, all results will be zeros. Calling refit_module_weights() restores the weights.

zewenli98 · Sep 24 '24 05:09

I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

> Users want a stripped weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If running the compiled program with inputs immediately, all the results will always be zeros. Calling refit_module_weights() will make weights back

I still don't know what the use case for this is.

narendasan · Sep 30 '24 02:09

> How can we tell if a user is trying to refit an engine with different weights to an engine built with REFIT_IDENTICAL?

> I also thought about it earlier. The TRT doc says "if the refit weights are not identical to the build-time weights, behavior is undefined... This enables use of a single set of weights with different inference backends, or with TensorRT plans for multiple GPU architectures." My understanding is that we cannot tell if weights are identical in build time and refitting, from the perspective of engine itself, because weight-stripped engine doesn't compare weights in build time and refitting phase, or give any prompts. So users need to be clear what they are refitting.

We should think about a solution for this since behavior is undefined
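One possible mitigation, sketched as an assumption rather than anything TensorRT provides: record a digest of the build-time weights alongside the engine (where to store it is left open) and compare it before refitting a REFIT_IDENTICAL engine. `weights_digest` and `check_identical_refit` are hypothetical helpers.

```python
import hashlib
from typing import Dict


def weights_digest(weights: Dict[str, bytes]) -> str:
    """SHA-256 over sorted (name, bytes) pairs, length-prefixed to avoid ambiguity."""
    h = hashlib.sha256()
    for name in sorted(weights):
        for field in (name.encode("utf-8"), weights[name]):
            h.update(len(field).to_bytes(8, "little"))
            h.update(field)
    return h.hexdigest()


def check_identical_refit(build_digest: str, refit_weights: Dict[str, bytes]) -> None:
    """Raise before refitting a REFIT_IDENTICAL engine with weights that differ from build time."""
    # Per the TRT doc quoted above, mismatched weights give undefined behavior,
    # so fail loudly here instead of refitting.
    if weights_digest(refit_weights) != build_digest:
        raise ValueError("refit weights differ from the build-time weights")
```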

narendasan · Sep 30 '24 02:09

> I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

> Users want a stripped weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If running the compiled program with inputs immediately, all the results will always be zeros. Calling refit_module_weights() will make weights back

> I still don't know what the use case for this is.

I think this allows users to have a weight-stripped compiled program first, and then they can refit with different weights. For example:

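import torch
import torch_tensorrt as torch_trt
from torch_tensorrt.dynamo._refit import refit_module_weights

# pyt_model and inputs are defined elsewhere
batch = torch.export.Dim("batch")
exp_program = torch.export.export(
    pyt_model, args=inputs, dynamic_shapes={"x": {0: batch}}
)
stripped_gm = torch_trt.dynamo.compile(exp_program, strip_engine_weights=True, ...)
refitted_gm = refit_module_weights(stripped_gm, exp_program)
exp_program2 = ...
refitted_gm2 = refit_module_weights(stripped_gm, exp_program2)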

zewenli98 · Sep 30 '24 20:09

> I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

I was aware that some ops/nodes are not refittable, like cumsum or embedding_bag. That was handled in the previous PR: https://github.com/pytorch/TensorRT/pull/3159. However, if make_refittable is removed, we may need another flag to tell TRT to build a non-refittable engine?

zewenli98 · Oct 01 '24 07:10

> I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

> Users want a stripped weights compiled program. They just need to call torch.compile() or torch_trt.dynamo.compile() with strip_engine_weights=True. If running the compiled program with inputs immediately, all the results will always be zeros. Calling refit_module_weights() will make weights back

> I still don't know what the use case for this is.

> I think this allows users to have a weight-stripped compiled program first, and then they can refit with different weights. For example:

Yeah, but why would they care that there are no weights inside? What I'm worried about is that we now return these non-live programs to users. So either we need to be able to check whether the weights have been refitted before running, or we shouldn't let exported programs have stripped weights.

narendasan · Oct 01 '24 14:10

> I think we should remove non-refittable then, and we can add it back as a non-default workflow later if there's some reason to.

> I was aware that some ops/nodes are not refittable, like cumsum or embedding_bag. That was handled in the previous PR: #3159. However, if make_refittable is removed, we may need another flag to tell TRT to build a non-refittable engine?

The behavior for those cases is to fall back to PyTorch when refitting is enabled. But we could have a setting like immutable_weights that indicates the weights can't be changed, in which case we skip refitting.
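The rule described here can be sketched as a per-node placement decision. `NON_REFITTABLE_OPS` and `place_node` are hypothetical; the op set only holds the two examples named earlier in the thread, and `immutable_weights` is the setting suggested in this comment.

```python
from typing import FrozenSet

# Ops named in this thread as not refittable; illustrative, not exhaustive.
NON_REFITTABLE_OPS: FrozenSet[str] = frozenset({"cumsum", "embedding_bag"})


def place_node(op_name: str, immutable_weights: bool) -> str:
    """Return "tensorrt" or "pytorch" for a node that has a TensorRT converter."""
    if immutable_weights:
        # Weights can never change, so refittability does not matter.
        return "tensorrt"
    # Refitting is enabled: non-refittable ops fall back to PyTorch.
    return "pytorch" if op_name in NON_REFITTABLE_OPS else "tensorrt"
```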

narendasan · Oct 01 '24 15:10

> Yeah but why would they care that there are no weights inside?

I was wondering whether this is a plausible scenario: users call the compile function to get a weight-stripped compiled program, distribute it to different nodes/GPUs, and each node refits it with different weights?

> The thing I'm worried about is now you get these non live programs returned to users. So either we need to be able to check if the weights have been refitted before running or we shouldn't let exported programs have stripped weights

TRT itself doesn't prevent running weight-stripped engines; it just returns all-zero results. But yeah, if the scenario above isn't a realistic case, we can just return weight-included exported programs.

zewenli98 · Oct 01 '24 16:10

> Natively TRT doesn't restrict the running of weight-stripped engines. It just returns results with all zeros. But yeah if the scenario above is not a possible case, we can just return weight-included exported programs.

I guess that's better than nothing; ideally we would still have a way to detect whether the weights have been refitted.

narendasan · Oct 02 '24 03:10

@zewenli98 / @peri044 I was able to reproduce the "misaligned address cuda error" with the weight streaming test. The problem started when the refit builder flag was added as the default. When I remove trt.BuilderFlag.REFIT, the test passes. I will look into it further.

def _populate_trt_builder_config(
    self,
    strict_type_constraints: bool = False,
    algorithm_selector: Optional[trt.IAlgorithmSelector] = None,
    tactic_sources: Optional[int] = None,
) -> trt.IBuilderConfig:

    ...
    if self.compilation_settings.immutable_weights:
        # non-refittable engine
        if self.compilation_settings.strip_engine_weights:
            _LOGGER.warning("strip_engine_weights will be ignored.")
    else:
        # refittable engine
        if self.compilation_settings.refit_identical_engine_weights:
            builder_config.set_flag(trt.BuilderFlag.REFIT_IDENTICAL)
        else:
            builder_config.set_flag(trt.BuilderFlag.REFIT)  # <----- refit flag is set as default

keehyuna · Nov 21 '24 10:11

Thanks @keehyuna! Does this error happen with trt.BuilderFlag.REFIT_IDENTICAL flag?

zewenli98 · Nov 21 '24 17:11

> Thanks @keehyuna! Does this error happen with trt.BuilderFlag.REFIT_IDENTICAL flag?

The problem does not occur when trt.BuilderFlag.REFIT_IDENTICAL is used. This was tested on the main branch.

keehyuna · Nov 22 '24 01:11