[BUG] Incorrect NaN handling in experimental FIL with categorical splits
Describe the bug
The experimental FIL library gives incorrect results for NaN feature values with categorical splits. It looks like if the split's category set contains zero, a NaN always goes left rather than obeying the node's "default_left" property. When the category set doesn't contain zero, "default_left" is honored correctly.
Steps/Code to reproduce bug
The following code reproduces the issue and compares results to Treelite GTIL and the standard FIL library:
#include <cuda_runtime.h>
#include <cuml/fil/fil.h>
#include <cuml/experimental/fil/treelite_importer.hpp>
#include <cuml/experimental/fil/forest_model.hpp>
#include <raft/raft.hpp>
#include <treelite/frontend.h>
#include <treelite/tree.h>
#include <treelite/gtil.h>

#include <cstring>
#include <iostream>
#include <limits>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#define CUDA_RT_CALL(call)                               \
  do {                                                   \
    cudaError_t cuda_status = (call);                    \
    if (cudaSuccess != cuda_status) {                    \
      throw std::runtime_error(                          \
        "CUDA call failed with status " +                \
        std::to_string(cuda_status) + ": " +             \
        std::string(cudaGetErrorString(cuda_status)));   \
    }                                                    \
  } while (false)
// Build a single-tree Treelite model with one categorical split on feature 0:
// categories {0, 2} go left, everything else goes right, and (with
// default_left = false) missing values should also go right.
std::unique_ptr<treelite::Model> create_model() {
  std::unique_ptr<treelite::Model> model = treelite::Model::Create<float, float>();
  model->num_feature = 1;
  model->average_tree_output = false;
  model->task_type = treelite::TaskType::kBinaryClfRegr;
  model->task_param.grove_per_class = false;
  model->task_param.output_type = treelite::TaskParam::OutputType::kFloat;
  model->task_param.num_class = 1;
  model->task_param.leaf_vector_size = 1;
  model->param.global_bias = 0.0;
  std::strncpy(model->param.pred_transform, "identity", sizeof(model->param.pred_transform));
  model->SetTreeLimit(1);
  auto* model_impl = dynamic_cast<treelite::ModelImpl<float, float>*>(model.get());
  treelite::Tree<float, float>& tree = model_impl->trees[0];
  tree.Init();
  tree.AddChilds(0);
  auto leftId = tree.LeftChild(0);
  auto rightId = tree.RightChild(0);
  bool defaultLeft = false;
  bool categoriesGoRight = false;
  std::vector<uint32_t> categories{0u, 2u};
  tree.SetCategoricalSplit(0, 0, defaultLeft, categories, categoriesGoRight);
  tree.SetLeaf(leftId, 1.0);
  tree.SetLeaf(rightId, 2.0);
  return model;
}
std::vector<float> predict_gtil(
  const std::unique_ptr<treelite::Model>& model, const std::vector<float>& features) {
  size_t num_rows = features.size() / model->num_feature;
  std::vector<float> predictions(num_rows);
  treelite::gtil::Configuration config;
  std::vector<size_t> output_shape(1);
  treelite::gtil::Predict(model.get(), features.data(), num_rows, predictions.data(), config, output_shape);
  return predictions;
}
std::vector<float> predict_orig_fil(
  const std::unique_ptr<treelite::Model>& model, const std::vector<float>& features) {
  ML::fil::treelite_params_t params{
    .algo = ML::fil::algo_t::ALGO_AUTO,
    .output_class = false,
    .threshold = 0.0f,
    .storage_type = ML::fil::storage_type_t::DENSE,
    .blocks_per_sm = 0,
    .threads_per_tree = 16,
    .n_items = 0,
    .pforest_shape_str = nullptr
  };
  cudaStream_t stream;
  CUDA_RT_CALL(cudaStreamCreate(&stream));
  raft::handle_t handle(stream);
  ML::fil::forest_variant forest_variant;
  ML::fil::from_treelite(handle, &forest_variant, model.get(), &params);
  ML::fil::forest_t<float> forest = std::get<ML::fil::forest_t<float>>(forest_variant);
  size_t num_features = model->num_feature;
  size_t num_rows = features.size() / num_features;
  float* device_features;
  float* device_preds;
  std::vector<float> predictions(num_rows);
  CUDA_RT_CALL(cudaMalloc(&device_features, num_features * num_rows * sizeof(float)));
  CUDA_RT_CALL(cudaMalloc(&device_preds, num_rows * sizeof(float)));
  CUDA_RT_CALL(cudaMemcpy(
    device_features, features.data(), num_features * num_rows * sizeof(float), cudaMemcpyHostToDevice));
  ML::fil::predict(handle, forest, device_preds, device_features, num_rows, false);
  CUDA_RT_CALL(cudaStreamSynchronize(stream));
  CUDA_RT_CALL(cudaMemcpy(
    predictions.data(), device_preds, num_rows * sizeof(float), cudaMemcpyDeviceToHost));
  CUDA_RT_CALL(cudaFree(device_features));
  CUDA_RT_CALL(cudaFree(device_preds));
  CUDA_RT_CALL(cudaStreamDestroy(stream));
  return predictions;
}
std::vector<float> predict_experimental_gpu(
  const std::unique_ptr<treelite::Model>& model, const std::vector<float>& features) {
  cudaStream_t stream;
  CUDA_RT_CALL(cudaStreamCreate(&stream));
  raft::handle_t handle(stream);
  auto fil_model = ML::experimental::fil::import_from_treelite_model(
    *model,
    ML::experimental::fil::tree_layout::breadth_first,
    128u,
    false,
    raft_proto::device_type::gpu,
    0,
    stream);
  size_t num_features = model->num_feature;
  size_t num_rows = features.size() / num_features;
  float* device_features;
  float* device_preds;
  std::vector<float> predictions(num_rows);
  CUDA_RT_CALL(cudaMalloc(&device_features, num_features * num_rows * sizeof(float)));
  CUDA_RT_CALL(cudaMalloc(&device_preds, num_rows * sizeof(float)));
  CUDA_RT_CALL(cudaMemcpy(
    device_features, features.data(), num_features * num_rows * sizeof(float), cudaMemcpyHostToDevice));
  fil_model.predict(
    handle,
    device_preds,
    device_features,
    num_rows,
    raft_proto::device_type::gpu,
    raft_proto::device_type::gpu);
  CUDA_RT_CALL(cudaStreamSynchronize(stream));
  CUDA_RT_CALL(cudaMemcpy(
    predictions.data(), device_preds, num_rows * sizeof(float), cudaMemcpyDeviceToHost));
  CUDA_RT_CALL(cudaFree(device_features));
  CUDA_RT_CALL(cudaFree(device_preds));
  CUDA_RT_CALL(cudaStreamDestroy(stream));
  return predictions;
}
std::vector<float> predict_experimental_cpu(
  const std::unique_ptr<treelite::Model>& model, const std::vector<float>& features) {
  raft::handle_t handle;
  auto fil_model = ML::experimental::fil::import_from_treelite_model(
    *model,
    ML::experimental::fil::tree_layout::breadth_first,
    0u,
    false,
    raft_proto::device_type::cpu,
    0,
    nullptr);
  size_t num_rows = features.size() / model->num_feature;
  std::vector<float> predictions(num_rows);
  fil_model.predict(
    handle,
    predictions.data(),
    const_cast<float*>(features.data()),
    num_rows,
    raft_proto::device_type::cpu,
    raft_proto::device_type::cpu);
  return predictions;
}
int main(int argc, char* argv[]) {
  auto model = create_model();
  std::vector<float> features{std::numeric_limits<float>::quiet_NaN()};
  auto gtil = predict_gtil(model, features);
  auto orig_fil = predict_orig_fil(model, features);
  auto experimental_cpu = predict_experimental_cpu(model, features);
  auto experimental_gpu = predict_experimental_gpu(model, features);
  std::cout << "GTIL result = " << gtil[0] << "\n"
            << "Original FIL result = " << orig_fil[0] << "\n"
            << "Experimental FIL CPU result = " << experimental_cpu[0] << "\n"
            << "Experimental FIL GPU result = " << experimental_gpu[0] << "\n";
  return 0;
}
This outputs:
GTIL result = 2
Original FIL result = 2
Experimental FIL CPU result = 1
Experimental FIL GPU result = 1
Expected behavior
I'd expect experimental FIL to return 2.0 like the other implementations, since default_left is false and a NaN should therefore be sent to the right child of the split.
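For reference, here is a minimal sketch of the routing I'd expect at a categorical node, matching the Treelite/GTIL semantics (the Node layout and field names are illustrative, not cuML's actual representation):

#include <bitset>
#include <cmath>
#include <cstdint>

// Illustrative node layout, not cuML's actual data structure.
struct Node {
  bool default_left;         // where missing values go
  bool categories_go_right;  // orientation of the category set
  std::bitset<32> categories;
  int left_child;
  int right_child;
};

// Expected routing: missing values follow default_left *before* any
// category lookup; only non-NaN values are tested against the set.
int next_child(Node const& node, float value) {
  if (std::isnan(value)) {
    return node.default_left ? node.left_child : node.right_child;
  }
  // Bounds handling simplified for the sketch.
  bool in_set = node.categories.test(static_cast<std::uint32_t>(value) % 32);
  return (in_set == node.categories_go_right) ? node.right_child : node.left_child;
}

With the repro's settings (default_left = false, categories_go_right = false), a NaN row takes the right child and should produce 2.0.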
Environment details:
- Environment location: Docker
- Linux Distro/Architecture: Ubuntu 20.04 x64
- GPU Model/Driver: NVIDIA Tesla T4, driver 525.125.06
- CUDA: 12.0
- Method of cuDF & cuML install: from source, commit a381e03 (23.06.00), CMake 3.23.2, gcc 9.4.0
Additional context
From briefly looking into the code (without properly debugging), I think this happens because NaN is cast to uint32, which yields 0, before being passed into bitset.test, so bitset.test returns true and the if branch here is never entered: https://github.com/rapidsai/cuml/blob/7d86042b8de06bc8acce632230fe5821bd36c17d/cpp/include/cuml/experimental/fil/detail/evaluate_tree.hpp#L74
Only one of the evaluate_tree_impl implementations seems to have this problem; the other checks for NaN before testing category values.
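To illustrate the suspected mechanism, here is a minimal standalone sketch (not the actual FIL kernel code): casting NaN to an unsigned integer is undefined behavior in C++, and on the affected targets it appears to come out as 0, so whenever category 0 is in the split's bitset the test succeeds and default_left is never consulted.

#include <bitset>
#include <cmath>
#include <iostream>
#include <limits>

int main() {
  float feature = std::numeric_limits<float>::quiet_NaN();
  std::bitset<32> categories;
  categories.set(0);  // split categories {0, 2}, as in the repro above
  categories.set(2);

  // Buggy order of operations: cast first, test second. The float-to-uint
  // cast is undefined behavior for NaN; on the affected targets it seems to
  // yield 0, so the NaN row spuriously matches category 0.
  unsigned as_category = 0;  // stands in for static_cast<uint32_t>(NaN)
  std::cout << "categories.test(0) = " << categories.test(as_category) << "\n";  // prints 1

  // Correct order: detect NaN before any category lookup and route on
  // default_left instead (here default_left = false, so go right -> 2.0).
  bool default_left = false;
  if (std::isnan(feature)) {
    std::cout << (default_left ? "go left" : "go right") << "\n";
  }
  return 0;
}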
Thanks for the detailed bug report. We will look into it and get back to you.