机器学习之决策树-利用Bokeh生成树状关系图

发布时间:2021-12-03 公开文章

Talk is cheap, show me the code.

import numpy as np
from sklearn import datasets
from sklearn import tree
from sklearn.tree import _tree
import networkx as nx
# Load the iris dataset: feature matrix X and class labels y.
iris = datasets.load_iris()
X = iris.data
y = iris.target

# Fit an entropy-based decision tree. random_state is fixed because
# scikit-learn randomly permutes the features at every split: when several
# splits are equally good, the chosen one (and therefore the printed tree and
# the graph built from it) can otherwise change from run to run.
dt = tree.DecisionTreeClassifier(criterion='entropy', random_state=0)
dt.fit(X, y)
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
def tree_to_code(tree, feature_names):
    '''Print a fitted decision tree as if/else pseudo-code and build its graph.

    Internal nodes are labelled with the feature name, "<=", the threshold
    (2 decimals) and " NODE <id>", joined by newlines. Leaves are labelled
    "return <value>NODE<id>". An edge to the left child carries
    ``name='yes'`` (the ``<=`` branch); an edge to the right child carries
    ``name='no'``.

    Parameters
    ----------
    tree : fitted scikit-learn decision-tree estimator
        Must expose a ``tree_`` attribute.
    feature_names : list of str
        Feature names, indexed by the feature ids in ``tree.tree_.feature``.

    Returns
    -------
    g : networkx.DiGraph
        Directed graph with edges from each split node to its two children.
    order_dict : dict
        Maps each node label with its "NODE<id>" suffix removed to the integer
        node id. Leaves with identical values share a label, so for them only
        the id of the last one encountered is kept.
    '''
    tree_ = tree.tree_
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    g = nx.DiGraph()

    def recurse(node, depth, g):
        '''Add `node` and its subtree to `g`, print pseudo-code, return its label.'''
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Split node: recurse into both children and link them.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            node_name = "{}\n<=\n{:.2f}\n".format(name, threshold) + ' NODE {}'.format(node)
            g.add_node(node_name)
            print("{}if {} <= {}:".format(indent, name, threshold))
            cl_name = recurse(tree_.children_left[node], depth + 1, g)
            g.add_edge(node_name, cl_name, name='yes')
            print("{}else:".format(indent))
            cr_name = recurse(tree_.children_right[node], depth + 1, g)
            g.add_edge(node_name, cr_name, name='no')
        else:
            # Leaf: the label embeds the raw tree_.value entry for this node.
            node_name = "return {}".format(tree_.value[node]) + 'NODE{}'.format(node)
            g.add_node(node_name)
            print("{}return {}".format(indent, tree_.value[node]))
        return node_name

    recurse(0, 1, g)

    order_dict = {}
    for n in g.nodes:
        # Split on the LAST "NODE" only: a feature name that itself contains
        # "NODE" would otherwise produce more than two parts and fail to unpack.
        label, order = n.rsplit('NODE', 1)
        order_dict[label] = int(order)

    return g, order_dict
# Print the fitted tree as pseudo-code and build its networkx graph.
g, order_dict = tree_to_code(dt, feature_names=list(iris.feature_names))
def tree(sepal length (cm), sepal width (cm), petal length (cm), petal width (cm)):
  if petal width (cm) <= 0.800000011920929:
    return [[50.  0.  0.]]
  else:
    if petal width (cm) <= 1.75:
      if petal length (cm) <= 4.950000047683716:
        if petal width (cm) <= 1.6500000357627869:
          return [[ 0. 47.  0.]]
        else:
          return [[0. 0. 1.]]
      else:
        if petal width (cm) <= 1.550000011920929:
          return [[0. 0. 3.]]
        else:
          if sepal length (cm) <= 6.949999809265137:
            return [[0. 2. 0.]]
          else:
            return [[0. 0. 1.]]
    else:
      if petal length (cm) <= 4.8500001430511475:
        if sepal length (cm) <= 5.950000047683716:
          return [[0. 1. 0.]]
        else:
          return [[0. 0. 2.]]
      else:
        return [[ 0.  0. 43.]]
def get_root(g):
    '''Return the root of the directed tree `g`.

    The root is the unique node with no incoming edges. This criterion also
    holds for a tree made of a single leaf (a node with no edges at all).

    Raises
    ------
    ValueError
        If `g` does not contain exactly one node without incoming edges.
    '''
    roots = [node for node, deg in g.in_degree() if deg == 0]
    if len(roots) != 1:
        raise ValueError(
            'expected exactly one root node, found {}'.format(len(roots))
        )
    return roots[0]
def set_pos_dict(g, parent, node, pos_dict, dx=1, dy=1, root_coord=(0, 1), eps=0.5):
    '''Recursively assign 2-D coordinates to `node` and all its descendants.

    Parameters
    ----------
    g : networkx.DiGraph
        Tree whose edges carry a 'name' attribute ('yes' or anything else).
    parent : node or None
        Parent of `node`. Pass None to start at the root; `node` is then
        ignored and the root is located with `get_root`.
    node : node or None
        Node to place.
    pos_dict : dict
        Filled in place with node -> numpy array (x, y).
    dx : float
        Horizontal offset from the parent: +dx along a 'yes' edge, -dx otherwise.
    dy : float
        Vertical distance between consecutive tree levels.
    root_coord : tuple
        (x, y) coordinates of the root.
    eps : float
        Factor by which dx is multiplied at each deeper level.
    '''
    if parent is None:
        node = get_root(g)
        x, y = root_coord
    else:
        x, y = pos_dict[parent]
        y = y - dy
        edge = g.get_edge_data(parent, node)
        x = x + dx if edge['name'] == 'yes' else x - dx
    pos_dict[node] = np.array((x, y))

    # Forward dy and eps explicitly; otherwise every level below the root
    # would silently fall back to the default dy=1 / eps=0.5.
    for child in g.successors(node):
        set_pos_dict(g, node, child, pos_dict, dx=dx * eps, dy=dy, eps=eps)
