机器学习之决策树-利用Bokeh生成树状关系图

发布时间:2021-12-03 公开文章

Base

Github加速

 
点此查看

Civil

土木分类资料

 
点此查看

Python

Python编程学习

 
点此查看

Games

JS前端编程学习

 
点此查看

Talk is cheap. Show me the code.

import numpy as np
from sklearn import datasets
from sklearn import tree
from sklearn.tree import _tree
import networkx as nx
# Load the iris dataset and split it into a feature matrix and label vector.
iris = datasets.load_iris()
X, y = iris.data, iris.target

# Fit a decision tree that chooses splits by entropy (information gain).
dt = tree.DecisionTreeClassifier(criterion='entropy')
dt.fit(X, y)
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
            max_features=None, max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, presort=False, random_state=None,
            splitter='best')
def tree_to_code(tree, feature_names):
    '''
    Print a fitted decision tree as nested if/else pseudo-code and build a
    networkx graph of its structure.

    Parameters
    ----------
    tree : fitted sklearn decision-tree estimator
        Its low-level ``tree_`` structure is walked recursively from node 0.
    feature_names : list
        Feature names, indexed by the feature ids stored in ``tree_.feature``.

    Returns
    -------
    g : networkx.DiGraph
        One graph node per tree node. Split nodes are labelled
        ``"<feature>\\n<=\\n<threshold:.2f>\\n NODE <id>"``; leaves are labelled
        ``"return <value>NODE<id>"``. Each split has an edge with
        ``name='yes'`` to its left child and ``name='no'`` to its right child.
    order_dict : dict
        Maps each label (the text before the trailing ``NODE``) to its node
        id. NOTE: distinct nodes may share a label (e.g. two leaves with the
        same value array); in that case only the last one visited is kept.
    '''
    tree_ = tree.tree_
    # Leaves carry TREE_UNDEFINED as their feature id.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    g = nx.DiGraph()

    def recurse(node, depth, g):
        '''Add ``node``'s subtree to ``g``, print its code and return its label.'''
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            # The "NODE <id>" suffix keeps labels unique inside the graph.
            node_name = "{}\n<=\n{:.2f}\n".format(name, threshold) + ' NODE {}'.format(node)
            g.add_node(node_name)
            print("{}if {} <= {}:".format(indent, name, threshold))
            cl_name = recurse(tree_.children_left[node], depth + 1, g)
            g.add_edge(node_name, cl_name, name='yes')
            print("{}else:".format(indent))
            cr_name = recurse(tree_.children_right[node], depth + 1, g)
            g.add_edge(node_name, cr_name, name='no')
        else:
            node_name = "return {}".format(tree_.value[node]) + 'NODE{}'.format(node)
            g.add_node(node_name)
            print("{}return {}".format(indent, tree_.value[node]))
        return node_name

    recurse(0, 1, g)

    order_dict = {}
    for n in g.nodes:
        # rsplit on the last "NODE": only the suffix appended above is the id
        # separator, so a feature name containing "NODE" cannot break unpacking.
        label, order = n.rsplit('NODE', 1)
        order_dict[label] = int(order)

    return g, order_dict
# Print the fitted tree as pseudo-code and collect its graph plus node-id lookup.
g, order_dict = tree_to_code(tree=dt, feature_names=list(iris.feature_names))
def tree(sepal length (cm), sepal width (cm), petal length (cm), petal width (cm)):
  if petal width (cm) <= 0.800000011920929:
    return [[50.  0.  0.]]
  else:
    if petal width (cm) <= 1.75:
      if petal length (cm) <= 4.950000047683716:
        if petal width (cm) <= 1.6500000357627869:
          return [[ 0. 47.  0.]]
        else:
          return [[0. 0. 1.]]
      else:
        if petal width (cm) <= 1.550000011920929:
          return [[0. 0. 3.]]
        else:
          if sepal length (cm) <= 6.949999809265137:
            return [[0. 2. 0.]]
          else:
            return [[0. 0. 1.]]
    else:
      if petal length (cm) <= 4.8500001430511475:
        if sepal length (cm) <= 5.950000047683716:
          return [[0. 1. 0.]]
        else:
          return [[0. 0. 2.]]
      else:
        return [[ 0.  0. 43.]]
