Mini-Case Study: A Toy Decision Tree
In this section, we construct a toy decision tree on a small-scale dataset taken from a sample question of the Modern Actuarial Statistics II Exam of the Casualty Actuarial Society and displayed in Table 5.1. The small number of observations makes it possible for us to perform calculations by hand and replicate the R output, which is inadequately explained in the PA e-learning modules and commonly misunderstood (partly due to the somewhat confusing documentation of the package we will use).
Table 5.1: The toy dataset

| X1 | X2 | Y   |
|----|----|-----|
| 1  | 0  | 1.2 |
| 2  | 1  | 2.1 |
| 3  | 2  | 1.5 |
| 4  | 1  | 3.0 |
| 2  | 2  | 2.0 |
| 1  | 1  | 1.6 |
# CHUNK 1
X1 <- c(1, 2, 3, 4, 2, 1)
X2 <- c(0, 1, 2, 1, 2, 1)
Y <- c(1.2, 2.1, 1.5, 3.0, 2.0, 1.6)
dat <- data.frame(X1, X2, Y)
LEARNING OBJECTIVES

After completing this mini-case study, you should be able to:

- Fit a decision tree using the rpart() function from the rpart package.
- Understand how the control parameters of the rpart() function control tree complexity.
- Produce a graphical representation of a fitted decision tree using the rpart.plot() function.
- Interpret the output for a decision tree, both in list form and in graphical form.
- Identify possible interactions between variables from a fitted tree.
- Prune a decision tree using the prune() function.
Basic Functions and Arguments
Main function rpart() and its control parameters

The rpart (short for “recursive partitioning”) package has rpart() as its main function. Just like the lm() and glm() functions, the rpart() function takes a formula argument specifying the target variable and the predictors, and a data argument specifying the data frame hosting all of the necessary variables.
The method argument determines whether a regression tree (with method = "anova") or a classification tree (with method = "class") is to be grown. If this argument is left unspecified, then R will examine the nature of the target variable and make an intelligent guess as to whether a regression or classification tree should be built, which may not be desirable, especially when the target variable is a categorical variable coded by numeric labels.
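As a quick illustration (not one of the numbered chunks in this case study; the object name dt.reg and the commented-out binary target Y01 are ours), we can make the choice explicit when fitting a tree to the data frame dat created in CHUNK 1:

# A sketch showing the method argument specified explicitly
library(rpart)

# Y is numeric, so we request a regression tree
dt.reg <- rpart(Y ~ X1 + X2, data = dat, method = "anova")

# If the target were a categorical label coded numerically (say, a 0/1 column Y01),
# we would convert it to a factor and request a classification tree:
# dt.class <- rpart(factor(Y01) ~ X1 + X2, data = dat, method = "class")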
Perhaps the most important argument is the control argument, which uses the rpart.control() function to specify a list of parameters “controlling” when the partitioning stops, or equivalently, the complexity of the tree to be built. Here are the most commonly used parameters:
minsplit
This is the minimum number of observations that must exist in a node in order for a split to be attempted. For example, if minsplit is set to 10, then a node with fewer than 10 observations will not be considered for further splitting. Everything else equal, the lower the value of minsplit (the lowest value is 1), the larger and more complex the fitted tree.
minbucket
Short for “minimum bucket size,” this is the minimum number of observations in any terminal node (“bucket”). If a split generates a node that has fewer than this number of observations, then the split will not be made. As with minsplit, the lower the value of minbucket (the lowest value is 1), the larger and more complex the fitted tree.

Comparing minsplit and minbucket, we can see that the former refers to the minimum number of observations before splitting and the latter to the minimum number of observations after splitting. In most situations, only one of minsplit or minbucket needs to be specified, as the two function similarly to limit decision tree complexity.
According to the documentation of rpart.control(), when only one of minsplit or minbucket is specified, the other is filled in automatically: either minsplit is set to minbucket*3 or minbucket is set to minsplit/3.
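A minimal sketch of this behavior (not one of the numbered chunks; the object name dt.small is ours), using the data frame dat from CHUNK 1:

library(rpart)

# Only minbucket is supplied; per the documentation quoted above,
# minsplit should then default to 3 * minbucket = 6
dt.small <- rpart(Y ~ X1 + X2, data = dat,
                  control = rpart.control(minbucket = 2, cp = 0))

# The control parameters actually used are stored in the fitted object
dt.small$control$minsplit   # expected to be 6
dt.small$control$minbucket  # 2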
cp
Short for “complexity parameter,” this refers to the Cp value penalizing a tree by its size when cost-complexity pruning is performed. Usually we can think of cp as the minimum amount of reduction in the relative training error required for a split to be made. When a split is made, the tree size grows by 1 and the penalty term Cp|T| increases by cp, so if the split does not decrease the relative training error by at least cp, then the split will not be performed.
The default value of cp is 0.01. The higher the value of cp, the fewer the splits made and the less complex the fitted tree. In the two extreme cases, setting cp = 0 leads to the most complex tree (subject to the constraints imposed by other parameters such as minbucket and maxdepth) and setting cp = 1 prohibits any splits.
To reduce tree complexity, we will later prune a fitted decision tree using a certain value of cp based on the results returned by the rpart() function. You may then wonder:

Why do we have to pre-specify a complexity parameter? Why not grow a very complex tree by setting cp = 0 and prune later?

The role played by the cp parameter in the initial call of the rpart() function is to “pre-prune” splits that are obviously not worthwhile to pursue (and that would likely be pruned off by cross-validation anyway), thus saving computational effort in the case of large datasets. We should be careful not to prescribe an inappropriately large cp, to avoid “over-pruning.” A reasonably small value such as cp = 0.001 will do this screening well.
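The following sketch (not one of the numbered chunks; the object names are ours) contrasts a small screening value of cp with the extreme value cp = 1 on the toy data frame dat:

library(rpart)

# A reasonably small cp keeps essentially all worthwhile splits
dt.loose <- rpart(Y ~ X1 + X2, data = dat,
                  control = rpart.control(minsplit = 1, minbucket = 1, cp = 0.001))

# cp = 1 demands that a single split remove all of the relative error, so none is made
dt.stump <- rpart(Y ~ X1 + X2, data = dat,
                  control = rpart.control(minsplit = 1, minbucket = 1, cp = 1))

nrow(dt.loose$frame)  # number of nodes in the larger tree
nrow(dt.stump$frame)  # expected to be 1 for this dataset, i.e., the root node only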
maxdepth
This is the maximum depth of the tree, or the maximum number of branches from the tree’s root node to the furthest terminal node. The higher the value of maxdepth, the more complex the fitted tree. The default value of maxdepth is 30.
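As a one-line sketch (again not one of the numbered chunks; the object name is ours), restricting the toy tree to a single level of splits:

library(rpart)

# maxdepth = 1 allows at most one split below the root (a "stump")
dt.depth1 <- rpart(Y ~ X1 + X2, data = dat,
                   control = rpart.control(minsplit = 1, minbucket = 1, cp = 0, maxdepth = 1))
nrow(dt.depth1$frame)  # expected to be 3: the root node and its two children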
Very often, the control parameters above are hyperparameters that are tuned by cross-validation (for cp) or manually (for minbucket and maxdepth) as part of the tree construction process in search of the right level of tree complexity.
The two parameters below are not related to tree complexity, but affect how tree splits are evaluated.
xval
This is the number of folds used when doing cross-validation. The default value is 10, i.e., 10-fold cross-validation is performed. Although this parameter has no effect on the complexity of the fitted tree, it affects the performance, assessed by cross-validation, of the decision trees automatically fitted when the rpart() function is called.
parms
This argument is specific to categorical target variables and describes the parameters that guide how the splits are performed, using either the Gini index (parms = list(split = "gini")) or information gain, which is the drop in entropy following the split (parms = list(split = "information")), as the node impurity measure.
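Because the toy target Y is numeric, parms plays no role in this case study. As a hypothetical sketch (the data frame dat2, the target Claim, and the object names are ours, not part of the case study), here is how the two impurity measures would be requested for a binary target:

library(rpart)

# Hypothetical classification data
dat2 <- data.frame(
  X1 = c(1, 2, 3, 4, 2, 1),
  X2 = c(0, 1, 2, 1, 2, 1),
  Claim = factor(c("N", "Y", "N", "Y", "Y", "N"))
)

# Classification tree grown with the Gini index as the impurity measure
ct.gini <- rpart(Claim ~ X1 + X2, data = dat2, method = "class",
                 parms = list(split = "gini"),
                 control = rpart.control(minsplit = 2, minbucket = 1, cp = 0))

# The same call with information gain (entropy) instead
ct.info <- rpart(Claim ~ X1 + X2, data = dat2, method = "class",
                 parms = list(split = "information"),
                 control = rpart.control(minsplit = 2, minbucket = 1, cp = 0))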
Fitting and Visualizing a Decision Tree
In CHUNK 2, we load the rpart and rpart.plot (for plotting rpart objects) packages, fit a regression tree for Y using X1 and X2 as predictors, and save it as an rpart object named dt (meaning a “decision tree”).
Because the sample size is so small and this case study is for illustration purposes, we have deliberately set the control parameters so that the most complex tree will be constructed. With minsplit = 1, minbucket = 1, and cp = 0, each node only has to contain at least one observation and a split will be made as long as it results in some improvement in node impurity. We then pass dt to the rpart.plot() function to give a graphical representation of the regression tree just grown.
# CHUNK 2
library(rpart)
library(rpart.plot)
dt <- rpart(Y ~ ., data = dat,
            control = rpart.control(minsplit = 1, minbucket = 1, cp = 0, xval = 6))
rpart.plot(dt)
Figure 5.2.1: The fully grown regression tree for Y with X1 and X2 as predictors
Root Node
In Figure 5.2.1, the top number in each node of the tree indicates the fitted target value and the bottom number represents the proportion of observations lying in that node. For instance, at the top of the tree the root node has all of the n = 6 observations (hence 100% of the observations) and a fitted value of 1.9, which is simply the average of the target value of the 6 observations:
\(\bar{y}=\dfrac{1.2+2.1+1.5+3.0+2.0+1.6}{6}=1.9\)
Child Nodes
The first split uses X1 < 4 as the splitting criterion, shown immediately below the root node, meaning that out of the two predictors X1 and X2 and among all possible cutoff points, partitioning the data using X1 and using 4 as the cutoff at this stage best separates the six observations and leads to the greatest reduction in tree impurity measured by RSS. Observations satisfying X1 < 4 (i.e., all except observation 4) are classified to the left branch (see the box labeled “yes”) while the one that does not (i.e., observation 4) is assigned to the right branch (see the box labeled “no”).
The partitions continue within the left branch, which still contains five observations, until no node can be split further to reduce the overall RSS of the model. In this simple case, each of the six observations occupies a separate terminal node with a fitted value equal to the observed value, and the RSS is exactly zero. We have a perfect fit, which will be the case whenever the control parameters impose no restrictions on the complexity of the fitted tree.
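We can confirm the perfect fit numerically (a quick check we add here; it is not one of the numbered chunks): the fitted values returned by predict() coincide with the observed target values, so the RSS is zero.

# Fitted values from the fully grown tree versus the observed target
cbind(fitted = predict(dt, newdata = dat), observed = dat$Y)

# The residual sum of squares is therefore zero
sum((dat$Y - predict(dt, newdata = dat))^2)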
Using a decision tree to detect interaction
From Figure 5.2.1, it appears that there is some interaction between X1 and X2. Recall from earlier chapters that an interaction exists if the relationship between one predictor and the target variable depends on the value of another predictor. In the case of dt, X2 has a significant effect on Y only when X1 < 4. The left branch of the root node is followed by a split using X2 while the same split cannot be found in the right branch. If X1 ≥ 4, then the fitted tree says that a further split by X2 is no longer useful (of course, we know that the right branch contains only one observation and so further splitting is impossible, but the considerations here apply to larger datasets). In contrast, if there were no interactions between X1 and X2, then we would expect the same set of splits that appear in the left branch to be used in the right branch as well. In fact, if we fit an ordinary least squares linear model regressing Y on X1 and X2 with interaction, then the model summary (see CHUNK 3) confirms that the interaction is statistically significant.
EXAM NOTE
The December 2018 and December 2019 exams expect candidates to recognize the existence of interaction between predictor variables from a fitted decision tree and then incorporate the interaction term into a GLM, although there is hardly any mention of interaction in the Rmd sections of Module 7 of the PA e-learning modules.
# CHUNK 3
ols <- lm(Y ~ X1 * X2, data = dat)
summary(ols)
Call:
lm(formula = Y ~ X1 * X2, data = dat)

Residuals:
        1         2         3         4         5         6 
 0.053226 -0.024194 -0.053226  0.008065  0.106452 -0.090323 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)   
(Intercept) -0.06129    0.18810  -0.326  0.77547   
X1           1.20806    0.10892  11.092  0.00803 **
X2           1.31774    0.20351   6.475  0.02303 * 
X1:X2       -0.77419    0.09995  -7.746  0.01626 * 
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Residual standard error: 0.1136 on 2 degrees of freedom
Multiple R-squared:  0.9871,	Adjusted R-squared:  0.9677 
F-statistic:    51 on 3 and 2 DF,  p-value: 0.01929
A closer look at the tree splits
If we type the name of an rpart object, we will get a condensed printout of how all of the tree splits are performed. This is done in CHUNK 4 for dt. Although the output is not as aesthetically appealing as the displayed tree in Figure 5.2.1, we can get more detailed information about the different splits.
# CHUNK 4
dt
n= 6 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 6 2.000 1.90  
   2) X1< 3.5 5 0.548 1.68  
     4) X2< 0.5 1 0.000 1.20 *
     5) X2>=0.5 4 0.260 1.80  
      10) X1>=2.5 1 0.000 1.50 *
      11) X1< 2.5 3 0.140 1.90  
        22) X1< 1.5 1 0.000 1.60 *
        23) X1>=1.5 2 0.005 2.05  
          46) X2>=1.5 1 0.000 2.00 *
          47) X2< 1.5 1 0.000 2.10 *
   3) X1>=3.5 1 0.000 3.00 *
There is one line per node. For each node, we can see:
- The splitting criterion in the split column
- The number of training observations lying in that node in the n column
- The residual sum of squares in the deviance column (recall that we are dealing with a numeric target variable)
- The fitted target value in the yval column
The tree starts with the root node, numbered node 1, where all 6 observations are hosted and the fitted target value is the sample target mean \(\bar{y}=\sum\nolimits_{i=1}^{6}{y_i}/6=1.9\), with a residual sum of squares of:
\(TSS=\sum\nolimits_{i=1}^{6}{(y_i-\bar{y})^2}=(1.2-1.9)^2+(2.1-1.9)^2+…+(1.6-1.9)^2=2\)
which is also the total sum of squares in the terminology of linear models. The root node is then split into child nodes 2 and 3, with observations satisfying X1 < 3.5 sent to the left branch, where the fitted target value is 1.68 (= (1.2 + 2.1 + 1.5 + 2.0 + 1.6)/5), and those satisfying X1 ≥ 3.5 (in fact, only observation 4) sent to the right branch, where the fitted target value is 3.
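As a check on the deviance column (a hand calculation we add here, mirroring the TSS calculation above), the deviance of 0.548 reported for node 2 is simply the RSS of its five observations around their mean of 1.68:

\(\sum{(y_i-1.68)^2}=(1.2-1.68)^2+(2.1-1.68)^2+(1.5-1.68)^2+(2.0-1.68)^2+(1.6-1.68)^2=0.548\)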
To facilitate identification and inspection, child nodes in a decision tree are indented and numbered in this format: The two child nodes of node x are numbered 2x and 2x + 1, so, for example, the two child nodes of node 2 are node 4 (= 2 x 2) and node 5 (= 2 x 2 + 1). When the splitting stops, we have reached a terminal node, which is signified by an asterisk (*) at the end of the line.
Pruning a Decision Tree
Reading the complexity parameter table
A powerful feature of the rpart() function is that in addition to fitting the decision tree by solving (5.1.1) with the indicated complexity parameter cp, behind the scenes it automatically solves (5.1.1) for all values of the complexity parameter greater than or equal to cp, fits the corresponding collection of simpler decision trees, and evaluates their predictive performance using internal cross-validation. This allows us to evaluate the model at a wide range of complexity levels and choose the best one. To access these results, we take advantage of the fact that an rpart object such as dt is a list (that is why we had to learn lists back in Subsection 1.2.4!) and extract its cptable component in CHUNK 5.
# CHUNK 5
dt$cptable
       CP nsplit rel error   xerror     xstd
1 0.72600      0    1.0000 1.440000 0.745794
2 0.14400      1    0.2740 2.399012 1.059821
3 0.06375      2    0.1300 2.220556 1.004330
4 0.00250      4    0.0025 2.420000 1.252211
5 0.00000      5    0.0000 2.420000 1.252211
The cptable is a matrix that shows, for each cutoff value of the complexity parameter (the CP column), the number of splits of the tree constructed (the nsplit column), the relative training error (the rel error column), the cross-validation error (the xerror column), and the standard error of the cross-validation error (the xstd column).
The number of splits is a measure of tree complexity; the total number of terminal nodes is the number of splits plus one. As we expect, the lower the value of the complexity parameter, the larger the number of splits and the lower the relative error. The rel error and xerror columns of the cptable are not adequately explained in the PA e-learning modules. Let's give them a closer look and verify some of the entries.
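As a hand verification we add here (it assumes, consistent with the rpart documentation, that rel error is the RSS of the fitted tree divided by the root-node deviance of 2 computed earlier, and that xerror is scaled the same way), the rel error entries for the one- and two-split trees can be reproduced from the deviances shown in CHUNK 4:

\(\text{rel error (1 split)}=\dfrac{0.548+0.000}{2}=0.274,\qquad \text{rel error (2 splits)}=\dfrac{0.000+0.260+0.000}{2}=0.130\)

The CP column records the per-split drop in relative error delivered by the next set of splits; for example, the first split lowers the RSS from 2.000 to 0.548, so its CP entry is \((2.000-0.548)/2=0.726\). The xerror entries, by contrast, depend on the random assignment of observations to the 6 cross-validation folds and cannot be reproduced by hand.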
Choosing the Best cp Value
While the relative error decreases as we descend the CP column (as the tree becomes more complex), the (scaled) cross-validation error follows a different pattern, which is somewhat irregular due to the small sample size. We can see that although the five-split tree has a zero relative error, it has the highest cross-validation error, showing that it is seriously overfitted. To ensure that the decision tree has the right level of complexity, we prune it by setting the complexity parameter to an appropriate value rather than 0. One natural choice is the value that gives rise to the smallest cross-validation error. (Another commonly used choice will be introduced in Section 5.3.) In CHUNK 7, we first extract the optimal complexity parameter and then use the prune() function to prune the tree to the corresponding level of complexity.
# CHUNK 7
cp.min <- dt$cptable[which.min(dt$cptable[, "xerror"]), "CP"]
cp.min
[1] 0.726
To understand how the optimal complexity parameter is extracted coding-wise, we first realize that:

- dt$cptable is a matrix, whose elements can be extracted by the square bracket [, ] operator (recall what we learned in Subsection 1.2.2).
- The command dt$cptable[, "xerror"] extracts the column of cross-validation errors, and which.min(dt$cptable[, "xerror"]) returns the row corresponding to the minimum cross-validation error.
- Finally, dt$cptable[which.min(dt$cptable[, "xerror"]), "CP"] uses this row to select along the column of complexity parameters and produces the optimal cp.
dt.pruned <- prune(dt, cp = cp.min)
rpart.plot(dt.pruned)
Figure 5.2.3: The pruned regression tree for Y with X1 and X2 as predictors
This optimal value is then fed into the cp argument of the prune() function to prune back the branches of dt that do not satisfy the impurity reduction requirement prescribed by the optimal complexity parameter.
As we expect, the pruned tree in this case is the no-split tree with only the root node. This suggests that the fitted decision tree is, to our dismay, not very useful, perhaps due to the small sample size.