
SDXL Accuracy Investigation

Status: Open · gs-olive opened this issue on Dec 12, 2023 · 2 comments

gs-olive commented on Dec 12, 2023

Analysis Findings

  • FP16 inference results for both the Llama and SDXL models in Torch-TensorRT's torch.compile backend show accuracy discrepancies relative to their Torch counterparts
    • Specifically, the Llama and SDXL inference results differ from the Torch inference results when given the same seed; this behavior is not reproduced on plain SD or on smaller transformer-based models
  • One of the few complex layers that SDXL's UNet and Llama share is torch.ops.aten._scaled_dot_product_efficient_attention.default, so we began with it as a potential source of error
  • After testing a multitude of configurations to isolate the issue, we observed that the following code block produces a large margin of error when TRT is compared against Torch and NumPy:
 Input shapes: [(2, 4096, 640), (2, 10, 77, 64), (2, 10, 77, 64), (2, 4096, 640)]
 graph():
    %getitem_28 : [num_users=1] = placeholder[target=getitem_28]
    %_to_copy_51 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%getitem_28,), kwargs = {dtype: torch.float16})
    %reshape_default_17 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%_to_copy_51, [8192, 640]), kwargs = {})
    %_frozen_param40 : [num_users=1] = get_attr[target=_frozen_param40]
    %mm_12 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%reshape_default_17, %_frozen_param40), kwargs = {})
    %reshape_default_18 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%mm_12, [2, 4096, 640]), kwargs = {})
    %reshape_default_23 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%reshape_default_18, [2, -1, 10, 64]), kwargs = {})
    %permute_20 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%reshape_default_23, [0, 2, 1, 3]), kwargs = {})
    %permute_21 : [num_users=1] = placeholder[target=permute_21]
    %permute_22 : [num_users=1] = placeholder[target=permute_22]
    %scaled_dot_product_attention_1 : [num_users=1] = call_function[target=torch._C._nn.scaled_dot_product_attention](args = (%permute_20, %permute_21, %permute_22), kwargs = {})
    %permute_23 : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%scaled_dot_product_attention_1, [0, 2, 1, 3]), kwargs = {})
    %reshape_default_26 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%permute_23, [2, -1, 640]), kwargs = {})
    %reshape_default_27 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%reshape_default_26, [8192, 640]), kwargs = {})
    %_frozen_param43 : [num_users=1] = get_attr[target=_frozen_param43]
    %mm_15 : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%reshape_default_27, %_frozen_param43), kwargs = {})
    %mul_34 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%mm_15, 1), kwargs = {})
    %_frozen_param44 : [num_users=1] = get_attr[target=_frozen_param44]
    %add_17 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param44, %mul_34), kwargs = {})
    %reshape_default_28 : [num_users=1] = call_function[target=torch.ops.aten.reshape.default](args = (%add_17, [2, 4096, 640]), kwargs = {})
    %clone_10 : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%reshape_default_28,), kwargs = {})
    %div_6 : [num_users=1] = call_function[target=torch.ops.aten.div.Tensor](args = (%clone_10, 1.0), kwargs = {})
    %add_16 : [num_users=1] = placeholder[target=add_16]
    %add_18 : [num_users=2] = call_function[target=torch.ops.aten.add.Tensor](args = (%div_6, %add_16), kwargs = {})
    %_to_copy_54 : [num_users=1] = call_function[target=torch.ops.aten._to_copy.default](args = (%add_18,), kwargs = {dtype: torch.float32})
    return (_to_copy_54, add_18)
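
For reference, the subgraph above can be sketched in eager NumPy to make the shapes and dataflow concrete. This is a hand-translation, not generated code: the weights below are random stand-ins for the frozen parameters, and everything is kept in float32 so that NumPy's matmuls stay fast (the real graph runs the middle of this pipeline in float16, per `_to_copy_51`, and casts back to FP32 at `_to_copy_54`):

```python
import numpy as np

rng = np.random.default_rng(0)

# Shapes taken from the listing; weights are random stand-ins for the
# frozen parameters (_frozen_param40/43/44), not the real SDXL weights.
x = rng.standard_normal((2, 4096, 640), dtype=np.float32)         # getitem_28
k = rng.standard_normal((2, 10, 77, 64), dtype=np.float32)        # permute_21
v = rng.standard_normal((2, 10, 77, 64), dtype=np.float32)        # permute_22
w_q = rng.standard_normal((640, 640), dtype=np.float32)           # _frozen_param40
w_o = rng.standard_normal((640, 640), dtype=np.float32)           # _frozen_param43
b_o = rng.standard_normal(640, dtype=np.float32)                  # _frozen_param44
residual = rng.standard_normal((2, 4096, 640), dtype=np.float32)  # add_16

# Q projection: reshape_default_17, mm_12, reshape_default_18/23, permute_20
q = (x.reshape(8192, 640) @ w_q).reshape(2, 4096, 640)
q = q.reshape(2, -1, 10, 64).transpose(0, 2, 1, 3)   # heads-first: (2, 10, 4096, 64)

# scaled_dot_product_attention (no mask, no dropout); scale is 1/sqrt(64) = 1/8
scores = q @ k.transpose(0, 1, 3, 2) / 8.0           # (2, 10, 4096, 77)
scores -= scores.max(axis=-1, keepdims=True)         # numerically stable softmax
attn = np.exp(scores)
attn /= attn.sum(axis=-1, keepdims=True)
out = attn @ v                                       # (2, 10, 4096, 64)

# permute_23, reshape_default_26/27, output projection (mm_15 + bias add_17)
out = out.transpose(0, 2, 1, 3).reshape(8192, 640)
out = (out @ w_o + b_o).reshape(2, 4096, 640)

# Residual connection (add_18); the real graph casts this back to FP32 here
add_18 = out + residual
print(add_18.shape)
```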

The above was then narrowed to the following simple matrix multiply which, when run in FP16 with dimensions (8192 x 640), (640 x 640) as used in our SDXL configuration, produces a maximum elementwise difference of roughly 10 between the TRT and Torch outputs. The mean difference was also high, at around 0.5.

    class TestModule(torch.nn.Module):
        def forward(self, q, k):
            return q @ k
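
To see how much of this can come from float16 arithmetic alone, independent of TRT, a small NumPy sketch is enough: run the same matmul at float16 precision and compare it against a float64 reference on identical inputs. The row count is reduced from 8192 to 1024 here purely so NumPy's non-BLAS float16 loops stay fast; the 640-wide reduction dimension, which drives the accumulation error, is unchanged. The error magnitudes depend on the input distribution, so the printed numbers will not match the TRT-vs-Torch figures above:

```python
import numpy as np

rng = np.random.default_rng(0)

# Same 640-wide reduction as the SDXL case; rows reduced from 8192 to 1024
# only to keep NumPy's float16 matmul fast.
q = rng.standard_normal((1024, 640)).astype(np.float16)
k = rng.standard_normal((640, 640)).astype(np.float16)

out_fp16 = q @ k                                   # product at float16 precision
ref = q.astype(np.float64) @ k.astype(np.float64)  # reference on identical inputs

diff = np.abs(out_fp16.astype(np.float64) - ref)
print(f"max abs diff:  {diff.max():.4f}")
print(f"mean abs diff: {diff.mean():.4f}")
```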

Additionally, the native_layer_norm operator may be contributing to the error, since excluding it also improves accuracy. This is still under investigation.
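
As a rough illustration of why layer norm is precision-sensitive (this is a plain NumPy sketch, not the TRT or aten implementation), the snippet below normalizes over the 640-channel axis in float16 and in float64 and compares the two; the mean/variance reductions and the reciprocal square root are where reduced precision typically bites:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last axis, keeping the input dtype throughout
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4096, 640)).astype(np.float16)

out_fp16 = layer_norm(x)                # all reductions at float16
ref = layer_norm(x.astype(np.float64))  # float64 reference, same inputs

diff = np.abs(out_fp16.astype(np.float64) - ref)
print(f"max abs diff:  {diff.max():.5f}")
print(f"mean abs diff: {diff.mean():.5f}")
```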

gs-olive commented on Dec 29, 2023

Update

We have further narrowed the matmul cases to make the examples easier to reproduce.

Next Steps

  • Check whether the issues persist at FP32 precision; if they do, again narrow down the cases to identify the layers responsible for the accuracy loss

gs-olive commented on Jan 8, 2024