import numpy as np
from sklearn import datasets
from sklearn import tree
from sklearn.tree import _tree
import networkx as nx
# Load the iris dataset.
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Fit a decision-tree classifier using the entropy (information gain) criterion.
# random_state is pinned because scikit-learn randomly permutes the features at
# every split; when two splits are equally good (e.g. petal length vs. petal
# width at the root) an unseeded model can produce a different tree each run.
dt = tree.DecisionTreeClassifier(criterion='entropy', random_state=0)
dt.fit(X, y)
# Notebook output of dt.fit(X, y):
# DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
#             max_features=None, max_leaf_nodes=None,
#             min_impurity_decrease=0.0, min_impurity_split=None,
#             min_samples_leaf=1, min_samples_split=2,
#             min_weight_fraction_leaf=0.0, presort=False, random_state=None,
#             splitter='best')
def tree_to_code(tree, feature_names):
    '''
    Print a fitted decision tree as nested if/else pseudo-code and build a
    networkx directed graph of the same tree.

    Parameters
    -----------
    tree: fitted decision-tree model
        Any object exposing a scikit-learn ``tree_`` attribute.
    feature_names: list
        Feature names; ``feature_names[i]`` names feature id ``i``.

    Returns
    -------
    g : networkx.DiGraph
        One graph node per tree node. Split nodes are named
        ``"<feature>\\n<=\\n<threshold:.2f>\\n NODE <id>"`` and leaves
        ``"return <value>NODE<id>"``. Each edge stores ``name='yes'`` (to the
        left child, taken when ``feature <= threshold``) or ``name='no'``
        (to the right child).
    order_dict : dict
        Maps each node's display label (its name without the ``NODE <id>``
        suffix) to its integer tree node id. Labels are NOT unique -- two
        leaves with the same ``value`` share one -- and for a shared label the
        node visited last in pre-order wins.

    Notes
    -----
    The printed text is pseudo-code: feature names such as
    ``"petal width (cm)"`` are not valid Python identifiers.
    '''
    tree_ = tree.tree_
    # Per-node feature name; leaves carry TREE_UNDEFINED as their feature id.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))
    g = nx.DiGraph()
    order_dict = {}

    def recurse(node, depth):
        """Print ``node``'s subtree, add it to ``g``, return the node's graph name."""
        indent = " " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            name = feature_name[node]
            threshold = tree_.threshold[node]
            # The label is recorded here rather than recovered later by
            # splitting the node name on "NODE", which broke for feature
            # names that themselves contain "NODE".
            label = "{}\n<=\n{:.2f}\n ".format(name, threshold)
            node_name = label + 'NODE {}'.format(node)
            g.add_node(node_name)
            order_dict[label] = int(node)
            print("{}if {} <= {}:".format(indent, name, threshold))
            cl_name = recurse(tree_.children_left[node], depth + 1)
            g.add_edge(node_name, cl_name, name='yes')
            print("{}else:".format(indent))
            cr_name = recurse(tree_.children_right[node], depth + 1)
            g.add_edge(node_name, cr_name, name='no')
        else:
            label = "return {}".format(tree_.value[node])
            node_name = label + 'NODE{}'.format(node)
            g.add_node(node_name)
            order_dict[label] = int(node)
            print("{}return {}".format(indent, tree_.value[node]))
        return node_name

    recurse(0, 1)
    return g, order_dict
# Print the fitted tree as pseudo-code; keep its graph and the label -> node id map.
g, order_dict = tree_to_code(dt, feature_names=list(iris.feature_names))
# Printed output of tree_to_code (if/else pseudo-code; nesting follows tree depth):
# def tree(sepal length (cm), sepal width (cm), petal length (cm), petal width (cm)):
#   if petal width (cm) <= 0.800000011920929:
#     return [[50. 0. 0.]]
#   else:
#     if petal width (cm) <= 1.75:
#       if petal length (cm) <= 4.950000047683716:
#         if petal width (cm) <= 1.6500000357627869:
#           return [[ 0. 47. 0.]]
#         else:
#           return [[0. 0. 1.]]
#       else:
#         if petal width (cm) <= 1.550000011920929:
#           return [[0. 0. 3.]]
#         else:
#           if sepal length (cm) <= 6.949999809265137:
#             return [[0. 2. 0.]]
#           else:
#             return [[0. 0. 1.]]
#     else:
#       if petal length (cm) <= 4.8500001430511475:
#         if sepal length (cm) <= 5.950000047683716:
#           return [[0. 1. 0.]]
#         else:
#           return [[0. 0. 2.]]
#       else:
#         return [[ 0. 0. 43.]]
def get_root(g):
    '''
    Return the root node of a tree-shaped directed graph.

    The root is the single node with no incoming edges. Unlike checking for a
    total degree of 2, this also works for a tree consisting of a single leaf.

    Parameters
    ----------
    g : directed graph exposing ``in_degree()`` as (node, in-degree) pairs.

    Returns
    -------
    The root node.

    Raises
    ------
    ValueError
        If the graph does not have exactly one node without incoming edges.
    '''
    roots = [node for node, in_deg in g.in_degree() if in_deg == 0]
    if len(roots) != 1:
        raise ValueError(
            'expected exactly one root (node with no incoming edges), '
            'found {}'.format(len(roots)))
    return roots[0]
