AITemplate
Support for non-square sizes in stable diffusion (e.g. 640x384) doesn't seem to work
From @terrychenism: "the group norm problem size is not supported yet."
My diff:
diff --git a/examples/05_stable_diffusion/compile.py b/examples/05_stable_diffusion/compile.py
index 513df5b..790f3c0 100644
--- a/examples/05_stable_diffusion/compile.py
+++ b/examples/05_stable_diffusion/compile.py
@@ -177,8 +177,8 @@ def map_clip_params(pt_mod, batch_size, seqlen, depth):
def compile_unet(
batch_size=2,
- hh=64,
- ww=64,
+ hh=48,
+ ww=80,
dim=320,
use_fp16_acc=False,
convert_conv_to_gemm=False,
@@ -339,7 +339,8 @@ def compile_diffusers(token, batch_size, img2img=False, use_fp16_acc=True, conve
use_auth_token=access_token,
).to("cuda")
- width = 96 if img2img else 64
+ width = 80
+ height = 48
# CLIP
compile_clip(batch_size=batch_size, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
@@ -347,11 +348,12 @@ def compile_diffusers(token, batch_size, img2img=False, use_fp16_acc=True, conve
compile_unet(
batch_size=batch_size * 2,
ww=width,
+ hh=height,
use_fp16_acc=use_fp16_acc,
convert_conv_to_gemm=convert_conv_to_gemm,
)
# VAE
- compile_vae(batch_size=batch_size, width=width, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
+ compile_vae(batch_size=batch_size, width=width, height=height, use_fp16_acc=use_fp16_acc, convert_conv_to_gemm=convert_conv_to_gemm)
if __name__ == "__main__":
Error:
/usr/include/cub/block/specializations/block_reduce_warp_reductions.cuh(75): here
instantiation of class "cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800]"
/usr/include/cub/block/block_reduce.cuh(249): here
instantiation of class "cub::BlockReduce<T, BLOCK_DIM_X, ALGORITHM, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> [with T=float, BLOCK_DIM_X=0, ALGORITHM=cub::BLOCK_REDUCE_WARP_REDUCTIONS, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(336): here
instantiation of "T <unnamed>::BlockAllReduce<ReductionOp,T,block_size>(T) [with ReductionOp=<unnamed>::SumOp, T=float, block_size=0]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(406): here
instantiation of "void <unnamed>::group_norm_smem<FuseSwish,H,W,C,C_G,ILP,BANK_CONFLICT,NUM_THREADS>(const half *, half *, half *, half *, int, float) [with FuseSwish=true, H=6, W=10, C=1280, C_G=40, ILP=8, BANK_CONFLICT=0, NUM_THREADS=0]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(566): here
instantiation of "cudaError_t <unnamed>::invokeGroupNorm<FuseSwish,H,W,C,G>(half *, half *, half *, half *, int, float, int, cudaStream_t) [with FuseSwish=true, H=6, W=10, C=1280, G=32]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(593): here
/usr/include/cub/warp/specializations/warp_reduce_shfl.cuh(73): error: division by zero
detected during:
instantiation of class "cub::WarpReduceShfl<T, LOGICAL_WARP_THREADS, PTX_ARCH> [with T=float, LOGICAL_WARP_THREADS=0, PTX_ARCH=800]"
/usr/include/cub/warp/warp_reduce.cuh(168): here
instantiation of class "cub::WarpReduce<T, LOGICAL_WARP_THREADS, PTX_ARCH> [with T=float, LOGICAL_WARP_THREADS=0, PTX_ARCH=800]"
/usr/include/cub/block/specializations/block_reduce_warp_reductions.cuh(75): here
instantiation of class "cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800]"
/usr/include/cub/block/block_reduce.cuh(249): here
instantiation of class "cub::BlockReduce<T, BLOCK_DIM_X, ALGORITHM, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> [with T=float, BLOCK_DIM_X=0, ALGORITHM=cub::BLOCK_REDUCE_WARP_REDUCTIONS, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(336): here
instantiation of "T <unnamed>::BlockAllReduce<ReductionOp,T,block_size>(T) [with ReductionOp=<unnamed>::SumOp, T=float, block_size=0]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(406): here
instantiation of "void <unnamed>::group_norm_smem<FuseSwish,H,W,C,C_G,ILP,BANK_CONFLICT,NUM_THREADS>(const half *, half *, half *, half *, int, float) [with FuseSwish=true, H=6, W=10, C=1280, C_G=40, ILP=8, BANK_CONFLICT=0, NUM_THREADS=0]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(566): here
instantiation of "cudaError_t <unnamed>::invokeGroupNorm<FuseSwish,H,W,C,G>(half *, half *, half *, half *, int, float, int, cudaStream_t) [with FuseSwish=true, H=6, W=10, C=1280, G=32]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(593): here
/usr/include/cub/block/specializations/block_reduce_warp_reductions.cuh(120): error: excessive recursion at instantiation of function "cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::ApplyWarpAggregates<FULL_TILE,ReductionOp,SUCCESSOR_WARP>(ReductionOp, T, int, cub::Int2Type<SUCCESSOR_WARP>) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>, SUCCESSOR_WARP=201]"
detected during:
instantiation of "T cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::ApplyWarpAggregates<FULL_TILE,ReductionOp,SUCCESSOR_WARP>(ReductionOp, T, int, cub::Int2Type<SUCCESSOR_WARP>) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>, SUCCESSOR_WARP=200]"
(120): here
instantiation of "T cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::ApplyWarpAggregates<FULL_TILE,ReductionOp,SUCCESSOR_WARP>(ReductionOp, T, int, cub::Int2Type<SUCCESSOR_WARP>) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>, SUCCESSOR_WARP=199]"
(120): here
instantiation of "T cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::ApplyWarpAggregates<FULL_TILE,ReductionOp,SUCCESSOR_WARP>(ReductionOp, T, int, cub::Int2Type<SUCCESSOR_WARP>) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>, SUCCESSOR_WARP=198]"
(120): here
instantiation of "T cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::ApplyWarpAggregates<FULL_TILE,ReductionOp,SUCCESSOR_WARP>(ReductionOp, T, int, cub::Int2Type<SUCCESSOR_WARP>) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>, SUCCESSOR_WARP=197]"
(120): here
instantiation of "T cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::ApplyWarpAggregates<FULL_TILE,ReductionOp,SUCCESSOR_WARP>(ReductionOp, T, int, cub::Int2Type<SUCCESSOR_WARP>) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>, SUCCESSOR_WARP=196]"
(120): here
[ 196 instantiation contexts not shown ]
instantiation of "T cub::BlockReduceWarpReductions<T, BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::Reduce<FULL_TILE,ReductionOp>(T, int, ReductionOp) [with T=float, BLOCK_DIM_X=0, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, FULL_TILE=true, ReductionOp=<unnamed>::SumOp<float>]"
/usr/include/cub/block/block_reduce.cuh(353): here
instantiation of "T cub::BlockReduce<T, BLOCK_DIM_X, ALGORITHM, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH>::Reduce(T, ReductionOp) [with T=float, BLOCK_DIM_X=0, ALGORITHM=cub::BLOCK_REDUCE_WARP_REDUCTIONS, BLOCK_DIM_Y=1, BLOCK_DIM_Z=1, PTX_ARCH=800, ReductionOp=<unnamed>::SumOp<float>]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(338): here
instantiation of "T <unnamed>::BlockAllReduce<ReductionOp,T,block_size>(T) [with ReductionOp=<unnamed>::SumOp, T=float, block_size=0]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(406): here
instantiation of "void <unnamed>::group_norm_smem<FuseSwish,H,W,C,C_G,ILP,BANK_CONFLICT,NUM_THREADS>(const half *, half *, half *, half *, int, float) [with FuseSwish=true, H=6, W=10, C=1280, C_G=40, ILP=8, BANK_CONFLICT=0, NUM_THREADS=0]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(566): here
instantiation of "cudaError_t <unnamed>::invokeGroupNorm<FuseSwish,H,W,C,G>(half *, half *, half *, half *, int, float, int, cudaStream_t) [with FuseSwish=true, H=6, W=10, C=1280, G=32]"
./tmp/UNet2DConditionModel/groupnorm_swish_603.cu(593): here
4 errors detected in the compilation of "./tmp/UNet2DConditionModel/groupnorm_swish_603.cu".
Done