spu
[Bug]: 8x more communication than reported in Cheetah
Issue Type
Performance
Modules Involved
MPC protocol
Have you reproduced the bug with SPU HEAD?
Yes
Have you searched existing issues?
Yes
SPU Version
0.9.0.dev20240425
OS Platform and Distribution
macOS Sonoma
Python Version
3.11.6
Compiler Version
Apple clang version 15.0.0 (clang-1500.1.0.2.5)
Current Behavior?
I tried an int32_t comparison using Cheetah. I expected to see 11*32/8 = 44 bytes of communication, but I see 352 bytes instead.
Standalone code to reproduce the issue
int main(int argc, char** argv) {
  SPDLOG_INFO("in process");
  llvm::cl::ParseCommandLineOptions(argc, argv);
  // MakeSPUContext() and the Salary flag are defined elsewhere in the demo (not shown).
  auto sctx = MakeSPUContext();
  spu::mpc::Factory::RegisterProtocol(sctx.get(), sctx->lctx());

  // Each party feeds one value: "x-0" from rank 0 and "x-1" from rank 1.
  spu::device::ColocatedIo cio(sctx.get());
  cio.hostSetVar(fmt::format("x-{}", sctx->lctx()->Rank()), Salary.getValue());
  cio.sync();
  auto x = cio.deviceGetVar("x-0");
  auto y = cio.deviceGetVar("x-1");

  // Measure bytes sent + received around three identical comparisons.
  auto b0 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  auto comp = hlo::Less(sctx.get(), x, y);
  auto b1 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  SPDLOG_INFO("bytes communicated: {}", b1 - b0);

  b0 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  comp = hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  SPDLOG_INFO("bytes communicated: {}", b1 - b0);

  b0 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  comp = hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  SPDLOG_INFO("bytes communicated: {}", b1 - b0);

  // Reveal the comparison result to check correctness.
  auto comp_revealed = hal::dump_public_as<float>(
      sctx.get(),
      hlo::Cast(sctx.get(), comp, spu::VIS_PUBLIC, comp.dtype()));
  SPDLOG_INFO("comp_revealed: {}", comp_revealed[0]);
  return 0;
}
Relevant log output
[2024-04-24 22:43:47.553] [info] [gelu.cc:42] bytes communicated: 940678
[2024-04-24 22:43:47.557] [info] [gelu.cc:48] bytes communicated: 352
[2024-04-24 22:43:47.563] [info] [gelu.cc:54] bytes communicated: 352
[2024-04-24 22:43:47.564] [info] [gelu.cc:61] comp_revealed: 1
@fionser would you mind taking a look?
@kanav99 Which specific OT did you use inside the MakeSPUContext function?
Currently, the Cheetah back-end supports three kinds of OT:
- EMP_Ferret: the Ferret OT implementation from the EMP library
- YACL_Ferret: the Ferret OT implemented by the SPU team
- YACL_Softspoken: a variant of the IKNP OT
Here is my standalone test file
#include "libspu/device/io.h"
#include "libspu/kernel/hlo/basic_binary.h"
#include "libspu/mpc/utils/simulate.h"

template <typename T>
spu::Value infeed(spu::SPUContext* hctx, const xt::xarray<T>& ds) {
  spu::device::ColocatedIo cio(hctx);
  if (hctx->lctx()->Rank() == 0) {
    cio.hostSetVar(fmt::format("x-{}", hctx->lctx()->Rank()), ds);
  }
  cio.sync();
  auto x = cio.deviceGetVar("x-0");
  return x;
}

TEST_P(ObjectivesTest, TestLess) {
  using namespace spu;
  using namespace spu::kernel;
  using namespace spu::mpc;
  //
  // preparing input here ...
  //
  spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
    spu::RuntimeConfig rt_config;
    rt_config.set_protocol(ProtocolKind::CHEETAH);
    rt_config.mutable_cheetah_2pc_config()->set_ot_kind(
        CheetahOtKind::YACL_Ferret);
    rt_config.set_field(FM32);
    rt_config.set_fxp_fraction_bits(12);
    auto _ctx = std::make_unique<spu::SPUContext>(rt_config, lctx);
    auto ctx = _ctx.get();
    spu::mpc::Factory::RegisterProtocol(ctx, lctx);

    auto x = infeed<double>(ctx, _x);
    auto y = infeed<double>(ctx, _x);
    int64_t numel = x.numel();

    auto b0 = ctx->lctx()->GetStats()->sent_bytes +
              ctx->lctx()->GetStats()->recv_bytes;
    auto comp = hlo::Less(ctx, x, y);
    auto b1 = ctx->lctx()->GetStats()->sent_bytes +
              ctx->lctx()->GetStats()->recv_bytes;
    SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / numel);

    b0 = ctx->lctx()->GetStats()->sent_bytes +
         ctx->lctx()->GetStats()->recv_bytes;
    comp = hlo::Less(ctx, x, y);
    b1 = ctx->lctx()->GetStats()->sent_bytes +
         ctx->lctx()->GetStats()->recv_bytes;
    SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / numel);
  });
}
The results are:
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.571] [info] [objectives_test.cc:500] bytes communicated: 2951.3418972332015
[2024-04-25 14:39:36.571] [info] [objectives_test.cc:500] bytes communicated: 2951.3418972332015
[2024-04-25 14:39:36.573] [info] [objectives_test.cc:508] bytes communicated: 40.569169960474305
[2024-04-25 14:39:36.573] [info] [objectives_test.cc:508] bytes communicated: 40.569169960474305
@kanav99 Here is another 2PC example (you can simply replace the main function in here). Yes, something seems to go wrong if we benchmark only a single value, so it is better to use a longer vector, e.g., n = 100 as in the following example.
// bazel run -c opt experimental/squirrel:squirrel_demo_main -- --rank=0 --lr=0.2 --field=1
// bazel run -c opt experimental/squirrel:squirrel_demo_main -- --rank=1 --lr=0.3 --field=1