def set_pos_dict(g, parent, node, pos_dict, dx=1, dy=1, root_coord=(0, 1), eps=0.5):
    '''
    Recursively assign 2-D coordinates to ``node`` and all of its descendants.

    Parameters
    ----------
    g : tree-shaped directed graph whose edges carry a ``'name'`` attribute.
    parent : parent of ``node``, or None to start from the root (``node`` is
        then ignored and the root is found with ``get_root``).
    node : node to place (ignored when ``parent`` is None).
    pos_dict : dict, filled in place with ``node -> np.array((x, y))``.
    dx : horizontal offset of ``node`` from its parent (unused for the root);
        each deeper level uses the previous offset multiplied by ``eps``.
    dy : vertical distance between a node and its parent.
    root_coord : (x, y) position of the root.
    eps : shrink factor applied to ``dx`` at every level.

    Children reached through a ``'yes'`` edge are placed to the right of their
    parent; all other children to the left.
    '''
    if parent is None:
        node = get_root(g)
        x, y = root_coord
    else:
        x, y = pos_dict[parent]
        y = y - dy
        if g.get_edge_data(parent, node)['name'] == 'yes':
            x = x + dx
        else:
            x = x - dx
    pos_dict[node] = np.array((x, y))
    # Forward dy/root_coord/eps explicitly: previously only dx was passed, so
    # every level below the first silently fell back to dy=1 and eps=0.5.
    for child in g.successors(node):
        set_pos_dict(g, node, child, pos_dict, dx=dx * eps, dy=dy,
                     root_coord=root_coord, eps=eps)
# Lay out the tree: compute 2-D coordinates for every node of g, starting from its root.
pos_dict = dict()
set_pos_dict(g, parent=None, node=None, pos_dict=pos_dict, dx=50, dy=3)
pos_dict  # bare expression: displayed as the notebook cell's output
# Notebook output (value of pos_dict):
# {'petal width (cm)\n<=\n0.80\n NODE 0': array([0, 1]),
#  'return [[50. 0. 0.]]NODE1': array([25., 0.]),
#  'petal width (cm)\n<=\n1.75\n NODE 2': array([-25., 0.]),
#  'petal length (cm)\n<=\n4.95\n NODE 3': array([-12.5, -1. ]),
#  'petal width (cm)\n<=\n1.65\n NODE 4': array([-6.25, -2. ]),
#  'return [[ 0. 47. 0.]]NODE5': array([-3.125, -3. ]),
#  'return [[0. 0. 1.]]NODE6': array([-9.375, -3. ]),
#  'petal width (cm)\n<=\n1.55\n NODE 7': array([-18.75, -2. ]),
#  'return [[0. 0. 3.]]NODE8': array([-15.625, -3. ]),
#  'sepal length (cm)\n<=\n6.95\n NODE 9': array([-21.875, -3. ]),
#  'return [[0. 2. 0.]]NODE10': array([-20.3125, -4. ]),
#  'return [[0. 0. 1.]]NODE11': array([-23.4375, -4. ]),
#  'petal length (cm)\n<=\n4.85\n NODE 12': array([-37.5, -1. ]),
#  'sepal length (cm)\n<=\n5.95\n NODE 13': array([-31.25, -2. ]),
#  'return [[0. 1. 0.]]NODE14': array([-28.125, -3. ]),
#  'return [[0. 0. 2.]]NODE15': array([-34.375, -3. ]),
#  'return [[ 0. 0. 43.]]NODE16': array([-43.75, -2. ])}
def fun_layout(g, pos=None, scale=None, center=None, dim=None):
    '''
    Layout function for Bokeh's ``from_networkx``: center and rescale
    precomputed node positions.

    Parameters
    ----------
    g : graph
        Unused; present because ``from_networkx`` passes the graph first.
    pos : dict, optional
        ``node -> (x, y)`` positions. Defaults to the module-level ``pos_dict``,
        looked up at call time. The input dict is NOT modified.
    scale : float, optional (default 1)
        Each axis is rescaled independently so that the node farthest from
        the mean lies at ``+/-scale``.
    center : (x, y), optional (default (0, 0))
        Offset added after rescaling.
    dim : ignored
        Accepted for compatibility with networkx-style layout signatures.

    Returns
    -------
    dict
        New ``node -> np.ndarray`` mapping with the same key order as ``pos``.
    '''
    # Resolved at call time instead of binding the dict as a default argument;
    # the old version also overwrote that shared dict in place, so rendering
    # twice normalized already-normalized coordinates.
    if pos is None:
        pos = pos_dict
    if not pos:
        return {}
    scale = 1 if scale is None else scale
    keys = list(pos)
    xy = np.array([pos[k] for k in keys], dtype=float)
    offset = np.zeros(xy.shape[1]) if center is None else np.asarray(center, dtype=float)
    # Normalize by the extent of the *centered* coordinates so every node
    # lands within [-scale, scale]; dividing by the uncentered max could not
    # guarantee that.
    xy = xy - xy.mean(axis=0)
    extent = np.abs(xy).max(axis=0)
    # An axis on which all nodes coincide has zero extent: keep it at 0
    # instead of dividing by zero.
    extent[extent == 0] = 1
    xy = xy * scale / extent + offset
    return dict(zip(keys, xy))
