Decision Trees - Ex 4: Understanding the decision tree structure


Decision Tree Example 4: Understanding the decision tree structure

http://scikit-learn.org/stable/auto_examples/tree/plot_unveil_tree_structure.html

Purpose of this example

This example takes a closer look at the internal structure of a decision tree, analyzing it to understand the relationship between the features and the target, and then using that structure to make predictions.

  1. A tree in which every node has at most two branches is called a binary tree.
  2. At each depth we can check whether a node is a leaf; in a binary tree, a node at the end of a decision path (with no further splits) is a leaf.
  3. decision_path returns the decision path taken by each sample.
  4. apply returns the prediction result, i.e. the leaf that each sample finally reaches.
  5. Once the tree is built, its rules can be used for prediction.
  6. For a group of samples, we can find the decision-path nodes they share.

(1) Importing libraries and test data

Import the libraries

  • load_iris imports the iris dataset.
import numpy as np  # used later when traversing the tree structure
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

Build the training/test sets and the decision tree classifier

  • X (feature data) and y (target data).
  • train_test_split(X, y, random_state) randomly splits the data into a training set and a test set.

    X is the feature data, y is the target data, and random_state seeds the random number generator.
  • DecisionTreeClassifier(max_leaf_nodes, random_state) builds the decision tree classifier.

    max_leaf_nodes is the maximum number of leaf nodes; random_state, if given, seeds the random number generator, otherwise np.random is used.
  • fit(X, y) trains the model, where X is the training feature data and y is the target data.
iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)
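Before digging into tree_ directly, one quick sanity check (not part of the original example) is sklearn.tree.export_text, available in scikit-learn 0.21 and later, which prints a plain-text view of the fitted tree:

from sklearn.tree import export_text  # requires scikit-learn >= 0.21

# Print a human-readable summary of the fitted tree,
# labelling the splits with the iris feature names.
print(export_text(estimator, feature_names=list(iris.feature_names)))

This shows the same structure that the rest of the example reconstructs by hand from the low-level arrays.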

(2) Exploring the decision tree structure

DecisionTreeClassifier has an attribute tree_ that stores the entire tree structure.

The binary tree is represented as a set of parallel arrays. The i-th element of each array stores information about node i, and node 0 is the root of the tree.

Note that some of the arrays only apply to split nodes; for nodes of the other type the stored values are arbitrary.

These arrays include:

  1. node_count: the total number of nodes.
  2. children_left: the id of each node's left child; -1 means the node has no children (it is a leaf).
  3. children_right: the id of each node's right child; -1 means the node has no children (it is a leaf).
  4. feature: the feature used to split each node; -2 means the node is a leaf and is not split.
  5. threshold: the threshold value at each split node; samples whose feature value is at most the threshold go to the left child, otherwise to the right child.
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

The contents of each array are shown below:

n_nodes = 5
children_left [ 1 -1 3 -1 -1]
children_right [ 2 -1 4 -1 -1]
feature [ 3 -2 2 -2 -2]
threshold [ 0.80000001 -2. 4.94999981 -2. -2. ]
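These arrays can be read off directly: the root (node 0) splits on feature 3 at threshold 0.8, its left child (node 1) is a leaf, and node 2 splits on feature 2 at threshold 4.95. A minimal sketch (not part of the original example) that recovers the leaf and split nodes from the arrays above:

# A node is a leaf exactly when its child entries are -1 (scikit-learn's TREE_LEAF sentinel).
leaf_ids = np.where(children_left == -1)[0]    # -> [1 3 4]
split_ids = np.where(children_left != -1)[0]   # -> [0 2]
print(leaf_ids, split_ids)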

Various properties can be computed by traversing the binary tree structure, such as the depth of each node and whether it is a leaf.

  • node_depth: the depth (level) of each node in the tree.
  • is_leaves: whether each node is at the bottom of the tree (a leaf).
  • stack: holds the nodes that have not yet been checked.

Pop a node from the stack and check whether its left child id equals its right child id. If they differ, push the left and right children onto the stack; if they are equal (both -1), the node is a leaf, so is_leaves is set to True.

node_depth = np.zeros(shape=n_nodes)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # initial: root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1
    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

Execution trace

stack len 1
node_id 0 parent_depth -1
node_depth [ 0. 0. 0. 0. 0.]
stack [(1, 0), (2, 0)]
stack len 2
node_id 2 parent_depth 0
node_depth [ 0. 0. 1. 0. 0.]
stack [(1, 0), (3, 1), (4, 1)]
stack len 3
node_id 4 parent_depth 1
node_depth [ 0. 0. 1. 0. 2.]
stack [(1, 0), (3, 1)]
stack len 2
node_id 3 parent_depth 1
node_depth [ 0. 0. 1. 2. 2.]
stack [(1, 0)]
stack len 1
node_id 1 parent_depth 0
node_depth [ 0. 1. 1. 2. 2.]
stack []
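As a cross-check (a small sketch, not part of the original example), the same depths and leaf flags can be computed recursively from children_left and children_right; the result matches the stack-based loop above.

def walk(node_id, depth):
    # Record this node's depth, then either mark it as a leaf or recurse into its children.
    node_depth2[node_id] = depth
    if children_left[node_id] == -1:           # no children: a leaf
        is_leaves2[node_id] = True
    else:                                      # a split node: visit both children
        walk(children_left[node_id], depth + 1)
        walk(children_right[node_id], depth + 1)

node_depth2 = np.zeros(n_nodes, dtype=int)
is_leaves2 = np.zeros(n_nodes, dtype=bool)
walk(0, 0)
print(node_depth2)   # [0 1 1 2 2]
print(is_leaves2)    # leaves are nodes 1, 3 and 4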


The next block prints the decision tree structure programmatically; this tree has 5 nodes in total.

When a test node is encountered, the threshold decides which child node to move to, until a leaf is reached.

  1. print("The binary tree structure has %s nodes and has "
  2. "the following tree structure:"
  3. % n_nodes)
  4. for i in range(n_nodes):
  5. if is_leaves[i]:
  6. print("%snode=%s leaf node." % (node_depth[i] * "\t", i)) #"\t"缩排
  7. else:
  8. print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
  9. "node %s."
  10. % (node_depth[i] * "\t",
  11. i,
  12. children_left[i],
  13. feature[i],
  14. threshold[i],
  15. children_right[i],
  16. ))

Output

The binary tree structure has 5 nodes and has the following tree structure:
node=0 test node: go to node 1 if X[:, 3] <= 0.800000011921 else to node 2.
	node=1 leaf node.
	node=2 test node: go to node 3 if X[:, 2] <= 4.94999980927 else to node 4.
		node=3 leaf node.
		node=4 leaf node.

Next we explore the decision path of each sample. The decision_path method gives us this information, and apply returns the leaf that each sample finally reaches.

Using sample 0 as an example: indices stores the nodes each sample passes through, indptr stores where each sample's nodes start and end within indices, and node_index holds the ids of the nodes visited by sample 0.

node_indicator = estimator.decision_path(X_test)
# Similarly, we can also have the leaves ids reached by each sample.
leave_id = estimator.apply(X_test)
# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's make it for the sample.
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]
print('node_index', node_index)
print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    # Only print the decision at the node where the sample ends up (its leaf).
    if leave_id[sample_id] != node_id:
        continue
    if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
        threshold_sign = "<="
    else:
        threshold_sign = ">"
    print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
          % (node_id,
             sample_id,
             feature[node_id],
             X_test[sample_id, feature[node_id]],  # index with sample_id (not a leftover loop variable)
             threshold_sign,
             threshold[node_id]))

Output

node_index [0 2 4]
Rules used to predict sample 0:
decision id node 4 : (X[0, -2] (= 1.5) > -2.0)
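Because decision_path returns a SciPy CSR sparse matrix, the same node list can also be read from a single row, without slicing indices/indptr by hand (a sketch, not part of the original example):

# Take the sparse row for sample 0 and list the columns (node ids) it passes through.
row = node_indicator.getrow(sample_id)
print(row.indices)   # -> [0 2 4], the same nodes as node_index above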

Next we look at whether several samples pass through the same nodes.

Using samples 0 and 1 as an example: node_indicator.toarray() gives one row per sample, where 0 means the sample did not pass through that node and 1 means it did. common_nodes holds True/False values; if the column sum for a node equals the number of input samples, every sample passed through that node.

# For a group of samples, we have the following common node.
sample_ids = [0, 1]
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                len(sample_ids))
print('node_indicator', node_indicator.toarray()[sample_ids])
print('common_nodes', common_nodes)
common_node_id = np.arange(n_nodes)[common_nodes]
print('common_node_id', common_node_id)
print("\nThe following samples %s share the node %s in the tree"
      % (sample_ids, common_node_id))
print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))

Output

node_indicator [[1 0 1 0 1]
 [1 0 1 1 0]]
common_nodes [ True False True False False]
common_node_id [0 2]
The following samples [0, 1] share the node [0 2] in the tree
It is 40.0 % of all nodes.
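The same shared nodes can also be found with plain Python sets, by intersecting the decision path of each sample (a sketch, not part of the original example):

# Build the set of visited nodes for each sample, then intersect them.
paths = [set(node_indicator.indices[node_indicator.indptr[s]:node_indicator.indptr[s + 1]])
         for s in sample_ids]
shared = sorted(int(n) for n in set.intersection(*paths))
print(shared)   # -> [0, 2], matching common_node_id above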

(3) Complete code

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X = iris.data
y = iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

estimator = DecisionTreeClassifier(max_leaf_nodes=3, random_state=0)
estimator.fit(X_train, y_train)

# The decision estimator has an attribute called tree_ which stores the entire
# tree structure and allows access to low level attributes. The binary tree
# tree_ is represented as a number of parallel arrays. The i-th element of each
# array holds information about the node `i`. Node 0 is the tree's root. NOTE:
# Some of the arrays only apply to either leaves or split nodes, resp. In this
# case the values of nodes of the other type are arbitrary!
#
# Among those arrays, we have:
# - left_child, id of the left child of the node
# - right_child, id of the right child of the node
# - feature, feature used for splitting the node
# - threshold, threshold value at the node
#
# Using those arrays, we can parse the tree structure:
n_nodes = estimator.tree_.node_count
children_left = estimator.tree_.children_left
children_right = estimator.tree_.children_right
feature = estimator.tree_.feature
threshold = estimator.tree_.threshold

# The tree structure can be traversed to compute various properties such
# as the depth of each node and whether or not it is a leaf.
node_depth = np.zeros(shape=n_nodes)
is_leaves = np.zeros(shape=n_nodes, dtype=bool)
stack = [(0, -1)]  # seed is the root node id and its parent depth
while len(stack) > 0:
    node_id, parent_depth = stack.pop()
    node_depth[node_id] = parent_depth + 1

    # If we have a test node
    if (children_left[node_id] != children_right[node_id]):
        stack.append((children_left[node_id], parent_depth + 1))
        stack.append((children_right[node_id], parent_depth + 1))
    else:
        is_leaves[node_id] = True

print("The binary tree structure has %s nodes and has "
      "the following tree structure:"
      % n_nodes)
for i in range(n_nodes):
    if is_leaves[i]:
        print("%snode=%s leaf node." % (int(node_depth[i]) * "\t", i))
    else:
        print("%snode=%s test node: go to node %s if X[:, %s] <= %s else to "
              "node %s."
              % (int(node_depth[i]) * "\t",
                 i,
                 children_left[i],
                 feature[i],
                 threshold[i],
                 children_right[i],
                 ))
print()

# First let's retrieve the decision path of each sample. The decision_path
# method allows to retrieve the node indicator functions. A non zero element of
# indicator matrix at the position (i, j) indicates that the sample i goes
# through the node j.
node_indicator = estimator.decision_path(X_test)

# Similarly, we can also have the leaves ids reached by each sample.
leave_id = estimator.apply(X_test)

# Now, it's possible to get the tests that were used to predict a sample or
# a group of samples. First, let's make it for the sample.
sample_id = 0
node_index = node_indicator.indices[node_indicator.indptr[sample_id]:
                                    node_indicator.indptr[sample_id + 1]]

print('Rules used to predict sample %s: ' % sample_id)
for node_id in node_index:
    if leave_id[sample_id] != node_id:
        continue

    if (X_test[sample_id, feature[node_id]] <= threshold[node_id]):
        threshold_sign = "<="
    else:
        threshold_sign = ">"

    print("decision id node %s : (X[%s, %s] (= %s) %s %s)"
          % (node_id,
             sample_id,
             feature[node_id],
             X_test[sample_id, feature[node_id]],
             threshold_sign,
             threshold[node_id]))

# For a group of samples, we have the following common node.
sample_ids = [0, 1]
common_nodes = (node_indicator.toarray()[sample_ids].sum(axis=0) ==
                len(sample_ids))

common_node_id = np.arange(n_nodes)[common_nodes]

print("\nThe following samples %s share the node %s in the tree"
      % (sample_ids, common_node_id))
print("It is %s %% of all nodes." % (100 * len(common_node_id) / n_nodes,))