
Due to some restrictions, I cannot use graphviz or webgraphviz.com to visualize a decision tree (my work network is closed off from the outside world).

Question: Is there an alternative utility, or some Python code, for at least a very simple visualization of a decision tree (Python/sklearn), maybe just an ASCII one?

I know I can use sklearn's tree.export_graphviz(), which produces a text file describing the tree structure, and the tree can be read from that file, but doing it "by eye" is not pleasant...
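For a pure-text view that needs neither graphviz nor network access, the fitted estimator's low-level `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `value`, `n_node_samples`) can be walked directly. The sketch below is one way to do it; `ascii_tree` is a hypothetical helper name, and the layout (indentation, "yes"/"no" labels) is an arbitrary choice, not an sklearn format.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def ascii_tree(clf, feature_names, class_names):
    """Hypothetical helper: indented plain-text rendering of a fitted DecisionTreeClassifier."""
    t = clf.tree_  # low-level arrays describing the fitted tree
    lines = []

    def walk(node, prefix, label):
        # sklearn marks leaves with children_left == -1
        if t.children_left[node] == -1:
            # value[node][0] holds the class distribution at this leaf; take the largest
            cls = class_names[int(np.argmax(t.value[node][0]))]
            lines.append(f"{prefix}{label}leaf: class={cls}, samples={t.n_node_samples[node]}")
            return
        name = feature_names[t.feature[node]]
        lines.append(f"{prefix}{label}[{name} <= {t.threshold[node]:.2f}]")
        # Left child: condition is true; right child: condition is false
        walk(t.children_left[node], prefix + "    ", "yes: ")
        walk(t.children_right[node], prefix + "    ", "no:  ")

    walk(0, "", "")
    return "\n".join(lines)


iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)
output = ascii_tree(clf, iris.feature_names, list(iris.target_names))
print(output)
```

Each split prints as `[feature <= threshold]`, with its two children indented four spaces below it, so the tree shape is visible directly in a terminal.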

PS: Note that

graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
Image(graph.create_png())

will NOT work, since create_png implicitly uses graphviz.

Alexander Chervov
  • I have used [networkx](https://networkx.github.io/documentation/stable/reference/drawing.html) before, but it requires a fair bit of tuning – G. Anderson Oct 04 '18 at 20:18
  • You may find more answers [on this question](https://stackoverflow.com/questions/7670280/tree-plotting-in-python) as well – G. Anderson Oct 04 '18 at 20:19
  • Also here some info: https://stackoverflow.com/questions/20224526/how-to-extract-the-decision-rules-from-scikit-learn-decision-tree?noredirect=1&lq=1 – Alexander Chervov Oct 05 '18 at 07:08
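Along the lines of the rule-extraction question linked above, a fitted tree can also be flattened into one "if ... then ..." rule per leaf by collecting the split conditions on each root-to-leaf path. This is a minimal sketch, assuming only the standard `tree_` attributes; `extract_rules` is a hypothetical helper name.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def extract_rules(clf, feature_names, class_names):
    """Hypothetical helper: one 'if ... then ...' string per leaf of a fitted tree."""
    t = clf.tree_
    rules = []

    def walk(node, conditions):
        if t.children_left[node] == -1:  # leaf node
            cls = class_names[int(np.argmax(t.value[node][0]))]
            body = " and ".join(conditions) if conditions else "True"
            rules.append(f"if {body} then class = {cls}")
            return
        name = feature_names[t.feature[node]]
        thr = t.threshold[node]
        # Left branch: feature <= threshold; right branch: feature > threshold
        walk(t.children_left[node], conditions + [f"{name} <= {thr:.2f}"])
        walk(t.children_right[node], conditions + [f"{name} > {thr:.2f}"])

    walk(0, [])
    return rules


iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)
rules = extract_rules(clf, iris.feature_names, list(iris.target_names))
for rule in rules:
    print(rule)
```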
  • This article may be useful for you: http://explained.ai/decision-tree-viz/index.html – avchauzov Oct 06 '18 at 10:22

1 Answer


Here is an answer that uses neither graphviz nor an online converter. As of scikit-learn version 0.21 (released around May 2019), decision trees can be plotted with matplotlib using scikit-learn's tree.plot_tree, without relying on graphviz.

import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

X, y = load_iris(return_X_y=True)

# Make an instance of the model
clf = DecisionTreeClassifier(max_depth=5)

# Train the model on the data
clf.fit(X, y)

fn = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
cn = ['setosa', 'versicolor', 'virginica']

# Setting dpi=300 to make the image clearer than the default
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(4, 4), dpi=300)

tree.plot_tree(clf,
               feature_names=fn,
               class_names=cn,
               filled=True)

fig.savefig('imagename.png')
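If even matplotlib output is inconvenient, scikit-learn 0.21 and later also ships sklearn.tree.export_text, which returns the tree as an indented ASCII report and so answers the "just ASCII" part of the question directly. A short sketch on the same iris data (the max_depth and random_state values here are arbitrary choices):

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(iris.data, iris.target)

# feature_names is optional; without it the report uses feature_0, feature_1, ...
report = export_text(clf, feature_names=list(iris.feature_names))
print(report)
```

Each line looks like `|--- petal width (cm) <= 0.80`, with nested `|   |---` prefixes for deeper levels and `class: ...` at the leaves.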

The plotted tree is saved as imagename.png.

The code was adapted from this post.