Incorrect calculation of num_triplets at SparseMatrix construction

junwha opened this issue.

Describe the bug

The sparse matrix is constructed with an incorrect num_triplets, which leads to a buffer-overflow read.

To Reproduce

The following PoC was adapted from test_sparse_matrix.py:

import taichi as ti
arch = ti.cpu # or ti.cuda
ti.init(arch=arch)

def test_build_sparse_matrix_frome_ndarray(dtype, storage_format):
    n = 8
    triplets = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=n)
    A = ti.linalg.SparseMatrix(n=10, m=10, dtype=ti.f32, storage_format=storage_format)

    @ti.kernel
    def fill(triplets: ti.types.ndarray()):
        for i in range(n):
            triplet = ti.Vector([i, i, i], dt=ti.f32)
            triplets[i] = triplet

    fill(triplets)
    A.build_from_ndarray(triplets)  # the out-of-bounds read happens inside this call

    for i in range(n):
        assert A[i, i] == i

test_build_sparse_matrix_frome_ndarray(ti.f32, "col_major")

Additional comments

In make_sparse_matrix_from_ndarray (taichi/program/sparse_matrix.cpp:378), num_triplets is computed as ndarray.get_nelement() * ndarray.get_element_size() / 3. Let N = ndarray.get_nelement() and M = ndarray.get_element_size(), the size of one element in bytes. Each ndarray element holds one (row, col, value) triplet, so data_ptr points to exactly 3*N values of type T.

void make_sparse_matrix_from_ndarray(Program *prog,
                                     SparseMatrix &sm,
                                     const Ndarray &ndarray) {
  std::string sdtype = taichi::lang::data_type_name(sm.get_data_type());
  auto data_ptr = prog->get_ndarray_data_ptr_as_int(&ndarray);
  auto num_triplets = ndarray.get_nelement() * ndarray.get_element_size() / 3;
  if (sdtype == "f32") {
    build_ndarray_template<float32>(sm, data_ptr, num_triplets);
  } else if (sdtype == "f64") {
    build_ndarray_template<float64>(sm, data_ptr, num_triplets);
  } else {
    TI_ERROR("Unsupported sparse matrix data type {}!", sdtype);
  }
}
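As a worked example, the sketch below plugs the PoC's numbers into the num_triplets formula. It is a hypothetical Python model, not Taichi code, and it assumes that get_element_size() returns the byte size of one whole element (a 3-component f32 vector, i.e. 12 bytes); the helper names are illustrative only.

```python
# Hypothetical model of the num_triplets computation; not Taichi code.
SIZEOF_F32 = 4  # bytes per f32 scalar


def buggy_num_triplets(nelement, element_size):
    # Mirrors: ndarray.get_nelement() * ndarray.get_element_size() / 3
    # using integer division, as in the C++ expression
    return nelement * element_size // 3


def fixed_num_triplets(nelement):
    # One ndarray element == one (row, col, value) triplet
    return nelement


N = 8                # shape=n in the PoC
M = 3 * SIZEOF_F32   # assumed: one vec3 of f32 = 12 bytes
print(buggy_num_triplets(N, M))  # 32
print(fixed_num_triplets(N))     # 8
```

If get_element_size() instead returned only the scalar size (4 bytes), the formula would give 8 * 4 // 3 = 10, which still exceeds the 8 triplets actually stored.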

build_ndarray_template (taichi/program/sparse_matrix.cpp:373) then casts data_ptr to a T array and reads indices 0 through 3*(num_triplets-1)+2, which is 3*(N*M/3-1)+2 = N*M-1 when N*M is divisible by 3. The last valid index, however, is 3*N-1, i.e. ((T *) data_ptr)[3*N-1]. Since M is a size in bytes and is at least sizeof(T) >= 4, N*M-1 exceeds 3*N-1, so the loop reads past the end of the buffer.

template <typename T>
void build_ndarray_template(SparseMatrix &sm,
                            intptr_t data_ptr,
                            size_t num_triplets) {
  using V = Eigen::Triplet<T>;
  std::vector<V> triplets;
  T *data = reinterpret_cast<T *>(data_ptr);
  for (int i = 0; i < num_triplets; i++) {
    triplets.push_back(
        V(data[i * 3], data[i * 3 + 1], taichi_union_cast<T>(data[i * 3 + 2])));
  }
  sm.build_triplets(static_cast<void *>(&triplets));
}
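The bounds argument can be checked with a small model of the loop in build_ndarray_template: it collects the T-typed indices the loop would read and compares them with the 3*N values the buffer actually holds. This is a hypothetical Python sketch, not Taichi code, and it again assumes a 12-byte element (vec3 of f32) for the PoC.

```python
# Hypothetical model of the read pattern in build_ndarray_template.
def indices_read(num_triplets):
    # Same pattern as the C++ loop: data[i*3], data[i*3+1], data[i*3+2]
    idx = []
    for i in range(num_triplets):
        idx.extend((i * 3, i * 3 + 1, i * 3 + 2))
    return idx


def out_of_bounds(nelement, num_triplets):
    # The buffer holds exactly 3 * nelement values of type T
    last_valid = 3 * nelement - 1
    return [k for k in indices_read(num_triplets) if k > last_valid]


N, M = 8, 12  # PoC shape and assumed element size in bytes
print(max(indices_read(N * M // 3)))    # 95, i.e. N*M - 1
print(len(out_of_bounds(N, N * M // 3)))  # 72 reads past index 23
print(out_of_bounds(N, N))              # [] with num_triplets = N
```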

To fix this, num_triplets should be computed as ndarray.get_nelement(), since each ndarray element holds exactly one triplet. I will open a PR for this.
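A minimal sketch of the corrected extraction logic, modeled in Python rather than C++: with num_triplets equal to the element count, only data[0] through data[3*N-1] are read. The function name is hypothetical, and the length check is an extra guard added for illustration, not part of the Taichi code.

```python
# Hypothetical model of build_ndarray_template with num_triplets = nelement.
def triplets_from_buffer(data, nelement):
    # Illustrative guard: the buffer must hold one triplet per element
    assert len(data) == 3 * nelement, "buffer must hold 3 values per element"
    out = []
    for i in range(nelement):
        row, col, value = data[i * 3], data[i * 3 + 1], data[i * 3 + 2]
        # Row/column indices are stored as floats in the ndarray
        out.append((int(row), int(col), value))
    return out


# PoC data: element i is the vector [i, i, i]
poc = [float(v) for i in range(8) for v in (i, i, i)]
print(triplets_from_buffer(poc, 8)[3])  # (3, 3, 3.0)
```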

Thank you! :)

junwha, Mar 31 '24 19:03