Gary Sieling

Rendering scikit Decision Trees in D3.js

Scikit-learn provides routines to export decision trees to Graphviz's DOT format, which is typically used to render a static image of the tree.
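For reference, here is a minimal sketch of the stock export. It uses the bundled iris dataset purely as a stand-in for real data. With `out_file=None`, `export_graphviz` returns the DOT source as a string, which Graphviz's `dot` command can then render to an image.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Iris is only a stand-in dataset for illustration
iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0)
clf.fit(iris.data, iris.target)

# out_file=None returns the DOT source instead of writing a file
dot = export_graphviz(clf, out_file=None, feature_names=iris.feature_names)
print(dot.splitlines()[0])
```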

For some applications this is valuable, but if the product of machine learning is the ability to generate models (rather than predictions), it would be preferable to provide interactive models.

Automatically constructing decision trees, for instance, might allow you to discover patterns in the underlying data, e.g. determining that many failures are caused by a particular device or vendor. In this scenario, being able to predict failure is relatively useless, since the goal is to take corrective action.

There are many awesome interactive tree examples built with D3.js, and the example that follows shows how to link scikit-learn and D3 together.

Scikit-learn provides a function called export_graphviz, which I copied and changed to export JSON instead. (The library would probably benefit from API calls that let you iterate over its trees, so that this isn't needed.)
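Even without a dedicated iteration API, a fitted estimator's `tree_` attribute exposes parallel arrays (`children_left`, `children_right`, `feature`, `threshold`, `n_node_samples`) indexed by node id, where a leaf has `children_left == -1`. The sketch below walks them depth-first without recursion, again using iris only as a stand-in dataset. The helper name `iter_nodes` is hypothetical.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)
t = clf.tree_

def iter_nodes(t):
  """Yield (node_id, depth) pairs in depth-first pre-order (hypothetical helper)."""
  stack = [(0, 0)]  # the root is always node 0
  while stack:
    node_id, depth = stack.pop()
    yield node_id, depth
    left, right = t.children_left[node_id], t.children_right[node_id]
    if left != -1:  # -1 (TREE_LEAF) marks a leaf
      # push right first so the left subtree is visited first
      stack.append((right, depth + 1))
      stack.append((left, depth + 1))

for node_id, depth in iter_nodes(t):
  indent = "  " * depth
  if t.children_left[node_id] == -1:
    print("%sleaf %d: %d samples" % (indent, node_id, t.n_node_samples[node_id]))
  else:
    name = iris.feature_names[t.feature[node_id]]
    print("%s%s <= %.4f" % (indent, name, t.threshold[node_id]))
```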

The whole function is a bit long, so feel free to skip to the next section.

import json

from sklearn.tree import _tree  # private module: defines the Tree class and TREE_LEAF (-1)


def viz(decision_tree, feature_names=None, max_depth=6):
  """Export a fitted scikit-learn decision tree (or its tree_) as nested JSON.

  Internal nodes get a "rule" and "left"/"right" children; true leaves get
  the per-class "value". Descent stops at max_depth (6 by default, matching
  the original cutoff), so deeper internal nodes are emitted without children.
  """
  if isinstance(decision_tree, _tree.Tree):
    tree, criterion = decision_tree, "impurity"
  else:
    tree, criterion = decision_tree.tree_, decision_tree.criterion
  if not isinstance(criterion, str):
    criterion = "impurity"

  def node_to_dict(node_id):
    node = {"id": str(node_id), "samples": str(tree.n_node_samples[node_id])}

    if tree.children_left[node_id] == _tree.TREE_LEAF:
      value = tree.value[node_id]
      if tree.n_outputs == 1:
        value = value[0, :]
      node["criterion"] = criterion
      node["impurity"] = str(tree.impurity[node_id])
      node["value"] = value.tolist()
    else:
      # tree.feature holds column indices; map them back to names if given
      feature = tree.feature[node_id]
      if feature_names is not None:
        feature = feature_names[feature]
      node["rule"] = "%s <= %.4f" % (feature, tree.threshold[node_id])
      node[criterion] = str(tree.impurity[node_id])

    return node

  def recurse(node_id, depth=0):
    node = node_to_dict(node_id)
    left_child = tree.children_left[node_id]
    right_child = tree.children_right[node_id]

    if left_child != _tree.TREE_LEAF and depth < max_depth:
      node["left"] = recurse(left_child, depth + 1)
      node["right"] = recurse(right_child, depth + 1)

    return node

  # json.dumps guarantees valid JSON (matching braces, escaped strings)
  return json.dumps(recurse(0), indent=2)

# Map column positions to column names so the rules print readable headers
cols = dict()
for i, c in enumerate(income_trn.columns):
  cols[i] = c.name

print(viz(clf, feature_names=cols))
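The `cols` mapping works because `tree_.feature` stores positional column indices into the training matrix, so any lookup keyed by position will do. For a pandas DataFrame, `dict(enumerate(df.columns))` is one way to build it. Here is a small sketch, assuming pandas is installed and using iris (loaded with `as_frame=True`) as a stand-in for the income table:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# as_frame=True returns the features as a pandas DataFrame (iris is a stand-in)
iris = load_iris(as_frame=True)
X, y = iris.data, iris.target
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X.values, y)

# tree_.feature stores positional column indices, so map position -> column name
cols = dict(enumerate(X.columns))

t = clf.tree_
rules = [
  "%s <= %.4f" % (cols[t.feature[i]], t.threshold[i])
  for i in range(t.node_count)
  if t.children_left[i] != -1  # leaves have no split rule
]
print(rules)
```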

I'm extending a previous example I wrote, which builds a decision tree from income data stored in Postgres. The "rules" determine whether a data point follows the left or right path down the tree. Internally, the tree refers to each feature by its column index (column 0, column 1, and so on), but the function above accepts a mapping so these can be turned back into column headers. The top of the resulting tree looks like this:

x = {
  id: '0', rule: 'marital_status <= 12.0000', 'gini': '0.360846845083', 'samples': '16281',
  left: 
  {
    id: '1', rule: 'capital_gain <= 86.5000', 'gini': '0.0851021980364', 'samples': '5434',
    left: 
    {
      id: '2', rule: 'capital_loss <= 2372.0000', 'gini': '0.0634649568944', 'samples': '5212',
      left: 
      {
        id: '3', rule: 'education <= 46.0000', 'gini': '0.0584505343', 'samples': '5177'      
      },
      right: 
      {
        id: '570', rule: 'capital_loss <= 2558.5000', 'gini': '0.489795918367', 'samples': '35'      
      }    
    },
    right: 
    {
      id: '577', rule: 'education_num <= 47.0000', 'gini': '0.43507020534', 'samples': '222',
      left: 
      {
        id: '578', rule: 'capital_gain <= 221.5000', 'gini': '0.272324674466', 'samples': '123'      
      },
      right: 
      {
        id: '607', rule: 'capital_gain <= 817.5000', 'gini': '0.499540863177', 'samples': '99'      
      }    
    }  
  },
  right: 
  {
    id: '632', rule: 'marital_status <= 55.0000', 'gini': '0.44372508662', 'samples': '10847',
    left: 
    {
      id: '633', rule: 'education_num <= 47.0000', 'gini': '0.493880410242', 'samples': '7403',
      left: 
      {
        id: '634', rule: 'capital_gain <= 22.0000', 'gini': '0.454579586462', 'samples': '4363'      
      },
      right: 
      {
        id: '2885', rule: 'education_num <= 113.5000', 'gini': '0.486689750693', 'samples': '3040'      
      }    
    },
    right: 
    {
      id: '4292', rule: 'capital_gain <= 86.5000', 'gini': '0.164770726851', 'samples': '3444',
      left: 
      {
        id: '4293', rule: 'education_num <= 47.0000', 'gini': '0.126711048456', 'samples': '3207'      
      },
      right: 
      {
        id: '4902', rule: 'education <= 46.0000', 'gini': '0.478627000659', 'samples': '237'      
      }    
    }  
  }
}
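One property worth noting before restructuring: every split partitions its node's samples, so each node's `samples` equals the sum of its two children's. Below is a quick check over the counts transcribed from the output above, keeping only `id` and `samples`. The helper `check_counts` is hypothetical.

```python
# (id, samples, children) transcribed from the exported tree above
tree = ("0", 16281, [
  ("1", 5434, [
    ("2", 5212, [("3", 5177, []), ("570", 35, [])]),
    ("577", 222, [("578", 123, []), ("607", 99, [])]),
  ]),
  ("632", 10847, [
    ("633", 7403, [("634", 4363, []), ("2885", 3040, [])]),
    ("4292", 3444, [("4293", 3207, []), ("4902", 237, [])]),
  ]),
])

def check_counts(node):
  """Return True if every parent's samples equal the sum of its children's."""
  node_id, samples, children = node
  if not children:
    return True
  if samples != sum(child[1] for child in children):
    return False
  return all(check_counts(child) for child in children)

print(check_counts(tree))  # True
```

This is what makes it reasonable to size only the bottom-most nodes by `samples`. A layout that sums child sizes upward, as D3's hierarchy layouts generally do, recovers each internal node's count.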

To render this with D3, I started from one of the existing D3 examples.

The nice thing about this visualization is that, while it shows the leaf nodes, it groups them within the tree hierarchy and lets you drill down selectively by clicking to zoom.

The D3 examples don't all expect the same JSON structure, so I wrote a function to reshape the tree above into exactly what this one needs: nested objects with a "name", plus either a "children" array or a numeric "size". (This is easier than rewriting the Python code for every example, or modifying the D3 examples.)

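function toJson(x)
{
  // D3 hierarchy format: { name, children: [...] } for internal nodes and
  // { name, size } for leaves. True leaves have no "rule", so fall back
  // to labelling them by node id.
  var result = { name: x.rule || ("node " + x.id) };

  // Keep every child, so no samples are dropped from the layout
  var children = [];
  if (x.left)
    children.push(toJson(x.left));
  if (x.right)
    children.push(toJson(x.right));

  if (children.length > 0)
    result.children = children;
  else
    result.size = parseInt(x.samples, 10);  // sample count sets the node's area

  return result;
}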

Then, the only change you need to make to the D3 example is to add this function and add a call:

  node = root = toJson(data);
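To check the reshaped output before it reaches the browser, here is a Python sketch of the same kind of transformation. It assumes the `left`/`right`/`rule`/`samples` keys shown in the exported tree, and, as an assumption, labels nodes that lack a `rule` by their `id`. The name `to_d3` is hypothetical.

```python
import json

def to_d3(node):
  """Reshape a {rule, samples, left, right} tree into {name, children | size}."""
  # Nodes without a rule (true leaves) are labelled by id (an assumption)
  result = {"name": node.get("rule") or "node %s" % node["id"]}
  children = [to_d3(node[k]) for k in ("left", "right") if k in node]
  if children:
    result["children"] = children
  else:
    result["size"] = int(node["samples"])
  return result

# A two-level excerpt of the exported tree shown earlier
x = {
  "id": "1", "rule": "capital_gain <= 86.5000", "samples": "5434",
  "left": {"id": "2", "rule": "capital_loss <= 2372.0000", "samples": "5212"},
  "right": {"id": "577", "rule": "education_num <= 47.0000", "samples": "222"},
}
print(json.dumps(to_d3(x), indent=2))
```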

Now, this isn't the prettiest result, and it is only one view of the tree (the leaves), but it wires up enough of the pieces that you can go on to find the right visualization for what you're doing.
