
Davided0/feedforward horizontal fusion

Open · davided0 opened this issue 4 weeks ago · 1 comment

Summary

Fuse the w1 and w3 linear layers in the FeedForward module into a single w13 layer via horizontal fusion, reducing CUDA kernel launches and improving throughput by up to 24% in the benchmarked configurations.

Motivation

The LLaMA FeedForward module computes silu(w1(x)) * w3(x), which requires two independent matrix multiplications, w1(x) and w3(x), on the same input. Concatenating the w1 and w3 weights into a single w13 layer replaces these two GEMMs with one, reducing the number of GEMM calls.

This mirrors the existing horizontal fusion optimization in the Attention module.
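
For concreteness, a minimal standalone sketch of the unfused and fused variants (illustrative only, not the exact torchao module; the class names and constructor arguments here are simplified):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Baseline: two independent GEMMs (w1, w3) on the same input."""
    def __init__(self, dim: int, intermediate_size: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, intermediate_size, bias=False)
        self.w3 = nn.Linear(dim, intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class FusedFeedForward(nn.Module):
    """Horizontally fused: one GEMM produces both halves, then chunk() splits them."""
    def __init__(self, dim: int, intermediate_size: int) -> None:
        super().__init__()
        self.w13 = nn.Linear(dim, 2 * intermediate_size, bias=False)
        self.w2 = nn.Linear(intermediate_size, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x3 = self.w13(x).chunk(2, dim=-1)
        return self.w2(F.silu(x1) * x3)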

Changes

  • Combined w1 and w3 into a single w13 linear layer (2 * intermediate_size output features)
  • Added _merge_w1_w3 state dict hook for backward compatibility with existing checkpoints (a sketch follows after this list)
  • Updated forward pass: single matmul → chunk() to split output
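
A hedged sketch of what the checkpoint-compatibility merge can look like (the actual _merge_w1_w3 hook in this PR may differ; this only illustrates the key rewrite, assuming nn.Linear weights of shape [out_features, in_features]):

import torch

def merge_w1_w3(state_dict: dict, prefix: str = "") -> None:
    """Rewrite an old-style checkpoint in place: concatenate the separate
    w1/w3 weights into the single w13 weight expected by the fused module."""
    w1_key, w3_key = f"{prefix}w1.weight", f"{prefix}w3.weight"
    if w1_key in state_dict and w3_key in state_dict:
        w1 = state_dict.pop(w1_key)  # [intermediate_size, dim]
        w3 = state_dict.pop(w3_key)  # [intermediate_size, dim]
        # [2 * intermediate_size, dim], matching w13 = nn.Linear(dim, 2 * intermediate_size)
        state_dict[f"{prefix}w13.weight"] = torch.cat([w1, w3], dim=0)

Registered as a load_state_dict pre-hook on the module, a function like this lets checkpoints saved with separate w1/w3 weights load directly into the fused layout.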

Performance Analysis

Kernel Reduction

The number of nn.linear calls in feed_forward is reduced from 192 to 128, a 33% reduction, as expected: each FeedForward block goes from three linear calls (w1, w3, w2) to two (w13, w2). This was measured with the following command:

  • TORCH_LOGS="graph_code" python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --compile --quantization int8wo --write_result benchmark_results.txt | grep feed_forward | grep -c "nn.linear("

The reduction of nn.linear calls is further confirmed by the FX graphs:

Before (2 separate matmuls):

mm: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute)      # w1
mm_1: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute_1)  # w3

After (1 fused matmul + split):

mm: "f32[1, 22016]" = torch.ops.aten.mm.default(primals_2, permute)  # w13
split = torch.ops.aten.split.Tensor(mm, 11008, -1)                   # chunk
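
The two formulations are numerically equivalent up to floating-point accumulation order. A quick standalone check (illustrative shapes matching Llama-2-7B; a loose tolerance is used because the fused and unfused GEMMs may reduce in different orders):

import torch

dim, inter = 4096, 11008
x = torch.randn(1, dim)
w1 = torch.randn(inter, dim)   # stands in for w1.weight
w3 = torch.randn(inter, dim)   # stands in for w3.weight

# Unfused: two GEMMs on the same input.
ref1, ref3 = x @ w1.T, x @ w3.T

# Fused: one GEMM against the concatenated weight, then split in two.
w13 = torch.cat([w1, w3], dim=0)            # [2 * inter, dim]
out1, out3 = (x @ w13.T).chunk(2, dim=-1)

assert torch.allclose(ref1, out1, rtol=1e-4, atol=1e-4)
assert torch.allclose(ref3, out3, rtol=1e-4, atol=1e-4)
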
New FX graph
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[22016, 4096]", primals_2: "f32[1, 4096]", primals_3: "f32[4096, 11008]"):
         # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:497 in forward, code: x1, x3 = self.w13(x).chunk(2, dim=-1)
        permute: "f32[4096, 22016]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        mm: "f32[1, 22016]" = torch.ops.aten.mm.default(primals_2, permute);  permute = None
        split = torch.ops.aten.split.Tensor(mm, 11008, -1);  mm = None
        getitem: "f32[1, 11008]" = split[0]
        getitem_1: "f32[1, 11008]" = split[1];  split = None
        
         # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:498 in forward, code: return self.w2(F.silu(x1) * x3)
        sigmoid: "f32[1, 11008]" = torch.ops.aten.sigmoid.default(getitem)
        mul: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(getitem, sigmoid);  sigmoid = None
        mul_1: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mul, getitem_1);  mul = None
        permute_1: "f32[11008, 4096]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
        mm_1: "f32[1, 4096]" = torch.ops.aten.mm.default(mul_1, permute_1)
        permute_4: "f32[4096, 11008]" = torch.ops.aten.permute.default(permute_1, [1, 0]);  permute_1 = None
        return (mm_1, primals_2, getitem, getitem_1, mul_1, permute_4)
Original FX graph
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "f32[11008, 4096]", primals_2: "f32[1, 4096]", primals_3: "f32[11008, 4096]", primals_4: "f32[4096, 11008]"):
         # File: /workspace/pytorch-ao/torchao/_models/llama/model.py:486 in forward, code: return self.w2(F.silu(self.w1(x)) * self.w3(x))
        permute: "f32[4096, 11008]" = torch.ops.aten.permute.default(primals_1, [1, 0]);  primals_1 = None
        mm: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute);  permute = None
        sigmoid: "f32[1, 11008]" = torch.ops.aten.sigmoid.default(mm)
        mul: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mm, sigmoid);  sigmoid = None
        permute_1: "f32[4096, 11008]" = torch.ops.aten.permute.default(primals_3, [1, 0]);  primals_3 = None
        mm_1: "f32[1, 11008]" = torch.ops.aten.mm.default(primals_2, permute_1);  permute_1 = None
        mul_1: "f32[1, 11008]" = torch.ops.aten.mul.Tensor(mul, mm_1);  mul = None
        permute_2: "f32[11008, 4096]" = torch.ops.aten.permute.default(primals_4, [1, 0]);  primals_4 = None
        mm_2: "f32[1, 4096]" = torch.ops.aten.mm.default(mul_1, permute_2)
        permute_5: "f32[4096, 11008]" = torch.ops.aten.permute.default(permute_2, [1, 0]);  permute_2 = None
        return (mm_2, primals_2, mm, mm_1, mul_1, permute_5)
Driver
import torch
from torch._inductor import config

# Enable Inductor's debug tracing so the compiled FX graphs above can be captured.
config.trace.enabled = True

from model import FeedForward, ModelArgs

args = ModelArgs()
model = FeedForward(args)
model = torch.compile(model, fullgraph=True)

# A single forward pass triggers compilation and graph capture.
x = torch.rand(size=(1, args.dim))
out = model(x)

Benchmark Results

Speedup is calculated as new tok/s divided by baseline tok/s, using the commands from torchao/_models/llama/benchmarks.sh; the other columns are computed the same way.
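
For illustration only (the numbers below are made up, not taken from the benchmark runs), each column's ratio is computed as:

# Hypothetical metric values for one configuration; the real values come from
# the generate.py benchmark output written to benchmark_results.txt.
baseline = {"tok/s": 100.0, "tok/s_decode": 101.0, "ttft": 0.40, "mem/s": 1300.0, "peak_mem": 14.0}
fused    = {"tok/s": 112.0, "tok/s_decode": 112.5, "ttft": 0.41, "mem/s": 1456.0, "peak_mem": 14.0}

speedup = {name: fused[name] / baseline[name] for name in baseline}
print(speedup["tok/s"])  # 1.12 -> the "Speedup (tok/s)" column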

  • GPU: A100-SXM4-40GB
  • CUDA: 11.8
  • PyTorch: 2.6.0+cu118

Notable improvements:

| Speedup (tok/s) | Speedup (tok/s_decode) | Speedup (ttft) | Speedup (mem/s) | Speedup (peak_mem) | Params |
|---|---|---|---|---|---|
| 1.236358749 | 1.24111364 | 0.9433497537 | 1.236358051 | 1.030627871 | quant: fp6 sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.119138562 | 1.122922546 | 0.9755244755 | 1.119154914 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |
| 1.118612397 | 1.122824894 | 0.9876977153 | 1.118619349 | 1 | quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 |

Most configurations are performance-neutral (within ±1%); the largest gains appear in the quantized configurations highlighted above.

Full benchmark table
Speedup (tok/s) Speedup (tok/s_decode) Speedup (ttft) Speedup (mem/s) Speedup (peak_mem) Params
0.9911355002 0.9907590074 0.9910447761 0.9910958849 1 quant: None sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.087821543 1.083854358 0.8843557382 1.087794498 0.9229922992 quant: int8dq sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
0.9950999355 0.9947361902 0.9940828402 0.9950826696 0.8942093541 quant: int8wo sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.236358749 1.24111364 0.9433497537 1.236358051 1.030627871 quant: fp6 sparse: None mod: Llama-2-7b-chat-hf kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
0.9922655878 0.9922562478 1 0.9922623814 1 quant: None sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.044583045 1.041311755 0.9201552537 1.044735709 1.02365416 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.119138562 1.122922546 0.9755244755 1.119154914 1 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.032804737 1.032319894 0.9538461538 1.03283344 1 quant: fp6 sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.float16 device: cuda repro: python generate.py --quantization fp6 --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.float16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
0.9941659721 0.9936612278 0.972972973 0.9941143342 0.9836065574 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.118612397 1.122824894 0.9876977153 1.118619349 1 quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.013262599 1.013150973 0.9518716578 1.013111394 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192
1.006341154 1.005653266 1.009287926 1.005828441 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 --kv_cache_quantization
1.000664894 1.000659196 0.9685279188 1.000531491 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 8192 --kv_cache_quantization --linear_causal_mask
1.009274874 1.008375209 1 1.008704931 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384
1.003857281 1.002870813 0.9502572899 1.003917791 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 --kv_cache_quantization
1.003898635 1.003872217 0.9458544839 1.003636364 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 16384 --kv_cache_quantization --linear_causal_mask
0.9970457903 0.9970631424 0.951285521 0.9973430427 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768
1.005033557 1.005 0.9988066826 1.005026248 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 --kv_cache_quantization
1.003367003 1.003344482 0.9929824561 1.003138662 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 32768 --kv_cache_quantization --linear_causal_mask
1.002873563 1.002857143 0.9963753524 1.002870813 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536
1.003215434 1 0.9958491871 1.001283422 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 --kv_cache_quantization
1 1.003205128 0.9958085924 1.001286725 1 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: True compile: False compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8 --cache_size 65536 --kv_cache_quantization --linear_causal_mask
1.045706371 1.041081311 0.9073971079 1.045563549 1.023333333 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.053301512 1.043371723 0.8395337302 1.053222673 1 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 32 --top_k 200 --temperature 0.8
1.01754386 1.012869565 0.9205726613 1.017601432 1.010028653 quant: int8dq sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 128 --top_k 200 --temperature 0.8
0.999491353 0.9988133464 0.9824561404 0.9994782306 1 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.044860076 1.044827115 0.953164557 1.044879946 0.9806501548 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 32 --top_k 200 --temperature 0.8
1.094404394 1.096069869 0.9857336957 1.094362018 0.9884947267 quant: int8wo sparse: None mod: Meta-Llama-3-8B kv_quant: False compile: True compile_prefill: False dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3-8B/model.pth --device cuda --precision torch.bfloat16 --compile --num_samples 5 --max_new_tokens 200 --batch_size 128 --top_k 200 --temperature 0.8
0.9984371948 0.9967086434 0.9965533748 0.9985293352 1.006455234 quant: None sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.032615385 1.040481928 0.9943841258 1.032706258 1 quant: int8dq sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8dq --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.013203825 1.018600098 0.9963788301 1.01321817 1.015944541 quant: int8wo sparse: None mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.bfloat16 device: cuda repro: python generate.py --quantization int8wo --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.bfloat16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8
1.018356868 1.02598419 0.9936114993 1.018331352 1.028186275 quant: sparse-marlin sparse: semi-structured mod: Meta-Llama-3.1-8B kv_quant: False compile: True compile_prefill: True dtype: torch.float16 device: cuda repro: python generate.py --quantization sparse-marlin --sparsity semi-structured --checkpoint_path ../../../checkpoints/meta-llama/Meta-Llama-3.1-8B/model.pth --device cuda --precision torch.float16 --compile --compile_prefill --prefill_size 8000 --num_samples 5 --max_new_tokens 200 --batch_size 1 --top_k 200 --temperature 0.8

Related Issues

Fixes #606

davided0 · Nov 24 '25

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3380

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Nov 24 '25