scikit-tree

Add SMERF tree functionality

Open tyler-tomita opened this issue 2 years ago • 1 comments

Similarity and Metric Random Forests (SMERF) take samples {X_1, ..., X_n} and an n-by-n matrix Z in which z_ij = f(X_i, X_j), where f is a true but unknown similarity measure between X_i and X_j. The goal is to learn f(•, •).

The splitting objective is to find the split that maximizes the average of Z(L) and Z(R), where Z(L) and Z(R) are the similarity matrices restricted to the points that fall to the left and right of the split, respectively.
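A minimal NumPy sketch of that objective, for illustration only. It assumes "the average of Z(L) and Z(R)" means the unweighted mean of the two within-child mean similarities; the issue does not pin down the weighting, and `split_objective` is a hypothetical helper, not part of scikit-tree.

```python
import numpy as np


def split_objective(Z, left_idx, right_idx):
    """Average of the mean pairwise similarity within each child.

    Assumption: the two children are weighted equally, regardless of size.
    """
    Z = np.asarray(Z, dtype=float)
    Z_left = Z[np.ix_(left_idx, left_idx)]      # Z(L): similarities among left samples
    Z_right = Z[np.ix_(right_idx, right_idx)]   # Z(R): similarities among right samples
    return 0.5 * (Z_left.mean() + Z_right.mean())


# Toy similarity matrix with two tight groups: {0, 1} and {2, 3}.
Z = np.array([
    [1.0, 0.9, 0.1, 0.2],
    [0.9, 1.0, 0.2, 0.1],
    [0.1, 0.2, 1.0, 0.8],
    [0.2, 0.1, 0.8, 1.0],
])
good = split_objective(Z, [0, 1], [2, 3])   # split that respects the groups
bad = split_objective(Z, [0, 2], [1, 3])    # split that mixes the groups
```

On this toy matrix the group-respecting split scores higher than the mixed one, which is the behaviour a maximizing splitter would rely on.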

tyler-tomita avatar Mar 02 '23 18:03 tyler-tomita

Okay, now that I've thought about this a little more, I think this can rely on the existing TreeBuilder/Splitter Cython code, where y can be a 2D array. The other changes will only require work in the Python code, which will be simpler. It seems that all the SMERF algorithm needs in order to compute its metric is the split point "s".

At every split the feature vector {x_1, ..., x_n} is sorted, which gives a reordering of the sample indices {1,...,n}. Given "s", and with the rows and columns of the similarity matrix y put in that sorted order, the left/right children are just:

  • left child = y[:s, :s]
  • right child = y[s:, s:]

So really, everything else can be passed as is...
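As a rough illustration of the slicing above, here is a NumPy sketch. It assumes y is stored in the original sample order and has to be permuted by the feature's sort order before slicing; `children_similarity` is a hypothetical name, and the real Splitter works on index arrays rather than copying y.

```python
import numpy as np


def children_similarity(x, y, s):
    """Return the (left, right) similarity blocks for split position s.

    x : (n,) values of the candidate feature
    y : (n, n) similarity matrix in the original sample order
    s : number of samples sent to the left child after sorting by x
    """
    order = np.argsort(x, kind="stable")    # reordering of the sample indices
    y_sorted = y[np.ix_(order, order)]      # permute rows and columns together
    return y_sorted[:s, :s], y_sorted[s:, s:]


x = np.array([0.7, 0.1, 0.9, 0.2])
y = np.array([
    [1.0, 0.2, 0.8, 0.1],
    [0.2, 1.0, 0.1, 0.9],
    [0.8, 0.1, 1.0, 0.2],
    [0.1, 0.9, 0.2, 1.0],
])
left, right = children_similarity(x, y, s=2)
```

Here the sort order is samples 1, 3, 0, 2, so the left block holds the similarities among samples 1 and 3 and the right block those among samples 0 and 2.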

Some ideas on Cython implementation

Most likely, you could subclass the ClassificationCriterion class here: https://github.com/neurodata/scikit-learn/blob/e185973346da837cb755a84ec4b2ccee42ea1690/sklearn/tree/_criterion.pyx#L217 and override most of the functions there. E.g.

  • __cinit__: the n_classes will just be set arbitrarily to 1 and we will use n_outputs to encode the dimensionality of the output?
  • set_sample_pointers, update, etc. I think

Here is a description of some of the data structures and functions used internally:

  • sum_total, sum_right and sum_left should keep track of the "metric" in the current node, right child and left child respectively. These are basically computed when the splitter calls set_sample_pointers and update.
  • node_impurity computes the impurity of the current node, i.e. the node being split
  • children_impurity computes the impurity of the left/right children
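To make the bookkeeping concrete, here is a pure-Python stand-in for such a criterion. This is only a sketch under several assumptions: y is symmetric and already in the sorted order of the feature being split on, impurity is taken to be the negative mean pairwise similarity (so lower is better), and the method signatures are simplified compared to the Cython Criterion, whose children_impurity writes into output pointers. `SimilarityCriterionSketch` is a hypothetical name.

```python
import numpy as np


class SimilarityCriterionSketch:
    """Hypothetical pure-Python mock of a similarity-based Criterion."""

    def __init__(self, y):
        self.y = np.asarray(y, dtype=float)
        self.n = self.y.shape[0]
        self.sum_total = self.y.sum()   # metric accumulated over the current node
        self.reset()

    def reset(self):
        # Start with every sample in the right child.
        self.pos = 0
        self.sum_left = 0.0
        self.sum_right = self.sum_total

    def update(self, new_pos):
        # Move samples pos..new_pos-1 from the right child to the left child,
        # adjusting the block sums incrementally (relies on y being symmetric).
        for k in range(self.pos, new_pos):
            diag = self.y[k, k]
            self.sum_left += 2.0 * self.y[k, :k].sum() + diag
            self.sum_right -= 2.0 * self.y[k, k + 1:].sum() + diag
        self.pos = new_pos

    def node_impurity(self):
        return -self.sum_total / self.n ** 2

    def children_impurity(self):
        # Only meaningful for 0 < pos < n, i.e. when both children are non-empty.
        n_left, n_right = self.pos, self.n - self.pos
        return -self.sum_left / n_left ** 2, -self.sum_right / n_right ** 2


y_sorted = np.array([
    [1.0, 0.9, 0.2, 0.1],
    [0.9, 1.0, 0.1, 0.2],
    [0.2, 0.1, 1.0, 0.8],
    [0.1, 0.2, 0.8, 1.0],
])
crit = SimilarityCriterionSketch(y_sorted)
crit.update(2)                                  # first two sorted samples go left
imp_left, imp_right = crit.children_impurity()
```

The incremental update means each candidate split position costs one row of y rather than a full recomputation of both blocks, which mirrors how sum_left and sum_right are maintained in the existing criteria.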

I think the first step would be sketching out the new Criterion class and then trying to add the Python class for SMERFDecisionTree, which can probably subclass BaseDecisionTree inside our sklearn fork.

adam2392 avatar Mar 02 '23 19:03 adam2392