// MakeLink, InfeedLabel and the Parties/Rank/Field/EngineTrace/LearningRate
// flags come from the squirrel demo and are not shown here.
std::unique_ptr<spu::SPUContext> MakeSPUContext() {
  auto lctx = MakeLink(Parties.getValue(), Rank.getValue());
  spu::RuntimeConfig config;
  config.set_protocol(spu::ProtocolKind::CHEETAH);
  config.mutable_cheetah_2pc_config()->set_enable_mul_lsb_error(true);
  // Or use `EMP_Ferret` / `YACL_Softspoken` here instead.
  config.mutable_cheetah_2pc_config()->set_ot_kind(
      spu::CheetahOtKind::YACL_Ferret);
  config.set_field(static_cast<spu::FieldType>(Field.getValue()));
  config.set_fxp_fraction_bits(18);
  config.set_fxp_div_goldschmidt_iters(1);
  config.set_enable_hal_profile(EngineTrace.getValue());
  auto hctx = std::make_unique<spu::SPUContext>(config, lctx);
  spu::mpc::Factory::RegisterProtocol(hctx.get(), lctx);
  return hctx;
}

int main(int argc, char** argv) {
  SPDLOG_INFO("in process");
  llvm::cl::ParseCommandLineOptions(argc, argv);
  // YACL_Ferret
  auto sctx = MakeSPUContext();
  std::vector<size_t> shape = {100};
  xt::xarray<double> input(shape);
  std::fill_n(input.data(), input.size(), LearningRate.getValue());
  auto x = InfeedLabel(sctx.get(), input, sctx->lctx()->Rank() == 0);
  auto y = InfeedLabel(sctx.get(), input, sctx->lctx()->Rank() == 1);

  // Measure bytes per element for three identical comparisons.
  size_t b0 = sctx->lctx()->GetStats()->sent_bytes +
              sctx->lctx()->GetStats()->recv_bytes;
  auto comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  size_t b1 = sctx->lctx()->GetStats()->sent_bytes +
              sctx->lctx()->GetStats()->recv_bytes;
  printf("size %zu\n", input.size());  // %zu is the specifier for size_t
  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  b0 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  b0 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  auto comp_revealed = spu::kernel::hal::dump_public_as<float>(
      sctx.get(),
      spu::kernel::hlo::Cast(sctx.get(), comp, spu::VIS_PUBLIC, comp.dtype()));
  SPDLOG_INFO("comp_revealed: {}", comp_revealed[0]);
  return 0;
}
Yes, the results still do not quite make sense to me. I would expect the second run to also communicate about 41 bytes per element. But if we enlarge the vector (e.g., n > 200), it works as I expected.
[2024-04-25 15:12:10.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 15:12:10.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
size 100
[2024-04-25 15:12:11.232] [info] [squirrel_demo_main.cc:283] bytes communicated: 12107.21
[2024-04-25 15:12:11.418] [info] [squirrel_demo_main.cc:291] bytes communicated: 2703.88
[2024-04-25 15:12:11.419] [info] [squirrel_demo_main.cc:299] bytes communicated: 41.28
[2024-04-25 15:12:11.419] [info] [squirrel_demo_main.cc:305] comp_revealed: 0
I0425 15:12:11 259 external/com_github_brpc_brpc/src/brpc/server.cpp:1218] Server[yacl::link::transport::internal::ReceiverServiceImpl] is going to quit
[2024-04-25 15:12:11.427] [warning] [channel.h:162] Channel destructor is called before WaitLinkTaskFinish, try stop send thread
Thanks for the detailed response. Working with a large batch works for me!
I can close this issue, or do you want me to keep it open? I realize that you do see an issue here.
@kanav99 I'll close this issue. The "weird" stats might be due to some implementation decisions :).