# Compute node positions, starting from the root (parent=None), and display them.
pos_dict = dict()
set_pos_dict(g, parent=None, node=None, pos_dict=pos_dict, dx=50, dy=3)
pos_dict
{'petal width (cm)\n<=\n0.80\n NODE 0': array([0, 1]),
 'return [[50.  0.  0.]]NODE1': array([25.,  0.]),
 'petal width (cm)\n<=\n1.75\n NODE 2': array([-25.,   0.]),
 'petal length (cm)\n<=\n4.95\n NODE 3': array([-12.5,  -1. ]),
 'petal width (cm)\n<=\n1.65\n NODE 4': array([-6.25, -2.  ]),
 'return [[ 0. 47.  0.]]NODE5': array([-3.125, -3.   ]),
 'return [[0. 0. 1.]]NODE6': array([-9.375, -3.   ]),
 'petal width (cm)\n<=\n1.55\n NODE 7': array([-18.75,  -2.  ]),
 'return [[0. 0. 3.]]NODE8': array([-15.625,  -3.   ]),
 'sepal length (cm)\n<=\n6.95\n NODE 9': array([-21.875,  -3.   ]),
 'return [[0. 2. 0.]]NODE10': array([-20.3125,  -4.    ]),
 'return [[0. 0. 1.]]NODE11': array([-23.4375,  -4.    ]),
 'petal length (cm)\n<=\n4.85\n NODE 12': array([-37.5,  -1. ]),
 'sepal length (cm)\n<=\n5.95\n NODE 13': array([-31.25,  -2.  ]),
 'return [[0. 1. 0.]]NODE14': array([-28.125,  -3.   ]),
 'return [[0. 0. 2.]]NODE15': array([-34.375,  -3.   ]),
 'return [[ 0.  0. 43.]]NODE16': array([-43.75,  -2.  ])}
def fun_layout(g, pos=None, scale=None, center=None, dim=None):
    '''Layout callback for bokeh's from_networkx: normalise precomputed positions.

    Each axis is shifted so its mean is zero and divided by its largest
    absolute deviation, giving values within [-1, 1]. The result is then
    multiplied by `scale` and shifted to `center`.

    Parameters
    ----------
    g : graph
        Unused; present because the graph is passed as the first argument.
    pos : dict or None
        Node -> coordinate array. None means the module-level ``pos_dict``,
        looked up at call time. The input dict is not modified.
    scale : float or None
        Half-extent of the output on each axis; None means 1.
    center : sequence of float or None
        Target centre of the output; None means the origin.
    dim : unused
        Kept for signature compatibility.

    Returns
    -------
    dict
        A new dict mapping each node to its normalised coordinate array.
    '''
    if pos is None:
        pos = pos_dict  # positions computed earlier by set_pos_dict
    if not pos:
        return {}

    nodes = list(pos)
    xy = np.array([pos[n] for n in nodes], dtype=float)
    scale = 1 if scale is None else scale
    center = np.zeros(xy.shape[1]) if center is None else np.asarray(center, dtype=float)

    deviation = xy - xy.mean(axis=0)
    spread = np.abs(deviation).max(axis=0)
    # An axis on which every node has the same coordinate (e.g. a single-node
    # tree) has zero spread; avoid 0/0 so that axis collapses onto `center`.
    spread[spread == 0] = 1
    xy = deviation * scale / spread + center
    return dict(zip(nodes, xy))
# Bokeh components used for the interactive graph rendering below.
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import (
    BoxSelectTool,
    Circle,
    HoverTool,
    MultiLine,
    Plot,
    Range1d,
    TapTool,
    WheelZoomTool,
)
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4

# Send Bokeh output to the notebook.
output_notebook()
<div class="bk-root">
    <a href="https://bokeh.pydata.org" target="_blank" class="bk-logo bk-logo-small bk-logo-notebook"></a>
    <span id="1027">Loading BokehJS ...</span>
</div>
# Interactive Bokeh rendering of the decision-tree graph built by tree_to_code.
G = g

plot = Plot(
    plot_width=600,
    plot_height=600,
    x_range=Range1d(-1.1, 1.1),
    y_range=Range1d(-1.1, 1.1),
)
plot.title.text = "图形交互演示"

# The tooltip reads the per-node "name" column that is filled in below.
hover = HoverTool(tooltips=[("Name:", "@name")])
plot.add_tools(hover, TapTool(), BoxSelectTool(), WheelZoomTool())

# scale/center are passed through to fun_layout as keyword arguments.
graph_renderer = from_networkx(G, fun_layout, scale=1, center=(0, 0))

# Node glyphs: default / selected / hovered.
node_renderer = graph_renderer.node_renderer
node_renderer.glyph = Circle(size=15, fill_color=Spectral4[0])
node_renderer.selection_glyph = Circle(size=15, fill_color=Spectral4[2])
node_renderer.hover_glyph = Circle(size=15, fill_color=Spectral4[1])
# Hover text is the node label with its trailing "NODE<id>" part removed.
node_renderer.data_source.data['name'] = [label.split('NODE')[0] for label in g.nodes]

# Edge glyphs: default / selected / hovered.
edge_renderer = graph_renderer.edge_renderer
edge_renderer.glyph = MultiLine(line_color="#CCCCCC", line_alpha=0.8, line_width=5)
edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=5)
edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=5)

# Selecting or hovering a node also highlights the edges linked to it.
graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = NodesAndLinkedEdges()

plot.renderers.append(graph_renderer)
show(plot)