Add blocksparse_int_addmm. Eliminate unnecessary contiguous calls, which leads to a performance increase.
As in the title.
This PR is created on top of https://github.com/pytorch/ao/pull/821 and requires https://github.com/pytorch/pytorch/pull/136104.
The diff between the jcaip/int8-bsr and pearu/int8-bsr branches is:
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 402948ad..d9e5e56f 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1179,7 +1179,7 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor,
w_vals = weight_tensor.layout_tensor
w_scales = weight_tensor.layout_tensor.scale
tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
- tmp_t = tmp.t().contiguous()
+ tmp_t = tmp.t()
# # Need to put this into custom op
# weight_bsr = torch.sparse_bsr_tensor(w_vals.crow_indices(),
@@ -1197,21 +1197,15 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor,
# input = torch.zeros(M, N, dtype=torch.int32, device=dense.device)
# y = _int_bsr_dense_addmm(input, weight_bsr, tmp_t).t().contiguous()
- y = torch.ops.blocksparse.int_mm(w_vals.crow_indices(),
- w_vals.col_indices(),
- w_vals.values(),
- w_vals.shape[0],
- w_vals.shape[1],
- tmp_t)
- # breakpoint()
-
-
- y = x_scales.reshape(-1, 1) * y
-
- y = (y * w_scales).reshape(
- *x_vals_int8.shape[:-1], y.shape[-1]
- )
+ y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(),
+ w_vals.col_indices(),
+ w_vals.values(),
+ tmp_t,
+ w_scales,
+ x_scales.reshape(-1))
+ y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
+ y = y.reshape(*y_shape)
# can downcast only at the very end
output_dtype = input_tensor.dtype
diff --git a/torchao/sparsity/prototype/superblock/benchmark.py b/torchao/sparsity/prototype/superblock/benchmark.py
index c960eb7b..20df85f7 100644
--- a/torchao/sparsity/prototype/superblock/benchmark.py
+++ b/torchao/sparsity/prototype/superblock/benchmark.py
@@ -13,6 +13,7 @@ import torch.utils.data
import utils
from torch import nn
from torch.sparse._triton_ops_meta import optimize_bsr_dense_addmm
+from torch.sparse._triton_ops_meta import dump as store_tuned_kernel_params
from torchao.sparsity.prototype.superblock.utils import accelerate_with_sparsity, simulate_sparsity
from torchao.utils import benchmark_model, profiler_runner
@@ -34,15 +35,30 @@ def main(args):
# BSR kernel tuning
if args.bsr and args.tune_kernel_params:
print("Tuning kernel params")
+ kwargs = dict(
+ dtype=torch.int8 if args.quantization else dtype,
+ sparsity=args.sparsity_linear, verbose=True,
+ # per blocksparse_int_addmm:
+ alpha=1, beta=0, use_left_alpha=True, use_right_alpha=True,
+ # force tuning because existing tuning parameters are
+ # computed for use_left/right_alpha=False, however, it
+ # turns out that re-tuning for use_left/right_alpha=True
+ # leads to the same set of tuning parameters:
+ # force=True
+ )
if args.model == "vit_b_16":
- optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
- optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+ optimize_bsr_dense_addmm(3072, 768, 50432, args.bsr, args.bsr, **kwargs)
+ optimize_bsr_dense_addmm(768, 3072, 50432, args.bsr, args.bsr, **kwargs)
elif args.model == "vit_h_14":
- optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
- optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, dtype=dtype, sparsity=args.sparsity_linear, verbose=True)
+ optimize_bsr_dense_addmm(5120, 1280, 65792, args.bsr, args.bsr, **kwargs)
+ optimize_bsr_dense_addmm(1280, 5120, 65792, args.bsr, args.bsr, **kwargs)
else:
raise NotImplementedError("Tuning kernel params for this model is not supported yet.")
-
+ # Warning: the following call will overwrite the source code
+ # of torch.sparse._triton_ops_meta (hence it is commented out
+ # by default) but when used, it enables reusing the tuned
+ # parameters in subsequent runs of this script:
+ # store_tuned_kernel_params()
print("Creating model")
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
diff --git a/torchao/sparsity/prototype/superblock/blocksparse.py b/torchao/sparsity/prototype/superblock/blocksparse.py
index 5bf458ea..8e520ac9 100644
--- a/torchao/sparsity/prototype/superblock/blocksparse.py
+++ b/torchao/sparsity/prototype/superblock/blocksparse.py
@@ -5,7 +5,7 @@ from typing import Optional, Tuple, List, Dict, Any, Callable
from torch.utils._python_dispatch import return_and_correct_aliasing
from torchao.utils import TorchAOBaseTensor
from torchao.quantization.quant_api import _get_linear_subclass_inserter
-from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims
+from torch.sparse._triton_ops import bsr_dense_mm, _int_bsr_dense_addmm, broadcast_batch_dims, bsr_dense_addmm
aten = torch.ops.aten
@@ -41,6 +41,31 @@ def blocksparse_int_mm_abstract(crow_indices: torch.Tensor, col_indices: torch.T
new_shape = (A.shape[-1], M)
return torch.empty(new_shape, dtype=torch.int8, device=A.device)
+
+@torch.library.custom_op("blocksparse::int_addmm", mutates_args=())
+def blocksparse_int_addmm(crow_indices: torch.Tensor,
+ col_indices: torch.Tensor,
+ values: torch.Tensor,
+ A: torch.Tensor,
+ left_alpha: torch.Tensor,
+ right_alpha: torch.Tensor) -> torch.Tensor:
+ assert values.dtype == torch.int8
+ M = left_alpha.shape[-1]
+ K = A.shape[-2]
+ N = A.shape[-1]
+ weight_bsr = torch.sparse_bsr_tensor(crow_indices, col_indices, values, size=(M, K))
+ original_batch_dims_broadcasted = broadcast_batch_dims(blocksparse_int_addmm, weight_bsr, A)
+ out = A.new_empty(original_batch_dims_broadcasted + (M, N))
+ return bsr_dense_addmm(out, weight_bsr, A, alpha=1, beta=0, out=out, left_alpha=left_alpha, right_alpha=right_alpha).t()
+
+
+@torch.library.register_fake("blocksparse::int_addmm")
+def blocksparse_int_addmm_abstract(crow_indices: torch.Tensor, col_indices: torch.Tensor, values: torch.Tensor, A: torch.Tensor, left_alpha: torch.Tensor, right_alpha: torch.Tensor) -> torch.Tensor:
+ N = A.shape[-1]
+ M = left_alpha.shape[-1]
+ return torch.empty((N, M), dtype=torch.int8, device=A.device)
+
+
# Subclass definition
class BlockSparseTensor(TorchAOBaseTensor):
bsr_crow_indices: Optional[torch.Tensor]
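For reference, here is a minimal dense sketch (illustrative names and shapes, not the Triton kernel itself) of the math the fused blocksparse::int_addmm op computes: the per-output-channel weight scales (passed as left_alpha) and per-row activation scales (passed as right_alpha) are applied inside the single call that the old path performed with two separate elementwise multiplications after the int8 matmul.

import torch

M, K, B = 8, 16, 4                                         # out_features, in_features, flattened batch
W_int8 = torch.randint(-8, 8, (M, K), dtype=torch.int8)    # stored as a BSR tensor in the real op
X_int8 = torch.randint(-8, 8, (B, K), dtype=torch.int8)
w_scales = torch.rand(M)                                    # per-output-channel weight scales (left_alpha)
x_scales = torch.rand(B)                                    # per-row activation scales (right_alpha)

y_int = X_int8.to(torch.int32) @ W_int8.to(torch.int32).t()  # (B, M) integer matmul

# Old path: two separate elementwise scalings after the int matmul.
y_old = (x_scales.reshape(-1, 1) * y_int) * w_scales

# Fused semantics: y[b, m] = x_scales[b] * w_scales[m] * sum_k W[m, k] * X[b, k]
y_fused = x_scales.reshape(-1, 1) * w_scales.reshape(1, -1) * y_int

torch.testing.assert_close(y_old, y_fused)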
As a result, the following performance test
python torchao/sparsity/prototype/superblock/benchmark.py --model vit_h_14 --batch-size 256 --sparsity-linear 0.8 --sp-linear-tile-size 64 --bsr 64 --sparsity bsr --quantization --tune-kernel-params
leads to
340.710 ms
2.935 img/s
Memory: 2592
which should be compared to the previous state:
581.503 ms
1.720 img/s
Memory: 5004
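That is, roughly a 1.7x improvement in latency and throughput, and about half the memory footprint.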