
[BUG] Can't concurrently join from multiple threads with cudf::hash_join

Open • ahmet-uyar opened this issue 1 year ago • 7 comments

Describe the bug

I create a single cudf::hash_join object and then probe it with several tables from multiple threads concurrently. I get non-deterministic results.

Steps/Code to reproduce bug

You can run the code below with gtest. It sometimes succeeds but most often fails.

```cpp
#include <cudf/column/column_view.hpp>
#include <cudf/copying.hpp>
#include <cudf/join.hpp>
#include <cudf/table/table.hpp>
#include <cudf/table/table_view.hpp>
#include <cudf/utilities/span.hpp>
#include <cudf_test/base_fixture.hpp>
#include <cudf_test/column_wrapper.hpp>

#include <rmm/device_uvector.hpp>

#include <cuda_runtime_api.h>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <memory>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>

struct HashJoinTest : public cudf::test::BaseFixture {};

// Probes the prebuilt hash table with `probe_view` and gathers the matching rows
// from both sides into one table (probe columns first, then build columns).
std::unique_ptr<cudf::table> InnerHashJoin(const cudf::table_view& probe_view,
                                           const cudf::table_view& build_view,
                                           const cudf::hash_join& hash_join_object,
                                           const std::vector<int>& join_column_indices) {
  std::unique_ptr<rmm::device_uvector<cudf::size_type>> build_indices, probe_indices;
  std::tie(probe_indices, build_indices) =
      hash_join_object.inner_join(probe_view.select(join_column_indices));
  std::size_t out_rows = probe_indices->size();
  auto left_indices_span =
      cudf::device_span<cudf::size_type const>(probe_indices->data(), out_rows);
  auto right_indices_span =
      cudf::device_span<cudf::size_type const>(build_indices->data(), out_rows);
  auto left_joined = cudf::gather(probe_view, cudf::column_view{left_indices_span},
                                  cudf::out_of_bounds_policy::DONT_CHECK);
  auto right_joined = cudf::gather(build_view, cudf::column_view{right_indices_span},
                                   cudf::out_of_bounds_policy::DONT_CHECK);
  auto left_cols = left_joined->release();
  auto right_cols = right_joined->release();
  std::move(right_cols.begin(), right_cols.end(), std::back_inserter(left_cols));
  return std::make_unique<cudf::table>(std::move(left_cols));
}

template <typename T>
using column_wrapper = cudf::test::fixed_width_column_wrapper<T>;

std::unique_ptr<cudf::table> MakeTable(column_wrapper<int32_t> col) {
  std::vector<std::unique_ptr<cudf::column>> cols;
  cols.push_back(col.release());
  return std::make_unique<cudf::table>(std::move(cols));
}

std::unique_ptr<cudf::table> MakeTable(column_wrapper<int32_t> col1, column_wrapper<int32_t> col2) {
  std::vector<std::unique_ptr<cudf::column>> cols;
  cols.push_back(col1.release());
  cols.push_back(col2.release());
  return std::make_unique<cudf::table>(std::move(cols));
}

TEST_F(HashJoinTest, MultiParallelJoin) {
  // build_table is the right table, probe_table is the left table
  auto build_table = MakeTable(column_wrapper<int32_t>{{0, 1, 2, 3, 4}});

  int join_count = 3;
  std::vector<std::unique_ptr<cudf::table>> probe_tables(join_count);
  std::vector<std::unique_ptr<cudf::table>> expected_results(join_count);
  probe_tables[0] = MakeTable(column_wrapper<int32_t>{{3, 4, 5, 6, 7}});
  expected_results[0] = MakeTable(column_wrapper<int32_t>{{3, 4}}, column_wrapper<int32_t>{{3, 4}});

  probe_tables[1] = MakeTable(column_wrapper<int32_t>{{0, 2, 4, 6, 8}});
  expected_results[1] = MakeTable(column_wrapper<int32_t>{{0, 2, 4}}, column_wrapper<int32_t>{{0, 2, 4}});

  probe_tables[2] = MakeTable(column_wrapper<int32_t>{{1, 3, 5, 7, 9}});
  expected_results[2] = MakeTable(column_wrapper<int32_t>{{1, 3}}, column_wrapper<int32_t>{{1, 3}});

  std::vector<int> join_column_indices = {0};
  cudf::hash_join hash_join_obj(build_table->view().select(join_column_indices),
                                cudf::null_equality::UNEQUAL);
  // Make sure the hash table build has finished before other threads probe it
  // (with PTDS, each thread issues work on its own default stream).
  cudaDeviceSynchronize();

  // Join the probe tables in parallel, one thread per probe table.
  std::vector<std::thread> threads(join_count);
  for (int i = 0; i < join_count; i++) {
    threads[i] = std::thread([&, i] {
      auto joined_table =
          InnerHashJoin(probe_tables[i]->view(), build_table->view(),
                        hash_join_obj, join_column_indices);
      EXPECT_EQ(joined_table->num_rows(), expected_results[i]->num_rows()) << "failed iteration: " << i;
    });
  }

  for (int i = 0; i < join_count; i++) {
    threads[i].join();
  }
}
```

Expected behavior

I think hash_join should support multiple concurrent joins. The documentation also says: "This class enables the hash join scheme that builds hash table once, and probes as many times as needed (possibly in parallel)."

Environment overview

I am using libcudf 24.06.01.03 with PTDS (per-thread default stream) enabled.



ahmet-uyar, Jul 26 '24 14:07

@ahmet-uyar Thanks for raising your concern.

Unlike pandas, cudf/libcudf joins use hash-based algorithms, so the output row order is not guaranteed. See:

  • https://docs.rapids.ai/api/cudf/stable/user_guide/10min/#join
  • https://github.com/rapidsai/cudf/blob/d953676e9281125a5b8bd9be739c997611471771/cpp/include/cudf/join.hpp#L80

The tests above should pass if the join output is sorted before it is compared with the expected result. Please let us know whether that fixes the issue.

PointKernel, Jul 31 '24 19:07
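A minimal sketch of such an order-insensitive check, written in plain C++ so it runs without a GPU. It assumes each joined row can be modeled as a hypothetical `JoinedRow` pair of (probe key, build key) rather than a real `cudf::table`; both results are sorted before they are compared.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical stand-in for one row of a joined table: (probe key, build key).
using JoinedRow = std::pair<int, int>;

// True when both results hold the same rows with the same multiplicities,
// regardless of the order in which the join emitted them.
bool SameRowsIgnoringOrder(std::vector<JoinedRow> actual,
                           std::vector<JoinedRow> expected) {
  // Both vectors are taken by value, so sorting leaves the caller's data untouched.
  std::sort(actual.begin(), actual.end());
  std::sort(expected.begin(), expected.end());
  return actual == expected;
}
```

For example, `SameRowsIgnoringOrder({{4, 4}, {0, 0}, {2, 2}}, {{0, 0}, {2, 2}, {4, 4}})` returns `true`, while a result that is missing a row compares unequal however it is ordered. In a libcudf test the same idea would mean sorting both tables (for example with `cudf::sort`) before comparing them.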

As you can see in the test case, I am comparing only num_rows() of the two tables, so this is not a row-ordering issue. Even the row counts of the joined tables don't match. When I add up num_rows() across the 3 joined tables, the total usually doesn't match the expected total of 7 rows.

ahmet-uyar, Jul 31 '24 20:07
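The expected counts can be cross-checked with a small CPU reference join. The sketch below is a plain C++ illustration, not part of cudf: the hypothetical `ReferenceInnerJoinRowCount` builds a key-to-multiplicity map from the build keys once, then adds one output row per match for each probe key. It ignores nulls, which the repro columns do not contain.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// CPU reference for the number of rows an equality inner join produces.
// Build phase: count how often each key occurs on the build side.
// Probe phase: every probe key emits one row per matching build row.
std::size_t ReferenceInnerJoinRowCount(const std::vector<int>& build_keys,
                                       const std::vector<int>& probe_keys) {
  std::unordered_map<int, std::size_t> multiplicity;
  for (int key : build_keys) {
    ++multiplicity[key];
  }

  std::size_t rows = 0;
  for (int key : probe_keys) {
    auto it = multiplicity.find(key);
    if (it != multiplicity.end()) {
      rows += it->second;  // one output row per matching build row
    }
  }
  return rows;
}
```

With the repro's build keys `{0, 1, 2, 3, 4}`, the three probe tables produce 2, 3, and 2 rows, which sum to the expected 7.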

Ah OK, sorry about the confusion. Looking into it now.

PointKernel, Jul 31 '24 20:07

@ahmet-uyar This is indeed a bug on our end, but it's not in cudf itself.

Long story short, the cudf hash join is implemented with cuCollections' static_multimap, and two of the APIs it relies on, pair_count and pair_retrieve, are not thread-safe. I've opened https://github.com/NVIDIA/cuCollections/issues/566 to track this issue.

Will keep you posted.

PointKernel, Aug 01 '24 03:08
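The thread does not spell out which state pair_count and pair_retrieve share between calls, so the following is only a hypothetical CPU analogue of how a "build once, probe many times" table can stop being thread-safe. It assumes the race comes from per-query scratch state kept in the shared object: `CountMatchesShared` keeps its running total in a counter owned by the table, so concurrent probes reset and add to each other's totals, while `CountMatchesLocal` keeps the counter on the calling thread's stack. `ToyHashTable` and `ProbeConcurrently` are illustrative names, not cuCollections APIs, and this is not the actual cuco code.

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <unordered_set>
#include <vector>

// Hypothetical multimap-like table: built once, then probed read-only.
class ToyHashTable {
 public:
  explicit ToyHashTable(const std::vector<int>& build_keys)
      : keys_(build_keys.begin(), build_keys.end()) {}

  // Unsafe under concurrency: every call resets and increments the same member
  // counter, so overlapping calls can wipe out or inflate each other's count.
  std::size_t CountMatchesShared(const std::vector<int>& probe_keys) const {
    shared_counter_.store(0);
    for (int key : probe_keys) {
      shared_counter_.fetch_add(keys_.count(key));  // one match per equal build key
    }
    return shared_counter_.load();
  }

  // Safe: the counter is local to the call; threads only share read-only keys.
  std::size_t CountMatchesLocal(const std::vector<int>& probe_keys) const {
    std::size_t counter = 0;
    for (int key : probe_keys) {
      counter += keys_.count(key);
    }
    return counter;
  }

 private:
  std::unordered_multiset<int> keys_;
  mutable std::atomic<std::size_t> shared_counter_{0};
};

// Probes the same table from one thread per probe table, mirroring the repro.
std::vector<std::size_t> ProbeConcurrently(const ToyHashTable& table,
                                           const std::vector<std::vector<int>>& probes,
                                           bool use_shared_counter) {
  std::vector<std::size_t> counts(probes.size());
  std::vector<std::thread> threads;
  for (std::size_t i = 0; i < probes.size(); ++i) {
    threads.emplace_back([&, i] {
      // Each thread writes only its own slot of `counts`.
      counts[i] = use_shared_counter ? table.CountMatchesShared(probes[i])
                                     : table.CountMatchesLocal(probes[i]);
    });
  }
  for (auto& t : threads) {
    t.join();
  }
  return counts;
}
```

With the repro's keys, `ProbeConcurrently(table, probes, false)` always returns `{2, 3, 2}`, whereas passing `true` can return different, wrong counts from run to run, which resembles the symptom in the report. Moving scratch state into each call is one common way to make such probes thread-safe.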

Update: https://github.com/rapidsai/cudf/pull/16496 demonstrates that this bug will be fixed once rapids-cmake fetches the fix in https://github.com/NVIDIA/cuCollections/pull/569.

@ahmet-uyar Will let you know once the fix is in place.

PointKernel, Aug 05 '24 23:08

Thanks for the update, @PointKernel.

ahmet-uyar, Aug 06 '24 06:08

@PointKernel should we update rapids-cmake, or are there blockers to adopting the latest cuco?

vyasr, Aug 16 '24 19:08

Closing, as this is resolved by the cuco version bump.

PointKernel, Sep 30 '24 21:09