How to Build Decision Trees in Python



Nowadays many Machine Learning (ML) algorithms can be applied to the same task, so choosing one for an ML project can be overwhelming. If you find yourself in that position, the best thing to do is to start with something simple, for example, Linear or Logistic Regression, and try more complex models later on.

To tell the truth, some Data Scientists state that simple models have the potential to work as well as complex ones. This is why at the beginning of each ML project simple models are usually applied to the task as they will at least help to create a baseline to beat. Among these simple models, there is one that differs a lot from the rest – a Decision Tree algorithm.

In this article we will talk about:

  • Machine Learning terminology
  • What is a Decision Tree?
  • Building a Decision Tree
    • Greedy algorithm
    • How are the splits made when building a Decision Tree?
    • The stopping criterion
    • Methods for handling missing values
    • Pruning methods
    • Popular Decision Tree building methods
  • Advantages and disadvantages of a Decision Tree
  • How to work with Decision Trees using Python?
    • Sklearn Decision Tree models
    • Decision Tree for Regression
    • Decision Tree for Classification
    • How to optimize the performance of a Decision Tree?
    • How to evaluate a Decision Tree?
    • How to visualize a Decision Tree?
  • Real-Life Applications of Decision Trees
    • Ensemble Learning
    • Boosting
    • Bagging and Random Forest

 

Let’s jump in.

Machine Learning terminology

To start with, it is important to use the correct Machine Learning terminology to understand the exact place of Decision Trees in the grand scheme of things. As you might know, it all starts with the Machine Learning task. An ML task is defined by the problem that needs to be solved, the type of prediction that needs to be made, and the available data.

You are familiar with some of the ML tasks, for example:

  • Regression task (predicts a continuous-valued attribute associated with an object)
  • Classification task (identifies which category an object belongs to)
  • Clustering task (automatically groups similar objects into sets)
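  • And others

Each ML task can be solved with many different Machine Learning algorithms. For example, Linear Regression is used for Regression tasks, Logistic Regression for Classification tasks, and a Decision Tree can be used for both.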

Finally, there are ML metrics. ML metrics are the metrics that are used to evaluate the performance of an ML algorithm on a specific ML task. For example:

  • MAE, MSE, and many other ML metrics can be used to evaluate an ML algorithm on a Regression task
  • Accuracy, Precision, Recall, and many other ML metrics can be used to evaluate an ML algorithm on a Classification task

Thus, you now know the basics of ML terminology. Also, you are aware that a Decision Tree is a Machine Learning algorithm used to solve both Regression and Classification ML tasks.

What is a Decision Tree?

A Decision Tree is a Supervised Machine Learning algorithm that can be easily visualized using a connected acyclic graph. In general, a connected acyclic graph is called a tree.

In maths, a graph is a set of vertices and a set of edges. Each edge in a graph connects exactly two vertices. A connected graph is a graph that has at least one path (a sequence of edges) between any pair of vertices. An acyclic graph is a graph that does not have any cycles in it (there is no path that starts and ends at the same vertex). So, as you can see, Decision Trees have a mathematical background behind them.

You can find a simple example of a Decision Tree below.

Figure: a simple Decision Tree (image source: Medium)

The edges of a tree (in the picture above they are visualized as arrows) are called branches. The vertices of a tree are known as nodes (blue circles in the picture above). The nodes without child nodes are called leaf nodes (in the picture above leaf nodes are green circles).

The basic Decision Tree concept is quite intuitive as it reflects the human decision-making process. When we are making a decision, we ask ourselves a question and based on the answer choose a direction of further discussion until the decision is made. 

The same concept is used in the Decision Tree algorithm. We will cover the exact algorithm a bit later but in general, each Decision Tree splits the data in each node using the object’s features based on a certain criterion until it reaches the stopping criterion. Thus, when the new object occurs the algorithm will know how to handle the object based on its features.

Thus, the Decision Tree algorithm proposes a unique approach to the Regression and Classification ML tasks that differs a lot from the traditional linear models.

Building a Decision Tree

Now it is time to talk in-depth about building a Decision Tree.

Greedy algorithm

The basic Decision Tree building algorithm is called a greedy one and has several steps in it:

  1. Imagine you have a dataset X that you want to use as a training set. Also, you have a splitting criterion Q, a function that needs to be maximized when splitting your training set.
  2. Find the best split for each feature in your dataset using the Q function. For a feature with K different values in it, there will be K - 1 different splits. You need to pick the split that maximizes the Q function. Thus, if you have 3 features in your training set, after this step you will have one best split for each feature.
  3. Find the best split S among the best splits found in Step 2. S will be used to split the training set in a node as it is the best possible split that can be made.
  4. Make the split using S. The training set will be split into two parts: some objects will go to the left child node, others to the right.
  5. For each of these nodes (left and right), you need to recursively repeat Steps 2-4. Thus, you will build a Decision Tree by constructing child nodes, starting from the root node.
  6. At each node, you need to check if some stopping condition has occurred. If it has, then you will stop the recursion and declare this node a leaf.
  7. When the Tree is built, a prediction value (the value that will be predicted by the Decision Tree if a new object gets to a leaf) is assigned to each leaf. In the case of Classification, it may be the majority class in the leaf. For Regression, it can be the mean or the median. Still, you might use your own logic when assigning a prediction value to each leaf.
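
The steps above can be sketched in a few dozen lines of NumPy. This is only an illustration under several assumptions that are not part of the original description: the impurity is the variance (Regression case), candidate thresholds are midpoints between consecutive distinct feature values, the only stopping conditions are a maximum depth and a minimum node size, and all function names (`best_split`, `build_tree`, `predict_one`) are hypothetical.

```python
import numpy as np

def variance_impurity(y):
    """Impurity H(R) for Regression: variance of the target in the node."""
    return float(np.var(y)) if len(y) > 0 else 0.0

def best_split(X, y, impurity=variance_impurity):
    """Steps 2-3: return (feature index, threshold, Q) of the best split, or None."""
    n_samples, n_features = X.shape
    parent = impurity(y)
    best = None
    for j in range(n_features):
        values = np.unique(X[:, j])                  # sorted distinct values
        thresholds = (values[:-1] + values[1:]) / 2  # K - 1 candidate splits
        for t in thresholds:
            left = X[:, j] <= t
            y_l, y_r = y[left], y[~left]
            q = (parent
                 - len(y_l) / n_samples * impurity(y_l)
                 - len(y_r) / n_samples * impurity(y_r))
            if best is None or q > best[2]:
                best = (j, float(t), float(q))
    return best

def build_tree(X, y, depth=0, max_depth=3, min_samples=2):
    """Steps 4-7: recursively build a tree as nested dicts; leaves store the mean."""
    can_split = depth < max_depth and len(y) >= min_samples
    split = best_split(X, y) if can_split else None
    if split is None or split[2] <= 0:               # stopping condition reached
        return {"leaf": True, "value": float(np.mean(y))}
    j, t, _ = split
    left = X[:, j] <= t
    return {
        "leaf": False, "feature": j, "threshold": t,
        "left": build_tree(X[left], y[left], depth + 1, max_depth, min_samples),
        "right": build_tree(X[~left], y[~left], depth + 1, max_depth, min_samples),
    }

def predict_one(tree, x):
    """Walk from the root to a leaf and return the value stored there."""
    while not tree["leaf"]:
        branch = "left" if x[tree["feature"]] <= tree["threshold"] else "right"
        tree = tree[branch]
    return tree["value"]

# Toy data: the target depends only on the first feature
X_demo = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0],
                   [10.0, 5.0], [11.0, 5.0], [12.0, 5.0]])
y_demo = np.array([1.0, 1.0, 1.0, 9.0, 9.0, 9.0])
tree_demo = build_tree(X_demo, y_demo)
print(tree_demo["feature"], tree_demo["threshold"])
```

On this toy data the root split is made on the first feature, between the two groups of points, and each child immediately becomes a leaf because no further split improves Q.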

Decision Trees can also handle missing values. This is done by modifying how the training set is split at a node, which can be achieved in several ways. After the tree is built, you can perform Pruning (removing some nodes to reduce the complexity of the model and improve its generalization ability). There are several approaches to that, so we will cover them a bit later.

In general, the Decision Tree building is determined by:

  1. The splitting criterion;
  2. The stopping criterion;
  3. The method for handling missing values;
  4. The Pruning method.

How are the splits made when building a Decision Tree?

As you already know, nodes are split based on a certain splitting criterion Q. In its usual form, it compares the impurity of the node being split with two terms, one for each child node:

$$Q(R_m) = H(R_m) - \frac{|R_l|}{|R_m|} H(R_l) - \frac{|R_r|}{|R_m|} H(R_r)$$

Here $R_m$ is the set of samples in the node that needs to be split. $R_l$ and $R_r$ are the samples that will respectively go to the left and right subtree after the split. $H(R)$ is the impurity criterion that evaluates the quality of the distribution of the target variable among the objects of the set $R$. The less varied the target variable is, the lower the value of the impurity criterion should be, so the impurity needs to be minimized. On the other hand, the function Q needs to be maximized.

In the Regression case, the impurity criterion is usually the variance of the target variable in the node (the lower the spread of the target variable, the better the node). The Classification case is a bit more interesting, as two basic impurity criteria are usually used (here $p_k$ is the fraction of samples of class $k$ in the node):

  • Entropy: $H(R) = -\sum_{k} p_k \log_2 p_k$
  • Gini: $H(R) = \sum_{k} p_k (1 - p_k)$

Let’s check a simple example to understand how entropy is used to build a Decision Tree. Imagine that you have a problem sorting a group of balls into two groups – yellow and blue.

There are 9 blue and 11 yellow balls in the picture. If we pull out a ball at random, then with probability 9/20 it will be blue, whereas with probability 11/20 it will be yellow. So, if we use the entropy formula, we get that the entropy of this state is

$$H_0 = -\frac{9}{20}\log_2\frac{9}{20} - \frac{11}{20}\log_2\frac{11}{20} \approx 0.993$$

Right now the entropy value is pretty useless, but let’s see how the entropy will change if the balls are divided into two groups – with coordinates less than or equal to 12 and greater than 12.


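Let’s check the entropy. The left group contains 13 balls (8 are blue and 5 are yellow). The entropy of this group is:

$$H_l = -\frac{8}{13}\log_2\frac{8}{13} - \frac{5}{13}\log_2\frac{5}{13} \approx 0.961$$

The right group contains 7 balls (1 blue and 6 yellow). The entropy of the right group is:

$$H_r = -\frac{1}{7}\log_2\frac{1}{7} - \frac{6}{7}\log_2\frac{6}{7} \approx 0.592$$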

As you can see, the entropy decreased in both groups compared to the initial one. This is exactly how the splitting works. Among all features in the dataset, you need to find the split that maximizes the Q function, which is the same as minimizing the weighted impurities H(Rl) and H(Rr) of the child nodes.
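
The ball example can be checked numerically. The sketch below assumes base-2 logarithms and the usual size-weighted form of Q (parent entropy minus the weighted child entropies); the helper name `entropy` is hypothetical.

```python
import numpy as np

def entropy(counts):
    """Shannon entropy (base 2) of a node, given the number of objects per class."""
    counts = np.asarray(counts, dtype=float)
    p = counts[counts > 0] / counts.sum()   # class fractions, skipping empty classes
    return float(-(p * np.log2(p)).sum())

h_parent = entropy([9, 11])   # 9 blue and 11 yellow balls before the split
h_left = entropy([8, 5])      # left group: 8 blue, 5 yellow
h_right = entropy([1, 6])     # right group: 1 blue, 6 yellow

# Q: parent impurity minus the size-weighted impurities of the two groups
q = h_parent - 13 / 20 * h_left - 7 / 20 * h_right

print(round(h_parent, 3), round(h_left, 3), round(h_right, 3), round(q, 3))
```

A positive Q means that the split reduced the weighted entropy, so splitting at coordinate 12 is better than not splitting at all.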

The stopping criterion

To tell the truth, you can come up with a large number of stopping criteria, for example:

  • Limit the maximum depth of a Tree;
  • Limit the minimum number of samples in a leaf;
  • Limit the maximum number of leaves in a Tree;
  • Stop building a Tree if all samples in a leaf belong to the same class;
  • And others.

The stopping criterion must be chosen wisely as it can significantly affect the overall quality of a Tree. However, such a selection is time-consuming as it requires cross-validation.
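
In scikit-learn, several of these stopping criteria are exposed as hyperparameters of the Tree models. The sketch below uses the Iris dataset and arbitrary limits chosen purely for illustration; it compares an unrestricted Tree with a restricted one.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# No limits: the Tree grows until every leaf is pure
unrestricted = DecisionTreeClassifier(random_state=0).fit(X, y)

# Three stopping criteria at once (the values are arbitrary examples)
limited = DecisionTreeClassifier(
    max_depth=3,          # maximum depth of the Tree
    min_samples_leaf=5,   # minimum number of samples in a leaf
    max_leaf_nodes=6,     # maximum number of leaves
    random_state=0,
).fit(X, y)

print("unrestricted:", unrestricted.get_depth(), unrestricted.get_n_leaves())
print("limited:", limited.get_depth(), limited.get_n_leaves())
```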

Methods for handling missing values

As mentioned above, Decision Trees can handle missing values but it requires changing the splitting logic a bit. In general, there are several approaches to this problem:

  • If the missing value occurs in a feature when building a Decision Tree (in a sample S from the training set) and the split should be made by this feature, then S must be put both in the left and the right subtree with certain weights;
  • If the missing value occurs in a feature when using a built Decision Tree (in a new sample S), then it must be also put in both subtrees and proceed further on with certain weights;
  • Surrogate predicate method requires finding a split that is as close as possible to the initial one by another feature;
  • Also, you can use other well-known methods, for example, replacing all missing values with zero. Moreover, for Decision Trees, it might be effective to replace the missing values in a feature with a number that exceeds any known value of that feature. Such an approach makes it possible to choose a split where all objects with known values go to the left subtree and all objects with missing values go to the right.
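
The last approach can be sketched as follows. The toy data is made up for illustration, and the sentinel (the largest known value plus one) is just one possible choice of a number that exceeds every known value of the feature.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Toy data: in this made-up example, a missing value is itself informative
X = np.array([[1.0], [2.0], [3.0], [np.nan], [np.nan], [4.0]])
y = np.array([0, 0, 0, 1, 1, 0])

# Replace NaN in each feature with a value above the largest known one
sentinel = np.nanmax(X, axis=0) + 1.0
X_filled = np.where(np.isnan(X), sentinel, X)

# A single split is enough to separate known values from missing ones
clf = DecisionTreeClassifier(max_depth=1, random_state=0).fit(X_filled, y)
threshold = clf.tree_.threshold[0]
print("split threshold:", threshold)
```

The learned threshold falls between the largest known value and the sentinel, so objects with known values go to the left child and objects with missing values go to the right.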

Pruning methods

Pruning is an alternative to the stopping criterion described above. When using Pruning, an overfitted Decision Tree is built first (for example, until there is one object in each leaf), and then its structure is optimized to improve generalization ability. Many studies show that Pruning results in a better quality of the model compared to early stopping the building based on various stopping criteria.

However, as of today, Pruning is rarely used, and many data analysis libraries do not implement it. This is because the Trees themselves are weak algorithms that are usually used in various ensembles, for example, Bagging (Random Forest) or Boosting, which do not require the Pruning technique.

Still, if you want to learn more about Pruning, you should probably study the Cost-Complexity Pruning method as it is used in the Decision Tree CART building method.
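
If you want to experiment with it, scikit-learn (version 0.22 and later) exposes Cost-Complexity Pruning through the `ccp_alpha` parameter and the `cost_complexity_pruning_path` method of its Tree models. The sketch below is a minimal illustration on the Iris dataset; picking the middle value of the path is an arbitrary choice, and in practice `ccp_alpha` would be selected with cross-validation.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0
)

# Step 1: grow a full (overfitted) Tree
full_tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Step 2: compute the candidate pruning strengths for this training set
path = full_tree.cost_complexity_pruning_path(X_train, y_train)

# Step 3: refit with one of them (middle of the path, clipped at zero for safety)
alpha = max(0.0, float(path.ccp_alphas[len(path.ccp_alphas) // 2]))
pruned_tree = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha)
pruned_tree.fit(X_train, y_train)

print("leaves before pruning:", full_tree.get_n_leaves())
print("leaves after pruning:", pruned_tree.get_n_leaves())
```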

Popular Decision Tree building methods

There are three popular Decision Tree building methods:

  • The ID3 method uses entropy criterion and builds a tree until each leaf contains objects of the same class, or while the partition of the node gives a decrease in the entropy criterion;
  • The C4.5 method uses the Gain Ratio criterion (normalized entropy criterion). As a stopping criterion, it limits the number of samples in a leaf. Pruning is performed using the Error-Based Pruning method that uses the generalization capacity estimates to make a vertex removal decision. Missing values are handled using the method that ignores objects with missing values when computing the branching criterion and then transfers such objects to both subtrees with certain weights.
  • The CART method uses Gini criterion and Cost-Complexity Pruning. Missing values are handled using the surrogate predicate method.

Advantages and disadvantages of a Decision Tree

Undoubtedly, the Decision Tree has several advantages:

  • It is pretty intuitive;
  • It can be easily visualized and interpreted;
  • It has the potential to achieve zero error on the training set for almost any problem;
  • It can be used without data preprocessing (normalization, scaling, and so on).

Still, there are major disadvantages:

  • It tends to overfit;
  • When solving a Regression task you might face the extrapolation problem as a Decision Tree algorithm is unable to predict values outside the interval of the target variable on the training set. In the case of Regression, a Decision Tree will underperform because of that.

Thus, despite the tangible advantages, it is better not to use a standalone Decision Tree model on any task.
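
The extrapolation problem is easy to reproduce. In the sketch below (synthetic data, made up for illustration) the target grows linearly with the feature, yet the Tree cannot predict anything above the largest target value it saw during training.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Training data: y = 2 * x for x = 0, 1, ..., 9, so the targets lie in [0, 18]
X_train = np.arange(0, 10, dtype=float).reshape(-1, 1)
y_train = 2.0 * X_train.ravel()

reg = DecisionTreeRegressor(random_state=0).fit(X_train, y_train)

# New objects far outside the training interval
X_new = np.array([[20.0], [100.0]])
preds = reg.predict(X_new)

print("largest training target:", y_train.max())
print("predictions for x = 20 and x = 100:", preds)
```

Both new objects end up in the leaf that holds the largest training target, so the prediction stays at that value no matter how large x becomes.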

How to work with Decision Trees using Python?

For this section I have prepared a Google Colab notebook for you featuring working with Decision Trees in Python, training on the Boston and Iris datasets, hyperparameter tuning using GridSearchCV, and some visualizations. Please feel free to experiment and play around as there is no better way to master something than practice.

Sklearn Decision Tree models

Fortunately, it is quite easy to work with Decision Trees in Python thanks to the scikit-learn (sklearn) Python package. As you might know, sklearn is an efficient and simple library for Machine Learning that has plenty of ML algorithms, metrics, datasets, and additional tools implemented. Thus, the easiest way to start working with Decision Trees is to use scikit-learn.

Sklearn has two standalone Decision Tree models implemented: DecisionTreeRegressor for Regression tasks and DecisionTreeClassifier for Classification tasks.

These are the basic Decision Tree models. If you want something advanced, sklearn has awesome documentation including valuable tutorials, simple examples, and brief descriptions of the algorithms, so if you want to explore all the opportunities please refer to it.

To get things going simply import the necessary models, for example:

				
from sklearn.tree import DecisionTreeRegressor
from sklearn.tree import DecisionTreeClassifier
				
			

Decision Tree for Regression

As mentioned above, sklearn has a built-in DecisionTreeRegressor model that can be used to solve a Regression task using a standalone Decision Tree. From the code perspective, it will take only a few lines. All you need to do is to import and initialize the sklearn model and use the fit method to train your model on the training set.

				
					
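# Note: load_boston was removed in scikit-learn 1.2, so this needs an older release
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

X, y = load_boston(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
dtr = DecisionTreeRegressor()
dtr.fit(X_train, y_train)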

				
			

By doing this you will get a trained Decision Tree model and will be able to proceed with further actions such as model optimization, evaluation, and visualization.

Decision Tree for Classification

Fortunately, working with a DecisionTreeClassifier model is no different. The steps are the following:

  1. Import a necessary model;
  2. Initialize it;
  3. Use the fit method to train it;
  4. Proceed with further exploration.
				
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
dtc = DecisionTreeClassifier()
dtc.fit(X_train, y_train)

				
			

Now you know how to train a standalone Decision Tree model for both Regression and Classification tasks. However, if you pay closer attention to the notebook, you will notice that both “out-of-the-box” models drastically overfit the training data. Thus, let’s try and optimize the Decision Tree performance.

How to optimize the performance of a Decision Tree?

When talking about Decision Tree performance optimization, Data Scientists usually mean reducing overfitting rather than obtaining the best possible metric. In general, if you use the “out-of-the-box” Decision Tree model, it will create a very complex tree in which almost every sample gets its own leaf node. Sure, such an approach will help to achieve zero error on the training set. Still, the model will generalize poorly and its performance on unseen data will leave much to be desired.

Nevertheless, you can use several techniques to optimize a Decision Tree performance, for example:

  1. Feature Engineering technique will help you to get rid of irrelevant features and create more valuable ones;
  2. If there are too many features in your dataset, then PCA might be a nice algorithm to help you with the dataset decomposition;
  3. Hyperparameter tuning will be useful for overfitting reduction;
  4. If you have a class imbalance in your dataset, then Sampling techniques are there to help.

It is worth mentioning that these methods will likely worsen the ML metric that you use to evaluate your Decision Tree model. Still, you must remember that your ultimate goal is to reduce overfitting, which cannot be achieved without a cost.
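
As an illustration of the second technique, PCA can be combined with a Decision Tree in a single scikit-learn pipeline. The sketch below uses the Iris dataset; the number of components and the maximum depth are arbitrary example values, not recommendations.

```python
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# PCA compresses the 4 original features into 2 components,
# then a shallow Decision Tree is trained on the components
model = make_pipeline(
    PCA(n_components=2),
    DecisionTreeClassifier(max_depth=3, random_state=0),
)

# 5-fold cross-validation; PCA is refitted inside every fold
scores = cross_val_score(model, X, y, cv=5)
print("mean accuracy:", scores.mean())
```

Wrapping both steps in a pipeline keeps the decomposition inside the cross-validation loop, so the validation folds do not leak into the PCA fit.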

To get rid of irrelevant features you simply need to train your model on the initial dataset and use the .feature_importances_ attribute to get the importance of each feature in the dataset. Then you can remove irrelevant features from the dataset and proceed with training.

				
import matplotlib.pyplot as plt
import numpy as np

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
dtr = DecisionTreeRegressor()
dtr.fit(X_train, y_train)
# feature_importances_ is a 1-D array with one non-negative value per feature
plt.figure(figsize=(12, 8))
plt.bar(np.arange(X.shape[1]), dtr.feature_importances_)

				
			

Image Source: Notebook

				
# Keep only the features whose importance is at least 0.1
features_to_save = np.where(dtr.feature_importances_ >= 1e-1)[0]
X_train_new = X_train[:, features_to_save]
X_test_new = X_test[:, features_to_save]

				
			

When tuning a model’s hyperparameters you need to be careful and precise as they can massively affect the model’s performance. Please refer to the Building a Decision Tree section for an in-depth description of the Decision Tree algorithm that will help you identify which of the model’s parameters are the most crucial. For the full list of the hyperparameters please refer to the sklearn documentation of the respective model (DecisionTreeRegressor or DecisionTreeClassifier).

From the code perspective, the task is simple, as you can use either RandomizedSearchCV or GridSearchCV to tune your model. Please remember, though, that it might take a while before these methods produce a result.

				
from sklearn.model_selection import RandomizedSearchCV

params_dist = {
    'max_depth': [4, 6, 8, 10, 12, 14, 16],
    'max_leaf_nodes': [1000, 2000, 3000],
    'min_samples_leaf': [20, 30, 40, 50],
    'min_samples_split': [30, 40, 50]
}
dtr = DecisionTreeRegressor()
# Randomly samples parameter combinations and scores each with 5-fold CV
random_search = RandomizedSearchCV(dtr, params_dist, cv=5)
random_search.fit(X_train, y_train)
tuned_dtr = random_search.best_estimator_

				
			

For more code please refer to the Notebook.

How to evaluate a Decision Tree?

Just like any other ML algorithm, a Decision Tree can be easily evaluated using suitable metrics and tools, for example:

  1. ML metrics designed for a specific ML task (MAE for Regression, Precision for Classification)
  2. Cross-Validation to check the algorithm’s ability to generalize

From the code perspective, you will face no obstacles trying to evaluate a Decision Tree model as the process is intuitive and does not differ from the evaluation process used when working with linear models (Linear and Logistic Regression).

You can simply use sklearn.metrics to import a necessary metric or sklearn.model_selection to import a necessary tool and then apply them to your model.

				
					
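from sklearn.datasets import load_boston  # removed in scikit-learn 1.2
from sklearn.metrics import mean_absolute_error as mae
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_boston(return_X_y=True)
dtr = DecisionTreeRegressor()
# For a regressor the default score is R^2, computed with 5-fold cross-validation
print('CV scores:', cross_val_score(dtr, X, y))

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
dtr = DecisionTreeRegressor()
dtr.fit(X_train, y_train)
print('MAE on train:', mae(y_train, dtr.predict(X_train)))
print('MAE on test:', mae(y_test, dtr.predict(X_test)))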
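
The same workflow applies to a Classification task. The sketch below evaluates a DecisionTreeClassifier on the Iris dataset with Accuracy, Precision, and Recall; the macro averaging, the fixed random_state, and the depth limit are choices made here for illustration only.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

dtc = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)
y_pred = dtc.predict(X_test)

# Iris has three classes, so Precision and Recall are averaged over the classes
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred, average="macro")
recall = recall_score(y_test, y_pred, average="macro")
print("Accuracy:", accuracy, "Precision:", precision, "Recall:", recall)

# Cross-Validation on the whole dataset (accuracy is the default classifier score)
cv_scores = cross_val_score(
    DecisionTreeClassifier(max_depth=3, random_state=0), X, y, cv=5
)
print("CV accuracy per fold:", cv_scores)
```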

				
			

How to visualize a Decision Tree?

As mentioned above, one of the clear Decision Tree advantages is the interpretability achieved through simple visualization. You can easily visualize a Decision Tree using the sklearn.tree.plot_tree function.


This is the part where you can get creative as you can work a bit on your visualization and get a nice presentable picture. Sure, you can still use the method as is and get the following picture:

				
from sklearn import tree
import matplotlib.pyplot as plt

dtc = DecisionTreeClassifier()
dtc.fit(X_train, y_train)
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(9, 9), dpi=800)
tree.plot_tree(dtc)
 

				
			

Still, it is always better if you make your visualizations a bit more advanced:

				
from sklearn import tree
import matplotlib.pyplot as plt

data = load_boston()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
dtr = DecisionTreeRegressor()
dtr.fit(X_train, y_train)
fn = list(data.feature_names)
# class_names only applies to classifiers, so it is omitted for the regressor
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=(9, 9), dpi=800)
tree.plot_tree(dtr,
               feature_names=fn,
               filled=True);

				
			

Also, please note that your Decision Tree might be quite big.
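
When a plot becomes too large to read, a text summary of a shallow Tree can be a handy alternative. The sketch below uses sklearn.tree.export_text on the Iris dataset; limiting the Tree to a depth of 2 is an arbitrary choice made to keep the output short.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_iris()

# A shallow Tree keeps the printed rules short and readable
dtc = DecisionTreeClassifier(max_depth=2, random_state=0)
dtc.fit(data.data, data.target)

# One line per node: the split condition or the predicted class
rules = export_text(dtc, feature_names=list(data.feature_names))
print(rules)
```

The plot_tree function also accepts a max_depth argument, which limits how many levels of an already trained Tree are drawn.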

Real-Life Applications of Decision Trees

As you already know, the overfitting problem that usually occurs when working with Decision Trees makes a standalone Decision Tree model almost irrelevant in the grand scheme of things. The extrapolation problem of DecisionTreeRegressor is also crucial as it makes a Decision Tree algorithm not very suitable for the Regression task.

That is why a single Decision Tree model is a bad model to use in a Machine Learning project. Moreover, you should probably avoid using Decision Trees for a Regression task as well. Still, the Decision Tree’s ability to achieve zero error on the training set for almost any problem makes it a very attractive algorithm for Data Scientists. So, ML specialists came up with a strategy called Ensemble Learning that minimizes the disadvantages of Decision Trees.

To make things clear, in real life Decision Trees are mostly used as a part of ensembles, and you will rarely encounter a standalone Decision Tree model.

Ensemble Learning

Ensemble learning is a process where multiple ML models (called base models) are generated and combined to solve a particular problem. There are several types of Ensemble Learning:

  1. Bagging;
  2. Boosting;
  3. Stacking;
  4. And others.

Decision Trees are commonly used as base models for both Boosting and Bagging.

Boosting

Boosting is a sequential Ensemble Learning technique that incrementally builds an ensemble by training each new model instance to emphasize the training instances that previous models misclassified. So, one model is learning from the mistakes of another which boosts the learning.

 

Multiple Boosting algorithms use Decision Trees as base models, for example, Gradient Boosting, XGBoost, CatBoost, AdaBoost, and others. 
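
As a minimal sketch of Boosting with Decision Trees as base models, the example below trains scikit-learn's GradientBoostingClassifier on the built-in breast cancer dataset. The dataset and the hyperparameter values are assumptions made for illustration, not tuned settings.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

# 100 shallow Trees are added one after another,
# each one fitted to correct the errors of the current ensemble
gbc = GradientBoostingClassifier(n_estimators=100, max_depth=2, random_state=0)
gbc.fit(X_train, y_train)

test_accuracy = gbc.score(X_test, y_test)
print("number of Trees:", gbc.estimators_.shape[0])
print("test accuracy:", test_accuracy)
```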

Bagging and Random Forest

Bootstrap Aggregating or Bagging was originally developed to overcome the overfitting problem. It is also known as parallel Ensemble Learning as it proposes training multiple base models (for example, Decision Trees) in parallel on unique training sets. Such an approach helps to minimize the overfit. For the detailed Bagging algorithm please refer to the related article.

Random Forest is a popular ML algorithm that builds on Bagging and develops its ideas further. The Forest part of the name comes from the approach of training many Decision Trees as base models. For the detailed Random Forest algorithm please refer to the related article.
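
To get a feeling for the difference, you can compare a standalone Decision Tree with a Random Forest using Cross-Validation. The sketch below uses scikit-learn's built-in breast cancer dataset and default-like settings, both chosen only for illustration.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

single_tree = DecisionTreeClassifier(random_state=0)
# 100 Trees, each trained on a bootstrap sample of the training set
forest = RandomForestClassifier(n_estimators=100, random_state=0)

tree_scores = cross_val_score(single_tree, X, y, cv=5)
forest_scores = cross_val_score(forest, X, y, cv=5)

print("single Tree, mean accuracy:", tree_scores.mean())
print("Random Forest, mean accuracy:", forest_scores.mean())
```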

Random Forest is a powerful yet simple ML algorithm that is widely used across the industry and can be applied to many tasks, for example:

  1. Fraud Detection;
  2. Kaggle competitions;
  3. Credit scoring;
  4. E-commerce use cases;
  5. Any other Classification problem.


Thus, Decision Trees are everywhere but usually as a part of ensembles. For extra support, you can access the related article for an in-depth explanation.

Final Thoughts

Hopefully, this tutorial will help you succeed and use the Decision Tree algorithm in your next Machine Learning project. 

To summarize, we started with an in-depth explanation of Machine Learning terminology. Then we covered some theoretical information about the Decision Tree algorithm, its advantages and disadvantages, and went through a step-by-step guide on how to use Decision Trees in Python for both the Regression and Classification tasks. Lastly, we talked about some real-life applications of Decision Trees.

If you enjoyed this post, a great next step would be to start building your Machine Learning project with all the relevant tools.

 

For extra support, you can access the Notebook for further code and documentation.

Thanks for reading, and happy training!
