flux
[QUESTION] GEMM + ReduceScatter on 8×H100
Why does GEMM + ReduceScatter (RS) perform much worse than the PyTorch baseline?
# ---------------------------------------------------------------------------
# Tuning space for the GEMM + ReduceScatter sweep.
#
# `space` starts out empty; presumably it is populated from the candidate
# axes below by code outside this snippet (TODO confirm).
# ---------------------------------------------------------------------------
space: List[TuningConfig] = []

# Problem shapes. In the tuning logs, m and n appear unchanged, while the
# logged k equals K / world_size with 8 ranks (28672 -> 3584, 8192 -> 1024).
space_M = [
    8192,
    16384,
    32768,
]
space_N = [
    8192,
]
space_K = [
    28672,
    8192,
]

# Kernel variants. fuse_reduction shows up in the logs as 0/1; the weight
# transpose choice presumably surfaces as gemm_layout RRR vs RCR
# (TODO confirm which value maps to which layout).
space_fuse_reduction = [
    False,
    True,
]
space_transpose_weight = [
    True,
    False,
]

# Element type and bias: a single candidate each, i.e. FP16 without bias.
space_dtype = [
    torch.float16,
]
space_has_bias = [
    False,
]
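Tuning logs. For each shape, these list the two fastest hyper-parameter configurations, with TopK=1 the fastest. Note that the logged `k` equals K / world_size (8192 → 1024, 28672 → 3584).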
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (990.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1021.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1254.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1308.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1952.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2011.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (2452.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2580.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (3890.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (4164.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (4966.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (5311.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1212.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1230.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1493.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1510.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (2275.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2426.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (2796.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2868.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (4568.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (4788.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RRR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (5774.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (5995.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (999.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1063.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1238.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1303.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1949.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2073.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (2479.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2593.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (3962.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (3986.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=0,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (5016.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (5253.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1200.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1224.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=8192,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (1466.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (1494.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (2282.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2300.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=16384,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (2822.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (2862.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=1024,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (4595.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (4770.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
GemmMeta(dtype=GemmDTypeConfig(a=FP16,b=FP16,c=Void,d=FP16,acc=FP32,blockscale=FP32),arch=Sm90,comm_op=ReduceScatter,gemm_layout=RCR,impl=GemmV3,impl_spec=GemmV3Meta(fast_accum=0,block_scale=0),comm_spec=ReduceScatterMeta(fuse_reduction=1,comm_kind=IntraNode))
RuntimeConfig(m=32768,n=8192,k=3584,comm_spec=ReduceScatterRuntimeConfig(world_size=8,nnodes=1))
* TopK=1 (5732.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(2,1,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
* TopK=2 (5885.000 µs): GemmHParams(impl_spec=GemmV3HParams(cluster_shape=(1,2,1),kernel_schedule=Cooperative),comm_spec=None,tile_shape=(128,256,64),gemm_kind=GemmDefault,mainloop_stage=0,raster_order=RasterHeuristic)
Environment: Python 3.11, PyTorch 2.6.0