TensorRT
SDXL Accuracy Investigation
Analysis Findings
- Inference results (FP16) for both the Llama and SDXL models under Torch-TensorRT's `torch.compile` backend show accuracy discrepancies relative to their Torch counterpart models.
- Specifically, the Llama and SDXL inference results differ from the Torch inference results given the same seed. This behavior is not reproduced on plain SD or on smaller transformer-based models.
- One of the few complex layers shared by SDXL's UNet and Llama is `torch.ops.aten._scaled_dot_product_efficient_attention.default`, so we began with it as a potential source of error.
- After testing a multitude of configurations to isolate the issue, we observed that the following code block produces a large margin of error when TRT is compared against Torch and NumPy:
Input shapes: [(2, 4096, 640), (2, 10, 77, 64), (2, 10, 77, 64), (2, 4096, 640)]
graph():
%getitem_28 : [num_users=1] = placeholder[target=getitem_28]
%_to_copy_51 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%getitem_28,), kwargs = {dtype: torch.float16})
%reshape_default_17 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%_to_copy_51, [8192, 640]), kwargs = {})
%_frozen_param40 : [num_users=1] = get_attr[target=_frozen_param40]
%mm_12 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%reshape_default_17, %_frozen_param40), kwargs = {})
%reshape_default_18 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_12, [2, 4096, 640]), kwargs = {})
%reshape_default_23 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%reshape_default_18, [2, -1, 10, 64]), kwargs = {})
%permute_20 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%reshape_default_23, [0, 2, 1, 3]), kwargs = {})
%permute_21 : [num_users=1] = placeholder[target=permute_21]
%permute_22 : [num_users=1] = placeholder[target=permute_22]
%scaled_dot_product_attention_1 : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%permute_20, %permute_21, %permute_22), kwargs = {})
%permute_23 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%scaled_dot_product_attention_1, [0, 2, 1, 3]), kwargs = {})
%reshape_default_26 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_23, [2, -1, 640]), kwargs = {})
%reshape_default_27 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%reshape_default_26, [8192, 640]), kwargs = {})
%_frozen_param43 : [num_users=1] = get_attr[target=_frozen_param43]
%mm_15 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%reshape_default_27, %_frozen_param43), kwargs = {})
%mul_34 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm_15, 1), kwargs = {})
%_frozen_param44 : [num_users=1] = get_attr[target=_frozen_param44]
%add_17 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param44, %mul_34), kwargs = {})
%reshape_default_28 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%add_17, [2, 4096, 640]), kwargs = {})
%clone_10 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%reshape_default_28,), kwargs = {})
%div_6 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%clone_10, 1.0), kwargs = {})
%add_16 : [num_users=1] = placeholder[target=add_16]
%add_18 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%div_6, %add_16), kwargs = {})
%_to_copy_54 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add_18,), kwargs = {dtype: torch.float32})
return (_to_copy_54, add_18)
The above was then narrowed to the following simple matrix multiply, which, when run in FP16 with dimensions (8192 x 640) x (640 x 640) as used in our SDXL configuration, produces a maximum elementwise difference of roughly 10 between the TRT and Torch outputs. The mean difference was also high, at around 0.5.
class TestModule(torch.nn.Module):
    def forward(self, q, k):
        # Plain matrix multiply: (8192 x 640) @ (640 x 640) in FP16
        return q @ k
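To put the observed discrepancy in context, it helps to know the error floor that FP16 itself imposes at these dimensions. Below is a minimal NumPy sketch comparing an FP16 product against an FP64 reference on identical inputs; the random standard-normal inputs are an assumption, and real SDXL activations may be scaled differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Same shapes as the failing SDXL matmul: (8192 x 640) @ (640 x 640)
q = rng.standard_normal((8192, 640)).astype(np.float16)
k = rng.standard_normal((640, 640)).astype(np.float16)

# FP64 reference computed from the identical FP16 inputs, so any
# difference is attributable to low-precision compute, not input rounding
ref = q.astype(np.float64) @ k.astype(np.float64)

# FP16 product (NumPy may accumulate internally at higher precision;
# a backend accumulating purely in FP16 would show larger error)
out = (q @ k).astype(np.float64)

print("max abs diff :", np.abs(out - ref).max())
print("mean abs diff:", np.abs(out - ref).mean())
```

If the TRT-vs-Torch differences substantially exceed this floor, the discrepancy is unlikely to be explained by FP16 rounding alone.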
Additionally, the native_layer_norm operator may be contributing to the error, since excluding it also improves accuracy. This is under investigation.
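One way to probe the layer-norm hypothesis offline is to compare a normalization computed entirely in FP16 against an FP32 reference on the same inputs. This is a minimal NumPy sketch under stated assumptions: the tensor shape and input scale are illustrative, and the elementwise affine parameters of `native_layer_norm` are omitted.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last axis, as aten.native_layer_norm does
    # (elementwise weight/bias omitted for brevity)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4096, 640)).astype(np.float32)

ref = layer_norm(x)                                         # FP32 reference
lowp = layer_norm(x.astype(np.float16)).astype(np.float32)  # all-FP16 path

print("max abs diff:", np.abs(lowp - ref).max())
```

Because layer norm divides by a small variance-derived quantity, reduced-precision mean/variance computation can amplify error downstream, which would be consistent with the improvement seen when the operator is excluded.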
Update
We have further narrowed the matmul cases to make the reproducing example simpler.
Next Steps
- Check whether the issues persist at FP32 precision; if they do, narrow the cases down again to identify the layers responsible for the accuracy discrepancies.
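As a quick sanity bound for the FP32 experiment, the same (8192 x 640) @ (640 x 640) product can be checked in FP32 against an FP64 reference. If FP32 round-off at this size is orders of magnitude below the FP16 discrepancy, then any error that persists at FP32 would point to something other than numerical precision. A minimal NumPy sketch (random inputs are an assumption):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((8192, 640)).astype(np.float32)
b = rng.standard_normal((640, 640)).astype(np.float32)

ref = a.astype(np.float64) @ b.astype(np.float64)  # FP64 reference
out = (a @ b).astype(np.float64)                   # FP32 product

# FP32 round-off here sits far below the ~10 max difference seen in FP16
print("max abs diff:", np.abs(out - ref).max())
```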