
[Bug]: 8x communication compared to reported in Cheetah

Open kanav99 opened this issue 9 months ago • 3 comments

Issue Type

Performance

Modules Involved

MPC protocol

Have you reproduced the bug with SPU HEAD?

Yes

Have you searched existing issues?

Yes

SPU Version

0.9.0.dev20240425

OS Platform and Distribution

macOS Sonoma

Python Version

3.11.6

Compiler Version

Apple clang version 15.0.0 (clang-1500.1.0.2.5)

Current Behavior?

I tried doing an int32_t comparison using Cheetah. I expected to see 11*32/8 = 44 bytes of communication, but I am seeing 352 bytes instead (8x more).
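
The arithmetic behind that expectation can be sketched as follows. This is only an illustration: the factor of 11 bits per input bit is taken from the report as given, and the names `kBitsPerInputBit`, `ExpectedCompareBytes`, and `Overhead` are hypothetical helpers, not part of SPU.

```cpp
#include <cassert>
#include <cstdint>

// Assumption taken from the report: roughly 11 bits of communication
// per bit of input for one secure comparison.
constexpr std::uint64_t kBitsPerInputBit = 11;

// Expected bytes for comparing two `bit_width`-bit values.
constexpr std::uint64_t ExpectedCompareBytes(std::uint64_t bit_width) {
  return kBitsPerInputBit * bit_width / 8;
}

// Ratio of observed traffic to the expected figure.
inline double Overhead(std::uint64_t observed_bytes, std::uint64_t bit_width) {
  const std::uint64_t expected = ExpectedCompareBytes(bit_width);
  assert(expected > 0);
  return static_cast<double>(observed_bytes) /
         static_cast<double>(expected);
}
```

With a 32-bit input, `ExpectedCompareBytes(32)` gives 44, and `Overhead(352, 32)` gives the factor of 8 mentioned in the title.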

Standalone code to reproduce the issue

int main(int argc, char** argv) {
  SPDLOG_INFO("in process");
  llvm::cl::ParseCommandLineOptions(argc, argv);
  
  auto sctx = MakeSPUContext();
  spu::mpc::Factory::RegisterProtocol(sctx.get(), sctx->lctx());

  spu::device::ColocatedIo cio(sctx.get());
  cio.hostSetVar(fmt::format("x-{}", sctx->lctx()->Rank()), Salary.getValue());
  cio.sync();

  auto x = cio.deviceGetVar("x-0");
  auto y = cio.deviceGetVar("x-1");

  auto b0 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  auto comp = hlo::Less(sctx.get(), x, y);
  auto b1 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", b1 - b0);

  b0 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  comp = hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", b1 - b0);

  b0 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;
  comp = hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes + sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", b1 - b0);
  
  auto comp_revealed = hal::dump_public_as<float>(
      sctx.get(),
      hlo::Cast(sctx.get(), comp, spu::VIS_PUBLIC, comp.dtype())
  );

  SPDLOG_INFO("comp_revealed: {}", comp_revealed[0]);
  return 0;
}

Relevant log output

[2024-04-24 22:43:47.553] [info] [gelu.cc:42] bytes communicated: 940678
[2024-04-24 22:43:47.557] [info] [gelu.cc:48] bytes communicated: 352
[2024-04-24 22:43:47.563] [info] [gelu.cc:54] bytes communicated: 352
[2024-04-24 22:43:47.564] [info] [gelu.cc:61] comp_revealed: 1

kanav99 avatar Apr 25 '24 02:04 kanav99

@fionser would you mind taking a look?

tpppppub avatar Apr 25 '24 03:04 tpppppub

@kanav99 Which specific OT did you use inside the MakeSPUContext function? For now, the Cheetah back-end supports 3 kinds of OT.

  • EMP_Ferret: the Ferret OT implementation from the EMP library
  • YACL_Ferret: the Ferret OT implemented by the SPU team
  • YACL_Softspoken: the IKNP OT (variant)

Here is my standalone test file

#include "libspu/device/io.h"
#include "libspu/kernel/hlo/basic_binary.h"
#include "libspu/mpc/utils/simulate.h"


template <typename T>
spu::Value infeed(spu::SPUContext* hctx, const xt::xarray<T>& ds) {
  spu::device::ColocatedIo cio(hctx);
  if (hctx->lctx()->Rank() == 0) {
    cio.hostSetVar(fmt::format("x-{}", hctx->lctx()->Rank()), ds);
  }
  cio.sync();
  auto x = cio.deviceGetVar("x-0");
  return x;
}

TEST_P(ObjectivesTest, TestLess) {
  using namespace spu;
  using namespace spu::kernel;
  using namespace spu::mpc;
  //
  // preparing input here ...
  //

  spu::mpc::utils::simulate(2, [&](std::shared_ptr<yacl::link::Context> lctx) {
    spu::RuntimeConfig rt_config;
    rt_config.set_protocol(ProtocolKind::CHEETAH);
    rt_config.mutable_cheetah_2pc_config()->set_ot_kind(
        CheetahOtKind::YACL_Ferret);
    rt_config.set_field(FM32);
    rt_config.set_fxp_fraction_bits(12);

    auto _ctx = std::make_unique<spu::SPUContext>(rt_config, lctx);
    auto ctx = _ctx.get();
    spu::mpc::Factory::RegisterProtocol(ctx, lctx);

    auto x = infeed<double>(ctx, _x);
    auto y = infeed<double>(ctx, _x);
    int64_t numel = x.numel();

    auto b0 = ctx->lctx()->GetStats()->sent_bytes +
              ctx->lctx()->GetStats()->recv_bytes;
    auto comp = hlo::Less(ctx, x, y);

    auto b1 = ctx->lctx()->GetStats()->sent_bytes +
              ctx->lctx()->GetStats()->recv_bytes;

    SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / numel);

    b0 = ctx->lctx()->GetStats()->sent_bytes +
         ctx->lctx()->GetStats()->recv_bytes;
    comp = hlo::Less(ctx, x, y);
    b1 = ctx->lctx()->GetStats()->sent_bytes +
         ctx->lctx()->GetStats()->recv_bytes;

    SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / numel);
  });
}

The results are

[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.155] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 14:39:36.571] [info] [objectives_test.cc:500] bytes communicated: 2951.3418972332015
[2024-04-25 14:39:36.571] [info] [objectives_test.cc:500] bytes communicated: 2951.3418972332015
[2024-04-25 14:39:36.573] [info] [objectives_test.cc:508] bytes communicated: 40.569169960474305
[2024-04-25 14:39:36.573] [info] [objectives_test.cc:508] bytes communicated: 40.569169960474305
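
The measurement pattern used in this test (snapshot the link counters, run the operation, snapshot again, divide by the element count) can be factored into a small helper. The sketch below is self-contained and makes assumptions: `LinkStats` is a hypothetical stand-in for whatever `lctx->GetStats()` returns, of which only the `sent_bytes` and `recv_bytes` fields appear in this thread, and `BytesPerElement` is an illustrative name.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical stand-in for the link statistics object; only the two
// counters used in this thread are modelled.
struct LinkStats {
  std::uint64_t sent_bytes = 0;
  std::uint64_t recv_bytes = 0;
};

inline std::uint64_t TotalBytes(const LinkStats& s) {
  return s.sent_bytes + s.recv_bytes;
}

// Runs `op` once and returns the bytes exchanged per element.
// `stats` must be the live counters that `op` updates as it communicates.
template <typename Op>
double BytesPerElement(const LinkStats& stats, std::size_t numel, Op&& op) {
  assert(numel > 0);
  const std::uint64_t before = TotalBytes(stats);
  op();
  const std::uint64_t after = TotalBytes(stats);
  return static_cast<double>(after - before) / static_cast<double>(numel);
}
```

Calling such a helper several times in a row on the same context makes it easy to see whether the first call carries one-time setup cost that later calls do not.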

fionser avatar Apr 25 '24 06:04 fionser

@kanav99 Here is another 2PC example (you can simply replace the main function in here). Yes, something does seem to go wrong when benchmarking on only a single value, so it might be better to use a longer vector, e.g., n = 100 in the following example.

// bazel run -c opt experimental/squirrel:squirrel_demo_main -- --rank=0 --lr=0.2 --field=1
// bazel run -c opt experimental/squirrel:squirrel_demo_main -- --rank=1 --lr=0.3 --field=1

std::unique_ptr<spu::SPUContext> MakeSPUContext() {
  auto lctx = MakeLink(Parties.getValue(), Rank.getValue());

  spu::RuntimeConfig config;
  config.set_protocol(spu::ProtocolKind::CHEETAH);
  config.mutable_cheetah_2pc_config()->set_enable_mul_lsb_error(true);
  // can be replaced with `EMP_Ferret` or `YACL_Softspoken`
  config.mutable_cheetah_2pc_config()->set_ot_kind(
      spu::CheetahOtKind::YACL_Ferret);

  config.set_field(static_cast<spu::FieldType>(Field.getValue()));
  config.set_fxp_fraction_bits(18);
  config.set_fxp_div_goldschmidt_iters(1);
  config.set_enable_hal_profile(EngineTrace.getValue());
  auto hctx = std::make_unique<spu::SPUContext>(config, lctx);
  spu::mpc::Factory::RegisterProtocol(hctx.get(), lctx);
  return hctx;
}

int main(int argc, char** argv) {
  SPDLOG_INFO("in process");
  llvm::cl::ParseCommandLineOptions(argc, argv);
  // YACL_Ferret 
  auto sctx = MakeSPUContext();
  std::vector<size_t> shape = {100};
  xt::xarray<double> input(shape);
  std::fill_n(input.data(), input.size(), LearningRate.getValue());

  auto x = InfeedLabel(sctx.get(), input, sctx->lctx()->Rank() == 0);
  auto y = InfeedLabel(sctx.get(), input, sctx->lctx()->Rank() == 1);

  size_t b0 = sctx->lctx()->GetStats()->sent_bytes +
              sctx->lctx()->GetStats()->recv_bytes;
  auto comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  size_t b1 = sctx->lctx()->GetStats()->sent_bytes +
              sctx->lctx()->GetStats()->recv_bytes;

  printf("size %zu\n", input.size());  // size() is unsigned, so use %zu

  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  b0 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  b0 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;
  comp = spu::kernel::hlo::Less(sctx.get(), x, y);
  b1 = sctx->lctx()->GetStats()->sent_bytes +
       sctx->lctx()->GetStats()->recv_bytes;

  SPDLOG_INFO("bytes communicated: {}", (b1 - b0) * 1. / x.numel());

  auto comp_revealed = spu::kernel::hal::dump_public_as<float>(
      sctx.get(),
      spu::kernel::hlo::Cast(sctx.get(), comp, spu::VIS_PUBLIC, comp.dtype()));

  SPDLOG_INFO("comp_revealed: {}", comp_revealed[0]);
  return 0;
}

Yes, the results still do not make sense to me. I would expect the 2nd run to also communicate about 41 bytes per element. But if we enlarge the vector size (e.g., n > 200), it works as I expected.

[2024-04-25 15:12:10.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
[2024-04-25 15:12:10.941] [info] [thread_pool.cc:30] Create a fixed thread pool with size 4
size 100
[2024-04-25 15:12:11.232] [info] [squirrel_demo_main.cc:283] bytes communicated: 12107.21
[2024-04-25 15:12:11.418] [info] [squirrel_demo_main.cc:291] bytes communicated: 2703.88
[2024-04-25 15:12:11.419] [info] [squirrel_demo_main.cc:299] bytes communicated: 41.28
[2024-04-25 15:12:11.419] [info] [squirrel_demo_main.cc:305] comp_revealed: 0
I0425 15:12:11   259 external/com_github_brpc_brpc/src/brpc/server.cpp:1218] Server[yacl::link::transport::internal::ReceiverServiceImpl] is going to quit
[2024-04-25 15:12:11.427] [warning] [channel.h:162] Channel destructor is called before WaitLinkTaskFinish, try stop send thread

fionser avatar Apr 25 '24 07:04 fionser

Thanks for the detailed response. Working in a large batch works for me!

I can close this issue, but do you want me to keep it open as is? I realize that you do see an issue here.

kanav99 avatar May 01 '24 13:05 kanav99

@kanav99 I am closing this issue. The "weird" stats might be due to some implementation decisions :).

fionser avatar May 01 '24 15:05 fionser