Gary Sieling

Convert scikit-learn decision trees to JSON

scikit-learn has a function (`export_graphviz`) that converts decision trees to Graphviz format for rendering, but I find JSON more helpful: it is easier to read, and it can be used directly in web apps. The function below produces JSON.

This is necessary (rather than just calling `json.dumps`) because the decision tree objects don't support the interfaces the `json` library needs in order to serialize them. Additionally, even if they did, the JSON library in Python dies on very small floating-point numbers, which is why my version formats the numbers itself instead of relying on the `json` library.


def treeToJson(decision_tree, feature_names=None):
    """Serialize a scikit-learn decision tree to a pretty-printed JSON string.

    Args:
        decision_tree: Either a fitted tree estimator exposing ``tree_``
            (and normally ``criterion``), or a low-level ``Tree`` object
            itself, in which case the criterion is reported as "impurity".
        feature_names: Optional sequence mapping feature indices to names.
            When omitted, the numeric feature index is used in split rules.

    Returns:
        A JSON object string. Every node has "id" and "samples".
        Leaves add "criterion", "impurity" and "value" (an array, nested
        per output when the tree has several outputs). Split nodes add
        "rule", an impurity entry keyed by the criterion name, and nested
        "left"/"right" child objects. Apart from "value", numbers are
        emitted as JSON strings, matching the original output format.
    """
    import json

    try:
        from sklearn.tree._tree import TREE_LEAF
    except ImportError:
        # scikit-learn marks "no child" with -1 in children_left/right.
        TREE_LEAF = -1

    def quote(text):
        # Escape quotes/backslashes/control characters so arbitrary feature
        # names stay valid JSON; ensure_ascii=False keeps non-ASCII as-is.
        return json.dumps(str(text), ensure_ascii=False)

    def format_values(values):
        # Render numbers as the body of a JSON array. Sub-arrays (ndim > 0,
        # i.e. one row per output) become nested arrays instead of numpy's
        # "[1. 2.]" repr, which is not valid JSON.
        parts = []
        for x in values:
            if isinstance(x, (list, tuple)) or getattr(x, "ndim", 0) > 0:
                parts.append("[" + format_values(x) + "]")
            else:
                # str() avoids the json module for numbers (see the article);
                # NOTE(review): nan/inf would still yield invalid JSON.
                parts.append(str(x))
        return ", ".join(parts)

    def node_to_str(tree, node_id, criterion):
        # Build the key/value fields of one node, without the braces.
        if not isinstance(criterion, str):
            criterion = "impurity"

        if tree.children_left[node_id] == TREE_LEAF:
            value = tree.value[node_id]
            if tree.n_outputs == 1:
                value = value[0, :]
            return (
                '"id": "%s", "criterion": %s, "impurity": "%s", '
                '"samples": "%s", "value": [%s]'
                % (
                    node_id,
                    quote(criterion),
                    tree.impurity[node_id],
                    tree.n_node_samples[node_id],
                    format_values(value),
                )
            )

        index = tree.feature[node_id]
        # str() is required: without names the index is a numpy int, and
        # `"=" in <int>` would raise TypeError.
        feature = str(feature_names[index] if feature_names is not None else index)

        if "=" in feature:
            # Presumably a one-hot column named like "color=red" (e.g. from
            # DictVectorizer); the rule is rendered as "<name> = false".
            rule = "%s = false" % feature
        else:
            rule = "%s <= %.4f" % (feature, tree.threshold[node_id])

        return '"id": "%s", "rule": %s, %s: "%s", "samples": "%s"' % (
            node_id,
            quote(rule),
            quote(criterion),
            tree.impurity[node_id],
            tree.n_node_samples[node_id],
        )

    def recurse(tree, node_id, criterion, depth=0):
        # Emit the node as an indented JSON object, descending into children.
        tabs = "  " * depth
        js = "\n" + tabs + "{\n" + tabs + "  " + node_to_str(tree, node_id, criterion)

        left_child = tree.children_left[node_id]
        if left_child != TREE_LEAF:
            right_child = tree.children_right[node_id]
            js += (
                ",\n" + tabs + '  "left": '
                + recurse(tree, left_child, criterion, depth + 1)
                + ",\n" + tabs + '  "right": '
                + recurse(tree, right_child, criterion, depth + 1)
            )

        return js + "\n" + tabs + "}"

    # Duck-type instead of isinstance(..., sklearn.tree.tree.Tree): that
    # private module no longer exists in current scikit-learn releases.
    tree_ = getattr(decision_tree, "tree_", None)
    if tree_ is None:
        return recurse(decision_tree, 0, "impurity")
    return recurse(tree_, 0, getattr(decision_tree, "criterion", "impurity"))