feat: improve engine caching and fix bugs
Description
As I requested, TensorRT 10.14 added a serialization flag, `trt.SerializationFlag.INCLUDE_REFIT`, that allows refitted engines to remain refittable, so an engine can now be refitted multiple times. Building on this capability, this PR enhances the existing engine caching and refitting features as follows (a sketch of the underlying TensorRT calls follows the list):
- To save hard disk space, engine caching now saves only weight-stripped engines on disk, regardless of `compilation_settings.strip_engine_weights`. When users pull a cached engine, it is automatically refitted and remains refittable.
- Compiled TRT modules can be refitted multiple times with `refit_module_weights()`, e.g.:
```python
for _ in range(3):
    trt_gm = refit_module_weights(trt_gm, exp_program)
```
- Previously, the insertion and pulling of cached engines happened in different places, which caused #3909. This PR unifies insertion and pulling in `_conversion.py`.
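For context, here is a minimal sketch of the TensorRT-level flow this relies on. It assumes an already-built refittable `engine` and a `current_weights` mapping from layer names to `trt.Weights` (both hypothetical placeholders); the serialization and refit APIs are standard TensorRT, and `INCLUDE_REFIT` is the new 10.14 flag:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)

# Serialize the engine weight-stripped (saves disk space) but still refittable.
ser_config = engine.create_serialization_config()
ser_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)  # drop weights from the plan
ser_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)    # stay refittable after refit
blob = engine.serialize_with_config(ser_config)

# On a cache hit, deserialize the stripped plan and refit it with current weights.
runtime = trt.Runtime(logger)
cached_engine = runtime.deserialize_cuda_engine(blob)
refitter = trt.Refitter(cached_engine, logger)
for name, weights in current_weights.items():
    refitter.set_named_weights(name, weights)
assert refitter.refit_cuda_engine()
```

Because the plan was serialized with `INCLUDE_REFIT`, the refitted engine can go through this refit step again on the next weight update.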
Type of change
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
Checklist:
- [x] My code follows the style guidelines of this project (You can use the linters)
- [x] I have performed a self-review of my own code
- [x] I have commented my code, particularly in hard-to-understand areas and hacks
- [ ] I have made corresponding changes to the documentation
- [x] I have added tests to verify my fix or my feature
- [x] New and existing unit tests pass locally with my changes
- [x] I have added the relevant labels to my PR so that the relevant reviewers are notified
@cehongwang please take a pass so we have multiple eyes on this PR
The reason the JIT path's output is not all zeros when `strip_engine_weights=True` is that AOT and JIT produce different GraphModules before conversion to a TRT engine. JIT graphs are always weightless because the weights are passed in as inputs.
In the AOT path, the weights are stored in the model as `get_attr` nodes:
```
graph():
    %conv1_weight : [num_users=1] = get_attr[target=conv1.weight]
    %bn1_weight : [num_users=1] = get_attr[target=bn1.weight]
    %bn1_bias : [num_users=1] = get_attr[target=bn1.bias]
    %layer1_0_conv1_weight : [num_users=1] = get_attr[target=layer1.0.conv1.weight]
    %layer1_0_bn1_weight : [num_users=1] = get_attr[target=layer1.0.bn1.weight]
    %layer1_0_bn1_bias : [num_users=1] = get_attr[target=layer1.0.bn1.bias]
    %layer1_0_conv2_weight : [num_users=1] = get_attr[target=layer1.0.conv2.weight]
    %layer1_0_bn2_weight : [num_users=1] = get_attr[target=layer1.0.bn2.weight]
    %layer1_0_bn2_bias : [num_users=1] = get_attr[target=layer1.0.bn2.bias]
    ...
    %layer4_1_bn2_running_mean : [num_users=1] = get_attr[target=layer4.1.bn2.running_mean]
    %layer4_1_bn2_running_var : [num_users=1] = get_attr[target=layer4.1.bn2.running_var]
    %x : [num_users=1] = placeholder[target=x]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %conv1_weight, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), kwargs = {})
    %_native_batch_norm_legit_no_training : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%convolution, %bn1_weight, %bn1_bias, %bn1_running_mean, %bn1_running_var, 0.1, 1e-05), kwargs = {})
```
The JIT path, by contrast, uses placeholders to receive the weights on the fly, so there are actually no weights to strip:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %arg3_1 : [num_users=1] = placeholder[target=arg3_1]
    %arg4_1 : [num_users=1] = placeholder[target=arg4_1]
    %arg5_1 : [num_users=1] = placeholder[target=arg5_1]
    %arg6_1 : [num_users=1] = placeholder[target=arg6_1]
    %arg7_1 : [num_users=1] = placeholder[target=arg7_1]
    %arg8_1 : [num_users=1] = placeholder[target=arg8_1]
    %arg9_1 : [num_users=1] = placeholder[target=arg9_1]
    ...
    %arg101_1 : [num_users=1] = placeholder[target=arg101_1]
    %arg102_1 : [num_users=1] = placeholder[target=arg102_1]
    %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%arg1_1, %arg0_1, None, [2, 2], [3, 3], [1, 1], False, [0, 0], 1), kwargs = {})
    %_native_batch_norm_legit_no_training : [num_users=1] = call_function[target=torch.ops.aten._native_batch_norm_legit_no_training.default](args = (%convolution, %arg4_1, %arg5_1, %arg2_1, %arg3_1, 0.1, 1e-05), kwargs = {})
```
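To see the two forms side by side, one can inspect both the lifted and unlifted graphs from a single `torch.export` call. A minimal sketch follows; node names vary by PyTorch version, and Torch-TensorRT's JIT path actually obtains its lifted graph through the dynamo/AOTAutograd pipeline rather than `torch.export`, so this only approximates the graphs above:

```python
import torch
import torchvision.models as models

model = models.resnet18().eval()
x = torch.randn(1, 3, 224, 224)

ep = torch.export.export(model, (x,))

# Lifted form: parameters/buffers are placeholder (input) nodes,
# like the JIT graph above -- nothing to strip.
print(ep.graph_module.graph)

# Unlifted form: parameters are reattached and read via get_attr nodes,
# like the AOT graph above -- these weights can be stripped and refitted.
print(ep.module().graph)
```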