Predicting Churn with KNN


Churn prediction is one of the classic applications of data science that works well in practice. It lets businesses identify the customers who are likely to leave, so they can take appropriate measures to retain them. In this project, we consider a dataset of 7,043 customers. The goal is to predict, with a fair degree of accuracy, which of them are likely to churn.

Exploratory Data Analysis

## Loading required package: ggplot2
## Loading required package: magrittr


DT::datatable(churn[,c(1:2,16:21)])

The Sankey diagram below traces how customers flow from phone service through multiple lines, device protection, and contract type to churn.

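# SankeyDiagram() is assumed to come from the flipPlots package (Displayr)
library(flipPlots)

my.data <- cbind(churn, data.frame(num = 1))   # one "count" per customer
my.data$SeniorCitizen <- factor(my.data$SeniorCitizen)
# Columns 7, 8, 12, 16, 21, 22 are PhoneService, MultipleLines, DeviceProtection,
# Contract, Churn and num; summing num counts customers per combination
my.data2 <- aggregate(num ~ ., data = my.data[, c(7:8, 12, 16, 21, 22)], sum)

SankeyDiagram(my.data2[, -ncol(my.data2)],   # drop the count column
              link.color = "Source",
              weights = my.data2$num)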
[Figure: Sankey diagram of customer flows, PhoneService → MultipleLines → DeviceProtection → Contract → Churn]
temp <- churn[churn$Churn == "Yes", ]
contract_type <- temp$Contract  # the categorical variable to display (avoids masking base::var)

nrows <- 10
df <- expand.grid(y = 1:nrows, x = 1:nrows)
categ_table <- round(table(contract_type) * ((nrows * nrows) / length(contract_type)))
# Rounding can leave the total a cell or two away from nrows^2 (here 89 + 9 + 3 = 101),
# so absorb the difference in the largest category
biggest <- which.max(categ_table)
categ_table[biggest] <- categ_table[biggest] - (sum(categ_table) - nrows^2)

df$category <- factor(rep(names(categ_table), categ_table))

## Plot
ggplot(df, aes(x = x, y = y, fill = category)) + 
  geom_tile(color = "black", size = 0.5) +
  scale_x_continuous(expand = c(0, 0)) +
  scale_y_continuous(expand = c(0, 0), trans = 'reverse') +
  scale_fill_brewer(palette = "Set3") +
  labs(title="Customer churn % (+)", subtitle="Contract Type",
       caption="Telecom") + 
  theme(
    plot.title = element_text(size = rel(1.2)),
    axis.text = element_blank(),
    axis.title = element_blank(),
    axis.ticks = element_blank(),
    legend.title = element_blank(),
    legend.position = "right")
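Patching the rounded counts after the fact works for this data, but it always charges the rounding error to one category. A largest-remainder allocation is a common alternative that always produces exactly `nrows^2` cells: take the floor of each category's exact share, then hand the leftover cells to the categories with the biggest fractional parts. The sketch below assumes that approach; `waffle_counts()` is a hypothetical helper, not part of ggplot2 or any package.

```r
# Hypothetical helper: allocate n_tiles waffle cells to categories with the
# largest-remainder method, so the counts always sum to exactly n_tiles.
waffle_counts <- function(x, n_tiles = 100) {
  counts <- table(x)
  quotas <- as.numeric(counts) / sum(counts) * n_tiles  # exact (fractional) share per category
  alloc <- floor(quotas)                                 # guaranteed whole cells
  leftover <- n_tiles - sum(alloc)                       # cells still to hand out
  if (leftover > 0) {
    # one extra cell each to the categories with the largest fractional remainders
    extra <- order(quotas - alloc, decreasing = TRUE)[seq_len(leftover)]
    alloc[extra] <- alloc[extra] + 1
  }
  names(alloc) <- names(counts)
  alloc
}
```

With the churned customers' contract types this yields 88, 9 and 3 cells, and it could feed the plot via `tiles <- waffle_counts(temp$Contract, nrows^2)` followed by `df$category <- factor(rep(names(tiles), tiles))`.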

Exploring

plot.barplot(x = churn, 
             xname = "Churn",
             fillname = "Contract")

##      
##       Month-to-month One year Two year
##   No            2220     1307     1647
##   Yes           1655      166       48
plot.barplot(churn, 
             xname = "Churn" ,
             fillname = "Dependents")

##      
##         No  Yes
##   No  3390 1784
##   Yes 1543  326
plot.barplot(churn, 
             xname = "Churn", 
             fillname = "StreamingMovies")

##      
##         No No internet service  Yes
##   No  1847                1413 1914
##   Yes  938                 113  818
plot.barplot(churn, 
             xname = "Churn", 
             fillname = "PaymentMethod") 

##      
##       Bank transfer (automatic) Credit card (automatic) Electronic check Mailed check
##   No                       1286                    1290             1294         1304
##   Yes                       258                     232             1071          308
ggplot(churn, aes(x = Churn, y = tenure, fill= Contract)) + 
  geom_boxplot() + theme_minimal()

ggplot(churn, aes(x = Churn, y = tenure)) + 
  geom_boxplot() + theme_minimal()

ggplot(churn, aes(x = Churn, y = tenure, fill= TechSupport)) + 
  geom_boxplot() + facet_wrap(~SeniorCitizen)+ theme_minimal()

ggplot(churn, aes(x = Churn, y = MonthlyCharges, fill= Contract)) + 
  geom_boxplot() + theme_minimal()

ggplot(churn, aes(x = tenure, y = TotalCharges, color = Contract)) + geom_point( ) + geom_smooth()+ theme_minimal()
## `geom_smooth()` using method = 'gam'
## Warning: Removed 11 rows containing non-finite values (stat_smooth).
## Warning: Removed 11 rows containing missing values (geom_point).

ggplot(churn, aes(x = MonthlyCharges, y = TotalCharges, color = Contract)) + geom_point( ) + geom_smooth()+ theme_minimal()
## `geom_smooth()` using method = 'gam'
## Warning: Removed 11 rows containing non-finite values (stat_smooth).

## Warning: Removed 11 rows containing missing values (geom_point).

ggplot(churn, aes(x = tenure, y = MonthlyCharges, color = Churn)) + 
  geom_point( )+ geom_smooth()+ facet_wrap(~factor(InternetService))+ theme_minimal()
## `geom_smooth()` using method = 'gam'

ggplot(churn, aes(x = Churn, y = tenure, fill = InternetService)) + 
  geom_boxplot() + facet_wrap(~TechSupport, nrow = 1)+ theme_minimal()

ggplot(churn, aes(x = MonthlyCharges, fill = Churn)) + 
  geom_histogram() + theme_minimal()#Monthly Charges by Churn
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

ggplot(churn, aes(x = MonthlyCharges, fill = TechSupport)) + 
  geom_histogram() + theme_minimal()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

ggplot(churn, aes(x = MonthlyCharges, fill = InternetService)) + 
  geom_histogram() + theme_minimal()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

ggplot(churn, aes(x = TotalCharges, fill = Contract)) + 
  geom_histogram() + facet_wrap(~InternetService) + theme_minimal()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## Warning: Removed 11 rows containing non-finite values (stat_bin).

 ggplot(churn, aes(x = tenure, y = MonthlyCharges, color = InternetService)) + 
  geom_point( )+ geom_smooth()+ theme_minimal()
## `geom_smooth()` using method = 'gam'

Model Building

For this mini-project, we will use logistic regression, decision trees, and k-nearest neighbors.
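Since k-nearest neighbors is the method the title is built around, it may help to see what it does. Below is a minimal from-scratch classifier (the hypothetical `knn_sketch()`, not `class::knn`): it standardizes each feature with the training means and standard deviations, finds the k training rows closest in Euclidean distance to each test row, and takes a majority vote. The standardization step is an assumption added for illustration; the `class::knn` call used later in this post receives the raw, unscaled features. Ties here go to the first factor level, whereas `class::knn` breaks them at random.

```r
# Hypothetical from-scratch k-nearest-neighbour classifier (illustrative sketch).
knn_sketch <- function(train_x, train_y, test_x, k = 5) {
  train_x <- as.matrix(train_x)
  test_x  <- as.matrix(test_x)
  # standardise with TRAINING statistics so no large-valued column dominates distance
  mu  <- colMeans(train_x)
  sdv <- apply(train_x, 2, sd)
  train_z <- scale(train_x, center = mu, scale = sdv)
  test_z  <- scale(test_x,  center = mu, scale = sdv)
  train_y <- factor(train_y)
  pred <- apply(test_z, 1, function(row) {
    d2 <- colSums((t(train_z) - row)^2)  # squared Euclidean distance to every training row
    nn <- order(d2)[seq_len(k)]          # indices of the k closest training rows
    votes <- table(train_y[nn])          # class counts among those neighbours
    names(votes)[which.max(votes)]       # majority vote; ties go to the first level
  })
  factor(pred, levels = levels(train_y))
}
```

On the KNN data frames built below it would be called as `knn_sketch(train[, -4], train[, 4], test[, -4], k = 22)`; because of the scaling, its predictions would not necessarily match `class::knn`.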

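set.seed(100)  # reproducible split
# 80% of the rows (5,634 of 7,043) go to training; the remaining rows form the test set
indx <- sample(seq_len(nrow(churn)), size = floor(0.8 * nrow(churn)), replace = FALSE)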

Logistic Regression

Since I am using logistic regression only to explore relationships, and not to maximize predictive accuracy, I fit it on the full dataset and hold off on the 80-20 split until the subsequent models.
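If you later want the logistic regression on the same footing as the other models, one option is to fit it on the training rows only and score it on the held-out rows. A minimal sketch, assuming a two-level response whose second level is the positive class (as with `Churn = "Yes"`) and a 0.5 probability threshold; `holdout_logit_accuracy()` is a hypothetical helper.

```r
# Hypothetical helper: fit a logistic regression on train_idx rows and
# return its accuracy on the remaining (held-out) rows.
holdout_logit_accuracy <- function(formula, data, train_idx, threshold = 0.5) {
  response <- all.vars(formula)[1]
  data[[response]] <- factor(data[[response]])          # glm() models P(second level)
  lv <- levels(data[[response]])
  fit <- glm(formula, data = data[train_idx, ], family = binomial)
  test <- data[-train_idx, ]
  p <- predict(fit, newdata = test, type = "response")  # predicted P(lv[2]) per test row
  pred <- ifelse(p > threshold, lv[2], lv[1])
  # rows with NA predictors get an NA prediction and are skipped
  mean(pred == as.character(test[[response]]), na.rm = TRUE)
}
```

For this post, a call might look like `holdout_logit_accuracy(Churn ~ SeniorCitizen + MonthlyCharges + tenure + Contract + PaymentMethod + TechSupport, churn, indx)`.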

churn$Churn <- factor(churn$Churn)
churn$SeniorCitizen <- factor(churn$SeniorCitizen )
fit <- glm(Churn ~ SeniorCitizen + MonthlyCharges + tenure + Contract +
             PaymentMethod +TechSupport, data = churn, family = "binomial")
summary(fit)
## 
## Call:
## glm(formula = Churn ~ SeniorCitizen + MonthlyCharges + tenure + 
##     Contract + PaymentMethod + TechSupport, family = "binomial", 
##     data = churn)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.8322  -0.6735  -0.3045   0.7601   3.1709  
## 
## Coefficients:
##                                       Estimate Std. Error z value Pr(>|z|)
## (Intercept)                          -1.330859   0.149909  -8.878  < 2e-16
## SeniorCitizen1                        0.349465   0.081552   4.285 1.83e-05
## MonthlyCharges                        0.022466   0.001796  12.508  < 2e-16
## tenure                               -0.034080   0.002107 -16.174  < 2e-16
## ContractOne year                     -0.812642   0.103350  -7.863 3.75e-15
## ContractTwo year                     -1.560966   0.171277  -9.114  < 2e-16
## PaymentMethodCredit card (automatic) -0.060629   0.112223  -0.540 0.589023
## PaymentMethodElectronic check         0.424624   0.092594   4.586 4.52e-06
## PaymentMethodMailed check            -0.068929   0.111202  -0.620 0.535353
## TechSupportNo internet service       -0.484025   0.142733  -3.391 0.000696
## TechSupportYes                       -0.536799   0.082908  -6.475 9.51e-11
##                                         
## (Intercept)                          ***
## SeniorCitizen1                       ***
## MonthlyCharges                       ***
## tenure                               ***
## ContractOne year                     ***
## ContractTwo year                     ***
## PaymentMethodCredit card (automatic)    
## PaymentMethodElectronic check        ***
## PaymentMethodMailed check               
## TechSupportNo internet service       ***
## TechSupportYes                       ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 8150.1  on 7042  degrees of freedom
## Residual deviance: 6000.7  on 7032  degrees of freedom
## AIC: 6022.7
## 
## Number of Fisher Scoring iterations: 6
exp(coef(fit))
##                          (Intercept)                       SeniorCitizen1 
##                            0.2642503                            1.4183090 
##                       MonthlyCharges                               tenure 
##                            1.0227206                            0.9664939 
##                     ContractOne year                     ContractTwo year 
##                            0.4436841                            0.2099333 
## PaymentMethodCredit card (automatic)        PaymentMethodElectronic check 
##                            0.9411728                            1.5290152 
##            PaymentMethodMailed check       TechSupportNo internet service 
##                            0.9333929                            0.6162979 
##                       TechSupportYes 
##                            0.5846168
exp(cbind(OR = coef(fit), confint(fit)))
## Waiting for profiling to be done...
##                                             OR     2.5 %    97.5 %
## (Intercept)                          0.2642503 0.1966656 0.3539899
## SeniorCitizen1                       1.4183090 1.2087080 1.6641164
## MonthlyCharges                       1.0227206 1.0191468 1.0263494
## tenure                               0.9664939 0.9624845 0.9704687
## ContractOne year                     0.4436841 0.3615988 0.5423228
## ContractTwo year                     0.2099333 0.1485725 0.2911130
## PaymentMethodCredit card (automatic) 0.9411728 0.7551599 1.1726155
## PaymentMethodElectronic check        1.5290152 1.2760141 1.8345824
## PaymentMethodMailed check            0.9333929 0.7507904 1.1611240
## TechSupportNo internet service       0.6162979 0.4650220 0.8139424
## TechSupportYes                       0.5846168 0.4965972 0.6873646

Some interesting points: holding the other predictors constant, customers who pay by electronic check have about 53% higher odds of churning (OR ≈ 1.53, 95% CI 1.28 to 1.83) than those paying by automatic bank transfer, the baseline category. Likewise, each one-unit increase in monthly charges corresponds to roughly 2.3% higher odds of churning.
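The percentages come from exponentiating the log-odds coefficients: a coefficient β gives an odds ratio of exp(β), i.e. a 100 × (exp(β) − 1)% change in the odds per unit of the predictor. The sketch below hard-codes two estimates copied from the `summary(fit)` output above; it also shows that effects compound multiplicatively over several units rather than adding up.

```r
# Log-odds coefficients copied from summary(fit) above
b <- c(ElectronicCheck = 0.424624, MonthlyCharges = 0.022466)

odds_ratio <- exp(b)                  # multiplicative change in the odds per unit
pct_change <- 100 * (odds_ratio - 1)  # the same change as a percentage

round(pct_change, 1)
# ElectronicCheck 52.9, MonthlyCharges 2.3

# A 10-unit rise in MonthlyCharges multiplies the odds by exp(10 * b), about 1.25,
# i.e. roughly +25.2%, not 10 * 2.3% = 23%
round(100 * (exp(10 * b[["MonthlyCharges"]]) - 1), 1)
```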

Decision Tree

# Classification tree on the full data; rpart() splits on the Gini index by default
tree <- rpart::rpart(Churn ~ SeniorCitizen + MonthlyCharges + tenure + Contract +
                       PaymentMethod + TechSupport + Dependents, data = churn)
rpart.plot::rpart.plot(tree)

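K-Nearest Neighbors

# KNN uses only the three numeric predictors plus the outcome
train <- churn[indx, c("MonthlyCharges", "tenure", "TotalCharges", "Churn")]
test <- churn[-indx, c("MonthlyCharges", "tenure", "TotalCharges", "Churn")]

# TotalCharges is missing for 11 customers; drop those rows
train <- na.omit(train)
test <- na.omit(test)
# k = 22 neighbours; l = 4 is the minimum vote for a definite decision (always met here).
# The features are not scaled, so TotalCharges, the largest in magnitude, dominates the distance.
kn <- class::knn(train[, -4], test[, -4], train[, 4], k = 22, l = 4, prob = FALSE, use.all = TRUE)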
length(kn)
## [1] 1408
dim(test)
## [1] 1408    4
df <- table(Prob = kn, Obs = test[,"Churn"]) %>% data.frame()

ggplot(df, aes(x = Prob, y = Freq, fill = Obs)) + 
  geom_bar(stat = 'identity', position = "dodge") + theme_light()

table(Prob = kn, Obs = test[,"Churn"])
##      Obs
## Prob   No Yes
##   No  950 223
##   Yes  86 149
table(Prob = kn, Obs = test[,"Churn"]) %>% prop.table()
##      Obs
## Prob          No        Yes
##   No  0.67471591 0.15838068
##   Yes 0.06107955 0.10582386

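This final model reaches about 78% accuracy on the held-out test set ((950 + 149) / 1,408). Two other figures from the same table put that in context: always predicting "No" would already be right about 74% of the time (1,036 / 1,408), and the model flags only 149 of the 372 actual churners (about 40%). Out of the box, KNN looks competitive with logistic regression and decision trees, although the two earlier models were fit on the full data and not scored on this test split, so that comparison is not yet like-for-like.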
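The headline accuracy, the share of churners caught, and the accuracy of always predicting the majority class can all be recomputed from the four counts in the KNN confusion matrix printed above. The sketch hard-codes those counts; `cm` and the metric names are illustrative choices.

```r
# Counts copied from the KNN confusion matrix above (rows = predicted, columns = observed)
cm <- matrix(c(950, 86, 223, 149), nrow = 2,
             dimnames = list(Prob = c("No", "Yes"), Obs = c("No", "Yes")))

accuracy    <- sum(diag(cm)) / sum(cm)              # share of all predictions that are correct
sensitivity <- cm["Yes", "Yes"] / sum(cm[, "Yes"])  # actual churners correctly flagged (recall)
specificity <- cm["No", "No"] / sum(cm[, "No"])     # actual non-churners correctly cleared
no_info     <- max(colSums(cm)) / sum(cm)           # accuracy of always predicting the majority class

round(c(accuracy = accuracy, sensitivity = sensitivity,
        specificity = specificity, no_info = no_info), 3)
# accuracy 0.781, sensitivity 0.401, specificity 0.917, no_info 0.736
```

Comparing `accuracy` with `no_info` is a quick check of whether a classifier adds anything over the trivial rule on an imbalanced outcome like churn.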


In God we trust. All others must bring data.

- W. Edwards Deming, Statistician