Davided0/feedforward horizontal fusion
Summary
Fuse w1 and w3 linear layers in the FeedForward module into a single w13 layer using horizontal fusion, reducing CUDA kernel launches and improving throughput by up to 24%.
Motivation
The LLaMA FeedForward module computes w2(silu(w1(x)) * w3(x)), where w1(x) and w3(x) are two independent matrix multiplications on the same input. Because the two projections share their input, their weights can be concatenated into a single w13 layer, turning the two gate/up GEMMs (and their kernel launches) into one.
This mirrors the existing horizontal fusion optimization in the Attention module.
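For reference, the fused forward pass boils down to the following minimal sketch (the real module is built from a `ModelArgs` config; plain `dim` / `intermediate_size` arguments are used here for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedFeedForward(nn.Module):
    """Sketch of the SwiGLU MLP with w1/w3 horizontally fused into w13."""

    def __init__(self, dim: int, intermediate_size: int) -> None:
        super().__init__()
        # One weight matrix holding both projections, stacked along the output dim:
        # rows [0, intermediate_size) come from w1, the rest from w3.
        self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Single GEMM, then split the result back into the w1 and w3 halves.
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(F.silu(x1) * x3)
```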
Changes
- Combined `w1` and `w3` into a single `w13` linear layer (`2 * intermediate_size` output features)
- Added a `_merge_w1_w3` state-dict hook for backward compatibility with existing checkpoints (see the sketch below)
- Updated the forward pass: a single matmul followed by `chunk()` to split the output
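The checkpoint-compatibility hook can be sketched as follows (illustrative: the name `_merge_w1_w3` comes from this PR, but the signature and registration shown here simply follow the load-state-dict pre-hook pattern already used for the fused Attention projections):

```python
import torch

# Method on FeedForward (illustrative):
def _merge_w1_w3(self, state_dict, prefix, *args):
    # Old checkpoints carry separate w1/w3 weights; merge them on load so the
    # fused module can consume them transparently.
    if prefix + "w1.weight" in state_dict:
        w1 = state_dict.pop(prefix + "w1.weight")
        w3 = state_dict.pop(prefix + "w3.weight")
        # w13.weight is [2 * intermediate_size, dim]: w1 rows first, then w3 rows.
        state_dict[prefix + "w13.weight"] = torch.cat([w1, w3], dim=0)

# Registered in FeedForward.__init__, e.g.:
#     self._register_load_state_dict_pre_hook(self._merge_w1_w3)
```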
Performance Analysis
Kernel Reduction
The number of `nn.linear` calls in `feed_forward` drops from 192 to 128, the expected 33% reduction: each FeedForward now issues two linear calls (`w13`, `w2`) instead of three (`w1`, `w3`, `w2`). This was measured with the following command:
```bash
TORCH_LOGS="graph_code" python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --compile --quantization int8wo --write_result benchmark_results.txt | grep feed_forward | grep -c "nn.linear("
```
The reduction of nn.linear calls is further confirmed by the FX graphs:
Before (2 separate matmuls):
```python
mm: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute)      # w1
mm_1: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute_1)  # w3
```
After (1 fused matmul + split):
```python
mm: "f32[1, 22016]" = torch.ops.aten.mm.default(primals_2, permute)  # w13
split = torch.ops.aten.split.Tensor(mm, 11008, -1)                   # chunk
```
New FX graph
```python
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[22016, 4096]", primals_2: "f32[1, 4096]", primals_3: "f32[4096, 11008]"):
        # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:497 in forward, code: x1, x3 = self.w13(x).chunk(2, dim=-1)
        permute: "f32[4096, 22016]" = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None
        mm: "f32[1, 22016]" = torch.ops.aten.mm.default(primals_2, permute); permute = None
        split = torch.ops.aten.split.Tensor(mm, 11008, -1); mm = None
        getitem: "f32[1, 11008]" = split[0]
        getitem_1: "f32[1, 11008]" = split[1]; split = None

        # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:498 in forward, code: return self.w2(F.silu(x1) * x3)
        sigmoid: "f32[1, 11008]" = torch.ops.aten.sigmoid.default(getitem)
        mul: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(getitem, sigmoid); sigmoid = None
        mul_1: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mul, getitem_1); mul = None
        permute_1: "f32[11008, 4096]" = torch.ops.aten.permute.default(primals_3, [1, 0]); primals_3 = None
        mm_1: "f32[1, 4096]" = torch.ops.aten.mm.default(mul_1, permute_1)
        permute_4: "f32[4096, 11008]" = torch.ops.aten.permute.default(permute_1, [1, 0]); permute_1 = None
        return (mm_1, primals_2, getitem, getitem_1, mul_1, permute_4)
```
Original FX graph
```python
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[11008, 4096]", primals_2: "f32[1, 4096]", primals_3: "f32[11008, 4096]", primals_4: "f32[4096, 11008]"):
        # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:486 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
        permute: "f32[4096, 11008]" = torch.ops.aten.permute.default(primals_1, [1, 0]); primals_1 = None
        mm: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute); permute = None
        sigmoid: "f32[1, 11008]" = torch.ops.aten.sigmoid.default(mm)
        mul: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mm, sigmoid); sigmoid = None
        permute_1: "f32[4096, 11008]" = torch.ops.aten.permute.default(primals_3, [1, 0]); primals_3 = None
        mm_1: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute_1); permute_1 = None
        mul_1: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mul, mm_1); mul = None
        permute_2: "f32[11008, 4096]" = torch.ops.aten.permute.default(primals_4, [1, 0]); primals_4 = None
        mm_2: "f32[1, 4096]" = torch.ops.aten.mm.default(mul_1, permute_2)
        permute_5: "f32[4096, 11008]" = torch.ops.aten.permute.default(permute_2, [1, 0]); permute_2 = None
        return (mm_2, primals_2, mm, mm_1, mul_1, permute_5)
```
Driver
```python
import torch
from torch._inductor import config

# Enable Inductor's debug trace, which writes graph dumps (including the FX
# graphs shown above) to a torch_compile_debug/ directory.
config.trace.enabled = True

from model import FeedForward, ModelArgs

args = ModelArgs()
model = FeedForward(args)
model = torch.compile(model, fullgraph=True)

x = torch.rand(size=(1, args.dim))
out = model(x)
```
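As a quick sanity check that the fused layout is numerically equivalent to the separate projections, something like the following can be run (illustrative, not part of the PR; shapes match Llama-2-7B):

```python
import torch
import torch.nn.functional as F

dim, intermediate = 4096, 11008
w1 = torch.randn(intermediate, dim)
w3 = torch.randn(intermediate, dim)
w13 = torch.cat([w1, w3], dim=0)  # same layout the merge hook produces
x = torch.randn(1, dim)

# Fused path: one GEMM, then chunk back into the w1 and w3 halves.
x1, x3 = F.linear(x, w13).chunk(2, dim=-1)
fused = F.silu(x1) * x3

# Reference path: two separate GEMMs.
reference = F.silu(F.linear(x, w1)) * F.linear(x, w3)
print(torch.allclose(fused, reference, atol=1e-5))  # True
```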
Benchmark Results
Speedup is calculated as new_tok/s / baseline_tok/s, with both runs driven by the commands from torchao/_models/llama/benchmarks.sh; the other columns are ratios computed the same way.
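Concretely, every cell is just a per-metric ratio between the fused run and the baseline run (trivial sketch; the metric names are illustrative, not the literal column headers of benchmark_results.txt):

```python
def speedup(fused: dict, baseline: dict) -> dict:
    # Ratios > 1.0 mean the w13-fused run reported the higher number for that metric.
    return {k: fused[k] / baseline[k] for k in fused.keys() & baseline.keys()}

# Made-up numbers, for illustration only:
print(speedup({"tok/s": 124.0}, {"tok/s": 100.0}))  # {'tok/s': 1.24}
```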
- GPU: A100-SXM4-40GB
- CUDA: 11.8
- PyTorch: 2.6.0+cu118
Notable improvements:
| Speedup (tok/s) | Speedup (tok/s_decode) | Speedup (ttft) | Speedup (mem/s) | Speedup (peak_mem) | Params |
|---|---|---|---|---|---|
| 1.236358749 | 1.24111364 | 0.9433497537 | 1.236358051 | 1.030627871 | quant: fp6 sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.119138562 | 1.122922546 | 0.9755244755 | 1.119154914 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.118612397 | 1.122824894 | 0.9876977153 | 1.118619349 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
Most configurations show roughly neutral performance (within about ±1%), with the largest gains concentrated in specific quantization scenarios (e.g. fp6 and int8wo).
Full benchmark table
| Speedup (tok/s) | Speedup (tok/s_decode) | Speedup (ttft) | Speedup (mem/s) | Speedup (peak_mem) | Params |
|---|---|---|---|---|---|
| 0.9911355002 | 0.9907590074 | 0.9910447761 | 0.9910958849 | 1 | quant: None sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.087821543 | 1.083854358 | 0.8843557382 | 1.087794498 | 0.9229922992 | quant: int8dq sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 0.9950999355 | 0.9947361902 | 0.9940828402 | 0.9950826696 | 0.8942093541 | quant: int8wo sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.236358749 | 1.24111364 | 0.9433497537 | 1.236358051 | 1.030627871 | quant: fp6 sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 0.9922655878 | 0.9922562478 | 1 | 0.9922623814 | 1 | quant: None sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.044583045 | 1.041311755 | 0.9201552537 | 1.044735709 | 1.02365416 | quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.119138562 | 1.122922546 | 0.9755244755 | 1.119154914 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.032804737 | 1.032319894 | 0.9538461538 | 1.03283344 | 1 | quant: fp6 sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 0.9941659721 | 0.9936612278 | 0.972972973 | 0.9941143342 | 0.9836065574 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.118612397 | 1.122824894 | 0.9876977153 | 1.118619349 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.013262599 | 1.013150973 | 0.9518716578 | 1.013111394 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 |
| 1.006341154 | 1.005653266 | 1.009287926 | 1.005828441 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 --kv_cache_quantization |
| 1.000664894 | 1.000659196 | 0.9685279188 | 1.000531491 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 --kv_cache_quantization --linear_causal_mask |
| 1.009274874 | 1.008375209 | 1 | 1.008704931 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 |
| 1.003857281 | 1.002870813 | 0.9502572899 | 1.003917791 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 --kv_cache_quantization |
| 1.003898635 | 1.003872217 | 0.9458544839 | 1.003636364 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 --kv_cache_quantization --linear_causal_mask |
| 0.9970457903 | 0.9970631424 | 0.951285521 | 0.9973430427 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 |
| 1.005033557 | 1.005 | 0.9988066826 | 1.005026248 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 --kv_cache_quantization |
| 1.003367003 | 1.003344482 | 0.9929824561 | 1.003138662 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 --kv_cache_quantization --linear_causal_mask |
| 1.002873563 | 1.002857143 | 0.9963753524 | 1.002870813 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 |
| 1.003215434 | 1 | 0.9958491871 | 1.001283422 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 --kv_cache_quantization |
| 1 | 1.003205128 | 0.9958085924 | 1.001286725 | 1 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 --kv_cache_quantization --linear_causal_mask |
| 1.045706371 | 1.041081311 | 0.9073971079 | 1.045563549 | 1.023333333 | quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.053301512 | 1.043371723 | 0.8395337302 | 1.053222673 | 1 | quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 32 --top_k 200 --temperature 0.8 |
| 1.01754386 | 1.012869565 | 0.9205726613 | 1.017601432 | 1.010028653 | quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 128 --top_k 200 --temperature 0.8 |
| 0.999491353 | 0.9988133464 | 0.9824561404 | 0.9994782306 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.044860076 | 1.044827115 | 0.953164557 | 1.044879946 | 0.9806501548 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 32 --top_k 200 --temperature 0.8 |
| 1.094404394 | 1.096069869 | 0.9857336957 | 1.094362018 | 0.9884947267 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 128 --top_k 200 --temperature 0.8 |
| 0.9984371948 | 0.9967086434 | 0.9965533748 | 0.9985293352 | 1.006455234 | quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.032615385 | 1.040481928 | 0.9943841258 | 1.032706258 | 1 | quant: int8dq sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.013203825 | 1.018600098 | 0.9963788301 | 1.01321817 | 1.015944541 | quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.018356868 | 1.02598419 | 0.9936114993 | 1.018331352 | 1.028186275 | quant: sparse-marlin sparse: semi-structured mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.float16 device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
Related Issues
Fixes #606