Machine-Learning-in-Action-Python3
Machine-Learning-in-Action-Python3 copied to clipboard
DecisionTree_Project2/DecisionTree.py 方法classify 使用有误
530行里 调用classify方法,给定的第二项输入应该是完整的数据labels,且顺序和数据集顺序应该一致
同时优化了classify方法的写法 更加直观
def classify(inputTree, featLabels, testVec):
# 获取决策树结点
# 当前树节点的key首项 表明选择的特征类型
keyLabel = list(inputTree.keys())[0]
# 对应类型的特征树
currDict = inputTree[keyLabel]
# 获取特征类型在特征中的index
featIndex = featLabels.index(keyLabel)
# 获取当前的特征叶子 或者是 特征树
judgeValue = currDict.get(testVec[featIndex])
# 如果是树就继续向下走 如果是叶子输出
if type(judgeValue).__name__ == 'dict':
return classify(judgeValue, featLabels, testVec)
else:
return judgeValue