[Question]: The number of convolution multiplications decreases but the communication cost increases in SPU
Issue Type
Performance
Modules Involved
SPU runtime
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
spu 0.9.0.dev20240311
OS Platform and Distribution
Ubuntu 18.04.6 LTS by WSL
Python Version
3.10
Compiler Version
GCC 11.3.0
Current Behavior?
Not a bug, just a question about an unexpected result: I have measured the Comm. cost of evaluating the first conv layer of ResNet18 on CIFAR10 as a standalone layer, Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False). It cost 759,296 bytes of Comm. and 0.015497988s of latency.
Since Conv is multiplication-intensive, one way to reduce the Comm./latency cost is to reduce the number of multiplications with the Winograd algorithm. Winograd uses predefined matrices to transform the weight and input into their Winograd-domain counterparts and then performs an element-wise matrix multiplication (EWMM) between the transformed weight and input. After an inverse transformation, the output of the EWMM is equivalent to that of the standard Conv. For 3x3 convolutions, this reduces the number of multiplications by about 2.25x without any accuracy loss.
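For concreteness, here is a minimal single-tile sketch of Winograd F(2x2, 3x3) with the standard transform matrices (an illustration only, not the code used in my tests):

import jax.numpy as jnp
from jax import lax

B_T = jnp.array([[1.,  0., -1.,  0.],
                 [0.,  1.,  1.,  0.],
                 [0., -1.,  1.,  0.],
                 [0.,  1.,  0., -1.]])
G   = jnp.array([[1.,   0.,  0. ],
                 [0.5,  0.5, 0.5],
                 [0.5, -0.5, 0.5],
                 [0.,   0.,  1. ]])
A_T = jnp.array([[1., 1.,  1.,  0.],
                 [0., 1., -1., -1.]])

d = jnp.arange(16.).reshape(4, 4)   # one 4x4 input tile
g = jnp.ones((3, 3))                # one 3x3 filter

U = G @ g @ G.T                     # filter in the Winograd domain (4x4)
V = B_T @ d @ B_T.T                 # tile in the Winograd domain (4x4)
Y = A_T @ (U * V) @ A_T.T           # 16 element-wise products -> 2x2 output

# Reference: direct 3x3 conv over the same tile needs 2*2*3*3 = 36
# multiplications, versus 16 in the Winograd domain (the 2.25x reduction).
ref = lax.conv_general_dilated(d[None, None], g[None, None],
                               window_strides=(1, 1), padding='VALID')[0, 0]
assert jnp.allclose(Y, ref, atol=1e-4)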
I have tested the Comm./latency cost of the standard and Winograd Conv. Curiously, the cost of the Winograd conv is significantly higher: 6,291,456 bytes of Comm. and 0.0487127s of latency.
Theoretically, for the first layer of ResNet18 on CIFAR10, the standard conv needs 1,769,472 multiplications and the Winograd conv needs 786,432 multiplications (a 2.25x reduction), yet the Comm. increases by about 8.29x.
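These counts can be reproduced with the following back-of-envelope arithmetic (assuming Winograd F(2x2, 3x3), i.e. 4x4 tiles each producing a 2x2 output; the tiling parameters are my assumption):

# Multiplication counts for Conv2d(3, 64, kernel_size=3, padding=1) on a
# 32x32 CIFAR10 input.
C_in, C_out, K, H = 3, 64, 3, 32

std_muls  = C_out * C_in * K * K * H * H        # 1,769,472
tiles     = (H // 2) ** 2                       # 256 tiles, each giving a 2x2 output
wino_muls = tiles * 16 * C_in * C_out           # 786,432 (16 products per tile/filter pair)

print(std_muls, wino_muls, std_muls / wino_muls)   # 1769472 786432 2.25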
May I ask whether you know the underlying reason, or whether there are convolution-specific optimizations in SPU that I am not aware of?
Thanks a lot.
Standalone code to reproduce the issue
N/A
Relevant log output
Here I report the SPU logs relevant to standard and Winograd conv evaluation:
Standard conv:
[2024-05-10 18:26:50,734] [Process-1] Starting grpc server at 127.0.0.1:61320
[2024-05-10 18:26:50,734] [Process-2] Starting grpc server at 127.0.0.1:61321
[2024-05-10 18:26:59,661] [Process-2] Run : builtin_spu_init at node:1
[2024-05-10 18:26:59,661] [Process-1] Run : builtin_spu_init at node:0
I0510 18:26:59.665524 12394 external/com_github_brpc_brpc/src/brpc/server.cpp:1158] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61331.
W0510 18:26:59.665545 12394 external/com_github_brpc_brpc/src/brpc/server.cpp:1164] Builtin services are disabled according to ServerOptions.has_builtin_services
I0510 18:26:59.666197 12396 external/com_github_brpc_brpc/src/brpc/server.cpp:1158] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61330.
W0510 18:26:59.666213 12396 external/com_github_brpc_brpc/src/brpc/server.cpp:1164] Builtin services are disabled according to ServerOptions.has_builtin_services
[2024-05-10 18:26:59,667] [Process-2] spu-runtime (SPU) initialized
[2024-05-10 18:26:59,668] [Process-1] spu-runtime (SPU) initialized
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2024-05-10 18:26:59,986] [Process-1] Run : <lambda> at node:0
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2024-05-10 18:27:00,003] [Process-2] Run : <lambda> at node:1
[2024-05-10 18:27:00,005] [Process-2] Run : make_shares at node:1
[2024-05-10 18:27:00,006] [Process-2] RunR: builtin_fetch_meta at node:1
[2024-05-10 18:27:00,008] [Process-2] Run : make_shares at node:1
[2024-05-10 18:27:00,010] [Process-2] RunR: builtin_fetch_meta at node:1
[2024-05-10 18:27:00,011] [Process-1] Run : make_shares at node:0
[2024-05-10 18:27:00,012] [Process-1] RunR: builtin_fetch_meta at node:0
[2024-05-10 18:27:00,021] [Process-2] Run : builtin_spu_run at node:1
[2024-05-10 18:27:00,023] [Process-1] RunR: builtin_fetch_object at node:0
[2024-05-10 18:27:00,023] [Process-1] Run : builtin_spu_run at node:0
[2024-05-10 18:27:00,025] [Process-2] RunR: builtin_fetch_object at node:1
[2024-05-10 18:27:00,026] [Process-2] RunR: builtin_fetch_object at node:1
[2024-05-10 18:27:00.031] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 18:27:00.044] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 18:27:00.049] [info] [api.cc:158] [Profiling] SPU execution infer completed, input processing took 8.61e-07s, execution took 0.022561069s, output processing took 1.745e-06s, total time 0.022563675s.
[2024-05-10 18:27:00.049] [info] [api.cc:191] HLO profiling: total time 0.022231636
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.add, executed 1 times, duration 0.001296596s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.broadcast, executed 1 times, duration 6.47e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.constant, executed 1 times, duration 6.612e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.convert, executed 1 times, duration 2.1625e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.convolution, executed 1 times, duration 0.020818828s, send bytes 759296
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.free, executed 5 times, duration 4.6146e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pphlo.pad, executed 1 times, duration 3.5359e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:191] HAL profiling: total time 0.018824205
[2024-05-10 18:27:00.049] [info] [api.cc:194] - f_add, executed 1 times, duration 0.001292103s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - f_tensordot, executed 1 times, duration 0.017514063s, send bytes 759296
[2024-05-10 18:27:00.049] [info] [api.cc:194] - seal, executed 1 times, duration 1.8039e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:191] MPC profiling: total time 0.021005868000000004
[2024-05-10 18:27:00.049] [info] [api.cc:194] - add_aa, executed 1 times, duration 0.001286959s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - broadcast, executed 1 times, duration 2.828e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - concatenate, executed 1 times, duration 0.000931655s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - extract_slice, executed 1024 times, duration 0.000790953s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - mmul_aa, executed 1 times, duration 0.011660626s, send bytes 235008
[2024-05-10 18:27:00.049] [info] [api.cc:194] - p2a, executed 1 times, duration 1.3275e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - pad, executed 1 times, duration 3.3511e-05s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - reshape, executed 1029 times, duration 0.000453675s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - transpose, executed 2 times, duration 3.981e-06s, send bytes 0
[2024-05-10 18:27:00.049] [info] [api.cc:194] - trunc_a, executed 1 times, duration 0.005828405s, send bytes 524288
[2024-05-10 18:27:00.049] [info] [api.cc:204] Link details: total send bytes 759296, send actions 2
-------------------------------------------------
Winograd conv:
[2024-05-10 17:50:15,814] [ForkServerProcess-2] Starting grpc server at 127.0.0.1:61321
[2024-05-10 17:50:15,814] [ForkServerProcess-1] Starting grpc server at 127.0.0.1:61320
[2024-05-10 17:50:21,376] [ForkServerProcess-1] Run : builtin_spu_init at node:0
[2024-05-10 17:50:21,377] [ForkServerProcess-2] Run : builtin_spu_init at node:1
I0510 17:50:21.562828 8755 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61331.
W0510 17:50:21.562851 8755 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
I0510 17:50:21.563300 8753 external/com_github_brpc_brpc/src/brpc/server.cpp:1181] Server[yacl::link::transport::internal::ReceiverServiceImpl] is serving on port=61330.
W0510 17:50:21.563313 8753 external/com_github_brpc_brpc/src/brpc/server.cpp:1187] Builtin services are disabled according to ServerOptions.has_builtin_services
[2024-05-10 17:50:21,564] [ForkServerProcess-2] spu-runtime (SPU) initialized
[2024-05-10 17:50:21,564] [ForkServerProcess-1] spu-runtime (SPU) initialized
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[2024-05-10 17:50:21,614] [ForkServerProcess-1] Run : <lambda> at node:0
[2024-05-10 17:50:21,616] [ForkServerProcess-2] Run : <lambda> at node:1
[2024-05-10 17:50:21,617] [ForkServerProcess-1] Run : make_shares at node:0
[2024-05-10 17:50:21.617] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 17:50:21,636] [ForkServerProcess-1] RunR: builtin_fetch_meta at node:0
[2024-05-10 17:50:21,645] [ForkServerProcess-1] Run : builtin_spu_run at node:0
[2024-05-10 17:50:21,646] [ForkServerProcess-2] Run : builtin_spu_run at node:1
[2024-05-10 17:50:21,647] [ForkServerProcess-1] RunR: builtin_fetch_object at node:0
[2024-05-10 17:50:21.671] [info] [thread_pool.cc:30] Create a fixed thread pool with size 23
[2024-05-10 17:50:21.695] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 6.96e-07s, execution took 0.0487127s, output processing took 1.804e-06s, total time 0.0487152s.
[2024-05-10 17:50:21.696] [info] [api.cc:209] HLO profiling: total time 1.952e-06
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.constant, executed 1 times, duration 1.77e-06s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.free, executed 2 times, duration 9.2e-08s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.multiply, executed 1 times, duration 4.7e-08s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - pphlo.broadcast, executed 1 times, duration 4.3e-08s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:209] HAL profiling: total time 0.048321089
[2024-05-10 17:50:21.696] [info] [api.cc:212] - f_mul, executed 1 times, duration 0.048321089s, send bytes 6291456 recv bytes 6291456
[2024-05-10 17:50:21.696] [info] [api.cc:209] MPC profiling: total time 0.048304739
[2024-05-10 17:50:21.696] [info] [api.cc:212] - trunc_a, executed 1 times, duration 0.044309967s, send bytes 6291456 recv bytes 6291456
[2024-05-10 17:50:21.696] [info] [api.cc:212] - mul_ap, executed 1 times, duration 0.003989605s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:212] - broadcast, executed 1 times, duration 5.167e-06s, send bytes 0 recv bytes 0
[2024-05-10 17:50:21.696] [info] [api.cc:222] Link details: total send bytes 6291456, recv bytes 6291456, send actions 1
As you can see, the amount of truncation increases a lot, and that is expected. For a matmul, the number of truncated elements is only quadratic in the matrix size (one truncation per output element, independent of the inner dimension), whereas the Winograd EWMM truncates every intermediate product. To reduce the number of truncations for Winograd, I think you would need to extend SPU in C++, for example to add up several multiplication results and then perform a single truncation, instead of truncating each product separately.
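In the clear, the idea looks roughly like this (plain Python fixed-point arithmetic for illustration, not SPU code):

# Summing first and truncating once needs a single truncation; truncating each
# partial product needs one truncation per product, and in MPC every
# truncation costs communication.
f = 18                        # fractional bits of the fixed-point encoding
scale = 1 << f
xs = [0.5, -1.25, 3.0]
ys = [2.0, 0.75, -0.5]
X = [round(v * scale) for v in xs]
Y = [round(v * scale) for v in ys]

once = sum(a * b for a, b in zip(X, Y)) >> f       # 1 truncation
per  = sum((a * b) >> f for a, b in zip(X, Y))     # len(xs) truncations

print(once / scale, per / scale)                   # both ~= -1.4375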
Hi @warpoons, interesting idea. As pointed out by @fionser, the problem is due to the increased amount of truncation.
In the Winograd algorithm the matmul is split into several parts (and currently each part incurs its own truncations), which I believe is not SPU-friendly.
In my opinion, to get the full benefit of Winograd you may need to add a dedicated backend op and implement the algorithm in C++.
Hi @llCurious @fionser ! Thanks for your response!
As pointed out by @fionser, for a matmul the amount of truncation is only quadratic in the matrix size. In Winograd, the input feature map is split into several overlapping parts (tiles), and an element-wise matmul is performed separately on each tile, so the EWMMs over all tiles end up requiring more truncations than the standard conv. Is this understanding correct?
I have another question: is there a way to estimate the theoretical communication cost of the standard conv and the Winograd conv in SPU (counting only the EWMMs in Winograd, with the weight transformation done offline)?
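For what it is worth, a rough estimate I tried happens to reproduce both logs above. The constants are my own assumptions (8 bytes per FM64 ring element, a Beaver-style matmul opening both masked operands, one ring element sent per truncated value), so please correct me if the model is wrong:

# Rough SEMI2K (FM64) communication estimate; all constants are my assumptions.
B = 8                                    # bytes per ring element on FM64
C_in, C_out, K, H = 3, 64, 3, 32

# Standard conv viewed as an im2col matmul: (H*W, C_in*K*K) x (C_in*K*K, C_out)
m, k, n = H * H, C_in * K * K, C_out
mmul_bytes  = (m * k + k * n) * B        # 235,008  -> mmul_aa in the log
trunc_bytes = m * n * B                  # 524,288  -> trunc_a in the log
print(mmul_bytes + trunc_bytes)          # 759,296  -> total send bytes

# EWMM-based Winograd: the weight seems to be treated as public here (the op
# is mul_ap with zero send bytes), so the cost is truncating every product.
wino_products = (H // 2) ** 2 * 16 * C_in * C_out
print(wino_products * B)                 # 6,291,456 -> trunc_a in the Winograd log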
Thanks!
Hi @llCurious @fionser! This week I have further tested Winograd convolution for reducing multiplications in SPU.
As described earlier in this issue, Winograd converts the standard conv into an EWMM with fewer multiplications, at the cost of low parallelism in the EWMM.
There is another way: convert Winograd's EWMM into a general matmul (GEMM) by transposing the Winograd weights and inputs.
As noted in the NeurIPS 2023 paper CoPriv: Network/protocol co-optimization for communication-efficient private inference, the communication actually increases when Winograd is applied only for multiplication reduction; to reach the expected comm improvement, the EWMM->GEMM conversion should be used.
This finding somewhat explains why the comm size abnormally increases by ~8x with the EWMM-based Winograd.
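The idea, with shapes roughly matching this layer (the exact layout below is my assumption, following the batched-GEMM view of Winograd), is that the element-wise products plus the sum over input channels collapse into one small matmul per Winograd-domain position:

import jax.numpy as jnp

K_out, C_in, P = 64, 3, 256                  # out channels, in channels, 4x4 tiles
U = jnp.ones((4, 4, K_out, C_in))            # transformed weights
V = jnp.ones((4, 4, C_in, P))                # transformed input tiles

# 16 GEMMs of shape (K_out, C_in) x (C_in, P), one per Winograd-domain position,
# instead of 16 * K_out * C_in * P separate element-wise products.
M = jnp.einsum('xykc,xycp->xykp', U, V)      # (4, 4, K_out, P)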
Hence, I have further tested the GEMM-based Winograd to check whether the expected 2.25x comm reduction appears, but the answer is NO. The profiling with ("SEMI2K", "FM64") is:
- jnp.dtype = jnp.float32
[2024-05-31 15:56:54.026] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 1.297e-06s, execution took 0.059690046s, output processing took 1.735e-06s, total time 0.059693078s.
[2024-05-31 15:56:54.026] [info] [api.cc:209] HLO profiling: total time 5.8410000000000005e-06
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.constant, executed 6 times, duration 2.155e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.free, executed 50 times, duration 1.814e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.reshape, executed 18 times, duration 7.56e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.transpose, executed 7 times, duration 3.05e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.broadcast, executed 6 times, duration 2.09e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.dot, executed 4 times, duration 1.95e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.iota, executed 2 times, duration 8.2e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pphlo.convolution, executed 1 times, duration 4e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:209] HAL profiling: total time 0.054427574
[2024-05-31 15:56:54.026] [info] [api.cc:212] - f_mmul, executed 20 times, duration 0.041460792s, send bytes 3866624 recv bytes 3866624
[2024-05-31 15:56:54.026] [info] [api.cc:212] - f_tensordot, executed 1 times, duration 0.012745663s, send bytes 98304 recv bytes 98304
[2024-05-31 15:56:54.026] [info] [api.cc:212] - i_equal, executed 2 times, duration 0.000147861s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - mixed_mul, executed 1 times, duration 4.8814e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - seal, executed 1 times, duration 1.9777e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - int2fxp, executed 1 times, duration 4.667e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:209] MPC profiling: total time 0.056513282
[2024-05-31 15:56:54.026] [info] [api.cc:212] - trunc_a, executed 21 times, duration 0.044545624s, send bytes 3964928 recv bytes 3964928
[2024-05-31 15:56:54.026] [info] [api.cc:212] - mmul_ap, executed 55 times, duration 0.007790367s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - reshape, executed 332 times, duration 0.001793458s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - concatenate, executed 2 times, duration 0.001615727s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - extract_slice, executed 360 times, duration 0.00036179s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - transpose, executed 132 times, duration 0.00019234s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - equal_pp, executed 2 times, duration 0.000101518s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - mul_pp, executed 1 times, duration 4.7267e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - pad, executed 1 times, duration 3.4435e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - p2a, executed 1 times, duration 1.8228e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - broadcast, executed 6 times, duration 9.442e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:212] - lshift_p, executed 1 times, duration 3.086e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:56:54.026] [info] [api.cc:222] Link details: total send bytes 3964928, recv bytes 3964928, send actions 21
- Inspired by https://github.com/secretflow/spu/issues/672#issuecomment-2098578217, I have also tested jnp.dtype = jnp.integer to measure the comm without truncation. The profiling with ("SEMI2K", "FM64") is:
[2024-05-31 15:59:36.265] [info] [api.cc:163] [Profiling] SPU execution infer completed, input processing took 2.39e-07s, execution took 0.069512618s, output processing took 2.288e-06s, total time 0.069515145s.
[2024-05-31 15:59:36.265] [info] [api.cc:209] HLO profiling: total time 5.878e-06
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.constant, executed 6 times, duration 1.957e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.free, executed 50 times, duration 1.8e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.reshape, executed 18 times, duration 6.85e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.dot, executed 4 times, duration 5.09e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.transpose, executed 7 times, duration 2.92e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.broadcast, executed 6 times, duration 2.19e-07s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.convert, executed 2 times, duration 7.9e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.iota, executed 2 times, duration 7.8e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.equal, executed 2 times, duration 7.5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.dot_general, executed 1 times, duration 5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.convolution, executed 1 times, duration 4.5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.pad, executed 1 times, duration 4.5e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pphlo.multiply, executed 1 times, duration 4.4e-08s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:209] HAL profiling: total time 0.064209222
[2024-05-31 15:59:36.265] [info] [api.cc:212] - f_mmul, executed 2 times, duration 0.05382723s, send bytes 1572864 recv bytes 1572864
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_tensordot, executed 1 times, duration 0.009371561s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - mixed_mmul, executed 16 times, duration 0.000773674s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_equal, executed 2 times, duration 0.000100369s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_mmul, executed 2 times, duration 6.5911e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - i_mul, executed 1 times, duration 4.9618e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - seal, executed 1 times, duration 2.0859e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:209] MPC profiling: total time 0.066035378
[2024-05-31 15:59:36.265] [info] [api.cc:212] - trunc_a, executed 2 times, duration 0.05136436s, send bytes 1572864 recv bytes 1572864
[2024-05-31 15:59:36.265] [info] [api.cc:212] - mmul_ap, executed 55 times, duration 0.010664855s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - concatenate, executed 2 times, duration 0.001737654s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - reshape, executed 332 times, duration 0.001582893s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - extract_slice, executed 360 times, duration 0.000322834s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - transpose, executed 132 times, duration 0.000154002s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - equal_pp, executed 2 times, duration 9.7572e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - mul_pp, executed 1 times, duration 4.8151e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - pad, executed 1 times, duration 3.5026e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - p2a, executed 1 times, duration 1.9398e-05s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:212] - broadcast, executed 6 times, duration 8.633e-06s, send bytes 0 recv bytes 0
[2024-05-31 15:59:36.265] [info] [api.cc:222] Link details: total send bytes 1572864, recv bytes 1572864, send actions 2
The comm is reduced compared to the EWMM-based Winograd, but it is still far from the expected improvement.
Another issue is that with jnp.integer the profiling still shows trunc_a and its comm, and I cannot figure out the reason.
Also, with jnp.dtype = jnp.float32 the f_tensordot incurs comm, while with jnp.dtype = jnp.integer the i_tensordot incurs none.
To make it clear, this is my model.py defining the Winograd conv layer:
class FlaxConvWino(nn.Module):
    # (imports of flax.linen as nn, jax.numpy as jnp, warnings, and my own
    #  ZeroPad2dFlax / winUtils / params helpers are omitted here)
    inCh: int
    outCh: int
    filterDim: int
    outTileDim: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, input):
        padding = int((self.filterDim - 1) / 2)
        temp_padding = ZeroPad2dFlax(padding)
        input_ = temp_padding(input)
        # Pad further so the padded width is a whole number of output tiles.
        number_tiling_positions = (input_.shape[3] - 2 * padding) / self.outTileDim
        if number_tiling_positions.is_integer():
            Pad_tiling = ZeroPad2dFlax(0)
        else:
            decimal_part = number_tiling_positions - int(number_tiling_positions)
            to_pad = round((1.0 - decimal_part) * self.outTileDim)
            to_pad_even = round(to_pad / 2)
            Pad_tiling = ZeroPad2dFlax(to_pad_even)
        expected_output_width = input_.shape[2] - 2 * padding
        input_ = Pad_tiling(input_)
        Tiler = winUtils.TilerFlax(self.outTileDim, self.filterDim)
        input_ = Tiler.tile(input_)
        weight = jnp.ones((1, self.outCh, self.inCh, self.filterDim, self.filterDim), dtype=self.dtype)  # ★★★
        A_t = params.A_T
        B_t = params.B_T
        G = params.G
        # Note that the PI communication increases by over 10x without tile transposition.
        # Therefore we transpose the Winograd input and weight to convert the EWMM into a GEMM.
        # Weight/input transformation into the Winograd domain
        weight_winograd = jnp.matmul(jnp.matmul(G, weight), jnp.transpose(G, (1, 0)))
        input_winograd = jnp.matmul(jnp.matmul(B_t, input_), jnp.transpose(B_t, (1, 0)))
        # Tile transposition
        weight_winograd_TTrans = jnp.transpose(weight_winograd, (0, 3, 4, 1, 2))
        input_winograd_TTrans = jnp.transpose(input_winograd, (0, 3, 4, 2, 1))
        GEMM = jnp.matmul(weight_winograd_TTrans, input_winograd_TTrans)
        output = jnp.transpose(GEMM, (0, 4, 3, 1, 2))
        # Inverse transformation back to the spatial domain
        output = jnp.matmul(jnp.matmul(A_t, output), jnp.transpose(A_t, (1, 0)))
        output = Tiler.untile(output)
        if output.shape[3] != expected_output_width:
            warnings.warn('output dim is not expected. Error may occur !!!')
        padding = Pad_tiling.padding
        output = output[:, :, padding[0]:-padding[1], padding[2]:-padding[3]]
        return output
Note that the line marked with ★★★ initializes an all-ones weight inside the model definition, since I think the specific parameter values will not significantly affect the comm results. As a result, flax_model.init(jax.random.PRNGKey(1), jnp.ones(input_shape)) returns an empty {}. I am not sure whether this has an impact.
Sorry for taking your time. Thanks!