

Complete this tutorial to revisit and master the fundamentals of decision tree classification models, among the simplest and easiest models to explain.

Introduction

Data Scientists use machine learning techniques to make predictions under a variety of scenarios. Machine learning can be used to predict whether a borrower will default on a mortgage, or what the median house value in a given zip code area might be. Depending on whether the prediction is being made for a quantitative or a qualitative variable, a predictive model can be categorized as a regression model (e.g., predicting median house values) or a classification model (e.g., predicting loan defaults).

Decision trees happen to be one of the simplest and easiest classification models to explain and, as many argue, closely resemble human decision-making.

This tutorial has been developed to help you revisit and master the fundamentals of decision tree classification models which are expanded on in Data Science Dojo’s data science bootcamp and online data science certificate program. Our key focus will be to discuss the:

  1. Fundamental concepts of data partitioning, recursive binary splitting, nodes, etc.
  2. Data exploration and data preparation for building classification models
  3. Node impurity measures for decision tree models: Gini Index, Entropy, and Classification Error.

The content builds your classification model knowledge and skills in an intuitive and gradual manner.


The scenario

You are a Data Scientist working at the Centers for Disease Control (CDC) Division for Heart Disease and Stroke Prevention. Your division has recently completed a research study to collect health examination data from 303 patients who presented with chest pain and might have been suffering from heart disease.

The Chief Data Scientist of your division has asked you to analyze this data and build a predictive model that can accurately predict patients’ heart disease status, identifying the most important predictors of heart disease. Once your predictive model is ready, you will make a presentation to the doctors working at the health facilities where the research was conducted.

The data set has 14 attributes, including patients’ age, gender, blood pressure, cholesterol level, and heart disease status, indicating whether the diagnosed patient was found to have heart disease or not. You have already learned that to predict quantitative attributes such as “blood pressure” or “cholesterol level”, regression models are used, but to predict a qualitative attribute such as the “status of heart disease,”  classification models are used.

Classification models can be built using different techniques such as Logistic Regression, Discriminant Analysis, K-Nearest Neighbors (KNN), Decision Trees, etc. Decision Trees are very easy to explain and can easily handle qualitative predictors without the need to create dummy variables.

Although decision trees often do not match the predictive accuracy of techniques such as K-Nearest Neighbors or Discriminant Analysis, they serve as building blocks for more sophisticated classification techniques such as Random Forests, which makes mastering decision trees necessary!

We will now build decision trees to predict the status of heart disease i.e. to predict whether the patient has heart disease or not, and we will learn and explore the following topics along the way:

  • Data preparation for decision tree models
  • Classification trees using “rpart” package
  • Pruning the decision trees
  • Evaluating decision tree models

## You will need the following libraries for this exercise
## (tidyverse already attaches dplyr and ggplot2; loading them again is harmless)
library(dplyr) 
library(tidyverse)
library(ggplot2)
library(rpart)
library(rpart.plot)
library(rattle)
library(RColorBrewer)

## Suppress warnings for the rest of the session
## (this does not hide the startup messages printed while the packages above load)
options(warn = -1) 

The data

You will be working with the Heart Disease Data Set which is available at UC Irvine’s Machine Learning Repository. You are encouraged to visit the repository and go through the data description. As you will find, the data folder has multiple data files available. You will use the processed.cleveland.data.

Let’s read the data file into a data frame named “cardio”:

## Reading the data into "cardio" data frame
cardio <- read.csv("processed.cleveland.data", header = FALSE, na.strings = '?')            
## Let's look at the first few rows in the cardio data frame  
head(cardio)
V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V11 V12 V13 V14
63 1 1 145 233 1 2 150 0 2.3 3 0 6 0
67 1 4 160 286 0 2 108 1 1.5 2 3 3 2
67 1 4 120 229 0 2 129 1 2.6 2 2 7 1
37 1 3 130 250 0 0 187 0 3.5 3 0 3 0
41 0 2 130 204 0 2 172 0 1.4 1 0 3 0
56 1 2 120 236 0 0 178 0 0.8 1 0 3 0

As you can see, this data frame doesn’t have column names. However, we can refer to the data dictionary, given below, and add the column names:

Column Position | Attribute Name | Description | Attribute Type
#1 | Age | Age of patient | Quantitative
#2 | Sex | Gender of patient | Qualitative
#3 | CP | Type of chest pain (1: typical angina, 2: atypical angina, 3: non-anginal pain, 4: asymptomatic) | Qualitative
#4 | Trestbps | Resting blood pressure (in mm Hg on admission) | Quantitative
#5 | Chol | Serum cholesterol in mg/dl | Quantitative
#6 | FBS | Fasting blood sugar > 120 mg/dl (1 = true; 0 = false) | Qualitative
#7 | Restecg | Resting ECG results (0 = normal; 1 and 2 = abnormal) | Qualitative
#8 | Thalach | Maximum heart rate achieved | Quantitative
#9 | Exang | Exercise-induced angina (1 = yes; 0 = no) | Qualitative
#10 | Oldpeak | ST depression induced by exercise relative to rest | Quantitative
#11 | Slope | Slope of the peak exercise ST segment (1 = upsloping; 2 = flat; 3 = downsloping) | Qualitative
#12 | CA | Number of major vessels (0-3) colored by fluoroscopy | Qualitative
#13 | Thal | Thalassemia (3 = normal; 6 = fixed defect; 7 = reversible defect) | Qualitative
#14 | NUM | Angiographic disease status (0 = no heart disease; greater than 0 = heart disease) | Qualitative

The following code chunk will add column names to your data frame:

## Adding column names to dataframe 
names(cardio) <- c( "age", "sex", "cp", "trestbps", "chol","fbs", "restecg", 
                           "thalach","exang", "oldpeak","slope", "ca", "thal", "status")

You are going to build a decision tree model to predict the values of variable #14, status, the “angiographic disease status”, which labels or classifies each patient as “having heart disease” or “not having heart disease”.

Intuitively, we expect some of the other 13 variables to help us predict the values of status. In other words, we expect variables #1 to #13 to segment the patients, that is, to create partitions in the cardio data frame such that each partition (or segment) contains mostly patients “having heart disease” or mostly patients “not having heart disease”.


Data preparation for decision trees

It is time to get familiar with the data. Let’s begin with data types.

## We will use str() function  
str(cardio)
'data.frame':	303 obs. of  14 variables:
 $ age      : num  63 67 67 37 41 56 62 57 63 53 ...
 $ sex      : num  1 1 1 1 0 1 0 0 1 1 ...
 $ cp       : num  1 4 4 3 2 2 4 4 4 4 ...
 $ trestbps : num  145 160 120 130 130 120 140 120 130 140 ...
 $ chol     : num  233 286 229 250 204 236 268 354 254 203 ...
 $ fbs      : num  1 0 0 0 0 0 0 0 0 1 ...
 $ restecg  : num  2 2 2 0 2 0 2 0 2 2 ...
 $ thalach  : num  150 108 129 187 172 178 160 163 147 155 ...
 $ exang    : num  0 1 1 0 0 0 0 1 0 1 ...
 $ oldpeak  : num  2.3 1.5 2.6 3.5 1.4 0.8 3.6 0.6 1.4 3.1 ...
 $ slope    : num  3 2 2 3 1 1 3 1 2 3 ...
 $ ca       : num  0 3 2 0 0 0 2 0 1 0 ...
 $ thal     : num  6 3 7 3 3 3 3 3 7 7 ...
 $ status   : int  0 2 1 0 0 0 3 0 2 1 ...

As you can see, some qualitative variables in our data frame are stored as quantitative variables:

  • status is declared as int (integer), which makes it a quantitative variable, but we know disease status must be qualitative.
  • Likewise, sex, cp, fbs, restecg, exang, slope, ca, and thal must be qualitative too.

The next code-chunk will convert and correct the datatypes:

## We can use lapply to convert data types across multiple columns
factor_cols <- c("sex", "cp", "fbs", "restecg", "exang", "slope", "ca", "thal", "status")
cardio[factor_cols] <- lapply(cardio[factor_cols], factor)
## You can verify the data frame 
str(cardio)
'data.frame':	303 obs. of  14 variables:
 $ age     : num  63 67 67 37 41 56 62 57 63 53 ...
 $ sex     : Factor w/ 2 levels "0","1": 2 2 2 2 1 2 1 1 2 2 ...
 $ cp      : Factor w/ 4 levels "1","2","3","4": 1 4 4 3 2 2 4 4 4 4 ...
 $ trestbps: num  145 160 120 130 130 120 140 120 130 140 ...
 $ chol    : num  233 286 229 250 204 236 268 354 254 203 ...
 $ fbs     : Factor w/ 2 levels "0","1": 2 1 1 1 1 1 1 1 1 2 ...
 $ restecg : Factor w/ 3 levels "0","1","2": 3 3 3 1 3 1 3 1 3 3 ...
 $ thalach : num  150 108 129 187 172 178 160 163 147 155 ...
 $ exang   : Factor w/ 2 levels "0","1": 1 2 2 1 1 1 1 2 1 2 ...
 $ oldpeak : num  2.3 1.5 2.6 3.5 1.4 0.8 3.6 0.6 1.4 3.1 ...
 $ slope   : Factor w/ 3 levels "1","2","3": 3 2 2 3 1 1 3 1 2 3 ...
 $ ca      : Factor w/ 4 levels "0","1","2","3": 1 4 3 1 1 1 3 1 2 1 ...
 $ thal    : Factor w/ 3 levels "3","6","7": 2 1 3 1 1 1 1 1 3 3 ...
 $ status  : Factor w/ 5 levels "0","1","2","3",..: 1 3 2 1 1 1 4 1 3 2 ...

Also, note that status has 5 different values: 0, 1, 2, 3, and 4. While status = 0 indicates no heart disease, all other values of status indicate heart disease. In this exercise, you are building a decision tree model to classify each patient as “normal” (not having heart disease) or “abnormal” (having heart disease).

Therefore, you can merge status = 1, 2, 3, and 4 into a single level, status = “1”. This way you convert status into a binary (dichotomous) variable with only two values: status = “0” (normal) and status = “1” (abnormal).

Let’s do that!

##  We will use the 'forcats' package included in the 'tidyverse' package
##  The function to be used is fct_collapse()
cardio$status <- fct_collapse(cardio$status, "1" = c("1","2", "3", "4"))  


## Let's also change the labels under the "status" from (0,1) to (normal, abnormal)  
levels(cardio$status) <- c("normal", "abnormal")  

## levels under sex can also be changed to (female, male)   
## We can change level names in other categorical variables as well but we are not doing that  
levels(cardio$sex) <- c("female", "male")  
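Because `levels<-` relabels by position, the renaming above is only correct if the collapsed status factor keeps "0" as its first level. The following base-R sketch is an alternative to fct_collapse(), not the tutorial's method: it performs the same collapse and relabel on a small made-up vector of status codes, so the mapping can be confirmed with a cross-tabulation.

```r
## Made-up status codes (0 = no heart disease, 1-4 = heart disease), for illustration only
status_raw <- factor(c(0, 2, 1, 0, 0, 0, 3, 0, 2, 1, 4))

## Base-R alternative to fct_collapse(): every code other than "0" becomes "1",
## and the level order "0", "1" is fixed explicitly
status_bin <- factor(ifelse(status_raw == "0", "0", "1"), levels = c("0", "1"))

## levels<- relabels by position: first level "0" -> "normal", second "1" -> "abnormal"
levels(status_bin) <- c("normal", "abnormal")

## Cross-tabulate old codes against new labels to confirm the mapping
table(status_raw, status_bin)
```

On the real data, the same check would be a cross-tabulation of the original integer codes against the relabelled cardio$status.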

So, you have corrected the data types. What’s next?

How about getting a summary of all the variables in the data?

## Overall summary of all the columns 
summary(cardio)
      age            sex      cp         trestbps          chol       fbs    
 Min.   :29.00   female: 97   1: 23   Min.   : 94.0   Min.   :126.0   0:258  
 1st Qu.:48.00   male  :206   2: 50   1st Qu.:120.0   1st Qu.:211.0   1: 45  
 Median :56.00                3: 86   Median :130.0   Median :241.0          
 Mean   :54.44                4:144   Mean   :131.7   Mean   :246.7          
 3rd Qu.:61.00                        3rd Qu.:140.0   3rd Qu.:275.0          
 Max.   :77.00                        Max.   :200.0   Max.   :564.0

 restecg    thalach      exang      oldpeak     slope      ca        thal    
 0:151   Min.   : 71.0   0:204   Min.   :0.00   1:142   0   :176   3   :166  
 1:  4   1st Qu.:133.5   1: 99   1st Qu.:0.00   2:140   1   : 65   6   : 18  
 2:148   Median :153.0           Median :0.80   3: 21   2   : 38   7   :117  
         Mean   :149.6           Mean   :1.04           3   : 20   NA's:  2  
         3rd Qu.:166.0           3rd Qu.:1.60           NA's:  4             
         Max.   :202.0           Max.   :6.20                                

       status   
 normal  :164  
 abnormal:139  


Did you notice the missing values (NAs) under the ca and thal columns? With the following code, you can count the missing values across all the columns in your data frame.

# Counting the missing values in the data frame
sum(is.na(cardio))
6

There are only 6 missing values, affecting at most 6 of the 303 rows (approximately 2%). That is a very low proportion of missing values. What do you want to do with these missing values before you start building your decision tree model?

  • Option 1: discard the missing values before training.
  • Option 2: rely on the machine learning algorithm to deal with missing values during the model training.
  • Option 3: impute missing values before training.

For most learning methods, Option 3 (imputation) is necessary. The simplest approach is to impute the missing values by the mean or median of the non-missing values for the given feature.
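As a sketch of Option 3, the hypothetical helper impute_simple() below fills numeric columns with their median and factor columns with their most frequent level, since a median does not apply to factors such as ca and thal. The small data frame is made up for illustration; the tutorial itself discards the incomplete rows instead of imputing them.

```r
## Hypothetical helper: replace NAs with the median (numeric columns)
## or the most frequent level (factor columns)
impute_simple <- function(df) {
  for (col in names(df)) {
    x <- df[[col]]
    if (!anyNA(x)) next
    if (is.numeric(x)) {
      x[is.na(x)] <- median(x, na.rm = TRUE)
    } else if (is.factor(x)) {
      counts <- table(x)                       # table() ignores NAs by default
      x[is.na(x)] <- names(counts)[which.max(counts)]
    }
    df[[col]] <- x
  }
  df
}

## Made-up example with one missing value per column
toy <- data.frame(
  chol = c(233, 286, NA, 250, 204),
  thal = factor(c("3", "7", "3", NA, "3"))
)
toy_imputed <- impute_simple(toy)
toy_imputed
```

Here the missing chol value becomes the median of the four observed values, and the missing thal value becomes "3", its most common level.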

Whether Option 2 is available depends on the learning algorithm. Tree algorithms such as CART, which the rpart package implements, simply ignore missing values when determining the quality of a split. To determine whether a case with a missing value for the best split is sent left or right, the algorithm uses surrogate splits. You may want to read more on this here.
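The sketch below illustrates Option 2 with rpart's built-in kyphosis data set, after deliberately blanking out a few Start values (the rows chosen are arbitrary). With rpart's default na.action, rows with partly missing predictors stay in the training data, and a new case with a missing value is still classified.

```r
library(rpart)

## Copy of rpart's built-in kyphosis data with a few Start values removed
kyph <- kyphosis
kyph$Start[c(1, 10, 20, 30)] <- NA

## The default na.action (na.rpart) keeps rows whose predictors are only partly missing;
## it drops rows only when the response, or every predictor, is missing
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyph, method = "class")
fit$frame$n[1]   # number of observations in the root node

## A new case with Start missing is routed down the tree via surrogate splits
## (or, if those are missing too, in the majority direction), so it still gets a class
new_case <- data.frame(Age = 100, Number = 4, Start = NA_real_)
predict(fit, newdata = new_case, type = "class")
```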

However, if the relative amount of missing data is small, you can choose Option 1 and discard the cases with missing values, as long as doing so doesn’t create or worsen a class imbalance, which is briefly discussed in the following section.

For this data set, it is safe to delete the cases with missing values. The following code chunk does that for you.

## Removing missing values  
cardio <- na.omit(cardio)

Data exploration

Status is the variable that you want to predict with your model. As we have discussed earlier, other variables in the cardio dataset should help you predict status.

For example, among patients with heart disease, you might expect the average cholesterol level (chol) to be higher than among those who are normal. Likewise, among patients with high fasting blood sugar (fbs = 1), you would expect the proportion of patients with heart disease to be higher than among patients with normal fasting blood sugar (fbs = 0). You can check such hypotheses with some data visualization and exploration.
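A sketch of how these two hypotheses could be checked numerically: the hypothetical helper explore_by_class() returns the mean of a numeric column per class and the share of each class within each value of a binary flag. It is shown on a few made-up rows so that it runs on its own; with the tutorial's data you would call explore_by_class(cardio, "status", "chol", "fbs").

```r
## Hypothetical helper: per-class mean of a numeric column, plus the share of each
## class within each value of a flag column
explore_by_class <- function(df, class_col, num_col, flag_col) {
  list(
    ## average of the numeric column within each class (e.g. mean chol by status)
    mean_by_class = tapply(df[[num_col]], df[[class_col]], mean),
    ## rows = flag values, columns = classes; each row sums to 1
    class_share_by_flag = prop.table(table(df[[flag_col]], df[[class_col]]), margin = 1)
  )
}

## Made-up rows, for illustration only
toy <- data.frame(
  status = factor(c("normal", "normal", "abnormal", "abnormal", "normal", "abnormal"),
                  levels = c("normal", "abnormal")),
  chol   = c(200, 220, 260, 280, 210, 240),
  fbs    = factor(c("0", "0", "1", "0", "1", "1"))
)
res <- explore_by_class(toy, "status", "chol", "fbs")
res
```

If the hypotheses hold on the real data, the abnormal mean of chol should exceed the normal mean, and the abnormal share in the fbs = "1" row should exceed the one in the fbs = "0" row.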

You may want to start with the distribution of status. The following code chunk plots it:

## plotting a bar chart of status
## (geom_bar() counts the observations in each class)
cardio %>%
  ggplot(aes(x = status)) +
  geom_bar(fill = "steelblue") +
  theme_bw()

From this plot, you can observe a nearly even split between patients with normal and abnormal status.

This may not always be the case. There might be datasets in which one of the classes in the predicted variable has a very low proportion. Such datasets are said to have a class imbalance problem where one of the classes in the predicted variable is rare within the dataset.

Credit Card Fraud Detection Model or a Mortgage Loan Default Model are some examples of classification models that are built with a dataset having a class imbalance problem. What other scenarios come to your mind?

You are encouraged to read this article: ROSE: A Package for Binary Imbalanced Learning
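The balance seen in the plot can also be checked numerically. The sketch below uses the status counts printed by summary(cardio) before the missing values were removed (164 normal, 139 abnormal); on the data frame itself, prop.table(table(cardio$status)) gives the same kind of answer. The majority-to-minority ratio is just one simple way to summarise balance, not a standard threshold.

```r
## Class counts for status, as printed by summary(cardio) above (before removing NAs)
status_counts <- c(normal = 164, abnormal = 139)

## Share of each class
shares <- prop.table(status_counts)
round(shares, 3)

## Ratio of the majority class to the minority class (1 = perfectly balanced)
imbalance_ratio <- max(status_counts) / min(status_counts)
round(imbalance_ratio, 2)
```

A ratio this close to 1 is far from the rare-class situation of, say, fraud detection, where one class may make up only a tiny fraction of the data.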

You should now explore the distribution of quantitative variables. You can make density plots with frequency counts on the Y-axis and split the plot by the two levels in the status variable.

The following code produces the plots arranged in a grid of 2 rows:

## frequency plots for quantitative variables, split by status
## (pivot_longer() stacks the quantitative columns into var/value pairs;
##  after_stat(count) puts frequency counts on the Y-axis)
cardio %>%
  pivot_longer(c(age, trestbps, chol, thalach, oldpeak), names_to = "var", values_to = "value") %>%
  ggplot(aes(x = value, y = after_stat(count), colour = status)) +
  scale_color_manual(values = c("#008000", "#FF0000")) +
  geom_density() +
  facet_wrap(~var, scales = "free", nrow = 2) +
  theme_bw()

What are your observations from the quantitative plots? Some of your observations might be:

  • In all the plots, as we move along the X-axis, the abnormal curve mostly, but not always, lies below the normal curve. You should expect this, as fewer patients are classified as abnormal. However, for some values on the X-axis (smaller or larger values of X, depending on the predictor), the abnormal curve lies above the normal curve.
  • For example, look at the age plot. Up to x = 55 years, the majority of patients fall under the normal curve. Once x > 55 years, the majority shifts to abnormal patients and remains so until x = 68 years. Intuitively, age could be a good predictor of status, and you may want to partition the data at x = 55 years and again at x = 68 years. When you build your decision tree model, you may expect internal nodes with age > 55 and age > 68.
  • Next, observe the plot for chol. Except for a narrow range (x = 275 mg/dl to x = 300 mg/dl), the normal curve always lies above the abnormal curve. You may want to form the hypothesis that cholesterol is not a good predictor of status. In other words, you may not expect chol to be among the earliest internal nodes in your decision tree model.

Likewise, you can form hypotheses for other quantitative variables as well. Of course, your decision tree model will help you test these hypotheses.

Now you may want to turn your attention to qualitative variables.

## frequency plots for qualitative variables, split by status
## (values are converted to character because the factors have different levels;
##  the duplicated facet_wrap() call has been removed)
cardio %>%
  pivot_longer(c(sex, cp, fbs, restecg, exang, slope, ca, thal),
               names_to = "var", values_to = "value",
               values_transform = list(value = as.character)) %>%
  ggplot(aes(x = value, colour = status)) +
  scale_color_manual(values = c("#008000", "#FF0000")) +
  geom_bar(fill = "white") +
  facet_wrap(~var, scales = "free", nrow = 3) +
  theme_bw()

What are your observations from the qualitative plots? How do you want to partition data along the qualitative variables?

  • Observe the cp (chest pain) plot. The presence of asymptomatic chest pain, indicated by cp = 4, could provide a partition in the data and could be among the earliest nodes in your decision tree.
  • Likewise, observe the sex plot. Clearly, the proportion of abnormal is much lower among females (approximately 25%) than among males (approximately 50%). Intuitively, sex might also be a good predictor, and you may want to partition the patients’ data along sex. When you build your decision tree model, you may expect internal nodes with sex.

At this point, you may want to go back to both sets of plots and list the partitions (variables and, more importantly, variable values) that you expect to find in your decision tree model.

Of course, all these hypotheses will be tested once we build our decision tree model.


Partitioning data: Training and test sets

Before you start building your decision tree, split the cardio data into a training set and test set:

cardio.train: 70% of the dataset

cardio.test: 30% of the dataset

The following code-chunk will do that:

## Now you can randomly split your data into a 70% training set and a 30% test set
## You should set a seed to ensure that you get the same training vs. test split every time you run the code
set.seed(1) 

## randomly extract row numbers in cardio dataset which will be included in the training set  
train.index <- sample(1:nrow(cardio), round(0.70*nrow(cardio),0))

## subset cardio data set to include only the rows in train.index to get cardio.train  
cardio.train <- cardio[train.index, ]

## subset cardio data set to include only the rows NOT in train.index to get cardio.test  
## Did you note the negative sign?
cardio.test <- cardio[-train.index,  ]
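A quick sanity check of this split, sketched on row indices only. It assumes that 297 complete rows remain after na.omit() (303 rows minus the 6 rows with a missing ca or thal value); that figure is an inference, consistent with n= 208 in the printcp() output later on, not something the code above prints. With that assumption, a 70% split gives round(207.9) = 208 training rows and 89 test rows with no overlap.

```r
## Assumed number of complete rows after na.omit() (see the note above)
n <- 297
set.seed(1)

## Same sampling call as above, applied to row positions 1..n
train.index <- sample(1:n, round(0.70 * n, 0))

## Rows not drawn for training; equivalent to the negative indexing cardio[-train.index, ]
test.index <- setdiff(1:n, train.index)

length(train.index)                       # training rows
length(test.index)                        # test rows
length(intersect(train.index, test.index))  # overlap between the two sets
```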

Classification trees using rpart

 

“rpart” Package

You will now use the rpart package to build your decision tree model. The decision tree that you build can be plotted using the rpart.plot package or the rattle package, which provides better-looking plots.

You will use function rpart() to build your decision tree model. The function has the following key arguments:

formula: rpart(formula = < >, …)

The formula declares the response and the predictors used in your decision tree. You can specify status ~ . to indicate that you want to use all the other variables as predictors.

method: rpart(method = < >, …)

The same function can be used to build a classification tree as well as a regression tree. Use “class” to specify that you are building a classification tree; for a regression tree, you would specify “anova” instead.

cp: rpart(cp = < >, …)

The main role of the complexity parameter (cp) is to control the size of the decision tree. Any split that does not decrease the overall lack of fit by a factor of cp is not attempted. The default value is 0.01. A value of cp = 1 will result in a tree with no splits. Setting cp to a negative value grows the tree as far as the other stopping rules (such as minsplit and minbucket) allow.

minsplit: rpart(minsplit = < >, …)

The minimum number of observations that must exist in a node in order for a split to be attempted. The default value is 20.

minbucket: rpart(minbucket = < >, …)

The minimum number of observations in any terminal node. If only one of minbucket or minsplit is specified, the code either sets minsplit to minbucket*3 or minbucket to minsplit/3, as appropriate. The default is minbucket = round(minsplit/3).

You are encouraged to read the rpart package documentation.
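The relationships between these defaults can be checked directly with rpart.control(), which returns the control settings as a list. The sketch below uses rpart's built-in kyphosis data instead of cardio so that it runs on its own, and also shows that control arguments can be passed straight to rpart(), as this tutorial does with cp.

```r
library(rpart)

## Default control values: minsplit = 20, minbucket = round(20/3) = 7, cp = 0.01
ctrl_default <- rpart.control()
c(ctrl_default$minsplit, ctrl_default$minbucket, ctrl_default$cp)

## Setting only minsplit derives minbucket = round(minsplit / 3)
rpart.control(minsplit = 30)$minbucket

## Setting only minbucket derives minsplit = minbucket * 3
rpart.control(minbucket = 5)$minsplit

## Control arguments given to rpart() are forwarded to rpart.control()
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis, method = "class",
             cp = 0.05, minsplit = 30)
fit$control$minbucket
```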

You can build a decision tree using all the predictors and with a cp = 0.05. The following code chunk will build your decision tree model:

## using all the predictors and setting cp = 0.05 
cardio.train.fit <- rpart(status ~ . , data = cardio.train, method = "class", cp = 0.05)

It is time to plot your decision tree. You can use the function rpart.plot() for plotting your tree. However, the function fancyRpartPlot() in the rattle package produces a ‘fancier’ plot.

## Using fancyRpartPlot() from "rattle" package
fancyRpartPlot(cardio.train.fit, palettes = c("Greens", "Reds"), sub = "")

Interpreting decision tree plot

What are your observations from your decision tree plot?

Each box is a node of one of the following types:

Root Node (cp = 1, 2, 3): The root node represents the entire population, or 100% of the sample.

Decision Nodes (thal = 3 and ca = 0): These are the two internal nodes below the root that split further into internal or terminal nodes. Counting the root node, which also splits the data, there are 3 decision nodes here.

Terminal Nodes (Leaves): The nodes that do not split further are called terminal nodes or leaves. Your decision tree has 4 terminal nodes.

The decision tree plot gives the following information:

Predictors Used in Model: Only the thal, cp, and ca variables are included in this decision tree.

Predicted Probabilities: The predicted probability of a patient being normal or abnormal. Note that the two probabilities add up to 100% at each node.

Node Purities: Each node shows two proportions: the proportion of normal patients (left) and of abnormal patients (right) in that node. The leftmost leaf has 0.82 and 0.18. Since this leaf predicts normal, the number on the left, 0.82, is the proportion of the node that belongs to the predicted class, so this leaf has 82% purity.

Sample Proportion: Each node shows its share of the sample. The share is 100% for the root node, and the percentages of a node's two children add up to the percentage of their parent node.

Predicted class: Each node shows the predicted class, normal or abnormal. It is the most common class of the response (status) in that node, but the node might still include observations belonging to the other class. This is the concept of node impurity.
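The quantities shown in each box can be reproduced from a node's class counts. The hypothetical helper node_summary() below does this for the root's left child, using the counts that the summary() output later in this tutorial reports for node 2 (89 normal and 20 abnormal out of 208 training cases); the result should match the 0.82 / 0.18 leaf described above.

```r
## Hypothetical helper: derive what the plot shows for one node from its class counts
node_summary <- function(counts, n_root) {
  probs <- counts / sum(counts)                    # class proportions in the node
  predicted <- names(counts)[which.max(counts)]    # majority class = predicted class
  list(
    predicted_class = predicted,
    probabilities   = round(probs, 2),
    purity          = round(max(probs), 2),        # share belonging to the predicted class
    sample_percent  = round(100 * sum(counts) / n_root)  # node size as % of the root
  )
}

## Node 2 counts taken from the summary() output of the pruned tree
node2 <- node_summary(c(normal = 89, abnormal = 20), n_root = 208)
node2
```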


Fully grown decision tree

Is this the fully-grown decision tree?

No! Recall that you grew this decision tree with cp = 0.05 (the default is 0.01), which ensures that your decision tree doesn’t include any split that does not decrease the overall lack of fit by a factor of 0.05 (5%).

However, if you change this parameter, you might get a different decision tree. Run the following code chunk to get the plot of a fully grown decision tree, with cp = 0:

## using all the predictors, cp = 0, and default values for all other arguments
cardioFull <- rpart(status ~ . , data = cardio.train, method = "class", cp = 0)

## Using fancyRpartPlot() from "rattle" package
fancyRpartPlot(cardioFull, palettes = c("Greens", "Reds"),sub = "")

The fully grown tree adds two more predictors, oldpeak and thalach, to the tree that you built earlier. You have now seen that changing the cp parameter gives decision trees of different sizes, with more nodes and/or more leaves. At this stage, you might want to ask the following questions:

  • Which of the two decision trees should you go ahead with and present to your division’s Chief Data Scientist? The one developed with cp = 0.05 or the one with cp = 0?
  • Does a bigger decision tree make a better classification model or a worse one?
  • Is the default value of cp = 0.01 the best possible?
  • How would you select a cp value that ensures the best-performing decision tree model?

There are no rules of thumb on how large or small a decision tree should grow. However, you should be aware that:

  • a large tree might overfit the data and thus might lead to a model with high variance
  • a small tree might miss important structure in the data and thus might lead to a model with high bias

So, which of the two decision trees should you present to your division’s Chief Data Scientist? Which parameters can you control to build your best decision tree? And which metrics can you use to evaluate the performance of your decision tree model and justify your choice?


Pruning the decision trees

The optimal tree size is chosen adaptively from the training data. The recommended approach is to build a fully grown decision tree and then extract a nested sub-tree from it (prune it), choosing the sub-tree that best balances fit against size. In practice, as this section shows, that means the sub-tree with the lowest cross-validated error.

As you have learned in your in-class module, there are three different metrics to calculate the node impurities that can be used for a given node m:

Gini Index:

A measure of total variance across all the classes in the response variable. A smaller value of G indicates a purer, or more homogeneous, node.

$$G = \sum_{k=1}^{K} \hat{p}_{mk}\,(1 - \hat{p}_{mk})$$

Here, $\hat{p}_{mk}$ is the proportion of training observations in the mth region that are from the kth class, and K is the number of classes.

Cross-Entropy or Deviance:

Another measure of node impurity:

$$D = -\sum_{k=1}^{K} \hat{p}_{mk} \log \hat{p}_{mk}$$

As with the Gini index, the mth node is purer if the entropy D is smaller.

In your fitted decision tree model, the response variable has two classes (normal and abnormal), so K = 2; each terminal node of the tree corresponds to one region m.

Misclassification Error:

The fraction of the training observations in the mth node that do not belong to the most common class:

$$E = 1 - \max_{k}\,\hat{p}_{mk}$$

When growing a decision tree, the Gini index or entropy is typically used to evaluate the quality of a split.

However, for pruning the tree, the misclassification error is typically used.
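To make the three formulas concrete, the sketch below implements G, D, and E for a single node from its class counts and applies them to node 2 (89 normal, 20 abnormal) from the summary() output later in this tutorial. Its misclassification error, 20/109 ≈ 0.1835, is the same number that summary() reports as that node's expected loss. The function names are illustrative; rpart computes impurity internally. For classification trees, rpart splits on the Gini index by default, and parms = list(split = "information") switches it to entropy.

```r
## Node impurity measures for one node m, given its class counts.
## p_hat holds the class proportions \hat{p}_{mk}, k = 1..K.
gini <- function(counts) {
  p_hat <- counts / sum(counts)
  sum(p_hat * (1 - p_hat))
}

entropy <- function(counts) {
  p_hat <- counts / sum(counts)
  p_hat <- p_hat[p_hat > 0]      # convention: 0 * log(0) = 0
  -sum(p_hat * log(p_hat))       # natural logarithm
}

misclass_error <- function(counts) {
  p_hat <- counts / sum(counts)
  1 - max(p_hat)
}

## Node 2 of the tree: 89 normal, 20 abnormal
node2 <- c(normal = 89, abnormal = 20)
round(c(gini = gini(node2), entropy = entropy(node2), error = misclass_error(node2)), 4)
```

All three measures are 0 for a pure node and largest when the classes are evenly mixed (for K = 2: G = 0.5, D = log 2, E = 0.5).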

You can now get back to the fully grown decision tree that you built with cp = 0.

The complexity parameter (CP) table will help you evaluate the fitted decision tree model. For your fully grown decision tree cardioFull, you can print the CP table using printcp() and plot it using plotcp().

The CP table will help you select the decision tree that minimizes the cross-validated misclassification error. It lists all the trees nested within the fitted tree. The best nested sub-tree can then be extracted by selecting the corresponding value of cp.

The following code will print the CP table for you:

## printing the CP table for the fully-grown tree 
printcp(cardioFull)
Classification tree:
rpart(formula = status ~ ., data = cardio.train, method = "class", 
    cp = 0)

Variables actually used in tree construction:
[1] ca      cp      oldpeak thal    thalach

Root node error: 95/208 = 0.45673

n= 208 

        CP nsplit rel error  xerror     xstd
1 0.536842      0   1.00000 1.00000 0.075622
2 0.063158      1   0.46316 0.52632 0.064872
3 0.031579      3   0.33684 0.38947 0.058056
4 0.015789      4   0.30526 0.35789 0.056138
5 0.000000      6   0.27368 0.36842 0.056794

The plotcp() gives a visual representation of the cross-validation results in an rpart object.

## plotting the cp 
plotcp(cardioFull, lty = 3, col = 2, upper = "splits" )

CP table

How do we interpret the cp table? What is your objective here?

Your objective is to prune the fitted tree i.e. select a nested sub-tree from this fitted tree, such that the cross-validated error or the xerror is the minimum.

The CP table for your decision tree lists all the trees nested within the fitted tree, from the smallest tree possible (nsplit = 0, i.e. no splits) to the largest one (nsplit = 6, six splits). The number of terminal nodes (leaves) in a sub-tree is always one more than the number of splits.

For easier reading, the error columns have been scaled so that the first node (nsplit = 0) has an error of 1. In your decision tree, the model with no splits makes 95/208 misclassifications, so you can multiply the columns rel error, xerror, and xstd by 95 to get the absolute values. The complexity parameter in the first column has been scaled in the same way. From the CP table, we want to select the cp value that minimizes the cross-validated error (xerror).
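As a sketch of this scaling, the following R code copies the values printed by printcp(cardioFull) above into a small data frame, multiplies the scaled errors by the root node error of 95, and checks a relationship that holds for these printed rows: each CP equals the drop in rel error per split added before the next row. The data frame and its column names are introduced here for illustration only.

```r
## CP table values as printed by printcp(cardioFull) above
cp_tab <- data.frame(
  CP        = c(0.536842, 0.063158, 0.031579, 0.015789, 0.000000),
  nsplit    = c(0, 1, 3, 4, 6),
  rel_error = c(1.00000, 0.46316, 0.33684, 0.30526, 0.27368),
  xerror    = c(1.00000, 0.52632, 0.38947, 0.35789, 0.36842)
)
root_errors <- 95   # root node error: 95 of 208 training cases misclassified

## Scale back to counts of misclassified cases
cp_tab$train_misclassified <- round(cp_tab$rel_error * root_errors)
cp_tab$cv_misclassified    <- round(cp_tab$xerror * root_errors)
cp_tab

## Drop in rel error divided by the number of splits added to reach the next row;
## for these rows it reproduces the CP column (the last CP is the cp = 0 we requested)
drop_per_split <- -diff(cp_tab$rel_error) / diff(cp_tab$nsplit)
round(drop_per_split, 4)
```

For example, the 4-split tree misclassifies 29 of the 208 training cases, while its cross-validated error corresponds to about 34 misclassifications.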

CP plot

plotcp() gives a visual representation of the CP table. The Y-axis of the plot shows the xerror values, and the X-axis shows the geometric means of the intervals of cp values for which pruning is optimal. The horizontal line, drawn in red here because of col = 2, lies 1 SE above the minimum of the curve. A good choice of cp for pruning is typically the leftmost value for which the mean lies below this line.
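The horizontal line in plotcp() corresponds to the 1-SE rule. As a sketch, the hypothetical helper one_se_cp() below applies that rule to the CP table printed earlier, alongside the minimum-xerror rule that this tutorial uses. Keep in mind that xerror and xstd come from random cross-validation folds, so the numbers, and therefore the choice, can change from run to run.

```r
## Hypothetical helper: pick cp with the 1-SE rule, i.e. the smallest tree whose
## xerror is within one standard error (xstd) of the minimum xerror
one_se_cp <- function(cptable) {
  best <- which.min(cptable[, "xerror"])
  threshold <- cptable[best, "xerror"] + cptable[best, "xstd"]
  ## rows run from fewest to most splits, so the first qualifying row is the smallest tree
  chosen <- which(cptable[, "xerror"] <= threshold)[1]
  cptable[chosen, "CP"]
}

## CP table as printed by printcp(cardioFull); on the fitted model you would
## call one_se_cp(cardioFull$cptable) instead
cp_tab <- cbind(
  CP     = c(0.536842, 0.063158, 0.031579, 0.015789, 0.000000),
  nsplit = c(0, 1, 3, 4, 6),
  xerror = c(1.00000, 0.52632, 0.38947, 0.35789, 0.36842),
  xstd   = c(0.075622, 0.064872, 0.058056, 0.056138, 0.056794)
)

one_se_cp(cp_tab)                             # 1-SE rule
cp_tab[which.min(cp_tab[, "xerror"]), "CP"]   # minimum-xerror rule
```

With these printed values, the threshold is 0.35789 + 0.056138 ≈ 0.414, so the 1-SE rule would pick the 3-split tree (cp = 0.031579), whereas the minimum-xerror rule picks the 4-split tree (cp = 0.015789). The 1-SE rule trades a slightly higher cross-validated error for a smaller tree.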

The following code chunk will help you select the best cp from the cp table

## selecting the best cp, corresponding to the minimum value in xerror 
bestcp <- cardioFull$cptable[which.min(cardioFull$cptable[,"xerror"]),"CP"]

## print the best cp
bestcp

0.0157894736842105

You can now use this bestcp to prune the fully-grown decision tree

## Prune the tree using the best cp.
cardio.pruned <- prune(cardioFull, cp = bestcp)
## You can now plot the pruned tree 
fancyRpartPlot(cardio.pruned, palettes = c("Greens", "Reds"), sub = "")   

You can use the summary() function to get a detailed summary of the pruned decision tree. It prints the call, the table shown by printcp, the variable importance (summing to 100), and details for each node (the details depend on the type of tree).

## printing the detailed summary of the pruned tree
summary(cardio.pruned)  
Call:
rpart(formula = status ~ ., data = cardio.train, method = "class", 
    cp = 0)
  n= 208 

          CP nsplit rel error    xerror       xstd
1 0.53684211      0 1.0000000 1.0000000 0.07562158
2 0.06315789      1 0.4631579 0.5263158 0.06487215
3 0.03157895      3 0.3368421 0.3894737 0.05805554
4 0.01578947      4 0.3052632 0.3578947 0.05613824

Variable importance
      cp     thal    exang  thalach       ca  oldpeak trestbps      age 
      28       17       14       13       12       12        3        2 
     sex 
       1 

Node number 1: 208 observations,    complexity param=0.5368421
  predicted class=normal    expected loss=0.4567308  P(node) =1
    class counts:   113    95
   probabilities: 0.543 0.457 
  left son=2 (109 obs) right son=3 (99 obs)
  Primary splits:
      cp      splits as  LLLR,      improve=34.19697, (0 missing)
      thal    splits as  LRR,       improve=31.59722, (0 missing)
      exang   splits as  LR,        improve=23.76356, (0 missing)
      ca      splits as  LRRR,      improve=21.46291, (0 missing)
      thalach < 147.5 to the right, improve=17.90570, (0 missing)
  Surrogate splits:
      exang   splits as  LR,        agree=0.731, adj=0.434, (0 split)
      thal    splits as  LRR,       agree=0.702, adj=0.374, (0 split)
      thalach < 148.5 to the right, agree=0.683, adj=0.333, (0 split)
      ca      splits as  LRRR,      agree=0.625, adj=0.212, (0 split)
      oldpeak < 0.85  to the left,  agree=0.611, adj=0.182, (0 split)

Node number 2: 109 observations,    complexity param=0.03157895
  predicted class=normal    expected loss=0.1834862  P(node) =0.5240385
    class counts:    89    20
   probabilities: 0.817 0.183 
  left son=4 (98 obs) right son=5 (11 obs)
  Primary splits:
      oldpeak < 1.95  to the left,  improve=5.018621, (0 missing)
      slope   splits as  LRL,       improve=4.913298, (0 missing)
      thal    splits as  LRR,       improve=4.888193, (0 missing)
      ca      splits as  LRRR,      improve=3.642018, (0 missing)
      thalach < 152.5 to the right, improve=3.280350, (0 missing)

Node number 3: 99 observations,    complexity param=0.06315789
  predicted class=abnormal  expected loss=0.2424242  P(node) =0.4759615
    class counts:    24    75
   probabilities: 0.242 0.758 
  left son=6 (35 obs) right son=7 (64 obs)
  Primary splits:
      thal    splits as  LRR,       improve=8.002922, (0 missing)
      exang   splits as  LR,        improve=7.972659, (0 missing)
      ca      splits as  LRRR,      improve=7.539716, (0 missing)
      oldpeak < 0.7   to the left,  improve=3.625175, (0 missing)
      thalach < 175   to the right, improve=3.354320, (0 missing)
  Surrogate splits:
      trestbps < 116   to the left,  agree=0.717, adj=0.200, (0 split)
      oldpeak  < 0.05  to the left,  agree=0.707, adj=0.171, (0 split)
      thalach  < 175   to the right, agree=0.697, adj=0.143, (0 split)
      sex      splits as  LR,        agree=0.677, adj=0.086, (0 split)
      age      < 69.5  to the right, agree=0.667, adj=0.057, (0 split)

Node number 4: 98 observations
  predicted class=normal    expected loss=0.1326531  P(node) =0.4711538
    class counts:    85    13
   probabilities: 0.867 0.133 

Node number 5: 11 observations
  predicted class=abnormal  expected loss=0.3636364  P(node) =0.05288462
    class counts:     4     7
   probabilities: 0.364 0.636 

Node number 6: 35 observations,    complexity param=0.06315789
  predicted class=normal    expected loss=0.4857143  P(node) =0.1682692
    class counts:    18    17
   probabilities: 0.514 0.486 
  left son=12 (20 obs) right son=13 (15 obs)
  Primary splits:
      ca       splits as  LRRR,      improve=7.619048, (0 missing)
      exang    splits as  LR,        improve=6.294925, (0 missing)
      trestbps < 126.5 to the right, improve=2.519048, (0 missing)
      thalach  < 170   to the right, improve=2.057143, (0 missing)
      age      < 53.5  to the left,  improve=1.866667, (0 missing)
  Surrogate splits:
      thalach  < 134   to the right, agree=0.743, adj=0.400, (0 split)
      trestbps < 129   to the right, agree=0.714, adj=0.333, (0 split)
      exang    splits as  LR,        agree=0.686, adj=0.267, (0 split)
      oldpeak  < 1.7   to the left,  agree=0.686, adj=0.267, (0 split)
      age      < 62.5  to the left,  agree=0.657, adj=0.200, (0 split)

Node number 7: 64 observations
  predicted class=abnormal  expected loss=0.09375  P(node) =0.3076923
    class counts:     6    58
   probabilities: 0.094 0.906 

Node number 12: 20 observations
  predicted class=normal    expected loss=0.2  P(node) =0.09615385
    class counts:    16     4
   probabilities: 0.800 0.200 

Node number 13: 15 observations
  predicted class=abnormal  expected loss=0.1333333  P(node) =0.07211538
    class counts:     2    13
   probabilities: 0.133 0.867 
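rpart grows classification trees using the Gini index by default, and the improve values in the summary can be checked by hand. The sketch below reproduces the root node's improve=34.19697 for the cp split as the node-size-weighted drop in Gini impurity, using the class counts of nodes 1, 2 and 3 from the output above. It also reproduces the root's expected loss as the share of observations outside the predicted class. The helper gini() is a hypothetical name written for this illustration.

```r
# Gini impurity of a node, computed from its class counts
gini <- function(counts) {
  p <- counts / sum(counts)
  1 - sum(p^2)
}

# Class counts (normal, abnormal) copied from the summary output above
root  <- c(113, 95)  # node 1: 208 observations
left  <- c(89, 20)   # node 2: 109 observations
right <- c(24, 75)   # node 3: 99 observations

# Size-weighted impurity before the split minus size-weighted impurity after it
improve <- sum(root) * gini(root) - sum(left) * gini(left) - sum(right) * gini(right)
improve  # close to the 34.19697 reported for the cp split

# Expected loss at the root: share of observations not in the predicted class
loss <- 1 - max(root) / sum(root)
loss
```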

Evaluating decision tree models

You can now use the predict() function with your rpart model to predict the status of the patients in the test data, cardio.test.

The following code chunk predicts the status values for the test data:

## You can now use your pruned tree model to predict the status for your test data 
cardio.predict <- predict(cardio.pruned, cardio.test, type = "class")

You should now evaluate the performance of your model on the test data. You will build a confusion matrix and calculate the classification error of the predictions:

# confusion matrix (test data)
conf.matrix <- table(cardio.test$status, cardio.predict)
rownames(conf.matrix) <- paste("Actual", rownames(conf.matrix), sep = ":")
colnames(conf.matrix) <- paste("Predicted", colnames(conf.matrix), sep = ":")
print(conf.matrix)
                 cardio.predict
                  Predicted:normal Predicted:abnormal
  Actual:normal                 40                  7
  Actual:abnormal               14                 28

You can calculate the classification error as:

## calculating the classification error: misclassified (off-diagonal) cases / all test cases
round(1 - sum(diag(conf.matrix)) / sum(conf.matrix), 3)
0.236

So, your decision tree has a 23.6% prediction error. In other words, your model has been able to classify the patients as normal or abnormal with an accuracy of 76.4%. Your division’s Chief Data Scientist should be impressed. Also, you have a classification model that you can very easily explain to doctors.
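Overall accuracy hides how the errors are split between the two classes, which matters when presenting to doctors. The sketch below rebuilds the confusion matrix from the counts printed above and computes sensitivity and specificity, assuming "abnormal" is treated as the positive class. The object names are illustrative.

```r
# Confusion matrix of the pruned tree, with counts copied from the output above
conf <- matrix(c(40,  7,
                 14, 28),
               nrow = 2, byrow = TRUE,
               dimnames = list(Actual    = c("normal", "abnormal"),
                               Predicted = c("normal", "abnormal")))

accuracy    <- sum(diag(conf)) / sum(conf)                            # correct / all
sensitivity <- conf["abnormal", "abnormal"] / sum(conf["abnormal", ]) # abnormal patients caught
specificity <- conf["normal", "normal"] / sum(conf["normal", ])       # normal patients cleared
round(c(accuracy = accuracy, sensitivity = sensitivity, specificity = specificity), 3)
```

With these counts the accuracy is 0.764, the sensitivity is 0.667 (28 of the 42 abnormal patients) and the specificity is 0.851 (40 of 47), so most of the errors are missed abnormal cases.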

However, before we wind up, here is a small exercise for you.

Small Exercise:

Decision tree models can suffer from extremely high variance: a small change in the training data can give you very different results. This short exercise is designed to make this point. In the code chunk below, change the values of the following parameters one at a time, run the code, and observe how the decision tree model changes:

set.seed (a): Set the seed to a different number, e.g. ‘1729’, ‘9999’, or whatever you like

Training set proportion (p): Set the proportion to a different value, e.g. 0.70, 0.80, 0.90, or whatever you like

You can carry the code all the way through to calculating the prediction error, but even just plotting the fitted tree will make the point!

## You should keep the original data frame intact so let's make a copy cardioplay  
cardioplay <- cardio 

## you set the seed to ensure that you get the same training v/s. test split every time you run the code
## Keeping all else constant, you should change the seed from '1234' to any other number 
a <- as.numeric(1234) 


## proportion of rows in the cardio dataset to include (randomly sampled) in the training set
## Keeping all else constant, you should change the proportion from '50%' to any other proportion 
p <- as.numeric(0.50)
## You don't need to make any changes in this code-chunk
## Make changes in the code-chunk just above and observe the changes in the output of this code-chunk  

## seed 
set.seed(a) 

## rows in training data 
trainset <- sample(1:nrow(cardioplay), round(p*nrow(cardioplay),0))
cardioplay.train <- cardioplay[trainset, ]

## rows in test data  
cardioplay.test <- cardioplay[-trainset, ] 

## fit the tree 
cardioplay.train.fit <- rpart(status ~ . , data = cardioplay.train, method = "class") 

## plot the tree 
fancyRpartPlot(cardioplay.train.fit, palettes = c("Greens", "Reds"), sub = "")


Conclusion

Now you have a good understanding of how to perform exploratory data analysis and prepare your dataset before you set out to build a decision tree. You are also familiar with the rpart functions for building and pruning decision trees, and with plotting the fitted trees. As we discussed earlier, there are other tree-based approaches, such as Bagging, Random Forests, and Boosting, which can improve accuracy.

You are all set to start practicing exercises on these advanced topics!

August 18, 2022

Learn how logistic regression fits a dataset to make predictions in R, as well as when and why to use it.

Logistic regression is one of the statistical techniques in machine learning used to build prediction models. It is one of the most popular classification algorithms and is mostly used for binary classification problems (problems with two class values), although some variants, such as multinomial logistic regression, can handle multiple classes. It’s used for a wide range of research and industrial problems.

Therefore, it is essential to have a good grasp of logistic regression while learning data science. This tutorial is a sneak peek at one of the many hands-on exercises in Data Science Dojo’s data science Bootcamp program. You will learn how logistic regression fits a dataset to make predictions, as well as when and why to use it.

In short, logistic regression is used when the dependent variable (target) is categorical. For example:

  • To predict whether an email is spam (1) or not spam (0)
  • To predict whether a tumor is malignant (1) or not (0)

Intro to Logistic Regression

It is named ‘Logistic Regression’ because its underlying technique is quite similar to linear regression: both start from a weighted linear combination of the input features. However, there are structural differences in how the two operate. In particular, linear regression outputs an unbounded real number rather than a probability, which is why it isn’t suitable for classification problems.

Its name is derived from one of the core functions behind its implementation called the logistic function or the sigmoid function. It’s an S-shaped curve that can take any real-valued number and map it into a value between 0 and 1, but never exactly at those limits.
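A minimal R version of the logistic function described above follows. The name sigmoid is a hypothetical helper; base R's plogis() computes the same curve, which the last line checks.

```r
# The logistic (sigmoid) function g(z) = 1 / (1 + exp(-z))
sigmoid <- function(z) 1 / (1 + exp(-z))

z <- c(-10, -2, 0, 2, 10)
sigmoid(z)                      # increasing, strictly between 0 and 1, 0.5 at z = 0

# base R's plogis() is the same function
all.equal(sigmoid(z), plogis(z))
```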


The hypothesis function of logistic regression is shown below, together with the logistic function g(z):

$$h_\theta(x) = g(\theta^T x), \qquad g(z) = \frac{1}{1 + e^{-z}}$$

The hypothesis for logistic regression now becomes:

$$h_\theta(x) = \frac{1}{1 + e^{-\theta^T x}}$$

Here θ (theta) is a vector of parameters that our model will calculate to fit our classifier.

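Using this hypothesis, the cost function to minimize is the average negative log-likelihood over the training set:

$$J(\theta) = -\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log h_\theta\left(x^{(i)}\right) + \left(1 - y^{(i)}\right)\log\left(1 - h_\theta\left(x^{(i)}\right)\right)\right]$$

Here m is the number of training examples. As with linear regression, gradient descent can be used to minimize this cost function and calculate the vector θ (theta). Note that R’s glm() function, which we’ll use later, fits the model with Fisher scoring (iteratively reweighted least squares) rather than gradient descent; both target the same maximum-likelihood estimate.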
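The cost function J(θ) and the gradient-descent update can be sketched in a few lines of base R. This is a minimal illustration on simulated data, not how this tutorial's model is fit: the true coefficients (-0.5, 1.2, -0.8), the learning rate alpha = 0.5 and the 5,000 iterations are arbitrary choices, and the update uses the standard gradient of J(θ), (1/m) X^T (h - y). glm() is used only as a reference; its estimates should agree closely with the gradient-descent ones.

```r
# Simulated binary data with two standardized-looking features
set.seed(1)
m  <- 1000
x1 <- rnorm(m)
x2 <- rnorm(m)
y  <- rbinom(m, 1, plogis(-0.5 + 1.2 * x1 - 0.8 * x2))
X  <- cbind(1, x1, x2)                # design matrix with an intercept column

sigmoid <- function(z) 1 / (1 + exp(-z))

# J(theta): average negative log-likelihood
cost <- function(theta) {
  h <- sigmoid(X %*% theta)
  -mean(y * log(h) + (1 - y) * log(1 - h))
}

theta <- rep(0, ncol(X))              # start from all-zero parameters
alpha <- 0.5                          # learning rate (hand-picked)
for (i in 1:5000) {
  h     <- sigmoid(X %*% theta)
  grad  <- t(X) %*% (h - y) / m       # gradient of J(theta)
  theta <- theta - alpha * grad       # gradient-descent step
}

# Reference fit by Fisher scoring
fit <- glm(y ~ x1 + x2, family = binomial)
cbind(gradient_descent = as.vector(theta), glm = coef(fit))
```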

This tutorial will follow the format below to provide you with hands-on practice with Logistic Regression:

  1. Importing Libraries
  2. Importing Datasets
  3. Exploratory Data Analysis
  4. Feature Engineering
  5. Pre-processing
  6. Model Development
  7. Prediction
  8. Evaluation

The scenario

In this tutorial, we will be working with the Default of Credit Card Clients Data Set. This data set has 30,000 rows and 24 columns (25 including the ID column). It can be used to estimate the probability of default payment by credit card clients. The attributes relate to various details about a customer, their past payment information, and bill statements. It is hosted in Data Science Dojo’s repository.

Think of yourself as a lead data scientist employed at a large bank. You have been assigned to predict whether a particular customer will default on their payment next month or not. The result is an extremely valuable piece of information for the bank to make decisions regarding offering credit to its customers and could massively affect the bank’s revenue. Therefore, your task is very critical. You will learn to use logistic regression to solve this problem.

The dataset is a tricky one as it has a mix of categorical and continuous variables. Moreover, you will also get a chance to practice these concepts through short assignments given at the end of a few sub-modules. Feel free to change the parameters in the given methods once you have been through the entire notebook.


1) Importing libraries

We’ll begin by importing the dependencies that we require. The following dependencies are popularly used for data-wrangling operations and visualizations. We would encourage you to have a look at their documentation.

library(knitr)
library(tidyverse)
library(ggplot2)
library(mice)
library(lattice)
library(reshape2)
# if DataExplorer is not available, install it first with install.packages("DataExplorer")
library(DataExplorer)

2) Importing Datasets

The dataset is available in Data Science Dojo’s repository at the link below. We’ll use the head() function to view the first few rows.

## fetch the CSV file (the URL is split with paste0() only to keep the lines short)
path <- paste0("https://code.datasciencedojo.com/datasciencedojo/datasets/raw/master/",
               "Default%20of%20Credit%20Card%20Clients/default%20of%20credit%20card%20clients.csv")
data <- read.csv(file = path, header = TRUE)
head(data)

Since the actual column names are stored in the first row of the data rather than in the file’s header, we’ll use the code below to assign the column names from the first row and then delete that row from the dataset. This gives us the data in the form we want.

colnames(data) <- as.character(unlist(data[1,]))
data = data[-1, ]
head(data)

To avoid any complications ahead, we’ll rename our target variable “default payment next month” to a name without spaces using the code below.

colnames(data)[colnames(data)=="default payment next month"] <- "default_payment"
head(data)

3) Exploratory data analysis

Data exploration is one of the most significant parts of the machine learning process. Clean data can notably increase the accuracy of our model: no matter how powerful our model is, it cannot function well unless the data we provide has been thoroughly processed.

This section briefly walks you through visualizing your data, finding the relationships between variables, dealing with missing values and outliers, and getting a fundamental understanding of each variable we’ll use.

Moreover, this step will also enable us to figure out the most important attributes to feed our model and discard those that have no relevance.

We will start by using the dim function to print out the dimensionality of our data frame.

dim(data)

30000 25

The str() function shows the data type of each variable. Every column has been read in as a factor, so we’ll convert them all to numeric, which will be easier to use in the functions ahead.

str(data)
'data.frame':	30000 obs. of  25 variables:
 $ ID             : Factor w/ 30001 levels "1","10","100",..: 1 11112 22223 23335 24446 25557 26668 27779 28890 2 ...
 $ LIMIT_BAL      : Factor w/ 82 levels "10000","100000",..: 14 5 81 48 48 48 49 2 7 14 ...
 $ SEX            : Factor w/ 3 levels "1","2","SEX": 2 2 2 2 1 1 1 2 2 1 ...
 $ EDUCATION      : Factor w/ 8 levels "0","1","2","3",..: 3 3 3 3 3 2 2 3 4 4 ...
 $ MARRIAGE       : Factor w/ 5 levels "0","1","2","3",..: 2 3 3 2 2 3 3 3 2 3 ...
 $ AGE            : Factor w/ 57 levels "21","22","23",..: 4 6 14 17 37 17 9 3 8 15 ...
 $ PAY_0          : Factor w/ 12 levels "-1","-2","0",..: 5 1 3 3 1 3 3 3 3 2 ...
 $ PAY_2          : Factor w/ 12 levels "-1","-2","0",..: 5 5 3 3 3 3 3 1 3 2 ...
 $ PAY_3          : Factor w/ 12 levels "-1","-2","0",..: 1 3 3 3 1 3 3 1 5 2 ...
 $ PAY_4          : Factor w/ 12 levels "-1","-2","0",..: 1 3 3 3 3 3 3 3 3 2 ...
 $ PAY_5          : Factor w/ 11 levels "-1","-2","0",..: 2 3 3 3 3 3 3 3 3 1 ...
 $ PAY_6          : Factor w/ 11 levels "-1","-2","0",..: 2 4 3 3 3 3 3 1 3 1 ...
 $ BILL_AMT1      : Factor w/ 22724 levels "-1","-10","-100",..: 13345 10030 10924 15026 21268 18423 12835 1993 1518 307 ...
 $ BILL_AMT2      : Factor w/ 22347 levels "-1","-10","-100",..: 11404 5552 3482 15171 16961 17010 13627 12949 3530 348 ...
 $ BILL_AMT3      : Factor w/ 22027 levels "-1","-10","-100",..: 18440 9759 3105 15397 12421 16866 14184 17258 2072 365 ...
 $ BILL_AMT4      : Factor w/ 21549 levels "-1","-10","-100",..: 378 11833 3620 10318 7717 6809 16081 8147 2129 378 ...
 $ BILL_AMT5      : Factor w/ 21011 levels "-1","-10","-100",..: 385 11971 3950 10407 6477 6841 14580 76 1796 2638 ...
 $ BILL_AMT6      : Factor w/ 20605 levels "-1","-10","-100",..: 415 11339 4234 10458 6345 7002 14057 15748 12215 3230 ...
 $ PAY_AMT1       : Factor w/ 7944 levels "0","1","10","100",..: 1 1 1495 2416 2416 3160 5871 4578 4128 1 ...
 $ PAY_AMT2       : Factor w/ 7900 levels "0","1","10","100",..: 6671 5 1477 2536 4508 2142 4778 6189 1 1 ...
 $ PAY_AMT3       : Factor w/ 7519 levels "0","1","10","100",..: 1 5 5 646 6 6163 4292 1 4731 1 ...
 $ PAY_AMT4       : Factor w/ 6938 levels "0","1","10","100",..: 1 5 5 337 6620 5 2077 5286 5 813 ...
 $ PAY_AMT5       : Factor w/ 6898 levels "0","1","10","100",..: 1 1 5 263 5777 5 950 1502 5 408 ...
 $ PAY_AMT6       : Factor w/ 6940 levels "0","1","10","100",..: 1 2003 4751 5 5796 6293 963 1267 5 1 ...
 $ default_payment: Factor w/ 3 levels "0","1","default payment next month": 2 2 1 1 1 1 1 1 1 1 ...
data[, 1:25] <- sapply(data[, 1:25], as.character)

We added an intermediate step that converts our data to character first. as.character must come before as.numeric because factors are stored internally as integers, with a table that maps them to the factor level labels. Using as.numeric alone would only return the internal integer codes.

data[, 1:25] <- sapply(data[, 1:25], as.numeric)
str(data)
'data.frame':	30000 obs. of  25 variables:
 $ ID             : num  1 2 3 4 5 6 7 8 9 10 ...
 $ LIMIT_BAL      : num  20000 120000 90000 50000 50000 50000 500000 100000 140000 20000 ...
 $ SEX            : num  2 2 2 2 1 1 1 2 2 1 ...
 $ EDUCATION      : num  2 2 2 2 2 1 1 2 3 3 ...
 $ MARRIAGE       : num  1 2 2 1 1 2 2 2 1 2 ...
 $ AGE            : num  24 26 34 37 57 37 29 23 28 35 ...
 $ PAY_0          : num  2 -1 0 0 -1 0 0 0 0 -2 ...
 $ PAY_2          : num  2 2 0 0 0 0 0 -1 0 -2 ...
 $ PAY_3          : num  -1 0 0 0 -1 0 0 -1 2 -2 ...
 $ PAY_4          : num  -1 0 0 0 0 0 0 0 0 -2 ...
 $ PAY_5          : num  -2 0 0 0 0 0 0 0 0 -1 ...
 $ PAY_6          : num  -2 2 0 0 0 0 0 -1 0 -1 ...
 $ BILL_AMT1      : num  3913 2682 29239 46990 8617 ...
 $ BILL_AMT2      : num  3102 1725 14027 48233 5670 ...
 $ BILL_AMT3      : num  689 2682 13559 49291 35835 ...
 $ BILL_AMT4      : num  0 3272 14331 28314 20940 ...
 $ BILL_AMT5      : num  0 3455 14948 28959 19146 ...
 $ BILL_AMT6      : num  0 3261 15549 29547 19131 ...
 $ PAY_AMT1       : num  0 0 1518 2000 2000 ...
 $ PAY_AMT2       : num  689 1000 1500 2019 36681 ...
 $ PAY_AMT3       : num  0 1000 1000 1200 10000 657 38000 0 432 0 ...
 $ PAY_AMT4       : num  0 1000 1000 1100 9000 ...
 $ PAY_AMT5       : num  0 0 1000 1069 689 ...
 $ PAY_AMT6       : num  0 2000 5000 1000 679 ...
 $ default_payment: num  1 1 0 0 0 0 0 0 0 0 ...
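The factor pitfall behind the two-step conversion is easy to reproduce on a toy factor, sketched below with made-up values. Because levels are sorted as text, the internal codes do not follow the numeric values. The last line shows an equivalent, slightly more efficient idiom that converts only the levels.

```r
# A factor whose labels look like numbers (toy values)
f <- factor(c("10", "5", "20"))
levels(f)                      # levels are sorted as text: "10" "20" "5"
as.numeric(f)                  # internal integer codes: 1 3 2
as.numeric(as.character(f))    # the intended numbers: 10 5 20
as.numeric(levels(f))[f]       # same result, converting each level only once
```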

When applied to a data frame, the summary() function is applied to each column, and the results for all columns are shown together. For a continuous (numeric) variable like AGE, it returns the five-number summary (minimum, first quartile, median, third quartile, and maximum) along with the mean.

summary(data)
       ID          LIMIT_BAL            SEX          EDUCATION    
 Min.   :    1   Min.   :  10000   Min.   :1.000   Min.   :0.000  
 1st Qu.: 7501   1st Qu.:  50000   1st Qu.:1.000   1st Qu.:1.000  
 Median :15000   Median : 140000   Median :2.000   Median :2.000  
 Mean   :15000   Mean   : 167484   Mean   :1.604   Mean   :1.853  
 3rd Qu.:22500   3rd Qu.: 240000   3rd Qu.:2.000   3rd Qu.:2.000  
 Max.   :30000   Max.   :1000000   Max.   :2.000   Max.   :6.000  
    MARRIAGE          AGE            PAY_0             PAY_2        
 Min.   :0.000   Min.   :21.00   Min.   :-2.0000   Min.   :-2.0000  
 1st Qu.:1.000   1st Qu.:28.00   1st Qu.:-1.0000   1st Qu.:-1.0000  
 Median :2.000   Median :34.00   Median : 0.0000   Median : 0.0000  
 Mean   :1.552   Mean   :35.49   Mean   :-0.0167   Mean   :-0.1338  
 3rd Qu.:2.000   3rd Qu.:41.00   3rd Qu.: 0.0000   3rd Qu.: 0.0000  
 Max.   :3.000   Max.   :79.00   Max.   : 8.0000   Max.   : 8.0000  
     PAY_3             PAY_4             PAY_5             PAY_6        
 Min.   :-2.0000   Min.   :-2.0000   Min.   :-2.0000   Min.   :-2.0000  
 1st Qu.:-1.0000   1st Qu.:-1.0000   1st Qu.:-1.0000   1st Qu.:-1.0000  
 Median : 0.0000   Median : 0.0000   Median : 0.0000   Median : 0.0000  
 Mean   :-0.1662   Mean   :-0.2207   Mean   :-0.2662   Mean   :-0.2911  
 3rd Qu.: 0.0000   3rd Qu.: 0.0000   3rd Qu.: 0.0000   3rd Qu.: 0.0000  
 Max.   : 8.0000   Max.   : 8.0000   Max.   : 8.0000   Max.   : 8.0000  
   BILL_AMT1         BILL_AMT2        BILL_AMT3         BILL_AMT4      
 Min.   :-165580   Min.   :-69777   Min.   :-157264   Min.   :-170000  
 1st Qu.:   3559   1st Qu.:  2985   1st Qu.:   2666   1st Qu.:   2327  
 Median :  22382   Median : 21200   Median :  20089   Median :  19052  
 Mean   :  51223   Mean   : 49179   Mean   :  47013   Mean   :  43263  
 3rd Qu.:  67091   3rd Qu.: 64006   3rd Qu.:  60165   3rd Qu.:  54506  
 Max.   : 964511   Max.   :983931   Max.   :1664089   Max.   : 891586  
   BILL_AMT5        BILL_AMT6          PAY_AMT1         PAY_AMT2      
 Min.   :-81334   Min.   :-339603   Min.   :     0   Min.   :      0  
 1st Qu.:  1763   1st Qu.:   1256   1st Qu.:  1000   1st Qu.:    833  
 Median : 18105   Median :  17071   Median :  2100   Median :   2009  
 Mean   : 40311   Mean   :  38872   Mean   :  5664   Mean   :   5921  
 3rd Qu.: 50191   3rd Qu.:  49198   3rd Qu.:  5006   3rd Qu.:   5000  
 Max.   :927171   Max.   : 961664   Max.   :873552   Max.   :1684259  
    PAY_AMT3         PAY_AMT4         PAY_AMT5           PAY_AMT6       
 Min.   :     0   Min.   :     0   Min.   :     0.0   Min.   :     0.0  
 1st Qu.:   390   1st Qu.:   296   1st Qu.:   252.5   1st Qu.:   117.8  
 Median :  1800   Median :  1500   Median :  1500.0   Median :  1500.0  
 Mean   :  5226   Mean   :  4826   Mean   :  4799.4   Mean   :  5215.5  
 3rd Qu.:  4505   3rd Qu.:  4013   3rd Qu.:  4031.5   3rd Qu.:  4000.0  
 Max.   :896040   Max.   :621000   Max.   :426529.0   Max.   :528666.0  
 default_payment 
 Min.   :0.0000  
 1st Qu.:0.0000  
 Median :0.0000  
 Mean   :0.2212  
 3rd Qu.:0.0000  
 Max.   :1.0000

Using the introduce() function from DataExplorer, we can get basic information about the data frame, including the total number of missing values.

introduce(data)

As we can observe, there are no missing values in the dataframe.

The summary above gives a sense of the continuous and categorical features in our dataset. However, checking these details against the data description shows that categorical variables such as EDUCATION and MARRIAGE have categories beyond those given in the data dictionary. We’ll find these extra categories using dplyr’s count() function.

count(data, vars = EDUCATION)
vars n
0 14
1 10585
2 14030
3 4917
4 123
5 280
6 51

The data dictionary defines the following categories for EDUCATION: “Education (1 = graduate school; 2 = university; 3 = high school; 4 = others)”. However, we can also observe 0 along with numbers greater than 4, i.e. 5 and 6. One might assume that 0 means no formal education, but since we don’t have any further details about these values, we’ll place 0, 5, and 6 in the ‘others’ category along with 4.

count(data, vars = MARRIAGE)
vars n
0 54
1 13659
2 15964
3 323

The data dictionary defines the following categories for MARRIAGE: “Marital status (1 = married; 2 = single; 3 = others)”. Since category 0 hasn’t been defined anywhere in the data dictionary, we can include it in the ‘others’ category marked as 3.

# merge the undocumented categories into 'others' (4 for EDUCATION, 3 for MARRIAGE)
data$EDUCATION[data$EDUCATION %in% c(0, 5, 6)] <- 4
data$MARRIAGE[data$MARRIAGE == 0] <- 3
count(data, vars = MARRIAGE)
count(data, vars = EDUCATION)
vars n
1 13659
2 15964
3 377
vars n
1 10585
2 14030
3 4917
4 468

We’ll now move on to a multivariate analysis of our variables and draw a correlation heatmap using plot_correlation() from the DataExplorer library. The heatmap shows the correlation between each pair of variables. We are most interested in the correlation between our predictor attributes and the target attribute, default_payment (default payment next month). The color scheme depicts the strength of the correlation between the two variables.

This will be a simple way to quickly find out how much of an impact a variable has on our final outcome. There are other ways as well to figure this out.

plot_correlation(na.omit(data), maxcat = 5L)

Figure: correlation heatmap

We can observe the weak correlation of AGE, BILL_AMT1, BILL_AMT2, BILL_AMT3, BILL_AMT4, BILL_AMT5, and BILL_AMT6 with our target variable.

Now let’s do a univariate analysis of our variables. plot_histogram() draws a histogram of every numeric column, which lets us quickly check the distributions of the PAY variables as well as the others.

plot_histogram(data)

Figure: histograms of the numeric variables

We can make a few observations from the histograms above. In particular, nearly all PAY attributes are right-skewed.

4) Feature engineering

This step can be more important than the actual model used because a machine learning algorithm only learns from the data we give it, and creating features that are relevant to a task is absolutely crucial.

Analyzing our data above, we’ve noted the extremely weak correlation of some variables with the target variable. The following have particularly low correlation values: AGE, BILL_AMT2, BILL_AMT3, BILL_AMT4, BILL_AMT5, and BILL_AMT6. We’ll drop them, along with the ID column, which only identifies rows.

#deleting columns

data_new <- select(data, -one_of('ID','AGE', 'BILL_AMT2',
       'BILL_AMT3','BILL_AMT4','BILL_AMT5','BILL_AMT6'))
head(data_new)


5) Pre-processing

Standardization is a transformation that centers the data by removing the mean value of each feature and then scaling it by dividing (non-constant) features by their standard deviation. After standardizing data the mean will be zero and the standard deviation one.

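It is commonly recommended for techniques that assume Gaussian-distributed inputs, such as linear discriminant analysis, and for techniques that are sensitive to the scale of the features, such as gradient-based or regularized linear and logistic regression. If a feature has a variance that is orders of magnitude larger than others, it might dominate the objective function and keep the estimator from learning from the other features as expected. For an unregularized glm() fit like the one below, standardizing does not change the predicted probabilities, but it puts all the coefficients on a comparable, per-standard-deviation scale.

In the code below, we’ll use the scale() function to standardize the 17 predictor columns; column 18, default_payment, stays as 0/1.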

data_new[, 1:17] <- scale(data_new[, 1:17])
head(data_new)
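To see what scale() computes, the sketch below standardizes a small made-up vector by hand and compares the result. It then checks, on simulated data unrelated to the credit card dataset, that standardizing the predictors of an unregularized logistic regression rescales the coefficients but leaves the fitted probabilities unchanged. The variable names and simulation settings are arbitrary.

```r
# What scale() does to one column: subtract the mean, divide by the sample SD
x <- c(2, 4, 6, 8, 10)
z_manual <- (x - mean(x)) / sd(x)
z_scale  <- as.vector(scale(x))        # scale() returns a one-column matrix
all.equal(z_manual, z_scale)
c(mean = mean(z_scale), sd = sd(z_scale))   # approximately 0 and 1

# Standardized vs raw predictors in glm(): same fitted probabilities
set.seed(42)
n   <- 500
sim <- data.frame(income = rnorm(n, 50000, 15000), age = rnorm(n, 40, 10))
sim$y <- rbinom(n, 1, plogis(-1 + 0.00003 * sim$income - 0.02 * sim$age))

fit_raw <- glm(y ~ income + age, data = sim, family = binomial)
sim_scaled <- sim
sim_scaled[, 1:2] <- scale(sim[, 1:2])
fit_scaled <- glm(y ~ income + age, data = sim_scaled, family = binomial)

max(abs(fitted(fit_raw) - fitted(fit_scaled)))  # essentially zero
```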


The next step is to split the data into training and test sets, since we’ll use the test data to evaluate our model. We’ll reserve 30% of the rows for testing and use the remaining 70% for training. Because sample() draws the row numbers at random, the rows are shuffled before splitting; call set.seed() first if you want the same split every time.

# draw a random 70% of the row numbers (without replacement) and sort them
data2 <- sort(sample(nrow(data_new), nrow(data_new) * 0.7))

# training set: the sampled rows
train <- data_new[data2, ]

# test set: all remaining rows
test <- data_new[-data2, ]

Let’s print the dimensions of both data sets using the dim() function. You can see the 70-30 split.

dim(train)
dim(test)

21000 18

9000 18
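The split above changes every time the code is run, because no seed is set. The sketch below wraps the same sample() call in a hypothetical helper, split_rows(), that sets a seed first, so the same seed always returns the same training rows. It uses a stand-in data frame with 30,000 rows in place of data_new, and the seed 2022 is an arbitrary choice.

```r
# Stand-in for data_new: a data frame with the same number of rows
df <- data.frame(id = seq_len(30000))

# Hypothetical helper: same sampling call as in the tutorial, with a fixed seed
split_rows <- function(n, seed, prop = 0.7) {
  set.seed(seed)
  sort(sample(n, n * prop))
}

idx1 <- split_rows(nrow(df), seed = 2022)
idx2 <- split_rows(nrow(df), seed = 2022)
identical(idx1, idx2)                    # same seed, same split

train_demo <- df[idx1, , drop = FALSE]
test_demo  <- df[-idx1, , drop = FALSE]
c(train = nrow(train_demo), test = nrow(test_demo))
```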

6) Model development

We will now move on to the most important step: developing our logistic regression model. In R, logistic regression is available through the glm() function from the stats package (loaded by default) with family = binomial(link = "logit"), so no additional package is needed. We’ll store the fitted model in a variable named log.model.

We train the model on the train data frame, which contains 70% of our dataset, using default_payment as the response and all the other columns (the . in the formula) as predictors. This is a binary classification model.

## fit a logistic regression model with the training dataset
log.model <- glm(default_payment ~., data = train, family = binomial(link = "logit"))
summary(log.model)
Call:
glm(formula = default_payment ~ ., family = binomial(link = "logit"), 
    data = train)

Deviance Residuals: 
    Min       1Q   Median       3Q      Max  
-3.1171  -0.6998  -0.5473  -0.2946   3.4915  

Coefficients:
             Estimate Std. Error z value Pr(>|z|)    
(Intercept) -1.465097   0.019825 -73.900  < 2e-16 ***
LIMIT_BAL   -0.083475   0.023905  -3.492 0.000480 ***
SEX         -0.082986   0.017717  -4.684 2.81e-06 ***
EDUCATION   -0.059851   0.019178  -3.121 0.001803 ** 
MARRIAGE    -0.107322   0.018350  -5.849 4.95e-09 ***
PAY_0        0.661918   0.023605  28.041  < 2e-16 ***
PAY_2        0.069704   0.028842   2.417 0.015660 *  
PAY_3        0.090691   0.031982   2.836 0.004573 ** 
PAY_4        0.074336   0.034612   2.148 0.031738 *  
PAY_5        0.018469   0.036430   0.507 0.612178    
PAY_6        0.006314   0.030235   0.209 0.834584    
BILL_AMT1   -0.123582   0.023558  -5.246 1.56e-07 ***
PAY_AMT1    -0.136745   0.037549  -3.642 0.000271 ***
PAY_AMT2    -0.246634   0.056432  -4.370 1.24e-05 ***
PAY_AMT3    -0.014662   0.028012  -0.523 0.600677    
PAY_AMT4    -0.087782   0.031484  -2.788 0.005300 ** 
PAY_AMT5    -0.084533   0.030917  -2.734 0.006254 ** 
PAY_AMT6    -0.027355   0.025707  -1.064 0.287277    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 22176  on 20999  degrees of freedom
Residual deviance: 19535  on 20982  degrees of freedom
AIC: 19571

Number of Fisher Scoring iterations: 6
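The estimates above are on the log-odds scale. Exponentiating a coefficient gives an odds ratio, the multiplicative change in the odds of default for a one-unit increase in that predictor; because the predictors were standardized, one unit here is roughly one standard deviation. The sketch below copies a few estimates from the summary; with the model object available, exp(coef(log.model)) returns all of them at once.

```r
# A few coefficient estimates copied from summary(log.model) above
est <- c(`(Intercept)` = -1.465097, LIMIT_BAL = -0.083475, PAY_0 = 0.661918,
         BILL_AMT1 = -0.123582, PAY_AMT2 = -0.246634)

# Odds ratios: > 1 raises the odds of default, < 1 lowers them
odds_ratios <- exp(est)
round(odds_ratios, 3)

# The intercept is the log-odds for a client at (approximately) the mean of
# every standardized predictor; plogis() turns it into a probability
plogis(est[["(Intercept)"]])
```

For example, an odds ratio of about 1.94 for PAY_0 means that a one-standard-deviation increase in PAY_0 roughly doubles the odds of default, holding the other predictors fixed, and the intercept corresponds to a default probability of about 0.19.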

7) Prediction

Below, we’ll use the predict() function to obtain the predictions made by our logistic regression model. We’ll first print the first 10 rows of our test data set. Then we’ll store the predicted probabilities in log.predictions and print the values for the corresponding rows, which you can compare with the original labels in test$default_payment.

test[1:10,]


## predict with the logistic regression model; type = "response" returns probabilities
log.predictions <- predict(log.model, test, type="response")

## Look at probability output
head(log.predictions, 10)
row  probability
2    0.539623162720197
7    0.232835137994762
10   0.25988780274953
11   0.0556716133560243
15   0.422481223473459
22   0.165384552048511
25   0.0494775267027534
26   0.238225423596718
31   0.248366972046479
37   0.111907725985513

Below, we assign labels with the decision rule: if the predicted probability is greater than 0.5, assign 1, else 0.

log.prediction.rd <- ifelse(log.predictions > 0.5, 1, 0)
head(log.prediction.rd, 10)
row  label
2    1
7    0
10   0
11   0
15   0
22   0
25   0
26   0
31   0
37   0
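The 0.5 cut-off is a convention rather than a property of the model. Since only about 22% of clients default (the mean of default_payment in the summary earlier), a lower threshold may be worth considering if catching defaulters matters more than raising false alarms. The sketch below applies the same decision rule with three illustrative cut-offs to the ten probabilities printed above and counts how many clients each flags as likely to default.

```r
# The ten predicted probabilities printed above, named by test row
probs <- c(`2`  = 0.539623162720197, `7`  = 0.232835137994762,
           `10` = 0.25988780274953,  `11` = 0.0556716133560243,
           `15` = 0.422481223473459, `22` = 0.165384552048511,
           `25` = 0.0494775267027534, `26` = 0.238225423596718,
           `31` = 0.248366972046479, `37` = 0.111907725985513)

# Number of clients labelled 1 (default) at each cut-off
thresholds <- c(0.5, 0.4, 0.25)
flagged <- sapply(thresholds, function(t) sum(ifelse(probs > t, 1, 0)))
setNames(flagged, thresholds)
```

Lowering the cut-off from 0.5 to 0.25 raises the number of flagged clients among these ten rows from one to three.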

Evaluation

We’ll now discuss a few evaluation metrics to measure the performance of our machine learning model. This part matters because it tells us how well the model performs on data it was not trained on.

We will output the confusion matrix. It is a handy presentation of the accuracy of a model with two or more classes.

In the table below, the rows are the predicted labels and the columns are the actual labels; each cell counts the test observations with that combination of predicted and actual value.

Writing the table as [[a b] [c d]], the entries have the following meaning in the context of our study:

  • a (row 0, column 0) is the number of correct predictions that an instance is negative (true negatives),
  • b (row 0, column 1) is the number of incorrect predictions that an instance is negative (false negatives),
  • c (row 1, column 0) is the number of incorrect predictions that an instance is positive (false positives), and
  • d (row 1, column 1) is the number of correct predictions that an instance is positive (true positives).
table(log.prediction.rd, test[,18])
                 
log.prediction.rd    0    1
                0 6832 1517
                1  170  481

We’ll compute the accuracy as the share of correct predictions, i.e. the diagonal of the confusion matrix divided by the total:

conf.mat <- table(log.prediction.rd, test[,18])
sum(diag(conf.mat))/sum(conf.mat)

0.812555555555556
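Accuracy alone can flatter a model when the classes are imbalanced. The sketch below rebuilds the confusion matrix from the counts printed above (rows = predicted, columns = actual) and adds precision, recall, and the accuracy of a trivial baseline that always predicts "no default". Treating default (1) as the positive class is an assumption of this sketch, and the object names are illustrative.

```r
# Confusion matrix from the output above: rows = predicted, columns = actual
cm <- matrix(c(6832, 1517,
                170,  481),
             nrow = 2, byrow = TRUE,
             dimnames = list(predicted = c("0", "1"), actual = c("0", "1")))

accuracy  <- sum(diag(cm)) / sum(cm)
precision <- cm["1", "1"] / sum(cm["1", ])  # of clients predicted to default, share who did
recall    <- cm["1", "1"] / sum(cm[, "1"])  # of clients who defaulted, share the model caught
baseline  <- sum(cm[, "0"]) / sum(cm)       # accuracy of always predicting "no default"
round(c(accuracy = accuracy, precision = precision, recall = recall, baseline = baseline), 3)
```

With these counts the accuracy is 0.813 and the precision 0.739, but the recall is only 0.241 (481 of the 1,998 actual defaulters), and always predicting "no default" would already reach 0.778 accuracy on this test set.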

Conclusion

This tutorial has given you a brief and concise overview of the logistic regression algorithm and the steps involved in getting good results from our model. This notebook has also highlighted a few methods related to exploratory data analysis, pre-processing, and evaluation; however, there are several other methods that we would encourage you to explore on our blog or video tutorials.

If you want to take a deeper dive into several data science techniques, join our 5-day hands-on Data Science Bootcamp, preferred by working professionals. We cover the following topics:

  • Fundamentals of Data Mining
  • Machine Learning Fundamentals
  • Introduction to R
  • Introduction to Azure Machine Learning Studio
  • Data Exploration, Visualization, and Feature Engineering
  • Decision Tree Learning
  • Ensemble Methods: Bagging, Boosting, and Random Forest
  • Regression: Cost Functions, Gradient Descent, Regularization
  • Unsupervised Learning
  • Recommendation Systems
  • Metrics and Methods for Evaluating Predictive Models
  • Introduction to Online Experimentation and A/B Testing
  • Fundamentals of Big Data Engineering
  • Hadoop and Hive
  • Message Queues and Real-time Analytics
  • NoSQL Databases and HBase
  • Hack Project: Creating a Real-time IoT Pipeline
  • Naive Bayes
  • Logistic Regression
  • Time Series Forecasting

This post was originally sponsored on What’s The Big Data.

August 18, 2022