def get_root(g):
    '''Return the root node of the tree graph ``g``.

    The root is identified as the unique node with no incoming edge. This
    also handles a tree that is a single leaf. The previous "degree == 2"
    test only recognised a root that was a split node with two children.

    Raises
    ------
    Exception
        If ``g`` does not contain exactly one node without a parent
        (e.g. an empty graph or a forest).
    '''
    roots = [node for node, indeg in g.in_degree() if indeg == 0]
    if len(roots) != 1:
        raise Exception('something wrong')
    return roots[0]
def set_pos_dict(g, parent, node, pos_dict, dx=1, dy=1, root_coord=(0, 1), eps=0.5):
    '''Recursively assign 2-D coordinates to the nodes of tree graph ``g``.

    Parameters
    ----------
    g : networkx.DiGraph
        Tree whose edges carry ``name='yes'`` or ``name='no'``.
    parent : node or None
        Parent of ``node``. Pass ``None`` to start at the root; ``node`` is
        then ignored and the root is looked up with ``get_root``.
    node : node
        Node to place.
    pos_dict : dict
        Output mapping node -> ``np.array((x, y))``; filled in place.
    dx : float
        Horizontal offset of ``node`` from its parent. Children of ``node``
        are offset by ``dx * eps``, so spacing shrinks geometrically with depth.
    dy : float
        Vertical distance between consecutive levels.
    root_coord : tuple
        (x, y) position of the root.
    eps : float
        Shrink factor applied to ``dx`` at each level.

    A 'yes' child is placed to the right of its parent (+dx) and a 'no'
    child to the left (-dx).
    '''
    if parent is None:
        node = get_root(g)
        x, y = root_coord
    else:
        x, y = pos_dict[parent]
        y = y - dy
        edge = g.get_edge_data(parent, node)
        if edge['name'] == 'yes':
            x = x + dx
        else:
            x = x - dx
    pos_dict[node] = np.array((x, y))

    # successors() yields the out-neighbours directly instead of scanning
    # every edge of the graph for each node.
    for child in g.successors(node):
        # Forward dy and eps too. Previously only dx was forwarded, so the
        # caller's dy/eps were replaced by the defaults below the root.
        set_pos_dict(g, node, child, pos_dict, dx=dx * eps, dy=dy, eps=eps)
# Lay out every node starting from the root, then echo the resulting mapping.
pos_dict = dict()
set_pos_dict(g, parent=None, node=None, pos_dict=pos_dict, dx=50, dy=3)
pos_dict
{'petal width (cm)\n<=\n0.80\n NODE 0': array([0, 1]),
 'return [[50.  0.  0.]]NODE1': array([25.,  0.]),
 'petal width (cm)\n<=\n1.75\n NODE 2': array([-25.,   0.]),
 'petal length (cm)\n<=\n4.95\n NODE 3': array([-12.5,  -1. ]),
 'petal width (cm)\n<=\n1.65\n NODE 4': array([-6.25, -2.  ]),
 'return [[ 0. 47.  0.]]NODE5': array([-3.125, -3.   ]),
 'return [[0. 0. 1.]]NODE6': array([-9.375, -3.   ]),
 'petal width (cm)\n<=\n1.55\n NODE 7': array([-18.75,  -2.  ]),
 'return [[0. 0. 3.]]NODE8': array([-15.625,  -3.   ]),
 'sepal length (cm)\n<=\n6.95\n NODE 9': array([-21.875,  -3.   ]),
 'return [[0. 2. 0.]]NODE10': array([-20.3125,  -4.    ]),
 'return [[0. 0. 1.]]NODE11': array([-23.4375,  -4.    ]),
 'petal length (cm)\n<=\n4.85\n NODE 12': array([-37.5,  -1. ]),
 'sepal length (cm)\n<=\n5.95\n NODE 13': array([-31.25,  -2.  ]),
 'return [[0. 1. 0.]]NODE14': array([-28.125,  -3.   ]),
 'return [[0. 0. 2.]]NODE15': array([-34.375,  -3.   ]),
 'return [[ 0.  0. 43.]]NODE16': array([-43.75,  -2.  ])}