from bokeh.plotting import figure, show, output_notebook
from bokeh.models import Plot, Range1d, MultiLine, Circle, HoverTool, TapTool, BoxSelectTool, WheelZoomTool
from bokeh.models.graphs import from_networkx, NodesAndLinkedEdges, EdgesAndLinkedNodes
from bokeh.palettes import Spectral4
# Configure Bokeh to display plots inline in this notebook.
output_notebook()
# Notebook output of output_notebook() (HTML):
# <div class="bk-root">
#     <a href="https://bokeh.pydata.org" target="_blank" class="bk-logo bk-logo-small bk-logo-notebook"></a>
#     <span id="1027">Loading BokehJS ...</span>
# </div>
# Render the decision-tree graph as an interactive Bokeh plot: hovering shows a
# node's label; tapping / box-selecting highlights nodes and their linked edges.
G = g
# NOTE(review): plot_width/plot_height are the pre-3.0 Bokeh keyword names
# (renamed to width/height in Bokeh 3) -- update if Bokeh is upgraded.
plot = Plot(plot_width=600, plot_height=600,
            x_range=Range1d(-1.1, 1.1), y_range=Range1d(-1.1, 1.1))
plot.title.text = "图形交互演示"  # "interactive graph demo"
hover = HoverTool(tooltips=[("Name:", "@name")])
plot.add_tools(hover, TapTool(), BoxSelectTool(), WheelZoomTool())

# from_networkx forwards scale/center as keyword arguments to fun_layout.
graph_renderer = from_networkx(G, fun_layout, scale=1, center=(0, 0))

graph_renderer.node_renderer.glyph = Circle(size=15, fill_color=Spectral4[0])
graph_renderer.node_renderer.selection_glyph = Circle(size=15, fill_color=Spectral4[2])
graph_renderer.node_renderer.hover_glyph = Circle(size=15, fill_color=Spectral4[1])

# The "@name" hover field: each node's display label, i.e. its graph name
# without the trailing "NODE <id>" suffix. It is derived from the renderer's
# own 'index' column so labels line up row-for-row with the node data, and
# rsplit(..., 1) strips only the last "NODE" so feature names containing
# "NODE" stay intact.
node_source = graph_renderer.node_renderer.data_source
node_source.data['name'] = [e.rsplit('NODE', 1)[0] for e in node_source.data['index']]

graph_renderer.edge_renderer.glyph = MultiLine(line_color="#CCCCCC", line_alpha=0.8, line_width=5)
graph_renderer.edge_renderer.selection_glyph = MultiLine(line_color=Spectral4[2], line_width=5)
graph_renderer.edge_renderer.hover_glyph = MultiLine(line_color=Spectral4[1], line_width=5)

graph_renderer.selection_policy = NodesAndLinkedEdges()
graph_renderer.inspection_policy = NodesAndLinkedEdges()

plot.renderers.append(graph_renderer)
show(plot)