
SOA ASA Exam: Predictive Analysis (PA) – 5.3. Extended Case Study: Classification Trees


Extended Case Study: Classification Trees

LEARNING OBJECTIVES

The focus of this section is on constructing, evaluating, and interpreting base and ensemble trees. At the completion of this case study, you should be able to:

  • Understand how decision trees form tree splits based on categorical predictors.
  • Understand how decision trees deal with numeric predictors having a non-linear relationship with the target variable.
  • Build base classification trees, control their complexity by pruning, and interpret their output.
  • Build ensemble trees using the caret package and tune the model parameters for optimal performance.
  • Quantify the prediction accuracy of (base or ensemble) classification trees constructed.
  • Recommend a decision tree taking both prediction accuracy and interpretability into account.

 

Problem Set-up and Preparatory Steps

Data Description

This case study revolves around the Wage dataset in the ISLR package. This dataset contains the income and demographic information (e.g., age, education level, marital status) collected through an income survey for a group of 3,000 male workers residing in the Mid-Atlantic region of the US. The data dictionary is shown in Table 5.2.

Variable | Description | Values
year | Calendar year that wage information was recorded | Integer from 2003 to 2009
age | Age of worker | Integer from 18 to 80
maritl | Marital status of worker | Factor with 5 levels: 1. Never Married 2. Married 3. Widowed 4. Divorced 5. Separated
race | Race of worker | Factor with 4 levels: 1. White 2. Black 3. Asian 4. Other
education | Education level of worker | Factor with 5 levels: 1. < HS Grad 2. HS Grad 3. Some College 4. College Grad 5. Advanced Degree
region | Region of the country | Factor with 9 levels, but only 2. Middle Atlantic contains observations
jobclass | Job class of worker | Factor with 2 levels: 1. Industrial 2. Information
health | Health level of worker | Factor with 2 levels: 1. <=Good 2. >=Very Good
health_ins | Whether worker has health insurance | Factor with 2 levels: 1. Yes 2. No
logwage | Log of worker’s raw wage | Numeric from 3.000 to 5.763
wage | Worker’s raw wage (in $1,000s) | Numeric from 20.09 to 318.34

 

Objective

Our aim is to identify the key determinants of a worker’s wage with the aid of appropriate decision trees.

The findings of our analysis will have practical implications for society. Employers will have a quantitative basis for setting equitable wages for their employees and attracting new talent, and employees will be able to assess their fair market value.

Because the end users of these tree models are individuals who may not have the expertise in predictive analytics, having a model that is easy to interpret, communicate, and implement, and identifies key characteristics affecting wage, will be a big plus.

 

For starters, let’s run CHUNK 1 to load the ISLR package and the Wage data, and print a summary.

# CHUNK 1
# Load the data
library(ISLR)
data("Wage")

# Summarize the data
summary(Wage)
     year           age                     maritl           race                   education  
Min.   :2003   Min.   :18.00   1. Never Married: 648   1. White:2480   1. < HS Grad      :268  
1st Qu.:2004   1st Qu.:33.75   2. Married      :2074   2. Black: 293   2. HS Grad        :971  
Median :2006   Median :42.00   3. Widowed      :  19   3. Asian: 190   3. Some College   :650  
Mean   :2006   Mean   :42.41   4. Divorced     : 204   4. Other:  37   4. College Grad   :685  
3rd Qu.:2008   3rd Qu.:51.00   5. Separated    :  55                   5. Advanced Degree:426  
Max.   :2009   Max.   :80.00                                                                   
                                                                                               
                  region               jobclass               health      health_ins  
2. Middle Atlantic   :3000   1. Industrial :1544   1. <=Good     : 858   1. Yes:2083  
1. New England       :   0   2. Information:1456   2. >=Very Good:2142   2. No : 917  
3. East North Central:   0                                                            
4. West North Central:   0                                                            
5. South Atlantic    :   0                                                            
6. East South Central:   0                                                            
(Other)              :   0                                                            
   logwage           wage       
Min.   :3.000   Min.   : 20.09  
1st Qu.:4.447   1st Qu.: 85.38  
Median :4.653   Median :104.92  
Mean   :4.654   Mean   :111.70  
3rd Qu.:4.857   3rd Qu.:128.68  
Max.   :5.763   Max.   :318.34
str(Wage)
'data.frame':	3000 obs. of  11 variables:
 $ year      : int  2006 2004 2003 2003 2005 2008 2009 2008 2006 2004 ...
 $ age       : int  18 24 45 43 50 54 44 30 41 52 ...
 $ maritl    : Factor w/ 5 levels "1. Never Married",..: 1 1 2 2 4 2 2 1 1 2 ...
 $ race      : Factor w/ 4 levels "1. White","2. Black",..: 1 1 1 3 1 1 4 3 2 1 ...
 $ education : Factor w/ 5 levels "1. < HS Grad",..: 1 4 3 4 2 4 3 3 3 2 ...
 $ region    : Factor w/ 9 levels "1. New England",..: 2 2 2 2 2 2 2 2 2 2 ...
 $ jobclass  : Factor w/ 2 levels "1. Industrial",..: 1 2 1 2 2 2 1 2 2 2 ...
 $ health    : Factor w/ 2 levels "1. <=Good","2. >=Very Good": 1 2 1 2 1 2 2 1 2 2 ...
 $ health_ins: Factor w/ 2 levels "1. Yes","2. No": 2 2 1 1 1 1 1 1 1 1 ...
 $ logwage   : num  4.32 4.26 4.88 5.04 4.32 ...
 $ wage      : num  75 70.5 131 154.7 75 ...

There are 3,000 observations and 11 variables in the data. The target variable (at this stage) is the last variable, wage (note that the first letter is not capitalized, unlike the name of the dataset), representing the wage in $1,000s of each of the 3,000 workers in the data, and logwage, as its name implies, is the log-transformed version of wage. Out of the nine predictors, there are two integer variables, year and age, and the rest are factor variables, some of which, like maritl, race, and education, have multiple levels.

A peculiar variable in the data is region. The data summary suggests that it has multiple levels, but all of the 3,000 observations belong to “2. Middle Atlantic”. As this variable cannot differentiate any observations, it has no value and will be deleted in CHUNK 2.

# CHUNK 2 
Wage$region <- NULL

 

TASK 1: Consider transformations of wage
Your supervisor has asked you to construct a regression tree to predict wage.

  • Explore the distribution of wage.

You are now considering the following transformations of wage:

  • wage (no transformation)
  • Log of wage
  • Square root of wage

Recommend which transformation of wage from the list above to use for a regression tree. Justify your recommendation.

 

Let’s begin by exploring the distribution of the target variable, wage. Because wage is a numeric variable, let’s visualize the distribution of wage by means of a histogram, which is exhibited in the top left panel of Figure 5.3.1 (the summary statistics for wage in CHUNK 1 are also useful).

The histogram shows that wage has a right-skewed distribution, with some substantially large values in the 250-300 range. Remember that for a regression tree, the splits are chosen to minimize the residual sum of squares (RSS), which, for a particular node, is the sum of the squared discrepancies between the observed target values and the average of the target variable in that node.

For a right-skewed target variable, the large observed values will exert a disproportionate effect on the RSS calculations, which may not be desirable. This motivates us to consider transforming wage in an attempt to symmetrize its distribution.
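
To make the effect of large values on the RSS concrete, here is a minimal sketch in base R. The ten wage values and the helper node_rss() are hypothetical and chosen purely for illustration; they are not taken from the Wage data.

```r
# Residual sum of squares of one node: squared deviations from the node mean
node_rss <- function(y) sum((y - mean(y))^2)

# Hypothetical wages (in $1,000s): nine typical values and one very large one
wages <- c(70, 80, 85, 90, 95, 100, 105, 110, 120, 300)

# Share of the node's RSS contributed by the single largest observation
share_max <- (max(wages) - mean(wages))^2 / node_rss(wages)

# Same share after a square root transformation of the target
sqrt_wages <- sqrt(wages)
share_max_sqrt <- (max(sqrt_wages) - mean(sqrt_wages))^2 / node_rss(sqrt_wages)

print(c(raw = share_max, sqrt = share_max_sqrt))
```

In this toy node the largest value accounts for most of the RSS on the raw scale, and its share is smaller on the square root scale, although a single extreme value still carries a lot of weight in a sample this small.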

As we learned in Section 2.2, two commonly used transformations for reducing right skewness are the log transformation and the square root transformation, both of which are applicable in this setting because wage is a strictly positive variable.

The histograms of the resulting transformed versions of wage are displayed in the top right and bottom left panels of Figure 5.3.1. It appears that the two transformations reduce the right skewness of wage quite effectively, although the log transformation makes the distribution of wage slightly left-skewed. Between the two transformations, the square root transformation should be superior.

(Note: If a regression tree is fitted to the square root of wage, then for the purposes of prediction we will have to square the predictions produced by the tree to put things back to the original scale of wage.)
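
The visual comparison of skewness can be backed by a number. The sketch below uses simulated gamma values as a stand-in for a right-skewed, strictly positive target (an assumption; this is not the Wage data), a hypothetical skewness() helper, and two made-up predictions to show the back-transformation mentioned in the note above.

```r
# Sample skewness: third central moment divided by the cubed standard deviation
skewness <- function(x) {
  z <- x - mean(x)
  mean(z^3) / mean(z^2)^1.5
}

set.seed(2021)
# Simulated right-skewed, strictly positive values (a stand-in for wage)
y <- rgamma(5000, shape = 4, scale = 25)

skews <- c(raw = skewness(y), sqrt = skewness(sqrt(y)), log = skewness(log(y)))
print(round(skews, 2))

# Squaring undoes the square root, returning predictions to the original scale
pred_sqrt_scale <- c(9.5, 11.2)      # hypothetical predictions of sqrt(wage)
pred_wage_scale <- pred_sqrt_scale^2
print(pred_wage_scale)
```

A skewness near zero indicates a roughly symmetric distribution, a positive value indicates right skewness, and a negative value indicates left skewness.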

# CHUNK 3
library(ggplot2)
ggplot(Wage, aes(x = wage)) + geom_histogram()
ggplot(Wage, aes(x = log(wage))) + geom_histogram()
ggplot(Wage, aes(x = sqrt(wage))) + geom_histogram()

Figure 5.3.1: Histograms of wage, log-transformed wage, and the square root of wage in the Wage data

As a concrete illustration of how transformations of a target variable affect a decision tree, in CHUNK 4 we construct two regression trees on the full Wage data, one for wage untransformed, and one for the square root of wage, using age as the single predictor and a single split (achieved by setting maxdepth = 1). We will pay attention to the value of the split point and the compositions of the two resulting child nodes.

# CHUNK 4 
# Load needed packages for trees
library(rpart)
library(rpart.plot)

# Fit the two trees, each restricted to a single split (maxdepth = 1)
tree.untransformed <- rpart(wage ~ age,
                            data = Wage,
                            method = "anova",
                            control = rpart.control(maxdepth = 1))

tree.sqrt <- rpart(sqrt(wage) ~ age,
                   data = Wage,
                   method = "anova",
                   control = rpart.control(maxdepth = 1))

# Plot the two trees
rpart.plot(tree.untransformed, main = "Wage")
rpart.plot(tree.sqrt, main = "Square Root of Wage")

Findings:

  • The two regression trees are indeed different because the two tree splits make use of different cutoff values of age (30 and 26), leading to different numbers of observations in the two child nodes (15%/85% and 8%/92%).

 

EXAM NOTE

In the December 8, 2020 exam, many candidates incorrectly stated that transformations on the target variable have no impact on trees. Not good!

 

TASK 2: Switch the regression problem to a classification problem
Knowing that you are relatively new to decision trees, your supervisor has changed his mind and asked you to construct, for simplicity, classification trees to predict the probability of a worker earning more than $100,000.

  • Create a binary variable called wage_flag that equals 0 if the salary of a worker is less than $100,000, and 1 if the salary exceeds $100,000.
  • Discuss the pros and cons of modeling wage_flag in place of wage.

 

Creation of a wage flag variable

Instead of taking wage as the target variable, we are going to turn it into a binary variable, wage_flag, that flags workers earning a six-figure salary, and treat wage_flag as the target variable in the rest of this case study. Specifically, wage_flag equals 1 if the salary of a worker is $100,000 or more and 0 otherwise.

This binary variable is created by means of the ifelse() function and summarized in the first part of CHUNK 5. (Remember that wage is measured in $1,000s, so we should use $100,000/$1,000 = 100 as the cutoff value.)

# CHUNK 5 
Wage$wage_flag <- ifelse(Wage$wage >= 100, 1, 0)
summary(Wage$wage_flag)
  Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
0.0000  0.0000  1.0000  0.5523  1.0000  1.0000

The summary shows that 55.23% of the workers earn $100,000 or more while 44.77% do not (the mean of a 0/1 variable is the proportion of ones). For simplicity, we will refer to workers earning $100,000 or more as “high earners” and those earning less than $100,000 as “low earners.”
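
The treatment of a wage exactly at the cutoff depends on the comparison operator. The following minimal sketch uses four hypothetical wage values (not taken from the data) to contrast >= with >.

```r
# Hypothetical wages in $1,000s, including one exactly at the cutoff
wage <- c(75.0, 99.9, 100.0, 154.7)

flag_ge <- ifelse(wage >= 100, 1, 0)  # a wage at the cutoff counts as a high earner
flag_gt <- ifelse(wage >  100, 1, 0)  # a wage at the cutoff counts as a low earner

print(rbind(wage, flag_ge, flag_gt))

# The mean of a 0/1 flag is the proportion of ones
print(mean(flag_ge))
```

Only the third worker is classified differently by the two versions, so the choice matters only when some wages sit exactly at 100.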

 

wage vs. wage_flag

Turning our attention from wage to wage_flag will change the nature of our problem from regression to classification and is an alternative approach to understanding wage. As with most decisions in predictive analytics, taking wage_flag as the target variable comes with pros and cons:

Pros

    • If we treat wage_flag as the target variable, we will be able to avoid any technical issues arising from the skewed distribution of wage. We no longer need to worry about transformations to remedy the right skewness of wage.
      (Note: This is a weaker point in the context of this case study as the square root transformation symmetrizes the distribution of wage successfully.)
    • Treating wage_flag as the target variable may yield more specific insights into the drivers of being a high earner, if $100,000 is a wage level that is of general interest, e.g., tax rate changes substantially beyond this level.

Cons

    • The biggest drawback of analyzing wage_flag instead of wage is the (significant) loss of information, with respect to both the input and output. All wage values will be turned into 0 or 1, depending on whether they reach the cutoff of 100, and their exact values will be discarded.
      Our model will not distinguish between a wage of 150 and a wage of 200, for example. Due to the reduced amount of information, our model will not provide a prediction for wage, but only whether a worker is a high earner or not as output. Such a model may have limited applicability if cutoffs other than $100,000 are of interest, and refitting will be needed.

 

We will end this task by removing the two wage variables, which will not be used later.

# CHUNK 5 (Cont.) 
# Remove the two wage variables 
Wage$wage <- Wage$logwage <- NULL 

 

TASK 3: Explore and edit the data

Your supervisor has asked you to construct separate plots showing the proportion of high earners by each predictor.

  • Make these plots.
  • Identify which variable is least likely to have a significant impact on the probability of being a high earner. Justify your choice.
  • Identify two variables that are most likely to have a significant impact on the probability of being a high earner. Justify your choices.
  • Explain the advantage and disadvantage of keeping year as a numeric variable.
  • Reduce the number of levels of factor variables, if applicable.

 

In this preparatory task, we will get to understand the variables in the data. Specifically, we will identify key variables associated with high earners and process some of the variables to prepare for later analysis.

 

Bivariate Data Exploration

Recall that our target variable, wage_flag, is a binary variable. To explore the relationship between wage_flag and each of the eight predictors, we will make use of filled bar charts, which show the proportion of high earners over each value or level of the predictors. This approach also works for the two numeric predictors, year and age, which are discrete (in fact, integer-valued) with a relatively small number of distinct values. In CHUNK 6, we use a for loop to make these bar charts for each predictor. The plots for year, age, maritl and education, in this order, are shown on the next two pages for illustration.

# CHUNK 6
vars <- colnames(Wage)[1:8] # the eight predictors; column 9 (wage_flag) is excluded
for (i in vars) {
    # Filled bar chart: each bar shows the class proportions at one value/level of predictor i
    plot <- ggplot(Wage, aes(x = Wage[, i], fill = factor(wage_flag))) +
            geom_bar(position = "fill") +
            labs(x = i, y = "Proportion of High Earners") +
            theme(axis.text.x = element_text(angle = 90, hjust = 1))
    print(plot)
}
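
The filled bar charts have a numerical counterpart: a table of row-wise proportions. The sketch below uses a hypothetical eight-row data frame named toy (not the Wage data) to show the idea with base R only.

```r
# Hypothetical mini data set: education level and a 0/1 high-earner flag
toy <- data.frame(
  education = factor(c("HS", "HS", "HS", "HS",
                       "College", "College", "College", "College")),
  wage_flag = c(0, 0, 0, 1, 0, 1, 1, 1)
)

# Row-wise proportions: each row (level) sums to 1, like one filled bar
props <- prop.table(table(toy$education, toy$wage_flag), margin = 1)
print(props)
```

The column labelled 1 gives the proportion of high earners at each level, which is what the height of the corresponding bar segment displays.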

Findings:

Judging by how much the proportion of high earners changes across different values or levels of the predictors, we can see that all predictors are useful for predicting the probability of being a high earner, with the exception of year, which seems to be the weakest predictor.

It appears that the proportion of high earners increases with year, perhaps due to inflation (everything else equal, wage tends to increase over time), but only mildly. The predictive power of year is easily overshadowed by other predictors in the data.

The other seven predictors are all strongly associated with the proportion of high earners, the most conspicuous ones being age and education.

age

We can see that the proportion of high earners follows an inverted-U (downward-opening parabolic) shape, increasing steadily from age 20 to roughly age 60, then decreasing thereafter. The parabolic behavior can be attributed to two opposing forces. It makes sense that an older worker earns more as a result of more experience and expertise, but the very old workers may have received less up-to-date training than their younger counterparts and be less productive.

Approaching the retirement age, they may also work fewer hours, leading to a lower wage. A GLM (which we will not use in this case study) will have trouble capturing this non-monotonic behavior unless higher-order polynomial terms (e.g., a squared age term) are manually inserted.
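
To illustrate what manually inserting a polynomial term means, here is a minimal sketch on simulated data. The inverted-U relationship, its peak at age 50, and all object names are assumptions made for the illustration; the Wage data is not used.

```r
set.seed(42)
n <- 2000
age <- sample(18:80, n, replace = TRUE)

# Assumed inverted-U relationship on the logit scale, peaking at age 50
p <- plogis(1 - ((age - 50) / 15)^2)
high <- rbinom(n, 1, p)

# Logistic regression with a linear age term only, then with an added squared term
fit_linear    <- glm(high ~ age, family = binomial)
fit_quadratic <- glm(high ~ age + I(age^2), family = binomial)

print(c(linear = deviance(fit_linear), quadratic = deviance(fit_quadratic)))
print(coef(fit_quadratic))
```

The linear-only model can only produce a monotone curve in age, whereas the squared term (with a negative coefficient here) lets the fitted probability rise and then fall. A decision tree needs no such manual step because it can split on age more than once.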

education

The education level of a worker clearly makes a huge difference. As education increases, there is a sharp and unmistakable rise in the proportion of high earners. It is perfectly intuitive that the more educated a worker, the higher their salary, all else equal.

Note: If the task requests that you choose one and only one predictor that is most strongly associated with the proportion of high earners, then either age or education should be acceptable as long as your justification is sound.

In an exam environment, it seems easier to choose education because the narrative is more intuitive. From a statistical point of view, the significance of education is slightly stronger. The summary in CHUNK 1 shows that each level of education has a relatively bountiful number of observations, whereas the two ends of the distribution of age involve rather scarce observations.

(Exercise: Make a histogram for age to see this!)

 

Data Processing

Having explored the data and identified variables with predictive power, let’s make some simple data adjustments so that the dataset is fully ready for analysis.

Variable type conversions (from numeric to factor, or from factor to numeric) are less of a problem for decision trees than for GLMs. One of the numeric variables, year, deserves some special attention. As an integer variable, the values of year have a numeric order and the output in CHUNK 6 shows that the proportion of high earners increases (mildly) with year. In fact, for the purposes of making predictions in the future, it is necessary to keep year as a numeric variable.

  • year as a factor variable: If year is converted to a factor variable with seven levels, “2003”, “2004”, … , “2009”, then a predictive model using year as a predictor will have trouble making predictions for a worker after year 2009. Year 2022, for example, does not belong to any of the seven levels of year, meaning that the model cannot be used in practice to predict the probability that a worker in future years is a high earner.
  • year as a numeric variable: Taking year as a numeric variable, we are able to make predictions for workers in future years by extrapolating the increasing trend to the future. This, however, relies on the assumption that the proportion of high earners will continue to increase over time, which is not always correct.
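
The difference between the two representations can be seen in a few lines of base R. This is a minimal sketch; the object names are hypothetical, and the cutoff 2005.5 is borrowed from one of the year splits that appears in the Tree 1 output later in this section.

```r
# year stored as a factor with the seven levels seen in the data
year_levels <- 2003:2009
future_as_factor <- factor(2022, levels = year_levels)
print(future_as_factor)            # NA: 2022 is not one of the known levels

# year stored as a number: any future value is still a valid input,
# and a split condition such as "year >= 2005.5" can be evaluated for it
future_as_number <- 2022
print(future_as_number >= 2005.5)
```

As a factor, the year 2022 has no level to belong to. As a number, it can always be routed down one side of a year split.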

Among the factor variables, two of them have an unusually small number of observations at some of their levels.

  • race
    There are only 37 observations belonging to “4. Other”, where the proportion of high earners is the lowest. It is fine to combine these observations with the 293 Black workers, who have the second lowest proportion of high earners, but it is also valid to leave this class as is due to its special nature so that future workers who are not White, Black, or Asian will constitute a separate class and their effects on wage_flag can be accounted for separately.
  • maritl
    There are only 19 widowed workers and 55 separated workers. The proportions of high earners for these workers are close to the proportion for divorced workers. Because “3. Widowed”, “4. Divorced”, and “5. Separated” have somewhat similar meanings (they are all related to what happens after getting married) and similar relationships to wage_flag, we can combine them into a more populous level known as “3. Other” to improve the robustness of the models to be constructed. This is done in CHUNK 7 by means of the levels() function coupled with the elementary method of vector subsetting.
# CHUNK 7
levels(Wage$maritl)[3:5] <- "3. Other"
table(Wage$maritl)
1. Never Married       2. Married         3. Other 
             648             2074              278
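
The same levels() assignment can be tried on a small factor to see exactly what it does. The factor m below is hypothetical (six made-up observations that reuse the five original level names of maritl).

```r
# Hypothetical factor with the same five level names as maritl
m <- factor(c("1. Never Married", "2. Married", "3. Widowed",
              "4. Divorced", "5. Separated", "2. Married"))
print(levels(m))

# Assigning one name to levels 3 to 5 merges them into a single level
levels(m)[3:5] <- "3. Other"
print(table(m))
```

Because the three old levels receive an identical name, R merges them, and every observation that belonged to any of them is counted under the new level.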

 

Construction and Evaluation of Base Classification Trees

TASK 4: Construct base classification trees

  • Split the data into training and test sets.

Your assistant has provided code (see CHUNK 9) for fitting an unpruned classification tree for wage_flag using all predictors on the training set.

  • Interpret the output of the first split of the unpruned tree.
  • Review the complexity parameter table for the unpruned tree. Propose and justify two different choices of the complexity parameter value to get two pruned trees of different levels of complexity.

For the next two items, take the simplest tree you have constructed.

  • Interpret this tree as a whole.
  • Use this tree to predict the probability of being a high earner for the following worker:
    A 30-year old, married, Asian worker who has health insurance and is a college graduate

 

In this long but important task, we will construct a total of three classification trees of different levels of complexity for wage_flag. Prior to fitting any trees, we split the data into the training (70%, 2101 observations) and test (30%, 899 observations) sets, again with the use of the createDataPartition() function from the caret package.

# CHUNK 8 
library(caret)
set.seed (2021)
partition <- createDataPartition(y = as.factor(Wage$wage_flag),
                                 p = .7,
                                 list = FALSE)
data.train <- Wage[partition, ]
data.test <- Wage[-partition, ]
mean(data.train$wage_flag)
mean(data.test$wage_flag)
> mean(data.train$wage_flag)
[1] 0.552118
> mean(data.test$wage_flag)
[1] 0.5528365

To check that the two sets are representative, we note that the proportion of high earners is 55.2118% on the training set and 55.28365% on the test set. While these two proportions do not exactly match, they are close enough, showing the effectiveness of the built-in stratification of the target variable. We will use the same training/test partition for all tree models in this case study.
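
The idea behind stratified sampling on the target can be sketched in base R. This is an illustration of the concept only, not caret's actual implementation; the target y is hypothetical, and rounding each class to 70% is an assumption.

```r
set.seed(2021)
# Hypothetical binary target: 450 zeros and 550 ones
y <- rep(c(0, 1), times = c(450, 550))

# Sample 70% of the row indices separately within each class
idx_by_class <- split(seq_along(y), y)
train_idx <- unlist(lapply(idx_by_class, function(idx) {
  sample(idx, size = round(0.7 * length(idx)))
}))

print(c(overall = mean(y),
        train   = mean(y[train_idx]),
        test    = mean(y[-train_idx])))
```

Because the sampling is done class by class, the proportion of ones in the training and test sets stays close to the overall proportion, which is the behavior observed in CHUNK 8.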

 

Tree 1: Fitting a large, unpruned classification tree

To begin with, we build a large, unpruned classification tree for wage_flag using all available predictors. The control parameters (supposedly given by your assistant) are deliberately set so that a sufficiently complex tree can be grown. The complex tree will be pruned at a later stage to reduce its complexity. Here are the control parameters we use:

  • minbucket: The minimum number of observations in any terminal node is 5, which is reasonable in view of the relatively small size of the data.
  • cp: The complexity parameter is set to 0.0005. The low value of cp allows us to construct a sufficiently large tree while pre-pruning small branches that are not worthwhile to pursue.
  • maxdepth: The maximum depth of the tree is 7.
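
The effect of these control parameters on tree size can be explored on a small example. The sketch below uses the kyphosis data set that ships with the rpart package instead of the Wage data, a hypothetical helper count_leaves(), and parameter values chosen only for illustration.

```r
library(rpart)  # kyphosis is a small example data set shipped with rpart

# Number of terminal nodes: rows of the frame marked as leaves
count_leaves <- function(fit) sum(fit$frame$var == "<leaf>")

# Same data and formula, different control parameters
shallow <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class",
                 control = rpart.control(minbucket = 5, cp = 0, maxdepth = 1))
deep    <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class",
                 control = rpart.control(minbucket = 5, cp = 0, maxdepth = 7))
high_cp <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class",
                 control = rpart.control(minbucket = 5, cp = 0.5, maxdepth = 7))

print(c(shallow = count_leaves(shallow),
        deep    = count_leaves(deep),
        high_cp = count_leaves(high_cp)))
```

A smaller maxdepth or a larger cp can only make the tree simpler: a depth of 1 allows at most two terminal nodes, and a large cp blocks splits that do not improve the fit enough.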

 

EXAM NOTE

On the exam, you will most likely be given the control parameters directly and asked to run the code to grow a large tree without having to make any changes. Although you will rarely be asked to specify the control parameters yourself, some conceptual exam items may test how these parameters affect tree complexity, so it is necessary to know what they are.

# CHUNK 9
set.seed(60)

# method = "class" ensures that the target is treated as a categorical variable
tree1 <- rpart(wage_flag ~ .,
               data = data.train,
               method = "class",
               control = rpart.control(minbucket = 5,
                                       cp = 0.0005,
                                       maxdepth = 7),
               parms = list(split = "gini"))
tree1 
n= 2101 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 2101 941 1 (0.44788196 0.55211804)  
    2) education=1. < HS Grad,2. HS Grad 853 295 0 (0.65416178 0.34583822)  
      4) health_ins=2. No 334  62 0 (0.81437126 0.18562874) *
      5) health_ins=1. Yes 519 233 0 (0.55105973 0.44894027)  
       10) maritl=1. Never Married 89  16 0 (0.82022472 0.17977528)  
         20) health=2. >=Very Good 59   5 0 (0.91525424 0.08474576) *
         21) health=1. <=Good 30  11 0 (0.63333333 0.36666667)  
           42) age>=51.5 9   1 0 (0.88888889 0.11111111) *
           43) age< 51.5 21  10 0 (0.52380952 0.47619048)  
             86) year< 2005.5 11   4 0 (0.63636364 0.36363636) *
             87) year>=2005.5 10   4 1 (0.40000000 0.60000000) *
       11) maritl=2. Married,3. Other 430 213 1 (0.49534884 0.50465116)  
         22) age>=65.5 17   2 0 (0.88235294 0.11764706) *
         23) age< 65.5 413 198 1 (0.47941889 0.52058111)  
           46) education=1. < HS Grad 67  26 0 (0.61194030 0.38805970)  
             92) age< 34.5 10   0 0 (1.00000000 0.00000000) *
             93) age>=34.5 57  26 0 (0.54385965 0.45614035)  
              186) year< 2004.5 18   5 0 (0.72222222 0.27777778) *
              187) year>=2004.5 39  18 1 (0.46153846 0.53846154) *
           47) education=2. HS Grad 346 157 1 (0.45375723 0.54624277)  
             94) maritl=3. Other 45  19 0 (0.57777778 0.42222222)  
              188) age< 53.5 36  13 0 (0.63888889 0.36111111) *
              189) age>=53.5 9   3 1 (0.33333333 0.66666667) *
             95) maritl=2. Married 301 131 1 (0.43521595 0.56478405) *
    3) education=3. Some College,4. College Grad,5. Advanced Degree 1248 383 1 (0.30689103 0.69310897)  
      6) health_ins=2. No 302 137 0 (0.54635762 0.45364238)  
       12) maritl=1. Never Married 76  10 0 (0.86842105 0.13157895) *
       13) maritl=2. Married,3. Other 226  99 1 (0.43805310 0.56194690)  
         26) education=3. Some College,4. College Grad 179  85 1 (0.47486034 0.52513966)  
           52) age< 28.5 8   1 0 (0.87500000 0.12500000) *
           53) age>=28.5 171  78 1 (0.45614035 0.54385965)  
            106) year< 2005.5 85  39 0 (0.54117647 0.45882353)  
              212) jobclass=2. Information 47  18 0 (0.61702128 0.38297872) *
              213) jobclass=1. Industrial 38  17 1 (0.44736842 0.55263158) *
            107) year>=2005.5 86  32 1 (0.37209302 0.62790698)  
              214) race=2. Black,3. Asian 20   8 0 (0.60000000 0.40000000) *
              215) race=1. White 66  20 1 (0.30303030 0.69696970) *
         27) education=5. Advanced Degree 47  14 1 (0.29787234 0.70212766)  
           54) age>=33.5 41  14 1 (0.34146341 0.65853659)  
            108) age< 46.5 14   6 0 (0.57142857 0.42857143) *
            109) age>=46.5 27   6 1 (0.22222222 0.77777778)  
              218) race=2. Black,3. Asian 7   2 0 (0.71428571 0.28571429) *
              219) race=1. White 20   1 1 (0.05000000 0.95000000) *
           55) age< 33.5 6   0 1 (0.00000000 1.00000000) *
      7) health_ins=1. Yes 946 218 1 (0.23044397 0.76955603)  
       14) age< 31.5 149  70 1 (0.46979866 0.53020134)  
         28) education=3. Some College 65  23 0 (0.64615385 0.35384615)  
           56) age< 25.5 17   2 0 (0.88235294 0.11764706) *
           57) age>=25.5 48  21 0 (0.56250000 0.43750000)  
            114) age>=30.5 8   1 0 (0.87500000 0.12500000) *
            115) age< 30.5 40  20 0 (0.50000000 0.50000000)  
              230) jobclass=1. Industrial 23   9 0 (0.60869565 0.39130435) *
              231) jobclass=2. Information 17   6 1 (0.35294118 0.64705882) *
         29) education=4. College Grad,5. Advanced Degree 84  28 1 (0.33333333 0.66666667)  
           58) year< 2004.5 23  10 0 (0.56521739 0.43478261)  
            116) age< 29.5 16   5 0 (0.68750000 0.31250000) *
            117) age>=29.5 7   2 1 (0.28571429 0.71428571) *
           59) year>=2004.5 61  15 1 (0.24590164 0.75409836) *
       15) age>=31.5 797 148 1 (0.18569636 0.81430364)  
         30) education=3. Some College 260  77 1 (0.29615385 0.70384615)  
           60) year< 2003.5 41  20 1 (0.48780488 0.51219512)  
            120) jobclass=2. Information 18   5 0 (0.72222222 0.27777778) *
            121) jobclass=1. Industrial 23   7 1 (0.30434783 0.69565217)  
              242) age< 38 6   2 0 (0.66666667 0.33333333) *
              243) age>=38 17   3 1 (0.17647059 0.82352941) *
           61) year>=2003.5 219  57 1 (0.26027397 0.73972603) *
         31) education=4. College Grad,5. Advanced Degree 537  71 1 (0.13221601 0.86778399)  
           62) maritl=1. Never Married,3. Other 114  29 1 (0.25438596 0.74561404)  
            124) year< 2004.5 28  12 1 (0.42857143 0.57142857)  
              248) age>=52 6   2 0 (0.66666667 0.33333333) *
              249) age< 52 22   8 1 (0.36363636 0.63636364) *
            125) year>=2004.5 86  17 1 (0.19767442 0.80232558) *
           63) maritl=2. Married 423  42 1 (0.09929078 0.90070922) *

Coding-wise, note that:

  • We have preceded the call of the rpart() function with a random seed. This is to make sure that the internal cross-validation results returned by rpart() are reproducible.
    (No random seed was used in CHUNK 6 of Section 5.2 because the number of folds used equals the number of training observations, making the training/validation splits non-random.)
  • The argument parms = list(split = "gini") instructs rpart() to fit the large tree (with cp = 0.0005) using Gini as the tree-building criterion. This argument is, in fact, unnecessary because Gini is the default criterion for classification trees, but the code in almost all exams testing classification trees has this line specified for clarity.

 

Interpreting the output of a classification tree

As we learned in Section 5.2, a decision tree (classification trees included) can be displayed in two more or less equivalent ways:

List Form

When we type the name of a classification tree (tree1 in this case), we will have at a single glance a lot of useful information about all of the tree splits made. For each node, the output still shows the splitting criterion in the split column, the number of training observations in the n column, and the prediction (more precisely, the predicted class) in the yval column. Unlike a regression tree, we no longer have the RSS, which is now replaced by the number of misclassifications in each node, as shown in the loss column, along with the proportion of observations in each class (in this case, “0” and “1”, in this order) of the categorical target variable, as shown in the (yprob) column. The output of Tree 1 is formidably long. Let’s take the first split for illustration.

n= 2101 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

  1) root 2101 941 1 (0.44788196 0.55211804)  
    2) education=1. < HS Grad,2. HS Grad 853 295 0 (0.65416178 0.34583822)
    3) education=3. Some College,4. College Grad,5. Advanced Degree 1248 383 1 (0.30689103 0.69310897)

The output says that Tree 1 starts with the root node, where there are:

  • 2101 observations;
  • the majority (55.21%) of which are high earners, so “1” is the predicted class (see the yval column);
  • 941 of which are low earners and are therefore misclassified.

Then education defines the first split, meaning that out of all the predictors, partitioning the data on the basis of education (as a factor variable) leads to the greatest reduction in impurity (measured by Gini) at this point. The 2101 observations in the training set are differentiated as follows:

Case 1

Those belonging to “1. < HS Grad” or “2. HS Grad” will be sent to the left branch, or node 2, where:

    • only 34.583822% of the 853 observations (295 workers) are high earners, so:
    • the predicted class is “0”.
    • These 295 high earners are misclassified.

Case 2

Those belonging to “3. Some College”, “4. College Grad” or “5. Advanced Degree” will be sent to the right branch, or node 3, where:

    • 69.310897% of the 1248 observations are high earners and so:
    • the predicted class is “1”.
    • 30.689103% of the observations, or 383 workers, are low earners and are misclassified.

The fact that education appears in the first split of Tree 1 indicates that it is the most distinguishing predictor for whether a worker is a high earner, which is intuitive and consistent with the bivariate data exploration we performed in Task 3.

After the first split is made, you can see that the number of misclassifications decreases from 941 (in the root node) to only 295 + 383 = 678 in nodes 2 and 3 combined, which is a decent amount of improvement. The splitting continues recursively in each of the two new child nodes until no split can improve the impurity of the tree by more than the cp parameter. We can see that some terminal nodes such as nodes 42, 52, 55, 114, 117, 189, 218, 242, and 248 contain fewer than 10 observations, which may suggest that Tree 1 is overfitted.

 

Graphical Form

Figure 5.3.2 visualizes Tree 1 and shows the following for each node:

    • The predicted class (top value)
    • The proportion of observations in that node lying in the second class of the target variable, i.e., high earners (middle value)
    • The proportion of training observations belonging to that node (bottom value)
# CHUNK 9 (cont.)
rpart.plot(tree1, tweak = 2)

Figure 5.3.2: The first classification tree constructed for the Wage data

For instance, node 3, which is the right branch coming out of the root node, contains 59% of the training observations and has a predicted class of “1” since 69% of the observations there are high earners. Note that these percentages are approximate values. To get more precise values, we have to refer to the tree output in list form.

As you can see in Figure 5.3.2, Tree 1 is way overblown, with a total of 36 splits (equivalently, 37 terminal nodes). The labels of the tree nodes and the splitting criteria are barely legible even if we set the tweak argument to 2 to make the texts double the size of the default.

Remember that the classification tree is intended for use by individuals who may not be familiar with predictive analytics, so interpretability and ease of implementation should be key factors when deciding on which model to use, in addition to prediction accuracy.

 

Complexity parameter tables and plots

To produce a classification tree of reasonable size for both predictions and interpretation, let’s look at the complexity parameter table, or the cptable component of Tree 1, in CHUNK 10.

# CHUNK 10
tree1$cptable
             CP nsplit rel error    xerror       xstd
1  0.2794899044      0 1.0000000 1.0000000 0.02422262
2  0.0297555792      1 0.7205101 0.7810840 0.02323092
3  0.0100956429      3 0.6609989 0.6833156 0.02244820
4  0.0085015940      5 0.6408077 0.6599362 0.02222665
5  0.0074388948      9 0.6068013 0.6567481 0.02219536
6  0.0046050301     10 0.5993624 0.6620616 0.02224737
7  0.0042507970     13 0.5855473 0.6695005 0.02231896
8  0.0031880978     15 0.5770457 0.6673751 0.02229865
9  0.0028338647     18 0.5674814 0.6726886 0.02234921
10 0.0021253985     21 0.5589798 0.6663124 0.02228845
11 0.0017711654     22 0.5568544 0.6461211 0.02208915
12 0.0015940489     28 0.5462274 0.6471838 0.02209991
13 0.0007084662     30 0.5430393 0.6599362 0.02222665
14 0.0005000000     36 0.5387885 0.6737513 0.02235924

Just like a regression tree, the cptable of a classification tree shows how the relative error (which, in a classification setting, is the number of misclassification errors of the corresponding tree scaled by the number of misclassification errors of the tree with no split) and scaled cross-validation error vary with the complexity parameter.

As the complexity parameter decreases, the tree becomes more complex and the relative training error decreases consistently, as shown in the rel error column. We know, however, that it is the cross-validation error shown in the xerror column that gives a more reliable measure of the predictive performance of trees of different sizes. For Tree 1, the cross-validation error generally decreases down the cptable and reaches its minimum at 22 splits, beyond which it starts to rise.

Besides examining the entries of the xerror column of the cptable, we can visualize the behavior of the cross-validation error by means of the complexity parameter plot, which exhibits how the cross-validation error of a decision tree varies with the complexity parameter. Run the rest of CHUNK 10 to use the plotcp() function to generate such a plot for Tree 1 (Figure 5.3.3).

The top of the plot shows the size of the tree measured by the number of terminal nodes (which equals the number of splits plus one) associated with different values of the complexity parameter.

# CHUNK 10 (cont.)
plotcp(tree1)

Figure 5.3.3: Complexity parameter plot for Tree 1 in the Wage data

Tree 2: Pruning Tree 1 using the minimizer of xerror

Based on the cptable or cp plot, a common way to simplify a decision tree is to prune it using the cp value that corresponds to the lowest cross-validation error (xerror). In the case of Tree 1, this cp value is 0.0017711654, which is the value in the 11th row of the cptable. In CHUNK 11, we pass this value to the cp argument of the prune() function applied to Tree 1 to get a classification tree with 22 splits or 23 terminal nodes. Let’s call this reduced tree Tree 2.

Find cp value with the lowest cross-validation error

# CHUNK 11 
# Get the cp value with the lowest cross-validation error
cp.min <- tree1$cptable[which.min(tree1$cptable[, "xerror"]), "CP"] 
cp.min
[1] 0.001771165

 

Prune the Tree

# Prune the tree
tree2 <- prune(tree1, cp = cp.min)

 

EXAM NOTE

If you find the command tree1$cptable[which.min(tree1$cptable[, "xerror"]), "CP"] too complex, on the exam you may want to set cp directly to any value between the 10th and 11th CP values in the cptable (0.0021253985 and 0.0017711654), e.g., tree2 <- prune(tree1, cp = 0.002)

 

Tree 3: Pruning Tree 1 using the one-standard-error rule

Even with pruning, Tree 2 is still too bulky and complex. However, we can observe from the cptable that many of the smaller trees (i.e., those with 5 to 21 splits) have a cross-validation error which is comparable to that of Tree 2 while having a much smaller size. To identify a tree that is much more interpretable but is comparably predictive, one common practice is to employ the one-standard-error rule:

To select the smallest tree whose cross-validation error is within one standard error from the minimum cross-validation error.

In the current setting, this cutoff level equals 0.6461211 + 0.02208915 = 0.66821025, which is also shown in the complexity parameter plot (Figure 5.3.3) as a dotted line. Among all trees with a cross-validation error within one standard error from the minimum cross-validation error, the simplest tree has 5 splits (or 6 terminal nodes), as we can see from the cptable, with a complexity parameter of 0.0085015940. In CHUNK 12, we prune Tree 1 using this cp value to produce a much simpler tree, called Tree 3 and coded tree3.
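The one-standard-error rule can also be applied programmatically rather than by reading the table. The following is a minimal sketch that retypes the CP, nsplit, xerror, and xstd columns from the CHUNK 10 output into a data frame; the names cp_tab, i_min, cutoff, and i_1se are my own and do not appear in the exam code.

```r
# cptable values copied by hand from the CHUNK 10 output
cp_tab <- data.frame(
  CP     = c(0.2794899044, 0.0297555792, 0.0100956429, 0.0085015940,
             0.0074388948, 0.0046050301, 0.0042507970, 0.0031880978,
             0.0028338647, 0.0021253985, 0.0017711654, 0.0015940489,
             0.0007084662, 0.0005000000),
  nsplit = c(0, 1, 3, 5, 9, 10, 13, 15, 18, 21, 22, 28, 30, 36),
  xerror = c(1.0000000, 0.7810840, 0.6833156, 0.6599362, 0.6567481,
             0.6620616, 0.6695005, 0.6673751, 0.6726886, 0.6663124,
             0.6461211, 0.6471838, 0.6599362, 0.6737513),
  xstd   = c(0.02422262, 0.02323092, 0.02244820, 0.02222665, 0.02219536,
             0.02224737, 0.02231896, 0.02229865, 0.02234921, 0.02228845,
             0.02208915, 0.02209991, 0.02222665, 0.02235924)
)

i_min  <- which.min(cp_tab$xerror)                   # row with the lowest xerror
cutoff <- cp_tab$xerror[i_min] + cp_tab$xstd[i_min]  # one-standard-error cutoff
# Rows are sorted by increasing nsplit, so the first row under the cutoff
# is the smallest qualifying tree
i_1se  <- which(cp_tab$xerror <= cutoff)[1]

cutoff
cp_tab[i_1se, c("CP", "nsplit")]
```

Here i_min points to the 11th row and i_1se to the 4th row (5 splits), in agreement with the manual reading of the table. With a fitted tree, the same logic works on tree1$cptable directly, using matrix indexing such as tree1$cptable[, "xerror"].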

# CHUNK 12 
tree3 <- prune(tree1, cp = tree1$cptable[4, "CP"]) 
tree3 
n= 2101 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 2101 941 1 (0.4478820 0.5521180)  
   2) education=1. < HS Grad,2. HS Grad 853 295 0 (0.6541618 0.3458382) *
   3) education=3. Some College,4. College Grad,5. Advanced Degree 1248 383 1 (0.3068910 0.6931090)  
     6) health_ins=2. No 302 137 0 (0.5463576 0.4536424)  
      12) maritl=1. Never Married 76  10 0 (0.8684211 0.1315789) *
      13) maritl=2. Married,3. Other 226  99 1 (0.4380531 0.5619469) *
     7) health_ins=1. Yes 946 218 1 (0.2304440 0.7695560)  
      14) age< 31.5 149  70 1 (0.4697987 0.5302013)  
        28) education=3. Some College 65  23 0 (0.6461538 0.3538462) *
        29) education=4. College Grad,5. Advanced Degree 84  28 1 (0.3333333 0.6666667) *
      15) age>=31.5 797 148 1 (0.1856964 0.8143036) *
# Plot the tree 
rpart.plot(tree3)

Interpreting Tree 3

Compared to Trees 1 and 2, Tree 3 is considerably simpler and more interpretable. It only has five splits and six terminal nodes, starting with education, which appears to be the most influential predictor, then further splitting the larger bucket (59% of the training observations) by health_ins, maritl, age, or education again.

Examining each split in the tree, we can see that the characteristics of each resulting node make sense in relation to drivers of being a high earner. For instance:

  • The first split differentiates relatively low education levels on the left, where the predicted class is low earners, from relatively high education levels on the right, where the predicted class is high earners. As we noted in Task 3, it makes perfect sense that workers with a higher education level are more likely to be high earners, all else equal.
  • Node 7 is further partitioned using age, with 32 (or 31.5) as the cutoff point. Workers younger (resp. older) than 32 have a lower (resp. higher) probability of being a high earner, consistent with what we saw in Task 3.

This split has a peculiarity: Nodes 7, 14, and 15 all share the same predicted class (high earners), so splitting by age does not reduce the number of misclassification errors. Why, then, is the split performed at all? While the split has no effect on misclassification errors, it does reduce the Gini index of the classification tree and leads to improved node purity.

Remember that the command parms = list(split = "gini") in CHUNK 9 specified Gini as the criterion for building the largest classification tree with cp = 0.
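To see the effect of the age split numerically, here is a minimal sketch based on the class counts in the Tree 3 output for nodes 7, 14, and 15. The helper gini() is a hypothetical function written for this illustration only; rpart performs its own impurity calculations internally.

```r
# Gini index of a node, given its vector of class counts
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Class counts (low earners, high earners) read off the Tree 3 output
node7  <- c(218, 946 - 218)  # health_ins = 1. Yes (parent node)
node14 <- c(70, 149 - 70)    # age < 31.5
node15 <- c(148, 797 - 148)  # age >= 31.5

# Misclassifications = size of the minority class in each node
miss_before <- min(node7)
miss_after  <- min(node14) + min(node15)

# Gini before the split, and the size-weighted Gini of the two children
gini_before <- gini(node7)
gini_after  <- (sum(node14) * gini(node14) + sum(node15) * gini(node15)) /
  sum(node7)

c(miss_before = miss_before, miss_after = miss_after)
c(gini_before = gini_before, gini_after = gini_after)
```

The number of misclassifications stays at 218 before and after the split, whereas the weighted Gini index drops from about 0.355 to about 0.333, which is why the split is still worthwhile under the Gini criterion.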

 

EXAM NOTE

The December 2018 PA model solution says that “only some [candidates] interpreted [a decision tree] fully.” Here are things you can comment on (this list is not exhaustive):

  • How many splits or terminal nodes the tree has.
  • Which variables appear to be the most informative or influential as shown in the first few splits, and whether that makes sense or not (hopefully so!).
  • Which terminal nodes contain the most observations.
  • For classification trees, do point out at least some interesting combinations of the feature values that lead to the event of interest.

 

Using Tree 3 to make a prediction

As an illustration, let’s apply Tree 3 to make a prediction for a sample worker. Suppose that we have a 30-year-old, married, Asian worker who has health insurance and is a college graduate. This worker will travel along different branches of Tree 3 and be classified as follows:

Step 1. Belonging to “4. College Grad” rather than “1. < HS Grad” or “2. HS Grad”, the worker is assigned to the right in the first split.

Step 2. With health insurance, he violates the condition health_ins = 2. No, so he continues to go to the right in the second split.

Step 3. Of age 30, he fulfills the criterion age < 32 and is sent to the left in the third split.

Step 4. In the fourth and last split, he does not satisfy the criterion education = 3. Some College and arrives at the right terminal node (Node 29), where the proportion of high earners is 0.67 and will serve as the predicted probability that the worker is a high earner. As the probability exceeds 0.5, the predicted class is high earner.
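The four steps above can be mirrored by a small hand-written function. This is only a sketch: tree3_by_hand() is a hypothetical helper that hard-codes the splits and node probabilities printed in the CHUNK 12 output, and it is not something the rpart package provides.

```r
# Hand-coded version of the Tree 3 rules; returns the predicted probability
# of being a high earner (values copied from the CHUNK 12 output)
tree3_by_hand <- function(education, health_ins, maritl, age) {
  if (education %in% c("1. < HS Grad", "2. HS Grad")) {
    return(0.3458382)                       # node 2
  }
  if (health_ins == "2. No") {
    if (maritl == "1. Never Married") {
      return(0.1315789)                     # node 12
    }
    return(0.5619469)                       # node 13
  }
  if (age >= 31.5) {
    return(0.8143036)                       # node 15
  }
  if (education == "3. Some College") {
    return(0.3538462)                       # node 28
  }
  0.6666667                                 # node 29
}

# The sample worker: college graduate, insured, married, aged 30
p_high <- tree3_by_hand("4. College Grad", "1. Yes", "2. Married", 30)
p_high
ifelse(p_high > 0.5, "high earner", "low earner")
```

Race does not enter the function at all because Tree 3 never splits on it.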

Instead of using the classification rules depicted in Tree 3 to make the prediction manually, we can ask R to do the work for us, but this requires some coding effort. We will defer the use of the predict() function for decision trees to the next task.

 

Options of the rpart.plot() function

Before moving on to the next task, let’s digress a bit to explore some of the commonly used options of the rpart.plot() function. None of these extra options change a displayed tree materially, but some of them were used in the code provided on past PA exams, so it pays to take a quick look at these options. It is unnecessary to memorize what they do, however. (If you are interested, feel free to run ?rpart.plot and read the function’s documentation to learn more.)

  • type = 0: This command suppresses the ingredients of the intermediate, non-terminal nodes of a decision tree. Only the splitting criteria are drawn.
  • digits: This argument controls the number of significant digits to display.
  • extra = 4: This command, only for classification trees, displays the proportions for all classes (not only the non-baseline level) of the target variable in order (in the case of wage_flag, "0" and "1") in every node.

In CHUNK 13, we add these options to the rpart.plot() function to display Tree 3. The resulting plots, which can be compared with the plot in CHUNK 12, are given below.

# CHUNK 13 
rpart.plot(tree3, type = 0)

rpart.plot(tree3, digits = 4, extra = 4)

TASK 5: Select a base classification tree

  • Evaluate the three base trees constructed in Task 4 on the test set.
  • Make a recommendation as to which tree to use.
  • Interpret the accuracy, sensitivity, and specificity of the recommended tree.

 

Comparing the predictive performance of the three trees

Now that we have constructed three different base classification trees for wage_flag on the training set with rather different degrees of complexity, it is time to evaluate and rank their predictive performance on the test set. The first thing is to generate predictions for the test observations. As you can expect, this can be accomplished again by the super versatile predict() function we have been using. When the predict() function is applied to a classification tree fitted by rpart(), the type argument can be used to indicate the desired type of predictions:

  • "prob": To produce a matrix of predicted class probabilities, one column for each class of the target variable. This is the default for a classification tree fitted by rpart().
  • "class": To produce a vector of predicted class labels based on a cutoff of 0.5.

 

Confusion Matrices

In CHUNK 14, we produce the predicted classes and use them to construct the confusion matrices (defined in Subsection 4.1.4) for the three classification trees on the test set.

# CHUNK 14 
pred1.class <- predict(tree1, newdata = data.test, type = "class") 
pred2.class <- predict(tree2, newdata = data.test, type = "class") 
pred3.class <- predict(tree3, newdata = data.test, type = "class") 

confusionMatrix(pred1.class, as.factor(data.test$wage_flag), positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 253  90
         1 149 407
                                         
               Accuracy : 0.7341         
                 95% CI : (0.704, 0.7628)
    No Information Rate : 0.5528         
    P-Value [Acc > NIR] : < 2.2e-16      
                                         
                  Kappa : 0.4546         
                                         
 Mcnemar's Test P-Value : 0.0001756      
                                         
            Sensitivity : 0.8189         
            Specificity : 0.6294         
         Pos Pred Value : 0.7320         
         Neg Pred Value : 0.7376         
             Prevalence : 0.5528         
         Detection Rate : 0.4527         
   Detection Prevalence : 0.6185         
      Balanced Accuracy : 0.7241         
                                         
       'Positive' Class : 1     

 

confusionMatrix(pred2.class, as.factor(data.test$wage_flag), positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 267  90
         1 135 407
                                          
               Accuracy : 0.7497          
                 95% CI : (0.7201, 0.7777)
    No Information Rate : 0.5528          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.4883          
                                          
 Mcnemar's Test P-Value : 0.003353        
                                          
            Sensitivity : 0.8189          
            Specificity : 0.6642          
         Pos Pred Value : 0.7509          
         Neg Pred Value : 0.7479          
             Prevalence : 0.5528          
         Detection Rate : 0.4527          
   Detection Prevalence : 0.6029          
      Balanced Accuracy : 0.7415          
                                          
       'Positive' Class : 1   

 

confusionMatrix(pred3.class, as.factor(data.test$wage_flag), positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 298 143
         1 104 354
                                          
               Accuracy : 0.7253          
                 95% CI : (0.6948, 0.7542)
    No Information Rate : 0.5528          
    P-Value [Acc > NIR] : < 2e-16         
                                          
                  Kappa : 0.4494          
                                          
 Mcnemar's Test P-Value : 0.01561         
                                          
            Sensitivity : 0.7123          
            Specificity : 0.7413          
         Pos Pred Value : 0.7729          
         Neg Pred Value : 0.6757          
             Prevalence : 0.5528          
         Detection Rate : 0.3938          
   Detection Prevalence : 0.5095          
      Balanced Accuracy : 0.7268          
                                          
       'Positive' Class : 1 

 

The test accuracy of the three classification trees is ordered as follows:

Tree 3 < Tree 1 < Tree 2

As the cptable in CHUNK 10 shows, Tree 2 has the best out-of-sample performance, which, not surprisingly, translates into the highest accuracy on the test set.
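The summary statistics printed by confusionMatrix() can be reproduced from the four cell counts of each matrix. The sketch below copies the counts from the CHUNK 14 output and treats "1" (high earner) as the positive class, as specified in that chunk; the object names counts and metrics are my own.

```r
# Cell counts copied from the three confusion matrices in CHUNK 14
# (rows = prediction, columns = reference, positive class = "1")
counts <- data.frame(
  tree = c("Tree 1", "Tree 2", "Tree 3"),
  TN   = c(253, 267, 298),  # predicted 0, actually 0
  FN   = c(90, 90, 143),    # predicted 0, actually 1
  FP   = c(149, 135, 104),  # predicted 1, actually 0
  TP   = c(407, 407, 354)   # predicted 1, actually 1
)

metrics <- with(counts, data.frame(
  tree        = tree,
  accuracy    = (TP + TN) / (TP + TN + FP + FN),
  sensitivity = TP / (TP + FN),  # share of high earners correctly identified
  specificity = TN / (TN + FP)   # share of low earners correctly identified
))

cbind(tree = metrics$tree, round(metrics[, -1], 4))
```

The rounded values agree with the Accuracy, Sensitivity, and Specificity lines in the output above. Tree 3 gives up sensitivity in exchange for the highest specificity of the three trees.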

 

AUC

Instead of ranking the three classification trees on the basis of their accuracy, sensitivity, or specificity on the test set based on a given cutoff (which is 0.5 by default), we can compare them using a cutoff-free metric such as the test AUC. In CHUNK 15, we first generate the predicted probabilities for the positive class “1” on the test set. Then we feed these predicted probabilities into the roc() function (remember that we first used this function in Section 4.3).

# CHUNK 15 
library(pROC) 
# Extract the predicted probabilities 
pred1.prob <- predict(tree1, newdata = data.test, type = "prob")[, 2] 
pred2.prob <- predict(tree2, newdata = data.test, type = "prob")[, 2] 
pred3.prob <- predict(tree3, newdata = data.test, type = "prob")[, 2]
roc(data.test$wage_flag, pred1.prob)
roc(data.test$wage_flag, pred2.prob)
roc(data.test$wage_flag, pred3.prob)
> roc(data.test$wage_flag, pred1.prob)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = data.test$wage_flag, predictor = pred1.prob)

Data: pred1.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
Area under the curve: 0.8096

> roc(data.test$wage_flag, pred2.prob)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = data.test$wage_flag, predictor = pred2.prob)

Data: pred2.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
Area under the curve: 0.8135

> roc(data.test$wage_flag, pred3.prob)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = data.test$wage_flag, predictor = pred3.prob)

Data: pred3.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
Area under the curve: 0.759

Findings:

The three test AUCs, equal to 0.8096, 0.8135, and 0.759 for Trees 1, 2, and 3, respectively, have the same order as the test accuracy (this is not necessarily true for other datasets):

Tree 3 < Tree 1 < Tree 2

As mentioned above, the predict() function when applied to a classification tree produces, by default, a matrix of predicted probabilities, so we use the subscript [, 2] to extract the predicted probabilities for the second class, "1". Alternatively, we can specify the class “1” explicitly in the second subscript:

pred1.prob <- predict(tree1, newdata = data.test, type = "prob")[, "1"]
pred2.prob <- predict(tree2, newdata = data.test, type = "prob")[, "1"]
pred3.prob <- predict(tree3, newdata = data.test, type = "prob")[, "1"]
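For intuition about what the AUC measures, recall that it equals the probability that a randomly chosen positive observation receives a higher predicted probability than a randomly chosen negative one, with ties counted as one half. The sketch below checks this on a tiny made-up dataset; the function auc_by_pairs() and the vectors y and p are hypothetical and serve only as an illustration, not as a replacement for roc().

```r
# AUC as the proportion of (positive, negative) pairs ranked correctly,
# with tied pairs counted as 0.5
auc_by_pairs <- function(response, prob) {
  pos <- prob[response == 1]
  neg <- prob[response == 0]
  wins <- outer(pos, neg, FUN = function(a, b) (a > b) + 0.5 * (a == b))
  mean(wins)
}

# Made-up labels and predicted probabilities (not from the Wage data)
y <- c(0, 0, 1, 1, 0, 1)
p <- c(0.2, 0.6, 0.6, 0.8, 0.3, 0.4)

auc_toy <- auc_by_pairs(y, p)
auc_toy  # 7.5 correctly ranked pairs out of 9
```

Ties matter for decision trees in particular, because all observations in the same terminal node share the same predicted probability.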

 

Which tree is the best?

The three classification trees have different (though not substantially different) prediction performance measured by quantities based on the confusion matrices and the AUC. We can associate the different performance with the different levels of tree complexity as follows:

  • Tree 1: Tree 1 is the most complex tree out of the three trees. The performance metrics suggest that it may be unnecessarily complex and have overfitted the data.
  • Tree 3: Tree 3 is the opposite of Tree 1. While it is the most compact tree and easiest to interpret, it may have underfitted the data and failed to capture the signal in the data sufficiently, as evidenced by its relatively low test accuracy and markedly low test AUC.
  • Tree 2: Tree 2 appears to have reached the optimal level of tree complexity (with respect to cp) and has the best prediction performance out of the three trees. It is simpler but more predictive than Tree 1 and strikes a good balance between the other two trees.

 

If we have to recommend one tree to use, then it is reasonable to suggest using Tree 2 on the basis of test set performance. Even though Tree 3 is appealing with respect to interpretability, its inferior prediction performance is a cause for concern.

Based on Tree 2, we can say that:

  • Accuracy: It assigns 74.97% (= (267 + 407) / 899) of the new, unseen workers in the test set to their correct class.
  • Sensitivity: Among all high earners in the test set, it correctly classifies 81.89% (= 407 / (407 + 90)) of them as high earners. (Recall that CHUNK 14 sets positive = "1", so high earners form the positive class; without this argument, confusionMatrix() would take the first class of wage_flag, "0", as the positive class.)
  • Specificity: Among all low earners in the test set, it correctly classifies 66.42% (= 267 / (267 + 135)) of them as low earners.

 

Alternative to rpart: Fitting decision trees using caret

Besides the rpart package, the caret package, which we have been using to generate training and test sets and confusion matrices, can also be used to fit decision trees and automate the process of selecting the complexity parameter.

To fit a single tree, the rpart package is, in my opinion, easier to use and good enough to produce most of the information we need in Exam PA. The syntax of the tree-building functions in the caret package, called trainControl() and train(), is more complex, but we will see in the next subsection that the caret package provides a convenient platform for streamlining the construction and evaluation process for a wide range of predictive models. In particular, it is very useful for tuning parameters when ensemble tree models are fitted and used to make predictions. We will defer a detailed discussion of the caret package for fitting decision trees to the next subsection.

 

Exercise (Practice with the predict() function)

We previously dealt with the last item in Task 4 by visually inspecting Tree 3.

Write simple R code in terms of the predict() function to complete the same item.

 

Solution

We have to set up in advance a data frame containing the combinations of variable values of interest.

# CHUNK 16 
sample_worker <- data.frame(year = 2021, 
                            age = 30, 
                            maritl = "2. Married", 
                            race = "3. Asian", 
                            education = "4. College Grad", 
                            jobclass = "1. Industrial", 
                            health = "1. <=Good", 
                            health_ins = "1. Yes") 

Note that even though Tree 3 does not make use of year, jobclass, or health as split variables, the training set includes these three variables, so they must be present in the data frame. The values and levels of these three variables, however, can be arbitrarily set without affecting the prediction.

Now we feed the newly created data frame to the newdata argument of the predict() function to get the desired prediction.

# CHUNK 16 (Cont.) 
predict(tree3, newdata = sample_worker)
          0         1
1 0.3333333 0.6666667

As expected, the predicted probability of being a high earner is 0.6667, which is the same as what we got before.

 

Exercise (Explore how decision trees handle non-linearity)

To help you appreciate how decision trees handle non-linear relationships, your supervisor has kindly provided code (see CHUNK 17) for fitting a classification tree for wage_flag using age as the only predictor on the training data.

  • Run the code to fit and display the tree.
  • Using the fitted tree as an example, describe how decision trees treat numeric predictors that have a non-linear relationship with the target variable.

Your assistant, who has not passed Exam PA yet, is curious about what happens to the classification tree if the square of age is manually added as an extra feature.

  • Explain, without fitting a new classification tree, the effect of adding the square of age.
  • Confirm your thoughts by running a new classification tree for wage_flag using age and the square of age as the only predictors.

 

Solution

In Subsection 5.1.1, we mentioned that decision trees are capable of handling non-linear relationships even if their users do not identify the non-linearity in advance. In this experimental exercise (which is independent of Tasks 4 and 5), we are going to see this in action using the age variable for illustration. We saw in Task 3 that age has a quadratic relationship with the proportion of high earners. How does a decision tree take care of this quadratic relationship?

In CHUNK 17, we fit a classification tree called tree.age for wage_flag on the training set using age as the only predictor and the default control parameters.

# CHUNK 17 
tree.age <- rpart(wage_flag ~ age, method = "class" , data = data.train) 
tree.age
n= 2101 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 2101 941 1 (0.44788196 0.55211804)  
   2) age< 25.5 155  14 0 (0.90967742 0.09032258) *
   3) age>=25.5 1946 800 1 (0.41109969 0.58890031)  
     6) age< 31.5 241 108 0 (0.55186722 0.44813278) *
     7) age>=31.5 1705 667 1 (0.39120235 0.60879765)  
      14) age>=64.5 52  17 0 (0.67307692 0.32692308) *
      15) age< 64.5 1653 632 1 (0.38233515 0.61766485) *
rpart.plot(tree.age, digits = 4)

From the tree above, the predicted probability of being a high earner behaves as a function of age as follows:

Range of Age    Predicted Probability
[0, 26)         0.0903
[26, 32)        0.4481
[32, 65)        0.6177
[65, ∞)         0.3269

We can see that the predicted probability increases with age up to a certain point, beyond which it drops, exhibiting a downward parabolic shape consistent with the bivariate data exploration we did in Task 3. This illustrates how a decision tree accommodates non-linear relationships: It divides the feature space recursively into a set of mutually exclusive regions, each of which has a possibly different target mean.

As a function of the numeric predictor of interest, these target means can behave in a highly non-linear fashion, depending on the true relationship between the numeric predictor and the target variable. What is most amazing is that the user of the decision tree need not supply any extra features like polynomial terms manually.

What happens to the fitted classification tree above if we add the square of age as an additional feature? As we learned in Subsection 5.1.1, the presence of \(\text{age}^2\) will have no impact on the fitted tree. This is because \(\text{age}^2\) is a monotone function of age (as long as age is positive), so a split based on \(\text{age}^2\) divides the feature space in the same way as a split based on age itself. A split with \(\text{age}^2 > 1{,}600\) versus \(\text{age}^2 \le 1{,}600\), for instance, is the same as age > 40 versus age ≤ 40. The introduction of \(\text{age}^2\) will not inject extra flexibility into the classification tree.
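The equivalence of the two kinds of splits can be checked directly in base R without fitting any tree. The sketch below is only an illustration under my own assumptions: ages run over the integers 18 to 80 from the data dictionary, and the candidate cutoffs are taken to be the midpoints 18.5, 19.5, and so on, like the cutoffs 25.5, 31.5, and 64.5 seen in the tree output.

```r
age <- 18:80                    # integer ages, as in the data dictionary
cutoffs <- head(age, -1) + 0.5  # candidate cutoffs 18.5, 19.5, ..., 79.5

# For each cutoff, does splitting on age^2 send every observation to the
# same side as splitting on age?
same_split <- sapply(cutoffs, function(cut) {
  identical(age < cut, age^2 < cut^2)
})

all_same <- all(same_split)
all_same

# The specific example from the text: age > 40 versus age^2 > 1600
identical(age > 40, age^2 > 1600)
```

Every candidate split on the squared variable reproduces a split on age itself, so the set of partitions available to the tree is unchanged.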

In CHUNK 18, we confirm that the classification tree remains unchanged even if we revise the tree in CHUNK 17 by adding \(\text{age}^2\) as an extra feature. (Recall from Subsection 3.3.3 that the I() function is needed so that we can perform usual arithmetic in a formula in R.)

# CHUNK 18 
tree.age2 <- rpart(wage_flag ~ age + I(age^2), 
                   method = "class", 
                   data = data.train) 
tree.age2 
n= 2101 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

 1) root 2101 941 1 (0.44788196 0.55211804)  
   2) age< 25.5 155  14 0 (0.90967742 0.09032258) *
   3) age>=25.5 1946 800 1 (0.41109969 0.58890031)  
     6) age< 31.5 241 108 0 (0.55186722 0.44813278) *
     7) age>=31.5 1705 667 1 (0.39120235 0.60879765)  
      14) age>=64.5 52  17 0 (0.67307692 0.32692308) *
      15) age< 64.5 1653 632 1 (0.38233515 0.61766485) *
rpart.plot(tree.age2, digits = 4)

So... don’t bother to add polynomial terms when you fit decision trees on the exam!

 

Construction and Evaluation of Ensemble Trees

In this subsection, we turn to ensemble classification trees, including random forests and boosted trees, and see if the use of ensemble trees will lead to a significant improvement in predictive performance over the base trees fitted in Subsection 5.3.2 (we, of course, hope for a good, if not significant, improvement!). Coding-wise, we will switch from the rpart package to the caret package. The code will be more cumbersome, but it is (very!) likely that code chunks will be provided on the exam, as in the December 2019 PA exam and the Student Success sample project (if not, consult ?trainControl and ?train!), so you won’t need to write complex code from scratch. When studying this section, pay more attention to the tuning parameters you can control.

 

TASK 6: Construct random forests

Your assistant has provided code (see CHUNKs 19-21) to construct three random forests with different numbers of base trees for predicting the probability that a worker is a high earner.

  • Run the code to construct these three random forests on the training set.
  • Construct the confusion matrix and calculate the AUC for each random forest on the test set.
  • Describe the differences in the results with reference to the ntree parameter.
  • Recommend one random forest to use. For the recommended random forest, construct a variable importance plot and a partial dependence plot for the most important numeric variable. Interpret the two plots.

 

Ensemble Tree 1: Random Forest

In R, there are several packages for fitting a random forest, such as randomForest and caret, both of which are available on the exam. Here we will use the caret package because its functions provide a convenient unifying framework for building different types of models, including random forests and boosted trees, and for tuning tree parameters by cross-validation. In the December 2019 PA exam, the caret package was used to construct and tune a random forest and a boosted tree, so more exposure to functions in this package should be beneficial.

In the caret package, a generic function for fitting a number of classification and regression models is train(), which is the workhorse of caret, but this function is often preceded by the trainControl() function, which allows you to specify the type of resampling scheme for tuning hyperparameters. In Exam PA, the most important arguments of the trainControl() function are:

  • method: This argument determines the resampling method to be used. Common choices in Exam PA (as suggested by the PA e-learning modules and past exams) include "cv" and "repeatedcv", corresponding to cross-validation and repeated cross-validation, respectively.
  • number and repeats: The number argument specifies the number of folds used in k-fold cross-validation and the repeats argument, applicable only if method = "repeatedcv", controls how many times cross-validation is performed.

In the first part of CHUNK 19, we request that 5-fold cross-validation (number = 5) repeated 3 times (repeats = 3) for robustness (remember, the folds are created randomly) be conducted for tuning the random forest parameters (in fact, only mtry; see below) and under-sampling (sampling = "down") be applied. Because wage_flag is a rather balanced binary variable, the use of undersampling is only for illustration purposes and not absolutely necessary. (You can also try oversampling by specifying sampling = "up".)

# CHUNK 19 
# Set the controls 
ctrl <- trainControl(method = "repeatedcv", 
                     number = 5, # 5-fold CV 
                     repeats = 3, # 5-fold CV repeated 3 times 
                     sampling = "down") # undersampling
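To make the resampling scheme more concrete, here is a minimal base R sketch of what "5-fold cross-validation repeated 3 times" and "down-sampling" amount to. It is an illustration only: caret builds its folds with its own (stratified) routines, and the toy target y, the 60/40 class split, and all object names below are hypothetical.

```r
set.seed(123)
n <- 2101; k <- 5; repeats <- 3

# One column of fold labels (1..k) per repeat, each shuffled independently
fold_ids <- replicate(repeats, sample(rep(seq_len(k), length.out = n)))

# Training-set size for each fold = n minus the size of the held-out fold
train_sizes <- apply(fold_ids, 2, function(f) n - tabulate(f, nbins = k))
train_sizes  # a k x repeats matrix of sizes close to 0.8 * n

# Down-sampling: keep all minority cases, randomly drop majority cases
y <- factor(rep(c("0", "1"), times = c(60, 40)))  # hypothetical toy target
n_min <- min(table(y))
keep <- unlist(lapply(split(seq_along(y), y),
                      function(idx) idx[sample.int(length(idx), n_min)]))
balanced <- table(y[keep])
balanced
```

Each of the 15 resamples trains on roughly 80% of the 2,101 training observations, which is in line with the sample sizes reported in the train() output for this case study.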

 

Given the manner in which cross-validation will be performed, we can proceed to set up a “grid” containing all possible combinations of the values of the tuning parameter(s). The model training algorithm will then iterate over these parameter values to determine the optimal combination based on the cross-validation results. Technically, the grid should be a data frame whose rows correspond to the combinations of tuning parameter values of interest and whose columns are named to identify the tuning parameters.

Among all the parameters of a random forest, the number of features considered in each split, represented by the mtry parameter, is arguably the most important (in fact, the caret package only supports the tuning of mtry for a random forest model). Its default value is \(\sqrt{p}\) for a classification tree and p/3 for a regression tree, where p is the number of predictors.
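As a quick arithmetic check of these defaults for the eight predictors in this case study, consider the following sketch. The rounding down is an assumption based on how the randomForest package documents its defaults; the variable names are illustrative.

```r
p <- 8  # number of predictors in this case study

# Default mtry for classification: square root of p, rounded down
mtry_classification <- floor(sqrt(p))

# Default mtry for regression: p/3, rounded down (but at least 1)
mtry_regression <- max(floor(p / 3), 1)

c(classification = mtry_classification, regression = mtry_regression)
```

Both defaults equal 2 here, so a tuning grid of 1 to 5 comfortably brackets the default value.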

In the second part of CHUNK 19, we use the expand.grid() function to set up a data frame known as rf.grid containing the possible values of mtry (in this case, 1 to 5) that we would like to consider. (Here we follow the PA e-learning modules and the December 2019 PA exam, and use the expand.grid() function, which is actually not needed for this simple case. This function, however, will prove useful in CHUNK 26 below when we deal with multiple tuning parameters and will be explained there.)

# CHUNK 19 (Cont.) 
# Setup the tuning grid 
rf.grid <- expand.grid(mtry = 1:5) 
rf.grid
  mtry
1    1
2    2
3    3
4    4
5    5

The control parameters and tuning grid are then passed to the trControl and tuneGrid arguments of the train() function to “train” a predictive model (with hyperparameters tuned by cross-validation). In what follows, we are going to construct three random forests on the training set of the Wage data with different numbers of trees to grow.

In CHUNK 20, we first fit a random forest with 5 individual trees, called rf1, and generate its output.

# CHUNK 20 
# Set up the x and y variables 
target <- factor(data.train$wage_flag) 
predictors <- data.train[, -9] 

# Uncomment the next line the first time you use randomForest 
# install.packages("randomForest") 

# Train the first random forest 
set.seed(20) # because cross-validation will be done 
rf1 <- train( 
  y = target, 
  x = predictors, 
  method = "rf", # use the randomForest algorithm 
  ntree = 5, # only 5 trees (default = 500) 
  importance = TRUE, 
  trControl = ctrl, 
  tuneGrid = rf.grid 
) 

The syntax of the train() function is quite involved. Let’s look at its arguments in turn:

  • The train() function accepts the usual formula and data arguments, like glm() and rpart(), but here we have opted to use the x and y arguments, containing the observations of the predictors and target variable, respectively. Unlike the glmnet() function, the x argument is a data frame, not a matrix, of predictor values/levels. This alternative specification is not discussed at all in the PA e-learning modules, but was adopted in the December 2019 PA exam. The Rmd template of the exam says that:

“This set-up for y (the target) and x (the predictors) produces a more accurate model,”

but does not explain why. In fact, the main difference between the two specifications is that the formula interface automatically binarizes categorical variables, but the x-y specification does not. As we learned in Subsection 5.1.1, fitting decision trees with categorical variables binarized imposes unnecessary restrictions and is undesirable.

  • We have specified method = "rf" to instruct the train() function to construct a random forest (which is what "rf" stands for). In general, the method argument specifies the modeling method or the type of predictive model used, e.g., setting method = "rpart" fits a base tree using rpart().

  • The ntree parameter determines the number of trees to be grown. The default value is 500. In theory, it is desirable to have as many trees as possible, as long as computational resources allow, to ensure that each observation and each feature is represented at least once in the random forest, and that the variance reduction contributed by averaging is sufficient. On the exam, you may want to set ntree to a relatively small value (e.g., not more than 100), especially if your dataset is large, to ease the computational burden (e.g., the dataset in Section 7.3 of the PA e-learning modules has about 350,000 observations, for which ntree = 500 will freeze your computer for one day!).
  • With importance = TRUE, importance scores of the predictors will be computed. (The default is importance = FALSE.)
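The difference between the formula interface and the x-y specification noted above can be sketched in base R. The toy data frame below is hypothetical, and model.matrix() is used here only as a stand-in for the dummy-variable expansion that a formula interface performs; it is not a reproduction of caret's internals.

```r
# Hypothetical toy data: one factor predictor and one numeric predictor
toy <- data.frame(
  education = factor(c("HS Grad", "College Grad", "Advanced Degree", "HS Grad")),
  age = c(25, 40, 52, 33)
)

# Formula-style expansion: the factor becomes separate 0/1 dummy columns
x_formula <- model.matrix(~ education + age, data = toy)[, -1]  # drop intercept
colnames(x_formula)

# x-y style: the factor stays a single column with all of its levels intact
ncol(toy)
```

With binarization, a tree can only split on one level at a time (e.g., "HS Grad" versus everything else); with the factor kept whole, a single split can send any subset of levels to each branch.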

Now let’s view the output of the random forest.

# CHUNK 20 (Cont.)
# View the output
rf1
Random Forest 

2101 samples
   8 predictor
   2 classes: '0', '1' 

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 3 times) 
Summary of sample sizes: 1681, 1681, 1680, 1681, 1681, 1680, ... 
Addtional sampling using down-sampling

Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
  1     0.6798398  0.3493491
  2     0.6871251  0.3717869
  3     0.6783976  0.3520302
  4     0.6657083  0.3291085
  5     0.6655571  0.3298912

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 2.
ggplot(rf1) # OR simply plot(rf1)

Unlike a base decision tree, typing the name of a random forest will not return how the tree splits of the individual trees are defined (remember, there are a total of ntree trees fitted!); instead, various aspects of the random forest are shown:

  • the number of training observations (2101),
  • predictors (8),
  • levels of the target variable (“0” and “1”),
  • the resampling method used (5-fold cross-validation, repeated 3 times), and,
  • most importantly, a profile of the model performance over all combinations of hyperparameter values specified in the tuning grid.
    This profile can be visualized by applying the ggplot() function (or simply plot()) to a train object. (The kappa statistic beside the Accuracy column is not discussed in the PA e-learning modules.) For a classification (ensemble) tree, by default the combination of hyperparameter values that gives rise to the highest accuracy is selected, as you can see from the sentence:

Accuracy was used to select the optimal model using the largest value.

The last line of the output says that the optimal value of mtry is 2. This means that for best out-of-sample performance, only two predictors should be randomly sampled and considered in every split of every base tree built. (You can also directly access the best combination of hyperparameters by extracting the bestTune component of a train object, which is a list, i.e., by running rf1$bestTune.) Then the random forest is fitted to the whole training set with mtry = 2.
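The selection rule is simple enough to reproduce by hand. The sketch below copies the cross-validation accuracies from the rf1 output into a data frame (the object names are illustrative) and picks the row with the largest accuracy.

```r
# Cross-validation accuracies copied from the rf1 output
cv_results <- data.frame(
  mtry = 1:5,
  Accuracy = c(0.6798398, 0.6871251, 0.6783976, 0.6657083, 0.6655571)
)

# Select the tuning value with the largest accuracy
best <- cv_results[which.max(cv_results$Accuracy), ]
best
```

The selected row has mtry = 2, in agreement with the last line of the train() output.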

 

EXAM NOTE

Don’t be scared by the code in CHUNK 20. It is extremely unlikely for you to write code involving the train() function from scratch on the exam (there are so many arguments!). You only need to understand what each argument does and, if necessary, make modest changes.

 

In more or less the same way, in CHUNK 21 we construct two more random forests on the training set, one with ntree = 20 and one with ntree = 100, to study the impact of the ntree parameter.

# CHUNK 21
set.seed(50)

# Train the second random forest
rf2 <- train( 
  y = target, 
  x = predictors, 
  method = "rf", # use the randomForest algorithm 
  ntree = 20, # increased to 20
  importance = TRUE, 
  trControl = ctrl, 
  tuneGrid = rf.grid 
)

# View the output
rf2
Random Forest 

2101 samples
   8 predictor
   2 classes: '0', '1' 

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 3 times) 
Summary of sample sizes: 1681, 1681, 1681, 1681, 1680, 1681, ... 
Addtional sampling using down-sampling

Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
  1     0.7117230  0.4159599
  2     0.7004562  0.3958890
  3     0.6904656  0.3766939
  4     0.6784070  0.3547542
  5     0.6682581  0.3347144

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 1.
# Train the third random forest
rf3 <- train( 
  y = target, 
  x = predictors, 
  method = "rf", # use the randomForest algorithm 
  ntree = 100, # increased to 100
  importance = TRUE, 
  trControl = ctrl, 
  tuneGrid = rf.grid 
) 

# View the output
rf3
Random Forest 

2101 samples
   8 predictor
   2 classes: '0', '1' 

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 3 times) 
Summary of sample sizes: 1681, 1681, 1680, 1681, 1681, 1681, ... 
Addtional sampling using down-sampling

Resampling results across tuning parameters:

  mtry  Accuracy   Kappa    
  1     0.7147329  0.4214448
  2     0.7060136  0.4057759
  3     0.6942748  0.3843967
  4     0.6793575  0.3570236
  5     0.6688923  0.3350292

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was mtry = 1.

It turns out that for both of these random forests (rf2 and rf3), the optimal value of mtry is 1.

 

Confusion Matrices and AUC of the Random Forests

In CHUNK 22, we apply the three random forests to classify each case in the test set, generate their confusion matrices, and compute their AUCs. In contrast to an rpart object, for which the predict() function outputs predicted probabilities by default, the default type of predictions returned by predict() applied to an object of the train class is the predicted class (which can be specified explicitly by the type = "raw" option if you like).

# CHUNK 22 
pred.rf1.class <- predict(rf1, newdata = data.test, type = "raw") 
pred.rf2.class <- predict(rf2, newdata = data.test, type = "raw") 
pred.rf3.class <- predict(rf3, newdata = data.test, type = "raw") 

confusionMatrix(pred.rf1.class, as.factor(data.test$wage_flag), positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 288 148
         1 114 349
                                          
               Accuracy : 0.7086          
                 95% CI : (0.6777, 0.7381)
    No Information Rate : 0.5528          
    P-Value [Acc > NIR] : < 2e-16         
                                          
                  Kappa : 0.4153          
                                          
 Mcnemar's Test P-Value : 0.04148         
                                          
            Sensitivity : 0.7022          
            Specificity : 0.7164          
         Pos Pred Value : 0.7538          
         Neg Pred Value : 0.6606          
             Prevalence : 0.5528          
         Detection Rate : 0.3882          
   Detection Prevalence : 0.5150          
      Balanced Accuracy : 0.7093          
                                          
       'Positive' Class : 1  
confusionMatrix(pred.rf2.class, as.factor(data.test$wage_flag), positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 323 136
         1  79 361
                                          
               Accuracy : 0.7608          
                 95% CI : (0.7316, 0.7884)
    No Information Rate : 0.5528          
    P-Value [Acc > NIR] : < 2.2e-16       
                                          
                  Kappa : 0.5228          
                                          
 Mcnemar's Test P-Value : 0.0001339       
                                          
            Sensitivity : 0.7264          
            Specificity : 0.8035          
         Pos Pred Value : 0.8205          
         Neg Pred Value : 0.7037          
             Prevalence : 0.5528          
         Detection Rate : 0.4016          
   Detection Prevalence : 0.4894          
      Balanced Accuracy : 0.7649          
                                          
       'Positive' Class : 1  
confusionMatrix(pred.rf3.class, as.factor(data.test$wage_flag), positive = "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 314  98
         1  88 399
                                          
               Accuracy : 0.7931          
                 95% CI : (0.7651, 0.8191)
    No Information Rate : 0.5528          
    P-Value [Acc > NIR] : <2e-16          
                                          
                  Kappa : 0.5825          
                                          
 Mcnemar's Test P-Value : 0.5093          
                                          
            Sensitivity : 0.8028          
            Specificity : 0.7811          
         Pos Pred Value : 0.8193          
         Neg Pred Value : 0.7621          
             Prevalence : 0.5528          
         Detection Rate : 0.4438          
   Detection Prevalence : 0.5417          
      Balanced Accuracy : 0.7920          
                                          
       'Positive' Class : 1 

The test accuracy of the three random forests increases as we go from rf1 to rf2 to rf3 (0.7086, 0.7608, and 0.7931, respectively). In other words, the larger the ntree parameter, the more accurate the predictions produced by the random forests on the test set. The fact that increasing ntree leads to better prediction performance is further confirmed by the test AUCs, which we generate in CHUNK 23.
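As a sanity check, the headline statistics reported by confusionMatrix() can be recomputed directly from the four cell counts. The sketch below uses the counts for rf3 and treats "1" as the positive class; the object names are illustrative.

```r
# Confusion matrix of rf3 on the test set (rows = prediction, columns = reference)
cm <- matrix(c(314, 88, 98, 399), nrow = 2,
             dimnames = list(Prediction = c("0", "1"), Reference = c("0", "1")))

accuracy    <- sum(diag(cm)) / sum(cm)          # correct / total
sensitivity <- cm["1", "1"] / sum(cm[, "1"])    # true positives / actual positives
specificity <- cm["0", "0"] / sum(cm[, "0"])    # true negatives / actual negatives

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity), 4)
```

These values reproduce the Accuracy, Sensitivity, and Specificity lines of the rf3 output.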

# CHUNK 23 
# Add the type = "prob" option to return predicted probabilities 
pred.rf1.prob <- predict(rf1, newdata = data.test, type = "prob")[, 2]
pred.rf2.prob <- predict(rf2, newdata = data.test, type = "prob")[, 2]
pred.rf3.prob <- predict(rf3, newdata = data.test, type = "prob")[, 2]

roc(data.test$wage_flag , pred.rf1.prob) 
roc(data.test$wage_flag , pred.rf2.prob) 
roc(data.test$wage_flag , pred.rf3.prob)
> roc(data.test$wage_flag , pred.rf1.prob)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = data.test$wage_flag, predictor = pred.rf1.prob)

Data: pred.rf1.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
Area under the curve: 0.7883
> roc(data.test$wage_flag , pred.rf2.prob) 
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = data.test$wage_flag, predictor = pred.rf2.prob)

Data: pred.rf2.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
Area under the curve: 0.8391
> roc(data.test$wage_flag , pred.rf3.prob) 
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = data.test$wage_flag, predictor = pred.rf3.prob)

Data: pred.rf3.prob in 402 controls (data.test$wage_flag 0) < 497 cases (data.test$wage_flag 1).
Area under the curve: 0.8469

Consistent with the rankings based on the test accuracy, rf1 has the lowest test AUC (0.7883) and rf3 has the highest (0.8469). These results illustrate the general phenomenon we learned earlier that building more base trees tends to improve the prediction performance of a random forest. With more base trees, the variance reduction contributed by averaging becomes more significant and the model predictions become more precise. Of the three random forests, it is reasonable to recommend using rf3, which is the most predictive. After all, 100 trees only entail a small amount of computational power.
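For intuition about what the AUC measures, it can be computed without any package as the proportion of (positive, negative) pairs in which the positive case receives the higher predicted probability. The sketch below uses the rank-based (Mann-Whitney) form of this calculation on a hypothetical four-observation example; auc_rank() is an illustrative helper, not a function from pROC.

```r
# AUC via ranks: share of (case, control) pairs ranked in the right order
auc_rank <- function(y, score) {
  n1 <- sum(y == 1)          # number of positive cases
  n0 <- sum(y == 0)          # number of negative cases
  r <- rank(score)           # average ranks, so ties count as one half
  (sum(r[y == 1]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Hypothetical toy example
y_toy <- c(0, 0, 1, 1)
score_toy <- c(0.10, 0.40, 0.35, 0.80)
auc_toy <- auc_rank(y_toy, score_toy)
auc_toy
```

Of the four (positive, negative) pairs in the toy data, three are ordered correctly, giving an AUC of 0.75.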

 

Variable Importance Plots

Because ensemble trees consist of potentially hundreds of decision trees, it is difficult, if not impossible, to interpret the relationship between the predictors and the target variable using a series of easy-to-understand classification rules like what a single decision tree shows. Fortunately, there are two graphical devices that try to overcome this model opacity issue (for both random forests and boosted trees).

One useful tool is a variable importance plot, which ranks the predictors according to their importance scores. The importance score for a particular predictor is computed by totaling the drop in node impurity (RSS for regression trees and Gini index for classification trees) due to that predictor, averaged over all the trees in the ensemble tree.

In general, variables that are used to form most of the top splits in the individual trees lead to larger improvements in node purity and therefore are more important as captured by the variable importance score.

In CHUNK 24, we use the varImp() function to compute the variable importance of each of the eight predictors in the Wage data and make a variable importance plot for rf3. To show relative importance more clearly, the importance scores have been scaled so that the most important predictor has a score of 100 and the predictors are sorted in descending order of variable importance.

# CHUNK 24
imp <- varImp(rf3) 
imp
rf variable importance

           Importance
education      100.00
health_ins      97.97
age             63.09
maritl          59.23
health          26.71
jobclass        26.01
race            10.22
year             0.00
plot(imp, main = "Variable Importance of Classification Random Forest")

We can make a few observations from the importance plot above:

  • education is the most important predictor of being a high earner, followed by health_ins, age, and maritl. These four predictors also appear in the top splits of the base trees back in Task 4, so rf3 and the base trees agree on the most influential predictors.
    (Note: What if they don’t? In most cases, the order of importance indicated by a random forest should be more credible than that of a base tree because a random forest is based on a total of ntree trees rather than a single tree, which is notorious for being unstable.)
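  • The year variable is the least important predictor, with a scaled importance of zero. Because the scores are rescaled to run from 0 to 100, the least important predictor receives a score of 0 by construction, so this indicates that year contributes the least among the eight predictors rather than that it was never used in any of the 100 base trees in rf3.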

These findings align well with the exploratory data analysis we performed in Task 3.
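The 0-to-100 scaling of the importance scores can be sketched as follows, assuming the default min-max rescaling applied by varImp(). The raw scores below are hypothetical and serve only to show the mechanics.

```r
# Hypothetical raw importance scores (not taken from rf3)
raw <- c(education = 42, age = 27, year = 3)

# Min-max rescaling: the largest score maps to 100 and the smallest to 0
scaled <- (raw - min(raw)) / (max(raw) - min(raw)) * 100
round(scaled, 2)
```

Under this rescaling, a score of 0 says only that a predictor is the least important of the set; its raw importance need not be zero.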

 

Partial Dependence Plots

Although variable importance plots tell us which predictors are most influential, the importance scores do not shed light on the relationship (directional effect, in particular) between the predictors and the target variable. In other words, we know that a certain variable contributes significantly to the target variable, but whether that contribution is positive, negative, or follows a more complex relationship remains unknown. Partial (or average) dependence plots fill this gap and attempt to visualize the effect of a given variable on the model prediction after averaging the values or levels of other variables. Looking at these plots, we can gain some insights into how the target variable “depends” on the predictors.

Mathematically, consider a target variable Y and p predictors X1, … , Xp. The partial dependence of a predictive model for Y on X1 is defined as:

\(PD(x_1):=\dfrac{1}{n_{tr}}\sum\nolimits_{i=1}^{n_{tr}}{\hat{f}(x_1,x_{i2},\ldots,x_{ip})}\),

where:

  • \(\hat{f}\) is the fitted signal function (i.e., the predicted target mean)
  • x1 is a given value of X1
  • {(xi2, …, xip)} is the collection of observed values of X2, …, Xp in the training set, which has ntr observations.

In essence, PD(x1) is simply the model predictions averaged over all the observed values of X2, …, Xp in the training set while keeping the value of X1 fixed at x1. We can then examine the behavior of PD(x1) as a function of x1 with the goal of understanding how X1 affects the target variable.
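The definition translates almost line by line into code. The following is a minimal sketch that uses a linear model on the built-in mtcars data as a stand-in for a fitted random forest (an assumption made purely so that the example runs without extra packages); partial_dependence() is an illustrative helper, not the pdp implementation.

```r
# Stand-in model: any object with a predict() method would do
fit <- lm(mpg ~ wt + hp, data = mtcars)

# PD(x1): fix the variable of interest at x1 for every training row,
# predict, and average the predictions over the rows
partial_dependence <- function(model, data, var, grid) {
  sapply(grid, function(v) {
    data_mod <- data
    data_mod[[var]] <- v          # overwrite X1 with the fixed value x1
    mean(predict(model, newdata = data_mod))
  })
}

wt_grid <- seq(min(mtcars$wt), max(mtcars$wt), length.out = 10)
pd_wt <- partial_dependence(fit, mtcars, "wt", wt_grid)
pd_wt
```

For a linear model the partial dependence is a straight line whose slope equals the fitted coefficient of wt; for a random forest the same averaging typically produces a non-linear, step-like curve.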

Let’s look at a real partial dependence plot and try to interpret it. In R, partial dependence plots can be generated by the partial() function in the pdp package (available on the exam) and the variable of interest is specified as a character string in the pred.var argument. CHUNK 25 produces the partial dependence plot for age, which is the most influential numeric predictor for wage_flag.

# CHUNK 25
library(pdp)
partial(rf3, train = data.train, pred.var = "age", 
        plot = TRUE, rug = TRUE, smooth = TRUE) 

For a categorical target variable, the predictions are on the logit scale, meaning that what is shown on the vertical axis of a partial dependence plot is \(\ln\left(\dfrac{\hat{p}}{1-\hat{p}}\right)\), and the plot above shows how the log-odds vary with the value of age, with a blue smoothed curve superimposed. Here it is important to note that \(\hat{p}\) is the predicted probability that the target variable belongs to its first level, or, in this case study, the predicted probability that a worker is a low earner (because “0” precedes “1”). As we can see, the odds that a worker is a low earner (resp. high earner) decrease (resp. increase) up to about age 50, beyond which they increase (resp. decrease), which conforms to the findings in Task 3.
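Because the logit scale is easy to misread, a short numerical sketch of the conversion between probability, odds, and log-odds may help. The probability 0.3 is a hypothetical value; qlogis() and plogis() are the base R logit and inverse-logit functions.

```r
p_low <- 0.3                       # hypothetical P(low earner)

logit_low  <- qlogis(p_low)        # ln(0.3 / 0.7), negative because p < 0.5
logit_high <- qlogis(1 - p_low)    # log-odds of being a high earner
odds_low   <- exp(logit_low)       # odds = p / (1 - p)

c(logit_low = logit_low, logit_high = logit_high, odds_low = odds_low)
plogis(logit_low)                  # back-transform to the probability
```

The log-odds for the two classes are mirror images of each other, which is why a falling curve for the low earner class corresponds to a rising curve for the high earner class.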

 

TASK 7: Construct a boosted tree

Your assistant has also provided code (see CHUNK 26) to construct a boosted tree for predicting the probability that a worker is a high earner.

  • Run the code to construct the boosted tree.
  • Construct the confusion matrix and calculate the AUC of the boosted tree on the test set.

This task is a version of Task 6 for boosted trees.

 

Ensemble Tree 2: Boosted Model

The second ensemble tree we will fit is a boosted tree, also by the train() function in the caret package. Compared to a random forest, a boosted tree has a lot more parameters to tune and requires more coding effort. To begin with, we set up the grid of tuning parameters in the first part of CHUNK 26.

# CHUNK 26 
# Setup the tuning grid 
xgb.grid <- expand.grid(max_depth = 7, 
                        min_child_weight = 1, 
                        gamma = 0, 
                        nrounds = c(10, 50, 100), 
                        eta = c(0.01, 0.1), 
                        colsample_bytree = 0.6, 
                        subsample = 0.6) 

xgb.grid 
  max_depth min_child_weight gamma nrounds  eta colsample_bytree subsample
1         7                1     0      10 0.01              0.6       0.6
2         7                1     0      50 0.01              0.6       0.6
3         7                1     0     100 0.01              0.6       0.6
4         7                1     0      10 0.10              0.6       0.6
5         7                1     0      50 0.10              0.6       0.6
6         7                1     0     100 0.10              0.6       0.6

Note that all of the seven parameters need to be specified; missing one of them will lead to an error when a boosted tree is constructed. Parameters that we are interested in tuning (i.e., nrounds and eta) are given a vector of possible values and those that are not of interest are assigned a single value. These parameters are described below:

  • max_depth, min_child_weight, gamma: These parameters control the complexity of the underlying trees. At least one of these parameters should be tuned, but more can be tuned if computational resources allow.
  • colsample_bytree, subsample: These two parameters determine the proportion of features and observations used in each individual tree, respectively. Typically, at least the proportion of features should be tuned.
  • (Important!) nrounds: This parameter controls the maximum number of trees to grow or iterations in the model fitting process. It is often set to around 1,000, large enough so that sufficient trees are grown to capture the signal in the data effectively but not excessively large to avoid overfitting. When a good fit has been achieved, the algorithm may stop early.
  • (Important!) eta: This is the learning rate or shrinkage parameter, a scalar multiple between 0 and 1 that applies to the contribution of each tree. The higher the learning rate, the faster the model will reach optimality and the fewer the number of iterations required, though the resulting model will more likely overfit. Typically, eta is set to values between 0.01 and 0.2.
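To see how nrounds and eta interact, the following is a minimal sketch of gradient boosting with regression stumps under squared error loss, written in base R on simulated data. It is a deliberately simplified stand-in and not the XGBoost algorithm; fit_stump(), predict_stump(), and boost() are hypothetical helpers, and the error reported is the training error.

```r
# Fit a one-split regression stump to the residuals r
fit_stump <- function(x, r) {
  v <- sort(unique(x))
  cuts <- (head(v, -1) + tail(v, -1)) / 2      # midpoints between neighbours
  sse <- sapply(cuts, function(cc) {
    left <- r[x <= cc]; right <- r[x > cc]
    sum((left - mean(left))^2) + sum((right - mean(right))^2)
  })
  best <- cuts[which.min(sse)]
  list(cut = best, left = mean(r[x <= best]), right = mean(r[x > best]))
}

predict_stump <- function(s, x) ifelse(x <= s$cut, s$left, s$right)

# Boosting: each round fits a stump to the current residuals and adds
# only a fraction eta of its prediction to the running fit
boost <- function(x, y, nrounds, eta) {
  pred <- rep(mean(y), length(y))
  for (m in seq_len(nrounds)) {
    s <- fit_stump(x, y - pred)
    pred <- pred + eta * predict_stump(s, x)
  }
  mean((y - pred)^2)                           # training MSE
}

set.seed(1)
x <- runif(200)
y <- sin(2 * pi * x) + rnorm(200, sd = 0.2)

mse_base <- mean((y - mean(y))^2)              # no trees at all
mse_slow <- boost(x, y, nrounds = 10, eta = 0.01)
mse_fast <- boost(x, y, nrounds = 10, eta = 0.1)
c(base = mse_base, slow = mse_slow, fast = mse_fast)
```

With the same number of rounds, the larger learning rate drives the training error down faster, which is why a small eta generally has to be paired with a larger nrounds.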

 

EXAM NOTE

On the exam, it is likely that you will be directly given the tuning parameters of a boosted tree. You can also type ?xgboost to learn more about these parameters. There is no need to memorize their definitions.

 

The expand.grid() function creates a data frame called xgb.grid containing all combinations of the possible values of the tuning parameters. Passing this grid and the control parameters to the train() function and setting method = "xgbTree" (meaning an “extreme gradient boosting tree”), we build the boosted tree on the training set of the Wage data in the remainder of CHUNK 26. (For some reason, the train() function applied with method = "xgbTree" does not allow the x-y specification, so we use the usual formula interface. Do not forget to treat wage_flag as a factor variable!)

# CHUNK 26 (cont.)
# Set the controls
ctrl <- trainControl(method = "cv", 
                     number = 5, 
                     sampling = "down")

# Train the boosted tree
set.seed(42) 
xgb.tuned <- train(as.factor(wage_flag) ~ ., 
                   data = data.train, 
                   method = "xgbTree", 
                   trControl = ctrl, 
                   tuneGrid = xgb.grid)

# View the output
xgb.tuned 

eXtreme Gradient Boosting 

2101 samples
   8 predictor
   2 classes: '0', '1' 

No pre-processing
Resampling: Cross-Validated (5 fold, repeated 3 times) 
Summary of sample sizes: 1681, 1681, 1681, 1681, 1680, 1681, ... 
Addtional sampling using down-sampling

Resampling results across tuning parameters:

  eta   nrounds  Accuracy   Kappa    
  0.01   10      0.6977593  0.3908443
  0.01   50      0.7085447  0.4124306
  0.01  100      0.7053731  0.4059247
  0.10   10      0.7039502  0.4034157
  0.10   50      0.6985552  0.3929323
  0.10  100      0.6893519  0.3752518

Tuning parameter 'max_depth' was held constant at a value of 7
Tuning parameter 'gamma' was held constant at a value of 0
Tuning parameter 'colsample_bytree' was held constant at a value of 0.6
Tuning parameter 'min_child_weight' was held constant at a value of 1
Tuning parameter 'subsample' was held constant at a value of 0.6
Accuracy was used to select the optimal model using the largest value.
The final values used for the model were nrounds = 50, max_depth = 7, eta = 0.01, gamma = 0, colsample_bytree = 0.6, min_child_weight = 1 and subsample = 0.6.
ggplot(xgb.tuned)

The output shows that the optimal values of nrounds and eta are 50 and 0.01, respectively.

 

NOTE

Depending on the version of R and/or the version of the xgboost package you are using, you may get slightly different results in CHUNK 26. Regardless, the results can be interpreted in essentially the same way.

 

The test set performance of the boosted tree is evaluated in CHUNK 27.

# CHUNK 27 
pred.xgb.class <- predict(xgb.tuned, newdata = data.test, type= "raw") 
confusionMatrix(pred.xgb.class, as.factor(data.test$wage_flag), positive= "1")
Confusion Matrix and Statistics

          Reference
Prediction   0   1
         0 314 126
         1  88 371
                                          
               Accuracy : 0.762           
                 95% CI : (0.7327, 0.7895)
    No Information Rate : 0.5528          
    P-Value [Acc > NIR] : < 2e-16         
                                          
                  Kappa : 0.5228          
                                          
 Mcnemar's Test P-Value : 0.01143         
                                          
            Sensitivity : 0.7465          
            Specificity : 0.7811          
         Pos Pred Value : 0.8083          
         Neg Pred Value : 0.7136          
             Prevalence : 0.5528          
         Detection Rate : 0.4127          
   Detection Prevalence : 0.5106          
      Balanced Accuracy : 0.7638          
                                          
       'Positive' Class : 1  
pred.xgb.prob <- predict(xgb.tuned, newdata = data.test, type = "prob")[, 2]
roc(as.numeric(data.test$wage_flag), pred.xgb.prob) 
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = as.numeric(data.test$wage_flag), predictor = pred.xgb.prob)

Data: pred.xgb.prob in 402 controls (as.numeric(data.test$wage_flag) 0) < 497 cases (as.numeric(data.test$wage_flag) 1).
Area under the curve: 0.842

The accuracy of the boosted tree on the test set is 0.7620 and the test AUC is 0.842, which are comparable to those of the third random forest in Task 6.

 

TASK 8: Select the final model

Recommend which model should be used, the base tree selected in Task 5, the random forest in Task 6, or the boosted tree in Task 7. Do not base your recommendation solely on predictive performance.

 

Ensemble Trees vs. Base Trees: Which one to use here?

You probably have this feeling when fitting the random forests in Task 6 and the boosted tree in Task 7: What do they actually look like and what are they really trying to do? This is typical of the “black box” nature of ensemble methods, which often fare well in terms of prediction accuracy, but their users tend to have little clue as to how they work and how to interpret the model results.

In this case study, we have constructed a number of tree-based models to predict the probability that a worker is a high earner. The following table summarizes the prediction performance of these models:

Model                      Test Accuracy   Test AUC
Tree 2 (Task 5)            0.7497          0.8135
Random Forest 3 (Task 6)   0.7931          0.8469
Boosted Tree (Task 7)      0.7620          0.8420

 

We can see that the random forest and the boosted tree do result in an improvement in prediction performance measured by the test accuracy and the test AUC, but the improvement appears to be incremental compared to base tree 2 and is achieved at the expense of a significant loss of interpretability. Since one of the key considerations when selecting a predictive model in this case study is the ease of interpretation, it seems judicious to forgo a small amount of prediction accuracy for interpretability and to recommend a pruned base tree like Tree 2, which makes reasonably accurate predictions and lends itself to easy interpretation.

 

EXAM NOTE

Exam tasks that require you to recommend a model usually don’t have a clear-cut answer. The quality of your justification is more important than your final decision. In the current context, a case can also be made in favor of the random forest and the boosted tree if you can justify why the societal benefits brought by the improvement in prediction performance may outweigh the loss of interpretability. The examiners will particularly appreciate a response that keeps the business context of the exam project well in mind.