def fun_layout(g, pos=None, scale=None, center=None, dim=None):
    '''Rescale precomputed node positions for rendering (networkx-style layout).

    Each axis is centred on its mean and divided by its largest absolute
    centred coordinate, so every coordinate lies in ``[-scale, scale]``.
    The result is then shifted by ``center``.

    Parameters
    ----------
    g : graph
        Unused; accepted so the function matches the ``layout(graph, **kw)``
        calling convention used by ``from_networkx``.
    pos : dict or None
        Mapping node -> (x, y). Defaults to the module-level ``pos_dict``,
        looked up at call time. It is not modified.
    scale : float or None
        Half-width of the output box per axis; ``None`` means 1.
    center : sequence or None
        Offset added after scaling; ``None`` means the origin.
    dim : int or None
        Unused; kept for signature compatibility.

    Returns
    -------
    dict
        New mapping node -> rescaled ``np.ndarray`` coordinates.
    '''
    if pos is None:
        pos = pos_dict
    if scale is None:
        scale = 1
    nodes = list(pos)
    if not nodes:
        return {}
    xy = np.array([pos[k] for k in nodes], dtype=float)
    if center is None:
        center = np.zeros(xy.shape[1])
    centered = xy - xy.mean(axis=0)
    extent = np.abs(centered).max(axis=0)
    # Avoid 0/0 when every node shares the same coordinate on an axis.
    extent[extent == 0] = 1
    # Apply the center after scaling so the offset itself is not rescaled.
    xy = centered * scale / extent + np.asarray(center, dtype=float)
    return dict(zip(nodes, xy))
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool, WheelZoomTool
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4

# Direct subsequent Bokeh show() output into the Jupyter notebook.
output_notebook()
<div class="bk-root">
    <a href="https://bokeh.pydata.org" target="_blank" class="bk-logo bk-logo-small bk-logo-notebook"></a>
    <span id="1027">Loading BokehJS ...</span>
</div>
# Render the decision-tree graph as an interactive Bokeh plot.
G = g
# `width`/`height` are the canonical names; `plot_width`/`plot_height` were
# deprecated in Bokeh 2.4 and removed in Bokeh 3.0.
plot = Plot(width=600, height=600,
            x_range=Range1d(-1.1, 1.1), y_range=Range1d(-1.1, 1.1))

plot.title.text = "图形交互演示"
hover = HoverTool(tooltips=[("Name:", "@name")])
plot.add_tools(hover, TapTool(), BoxSelectTool(), WheelZoomTool())

# Node positions come from fun_layout, called with scale=1 and centred at the origin.
graph_renderer = from_networkx(G, fun_layout, scale=1, center=(0, 0))

graph_renderer.node_renderer.glyph = Circle(size=15, fill_color=Spectral4[0])
graph_renderer.node_renderer.selection_glyph = Circle(size=15, fill_color=Spectral4[2])
graph_renderer.node_renderer.hover_glyph = Circle(size=15, fill_color=Spectral4[1])

# Build the hover labels from the renderer's own 'index' column so each label
# is aligned with its node. Strip the trailing "NODE <id>" uniqueness tag.
node_source = graph_renderer.node_renderer.data_source
node_source.data['name'] = [e.rsplit('NODE', 1)[0] for e in node_source.data['index']]

graph_renderer.edge_renderer.glyph = MultiLine(line_color="#CCCCCC", line_alpha=0.8, line_width=5)
graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=5)
graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=5)

# Selecting or hovering a node also highlights the edges attached to it.
graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = NodesAndLinkedEdges()

plot.renderers.append(graph_renderer)

show(plot)