Introduction to Decision Trees

Basic concepts, advantages, and disadvantages of decision trees

  A decision tree is a decision analysis method. Given the known probabilities of various situations, it builds a tree to obtain the probability that the expected net present value is greater than or equal to zero, which is then used to evaluate project risk and judge feasibility. It is a graphical method that applies probability analysis in an intuitive way. Because the decision branches are drawn like the branches of a tree, the method is called a decision tree.

The main advantages of decision trees:

  1. They are highly interpretable: the model can generate rules that people can understand.
  2. The importance of each feature can be read from the trained model.
  3. The computational complexity of the model is low.
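The first two advantages can be illustrated with a minimal sketch. It assumes scikit-learn is installed and uses the iris data bundled with scikit-learn purely as a stand-in (it is not the penguin data used later in this post). `export_text` prints the learned rules and `feature_importances_` holds the importance scores; `max_depth=2`, `random_state=0`, and the name `clf_demo` are illustrative choices, not part of the original tutorial.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Stand-in data: the iris data set that ships with scikit-learn
iris = load_iris()

# A shallow tree keeps the printed rules short (illustrative setting)
clf_demo = DecisionTreeClassifier(max_depth=2, random_state=0)
clf_demo.fit(iris.data, iris.target)

# Advantage 1: the fitted tree can be printed as readable if/else rules
rules = export_text(clf_demo, feature_names=list(iris.feature_names))
print(rules)

# Advantage 2: each feature gets an importance score; the scores sum to 1
for name, score in zip(iris.feature_names, clf_demo.feature_importances_):
    print(f"{name}: {score:.3f}")
```

Features that the tree never splits on receive an importance of 0, so a shallow tree will usually report only a few non-zero scores.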

The main disadvantages of decision trees:

  1. The model is prone to overfitting, which has to be handled with pruning techniques.
  2. It does not make good use of continuous features.
  3. Its predictive ability is limited and usually cannot match that of stronger supervised models.
  4. The variance is high, and a slight change in the data distribution can easily result in a completely different tree structure.
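The overfitting problem in item 1 can be made concrete with a small sketch. It assumes scikit-learn and uses synthetic, deliberately noisy data from `make_classification`, not the penguin data. Limiting `max_depth` is used here as one simple way to restrict tree growth (a form of pre-pruning); it is a swapped-in illustration, and all names and parameter values are assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic data; flip_y=0.2 assigns random labels to about 20% of the samples
X, y = make_classification(n_samples=500, n_features=10, n_informative=4,
                           flip_y=0.2, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

# An unrestricted tree keeps splitting until every training leaf is pure
full_tree = DecisionTreeClassifier(random_state=0).fit(X_tr, y_tr)
# A depth-limited tree stops early and has at most 2**3 = 8 leaves
shallow_tree = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_tr, y_tr)

for name, model in [("unrestricted", full_tree), ("max_depth=3", shallow_tree)]:
    print(name,
          "| leaves:", model.get_n_leaves(),
          "| train accuracy:", round(model.score(X_tr, y_tr), 3),
          "| test accuracy:", round(model.score(X_te, y_te), 3))
```

The unrestricted tree fits the noisy training labels perfectly; a large gap between its training and test accuracy is the usual symptom of overfitting.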

A decision tree example based on the penguin dataset

  • Step1: Import libraries and read the data
## Core libraries
import numpy as np
import pandas as pd
## Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns

This time we use the penguin data (palmerpenguins) to train the model. The data set contains 8 variables in total: 7 feature variables and 1 categorical target variable. There are 344 samples, and the target variable is the penguin species, which is one of three: Adélie, Chinstrap, and Gentoo. The seven features are the island, bill (culmen) length, bill (culmen) depth, flipper length, body mass, sex, and year.


## Use the read_csv function from Pandas to read the file into a DataFrame
data = pd.read_csv('penguins_raw.csv')

## To keep things simple we select only four features; interested readers can explore the meaning and usage of the other features
data = data[['Species','Culmen Length (mm)','Culmen Depth (mm)','Flipper Length (mm)','Body Mass (g)']]
  • Step2: Take a quick look at the data
## Use .info() to view the overall information of the data
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 5 columns):
Species                344 non-null object
Culmen Length (mm)     342 non-null float64
Culmen Depth (mm)      342 non-null float64
Flipper Length (mm)    342 non-null float64
Body Mass (g)          342 non-null float64
dtypes: float64(4), object(1)
memory usage: 13.6+ KB
## For a quick look at the data, use .head() for the first rows and .tail() for the last rows
data.head()



Here we find that the data set contains NaNs. A NaN generally represents a missing value, which may come from an error in data collection or processing. Here we fill the missing values with -1. Other strategies, such as median filling and mean filling, are also available, and interested readers can try them.

data = data.fillna(-1)  # Fill missing values with -1
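As an alternative to the -1 fill, the median filling mentioned above can be sketched as follows. This is a minimal illustration on a made-up toy DataFrame (the values and the names `toy`, `medians`, and `filled` are assumptions for demonstration, not the tutorial's data); it relies only on `DataFrame.median` and `DataFrame.fillna` from pandas.

```python
import numpy as np
import pandas as pd

# Made-up toy values with one missing entry per column
toy = pd.DataFrame({
    'Culmen Length (mm)': [39.1, np.nan, 40.3, 36.7],
    'Body Mass (g)': [3750.0, 3800.0, np.nan, 3450.0],
})

# Per-column medians, computed while ignoring NaNs
medians = toy.median(numeric_only=True)

# fillna with a Series fills each column with its own median
filled = toy.fillna(medians)
print(filled)
```

Unlike -1, a median is a plausible measurement, so it does not introduce an artificial extreme value that a tree could split on. Replacing `median` with `mean` gives mean filling.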
## The category labels correspond to three different penguin species: 'Adelie Penguin', 'Gentoo penguin', and 'Chinstrap penguin'
data['Species'].unique()

array(['Adelie Penguin (Pygoscelis adeliae)',
       'Gentoo penguin (Pygoscelis papua)',
       'Chinstrap penguin (Pygoscelis antarctica)'], dtype=object)
## Use the value_counts function to view the number of samples in each category
data['Species'].value_counts()

Adelie Penguin (Pygoscelis adeliae)          152
Gentoo penguin (Pygoscelis papua)            124
Chinstrap penguin (Pygoscelis antarctica)     68
Name: Species, dtype: int64
## Compute summary statistics for the features
data.describe()



  • Step3: Visual description


## Scatter visualization of feature and label combinations
sns.pairplot(data=data, diag_kind='hist', hue= 'Species')


'''For convenience, we convert the labels to numbers:
'Adelie Penguin (Pygoscelis adeliae)'        ------ 0
'Gentoo penguin (Pygoscelis papua)'          ------ 1
'Chinstrap penguin (Pygoscelis antarctica)'  ------ 2 '''
## Look up the label order once; unique() returns labels in order of first appearance
species_names = list(data['Species'].unique())
def trans(x):
    ## The position of a label in species_names is its numeric code
    return species_names.index(x)
data['Species'] = data['Species'].apply(trans)
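The same label-to-number conversion can also be sketched with `pd.factorize`, which numbers labels in order of first appearance. This is an alternative shown on a small made-up Series (the shortened species names and the variable names are assumptions), not the code the tutorial uses.

```python
import pandas as pd

# Made-up labels, shortened for readability
species = pd.Series(['Adelie', 'Adelie', 'Gentoo', 'Chinstrap', 'Gentoo'])

# codes holds one integer per row; uniques maps each integer back to its label
codes, uniques = pd.factorize(species)
print(codes)          # integer code of each row
print(list(uniques))  # label for code 0, 1, 2, ...
```

Keeping `uniques` around makes it easy to translate predicted class numbers back into species names later.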

Next, we draw a boxplot of each feature for the different penguin species. The boxplots show how the distribution of each feature differs between the categories.

for col in data.columns:
    if col != 'Species':
        sns.boxplot(x='Species', y=col, saturation=0.5, palette='pastel', data=data)
        plt.title(col)
        plt.show()  # show each feature in its own figure instead of overdrawing one plot

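# Select the first three features to draw a 3D scatter plot

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions
fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')
data_class0 = data[data['Species']==0].values
data_class1 = data[data['Species']==1].values
data_class2 = data[data['Species']==2].values
# Column 0 of each array is 'Species', so the first three features are columns 1 to 3
# Adelie (0), Gentoo (1), Chinstrap (2)
ax.scatter(data_class0[:,1], data_class0[:,2], data_class0[:,3], label='Adelie (0)')
ax.scatter(data_class1[:,1], data_class1[:,2], data_class1[:,3], label='Gentoo (1)')
ax.scatter(data_class2[:,1], data_class2[:,2], data_class2[:,3], label='Chinstrap (2)')
plt.legend()
plt.show()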




  • Step4: Train a decision tree model and predict on a binary classification task


## To evaluate model performance properly, we split the data into a training set and a test set, train the model on the training set, and verify its performance on the test set.
from sklearn.model_selection import train_test_split
## Select the samples whose class is 0 or 1 (excluding samples of class 2)
data_target_part = data[data['Species'].isin([0,1])][['Species']]
data_features_part = data[data['Species'].isin([0,1])][['Culmen Length (mm)', 'Culmen Depth (mm)', 'Flipper Length (mm)', 'Body Mass (g)']]
## Hold out 20% of the data as the test set (an 80%/20% split)
x_train, x_test, y_train, y_test = train_test_split(data_features_part, data_target_part,
test_size = 0.2, random_state = 2020)
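When the classes are unbalanced, a plain random split can leave the test set with a different class ratio than the training set. The sketch below shows the optional `stratify` argument of `train_test_split`, which keeps the ratio the same in both parts. It is a variation on the split above, run on made-up data (70 samples of class 0 and 30 of class 1); the tutorial itself does not stratify.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Made-up data: 100 samples with one feature, 70 of class 0 and 30 of class 1
X = np.arange(100).reshape(-1, 1)
y = np.array([0] * 70 + [1] * 30)

# stratify=y keeps the 70/30 class ratio in both the training and the test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=2020, stratify=y)

print("train size:", len(y_train), "class 1 share:", y_train.mean())
print("test size:", len(y_test), "class 1 share:", y_test.mean())
```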


## Import decision tree model from sklearn
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
## Define the decision tree model
clf = DecisionTreeClassifier(criterion='entropy')
## Train the decision tree model on the training set
clf.fit(x_train, y_train)
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort=False, random_state=None,
## Visualize the trained tree with graphviz
import graphviz
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
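The prediction and evaluation part of this step can be sketched as follows. Because `penguins_raw.csv` may not be at hand, this self-contained sketch uses synthetic two-class data from `make_classification` with four features as a stand-in, so the printed accuracies say nothing about the penguin data. The variable names and the use of `accuracy_score` are assumptions; with the real data, the same `predict` call would be applied to `x_train` and `x_test`.

```python
from sklearn.datasets import make_classification
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in: 300 samples, 4 features, 2 classes
X, y = make_classification(n_samples=300, n_features=4, n_informative=3,
                           n_redundant=1, n_classes=2, random_state=2020)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.2, random_state=2020)

# Same model settings as in the tutorial
clf_sketch = DecisionTreeClassifier(criterion='entropy')
clf_sketch.fit(X_tr, y_tr)

# Predict class labels on both sets
train_pred = clf_sketch.predict(X_tr)
test_pred = clf_sketch.predict(X_te)

# Accuracy = share of samples whose predicted label matches the true label
train_acc = accuracy_score(y_tr, train_pred)
test_acc = accuracy_score(y_te, test_pred)
print("train accuracy:", train_acc)
print("test accuracy:", test_acc)
```

An unrestricted tree usually scores perfectly on its own training data, so the test accuracy is the number to watch.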


Posted by mcatalf0221 on Fri, 20 May 2022 19:53:53 +0300