
[QUESTION] How MultiLabel Classification in KNN works

Open Kirushikesh opened this issue 3 years ago • 1 comments

Hi @ageron, in chapter 3, under the Multilabel Classification section, you demonstrated the topic using the KNN algorithm. I was quite eager to know how it works, i.e., how does KNN handle the multilabel classification problem internally? I tried to browse the net for answers but nothing is as good as your explanations xD.

Kirushikesh, Apr 25 '22 13:04

Hi @Kirushikesh ,

Thanks for your kind words and your interesting question. Scikit-Learn's source code is actually quite readable, and that's often where I get the best answers to my questions when they're not in the docs. In this case, you'll find the answer in the predict() method of the KNeighborsClassifier class (see source code):

    for k, classes_k in enumerate(classes_):
        if weights is None:
            mode, _ = stats.mode(_y[neigh_ind, k], axis=1)
        else:
            mode, _ = weighted_mode(_y[neigh_ind, k], weights, axis=1)

In this code, _y contains the labels (as class indices), and neigh_ind contains the indices of the K nearest neighbors. Also, in this code k is the label index. For example, if the data has two labels "large / not large" and "odd / not odd" (as in the example in the book), then k will be 0 (for "large / not large"), then 1 (for "odd / not odd"). This k should not be confused with the K in "K Nearest Neighbors".

As you can see, the predict() method just looks at each label independently, and it finds the most common class (i.e., the mode) among the K nearest neighbors.
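To make the indexing concrete, here is a minimal NumPy sketch of the same per-label vote. It is not Scikit-Learn's actual code: the function name predict_multilabel and the toy arrays are made up for illustration, and np.bincount(...).argmax() stands in for stats.mode() under the assumption that the labels are non-negative integer class indices.

```python
import numpy as np

def predict_multilabel(y_train, neigh_ind):
    """Per-label majority vote among neighbors (illustrative sketch only).

    y_train:   (n_samples, n_labels) array of non-negative integer class indices
    neigh_ind: (n_queries, K) array of row indices into y_train
    """
    n_queries = neigh_ind.shape[0]
    n_labels = y_train.shape[1]
    y_pred = np.empty((n_queries, n_labels), dtype=y_train.dtype)
    for k in range(n_labels):              # k is the label index, not the K in KNN
        votes = y_train[neigh_ind, k]      # shape (n_queries, K)
        for i, row in enumerate(votes):
            # Count each class index and keep the most frequent one
            # (ties go to the smallest class index).
            y_pred[i, k] = np.bincount(row).argmax()
    return y_pred

# Hypothetical toy data: 5 training instances, 2 binary labels each.
y_train = np.array([[0, 1], [1, 1], [1, 0], [0, 0], [1, 1]])
# Two query instances, each with the indices of its K=3 nearest neighbors.
neigh_ind = np.array([[0, 1, 4], [2, 3, 0]])

y_pred = predict_multilabel(y_train, neigh_ind)
print(y_pred)
```

Each column of y_pred is computed without looking at the other columns, which is exactly the "each label independently" behavior described above.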

For example, if K=3, then the algorithm will only look at the 3 nearest neighbors. Suppose they are:

  • Large, Odd
  • Large, Not Odd
  • Not Large, Not Odd

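Then in this case the model will predict Large, Not Odd: two of the three neighbors are Large, and two of the three are Not Odd. Here the prediction happens to match the second neighbor, but since each label is voted on separately, that is not guaranteed in general. With three binary labels, for instance, neighbors labeled (1, 1, 0), (1, 0, 1) and (0, 1, 1) would give the prediction (1, 1, 1), a combination that none of the neighbors uses.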
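The vote for this example can be checked with a few lines of plain Python. This is only a sketch: vote_per_label is a hypothetical helper, not part of Scikit-Learn, and the three-label data at the end is an invented case (not from the book) showing that a column-by-column vote can produce a combination no single neighbor has.

```python
from collections import Counter

def vote_per_label(neighbor_labels):
    """Return the most common value in each label position (column)."""
    columns = zip(*neighbor_labels)   # regroup by label instead of by neighbor
    return tuple(Counter(col).most_common(1)[0][0] for col in columns)

# The three nearest neighbors from the example above.
neighbors = [
    ("Large", "Odd"),
    ("Large", "Not Odd"),
    ("Not Large", "Not Odd"),
]
prediction = vote_per_label(neighbors)
print(prediction)

# Hypothetical three-label case: every label has a 2-to-1 majority for 1,
# yet no neighbor carries the combination (1, 1, 1).
neighbors_3 = [(1, 1, 0), (1, 0, 1), (0, 1, 1)]
prediction_3 = vote_per_label(neighbors_3)
matches_a_neighbor = prediction_3 in neighbors_3
print(prediction_3, matches_a_neighbor)
```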

Note that the neighbors' votes can be weighted (for example with weights="distance", so that closer neighbors count more), in which case the weighted_mode() function will be called instead of stats.mode().
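Here is a rough sketch of how weighting can change the outcome for a single label. It assumes inverse-distance weights, which is what weights="distance" is documented to use. The helper weighted_vote and the numbers are made up for illustration, and summing weights per class with np.bincount is a simplified stand-in for weighted_mode(), not its actual implementation.

```python
import numpy as np

def weighted_vote(classes, weights):
    """Class index with the largest total weight among the neighbors."""
    totals = np.bincount(classes, weights=weights)   # sum of weights per class
    return int(totals.argmax())

# One label, K=3 neighbors: one very close neighbor of class 1,
# two more distant neighbors of class 0 (hypothetical values).
neighbor_classes = np.array([1, 0, 0])
neighbor_dist = np.array([0.1, 1.0, 1.0])
weights = 1.0 / neighbor_dist        # inverse-distance weighting

unweighted = int(np.bincount(neighbor_classes).argmax())   # plain majority vote
weighted = weighted_vote(neighbor_classes, weights)        # distance-weighted vote
print(unweighted, weighted)
```

With these numbers the plain majority picks class 0, while the weighted vote picks class 1 because the single close neighbor outweighs the two distant ones. In the multilabel case the same weights are applied to every label, one label at a time.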

Hope this helps.

ageron, May 05 '22 04:05