Add blocksparse_int_addmm. Eliminate unnecessary contiguous calls, which leads to a performance increase.

pearu opened this issue 1 year ago • 1 comment

As in the title.

This PR is created on top of https://github.com/pytorch/ao/pull/821 and requires https://github.com/pytorch/pytorch/pull/136104 .

The diff between the jcaip/int8-bsr and pearu/int8-bsr branches is:

diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 402948ad..d9e5e56f 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1179,7 +1179,7 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor,
     w_vals = weight_tensor.layout_tensor
     w_scales = weight_tensor.layout_tensor.scale
     tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
-    tmp_t = tmp.t().contiguous()
+    tmp_t = tmp.t()
 
     # # Need to put this into custom op
     # weight_bsr = torch.sparse_bsr_tensor(w_vals.crow_indices(),
@@ -1197,21 +1197,15 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor,
     # input = torch.zeros(M, N, dtype=torch.int32, device=dense.device)
     # y = _int_bsr_dense_addmm(input, weight_bsr, tmp_t).t().contiguous()
 
-    y = torch.ops.blocksparse.int_mm(w_vals.crow_indices(),
-                                          w_vals.col_indices(),
-                                          w_vals.values(),
-                                          w_vals.shape[0],
-                                          w_vals.shape[1],
-                                          tmp_t)
 
-    # breakpoint()
-
-
-    y = x_scales.reshape(-1, 1) * y
-
-    y = (y * w_scales).reshape(
-        *x_vals_int8.shape[:-1], y.shape[-1]
-    )
+    y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(),
+                                        w_vals.col_indices(),
+                                        w_vals.values(),
+                                        tmp_t,
+                                        w_scales,
+                                        x_scales.reshape(-1))
+    y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
+    y = y.reshape(*y_shape)
 
     # can downcast only at the very end
     output_dtype = input_tensor.dtype
diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py
index c960eb7b..20df85f7 100644
--- a/torchao/sparsity/prototype/superblock/benchmark.py
+++ b/torchao/sparsity/prototype/superblock/benchmark.py
@@ -13,6 +13,7 @@ import torch.utils.data
 import utils
 from torch import nn
 from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
+from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
 from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
 from torchao.utils import benchmark_model, profiler_runner
 
@@ -34,15 +35,30 @@ def main(args):
     # BSR kernel tuning
     if args.bsr and args.tune_kernel_params:
         print("Tuning kernel params")
+        kwargs = dict(
+            dtype=torch.int8 if args.quantization else dtype,
+            sparsity=args.sparsity_linear, verbose=True,
+            # per blocksparse_int_addmm:
+            alpha=1, beta=0, use_left_alpha=True, use_right_alpha=True,
+            # force tuning because existing tuning parameters are
+            # computed for use_left/right_alpha=False, however, it
+            # turns out that re-tuning for use_left/right_alpha=False
+            # leads to the same set of tuning parameters:
+            # force=True
+        )
         if args.model == "vit_b_16":
-            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs)
+            optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs)
         elif args.model == "vit_h_14":
-            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
-            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+            optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs)
+            optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs)
         else:
             raise NotImplementedError("Tuning kernel params for this model is not supported yet.")
-
+        # Warning: the following call will overwrite the source code
+        # of torch.sparse._triton_ops_meta (hence it is commented out
+        # by default) but when used, it enables reusing the tuned
+        # parameters in subsequent runs of this script:
+        # store_tuned_kernel_params()
     print("Creating model")
     model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)

diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py
index 5bf458ea..8e520ac9 100644
--- a/torchao/sparsity/prototype/superblock/blocksparse.py
+++ b/torchao/sparsity/prototype/superblock/blocksparse.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, List, Dict, Any, Callable
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.utils import TorchAOBaseTensor
 from torchao.quantization.quant_api import _get_linear_subclass_inserter
-from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims
+from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims, bsr_dense_addmm
 
 aten = torch.ops.aten
 
@@ -41,6 +41,31 @@ def blocksparse_int_mm_abstract(crow_indices: torch.Tensor, col_indices: torch.T
     new_shape = (A.shape[-1], M)
     return torch.empty(new_shape, dtype=torch.int8, device=A.device)
 
+
+@torch.library.custom_op("blocksparse::int_addmm", mutates_args=())
+def blocksparse_int_addmm(crow_indices: torch.Tensor,
+                          col_indices: torch.Tensor,
+                          values: torch.Tensor,
+                          A: torch.Tensor,
+                          left_alpha: torch.Tensor,
+                          right_alpha: torch.Tensor) -> torch.Tensor:
+    assert values.dtype == torch.int8
+    M = left_alpha.shape[-1]
+    K = A.shape[-2]
+    N = A.shape[-1]
+    weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
+    original_batch_dims_broadcasted = broadcast_batch_dims(blocksparse_int_addmm, weight_bsr, A)
+    out = A.new_empty(original_batch_dims_broadcasted + (M, N))
+    return bsr_dense_addmm(out, weight_bsr, A, alpha=1, beta=0, out=out, left_alpha=left_alpha, right_alpha=right_alpha).t()
+
+
+@torch.library.register_fake("blocksparse::int_addmm")
+def blocksparse_int_addmm_abstract(crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, A: torch.Tensor, left_alpha: torch.Tensor, right_alpha: torch.Tensor) -> torch.Tensor:
+    N = A.shape[-1]
+    M = left_alpha.shape[-1]
+    return torch.empty((N, M), dtype=torch.int8, device=A.device)
+
+
 # Subclass definition
 class BlockSparseTensor(TorchAOBaseTensor):
     bsr_crow_indices: Optional[torch.Tensor]

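For context, here is a minimal dense sketch of what the fused blocksparse::int_addmm is meant to compute. This is an illustration only, not the Triton/BSR kernel; the shapes and tensor names mirror the call site above, and it assumes (per the bsr_dense_addmm call) that left_alpha scales the rows of the product while right_alpha scales its columns:

```python
# Dense reference sketch of the intended semantics (assumption: left_alpha
# scales rows of W @ A and right_alpha scales its columns, matching the
# bsr_dense_addmm call in the diff). Not the actual BSR/Triton kernel.
import torch

M, K, N = 8, 16, 4                                     # out_features, in_features, tokens
W = torch.randint(-8, 8, (M, K), dtype=torch.int8)     # quantized weight (stored as BSR in the real code)
A = torch.randint(-8, 8, (K, N), dtype=torch.int8)     # transposed int8 activations (tmp_t)
w_scales = torch.rand(M)                               # per-output-channel weight scales (left_alpha)
x_scales = torch.rand(N)                               # per-token activation scales (right_alpha)

acc = W.to(torch.int32) @ A.to(torch.int32)            # (M, N) integer accumulator

# Old path: integer matmul, then two separate elementwise scaling passes.
y_old = (x_scales.reshape(-1, 1) * acc.t()) * w_scales

# New fused path: both scales folded into a single addmm-style call, transposed at the end.
y_new = (w_scales.reshape(-1, 1) * acc * x_scales).t()

torch.testing.assert_close(y_old, y_new)
```

In other words, the per-channel weight scales and per-token activation scales that were previously applied in two separate elementwise passes are now passed into the single kernel call.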
As a result, the following performance test

python torchao/sparsity/prototype/superblock/benchmark.py --model vit_h_14 --batch-size 256 --sparsity-linear 0.8 --sp-linear-tile-size 64 --bsr 64 --sparsity bsr --quantization --tune-kernel-params

leads to

340.710 ms
2.935 img/s
Memory: 2592

which should be compared to the previous state (roughly a 1.7x improvement in latency and throughput, and about half the peak memory):

581.503 ms
1.720 img/s
Memory: 5004

pearu • Sep 16 '24

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/891


pytorch-bot[bot] • Sep 16 '24