Scikit-learn has a function to export decision trees to “graphviz” format (for rendering), but I find JSON more helpful: it is easier to read, and you can use it directly in web apps. The function below will give you JSON.
The reason this is necessary (rather than simply calling the json.dumps function) is that the decision tree objects don’t support the interfaces the JSON library needs in order to serialize them. Additionally, even if they did, the JSON library in Python dies on very small floating-point numbers, which is why it’s not used at all in my version.
def treeToJson(decision_tree, feature_names=None):
    """Render a decision tree as a nested JSON string.

    Every node becomes one JSON object:

    * split nodes carry ``id``, ``rule``, the node impurity stored under the
      criterion name (e.g. ``"gini"``), ``samples`` and the nested ``left``
      and ``right`` child objects;
    * leaves carry ``id``, ``criterion``, ``impurity``, ``samples`` and
      ``value``.

    All scalar fields are written as JSON *strings*; only ``value`` is a JSON
    array of numbers (nested one level deeper when the tree has several
    outputs).  The text is assembled by hand, without the ``json`` module.

    Args:
        decision_tree: Either an estimator exposing ``tree_`` (and normally
            ``criterion``), or the low-level tree object itself.  The tree
            object must provide ``children_left``, ``children_right``,
            ``feature``, ``threshold``, ``impurity``, ``n_node_samples``,
            ``value`` and ``n_outputs``.
        feature_names: Optional sequence mapping a feature index to a display
            name.  When omitted, the numeric feature index is used in rules.
            A name containing ``"="`` is rendered as ``<name> = false``
            instead of ``<name> <= <threshold>``.

    Returns:
        str: The JSON document.  It starts with a newline and is indented by
        one space per tree level.
    """
    # scikit-learn's documented tree structure marks "no child" with -1
    # (exposed as TREE_LEAF in its private ``_tree`` module).  Using the
    # literal avoids importing private sklearn modules that have moved
    # between releases.
    tree_leaf = -1

    def escape(text):
        """Escape backslashes and double quotes for use inside a JSON string."""
        return str(text).replace("\\", "\\\\").replace('"', '\\"')

    def format_value(value):
        """Render a scalar or an arbitrarily nested sequence as JSON text."""
        try:
            items = iter(value)
        except TypeError:
            # Scalars are not iterable; str() keeps very small floats intact.
            return str(value)
        return "[%s]" % ", ".join(format_value(item) for item in items)

    def node_to_str(tree, node_id, criterion):
        """Return the comma-separated JSON members describing one node."""
        if not isinstance(criterion, str):
            # Criterion objects (as opposed to names) have no useful label.
            criterion = "impurity"

        if tree.children_left[node_id] == tree_leaf:
            value = tree.value[node_id]
            if tree.n_outputs == 1:
                # Drop the redundant output axis for single-output trees.
                value = value[0]
            return ('"id": "%s", "criterion": "%s", "impurity": "%s", '
                    '"samples": "%s", "value": %s'
                    % (node_id,
                       criterion,
                       tree.impurity[node_id],
                       tree.n_node_samples[node_id],
                       format_value(value)))

        feature = tree.feature[node_id]
        if feature_names is not None:
            feature = feature_names[feature]
        # Normalise to text so the "=" test below also works for the bare
        # integer index used when no feature names were supplied.
        feature = str(feature)

        if "=" in feature:
            rule_type = "="
            rule_value = "false"
        else:
            rule_type = "<="
            rule_value = "%.4f" % tree.threshold[node_id]

        return ('"id": "%s", "rule": "%s %s %s", "%s": "%s", "samples": "%s"'
                % (node_id,
                   escape(feature),
                   rule_type,
                   rule_value,
                   criterion,
                   tree.impurity[node_id],
                   tree.n_node_samples[node_id]))

    def recurse(tree, node_id, criterion, depth=0):
        """Return the JSON object for ``node_id`` and its whole subtree."""
        tabs = " " * depth
        parts = ["\n", tabs, "{\n", tabs, " ",
                 node_to_str(tree, node_id, criterion)]

        left_child = tree.children_left[node_id]
        if left_child != tree_leaf:
            right_child = tree.children_right[node_id]
            parts += [",\n", tabs, ' "left": ',
                      recurse(tree, left_child, criterion, depth + 1),
                      ",\n", tabs, ' "right": ',
                      recurse(tree, right_child, criterion, depth + 1)]

        # The indent emitted before the newline is kept so the output stays
        # byte-compatible with earlier versions of this function.
        parts += [tabs, "\n", tabs, "}"]
        return "".join(parts)

    if hasattr(decision_tree, "tree_"):
        # Estimator: serialise its underlying tree, labelled by its criterion.
        return recurse(decision_tree.tree_, 0,
                       getattr(decision_tree, "criterion", "impurity"))
    # Bare tree object: no criterion name is available.
    return recurse(decision_tree, 0, "impurity")