This KNN tutorial covers the technical side of KNN: 1) how train and test error depend on the choice of K, 2) cross-validation in KNN, 3) the validation process.
This time, I will demonstrate the process of 1) generating i.i.d. Gaussian (synthetic) data with fixed random seeds, 2) visualizing the data, 3) testing KNN for different values of K, and 4) making train and test sets and performing cross-validation.
For classification purposes, we generate two classes of data from Gaussian (Normal) distributions with different means. We first generate training data and use this training set to choose K for K-nearest-neighbors. We also generate some test data, which the classifier is not allowed to see during training. Both the training and test sets should be constructed with care, since the test set is only useful if it stays independent of training.
# install.packages("mvtnorm") for multivariate normal distribution
# install.packages("flexclust") , a package originally for Flexible Cluster Algorithms
library(mvtnorm)
library(flexclust)
library(class)
library(tidyverse)
## ── Attaching packages ────────────────────────────────── tidyverse 1.2.1 ──
## ✔ ggplot2 2.2.1 ✔ purrr 0.2.4
## ✔ tibble 1.4.2 ✔ dplyr 0.7.4
## ✔ tidyr 0.8.0 ✔ stringr 1.3.0
## ✔ readr 1.1.1 ✔ forcats 0.3.0
## ── Conflicts ───────────────────────────────────── tidyverse_conflicts() ──
## ✖ dplyr::filter() masks stats::filter()
## ✖ dplyr::lag() masks stats::lag()
library(stringr)
Now we generate the synthetic data for both the training and test datasets. We can use a Gaussian mixture model to generate such data. Our goal is to generate training and test datasets with 2 different classes, each with a different mean.
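Written out, the sampling scheme implemented by the function below is roughly the following. This is a sketch read off the function's hard-coded values; the symbols $\bar\mu_c$, $m_{c,j}$ and $J_i$ are notation introduced here for illustration, not the author's.

```latex
\begin{aligned}
\bar\mu_{+} &= (1, 0), \qquad \bar\mu_{-} = (0, 1), \\
m_{c,j} &\sim \mathcal{N}\!\left(\bar\mu_c,\; I_2\right), \qquad j = 1, \dots, 10, \quad c \in \{+, -\}, \\
J_i &\sim \mathrm{Uniform}\{1, \dots, 10\}, \\
x_i \mid J_i,\, c &\sim \mathcal{N}\!\left(m_{c,J_i},\; \tfrac{1}{5} I_2\right).
\end{aligned}
```

Each class is therefore a mixture of 10 Gaussian components whose centers are themselves drawn around the class's grand mean.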
# Set the same seed of mixture means for both training and test sets
mean_seed <- 238
# Draw train and test data
# Gaussian mixture model
GMM2d_distribution <- function(n_neg, n_pos, mean_seed=NA, data_seed=NA){
  # 2 class gaussian mixture model distribution
  # grand class means
  grand_mu_pos <- c(1, 0)
  grand_mu_neg <- c(0, 1)
  # If mean_seed is provided, sample the means from the mean seed for reproducibility
  if(!is.na(mean_seed)){
    set.seed(mean_seed)
    # Sample class means
    mu_pos <- rmvnorm(n=10, mean=grand_mu_pos, sigma=diag(2))
    mu_neg <- rmvnorm(n=10, mean=grand_mu_neg, sigma=diag(2))
    # If the data seed is not set then remove the seed
    if(is.na(data_seed)){
      rm(.Random.seed, envir=globalenv())
    }
  } else{ # If mean_seed is not provided
    # Sample class means
    mu_pos <- rmvnorm(n=10, mean=grand_mu_pos, sigma=diag(2))
    mu_neg <- rmvnorm(n=10, mean=grand_mu_neg, sigma=diag(2))
  }
  # If the data seed is provided, then set the data seed
  if(!is.na(data_seed)){
    set.seed(data_seed)
  }
  # Use sample function to pick which means to sample from
  m_index_pos <- sample.int(10, n_pos, replace = TRUE)
  m_index_neg <- sample.int(10, n_neg, replace = TRUE)
  # Sample data from each class
  X_pos <- map(1:n_pos, function(i) rmvnorm(n=1, mean=mu_pos[m_index_pos[i], ], sigma=diag(2)/5)) %>%
    unlist %>%
    matrix(ncol=n_pos) %>%
    t %>%
    as_tibble() %>%
    mutate(y=1)
  X_neg <- map(1:n_neg, function(i) rmvnorm(n=1, mean=mu_neg[m_index_neg[i], ], sigma=diag(2)/5)) %>%
    unlist %>%
    matrix(ncol=n_neg) %>%
    t %>%
    as_tibble() %>%
    mutate(y=-1)
  # Set column names
  colnames(X_pos) <- gsub("V", 'x', colnames(X_pos))
  colnames(X_neg) <- gsub("V", 'x', colnames(X_neg))
  # Put data into one data frame
  data <- rbind(X_pos, X_neg) %>%
    mutate(y = factor(y)) # Finally, class label should be a factor
  data
}
# We create a small training sample: 50 points from the negative class and 51 from the positive class
data <- GMM2d_distribution(n_neg=50, n_pos=51, mean_seed=mean_seed, data_seed=1232)
# We also create a test data set for validation
test_data <- GMM2d_distribution(n_neg=100, n_pos=100, mean_seed=mean_seed, data_seed=52345)
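The two calls above share `mean_seed` but use different `data_seed` values, so the training and test sets are drawn around the same mixture centers while the individual points differ. The effect of the seeds can be illustrated with a minimal base-R sketch; `draw_means` is a hypothetical helper introduced here for illustration and is not part of the tutorial's code.

```r
# Hypothetical helper (not part of the tutorial): draw 10 two-dimensional
# component means after fixing the seed, using base R only.
draw_means <- function(seed) {
  set.seed(seed)
  matrix(rnorm(20), ncol = 2)
}

# The same seed reproduces exactly the same component means
means_train <- draw_means(238)
means_test  <- draw_means(238)
same_means  <- identical(means_train, means_test)

# Different data seeds give different random draws
set.seed(1232)
noise_train <- rnorm(5)
set.seed(52345)
noise_test <- rnorm(5)
different_noise <- !identical(noise_train, noise_test)

print(same_means)
print(different_noise)
```

Sharing the mean seed keeps both data sets on the same underlying distribution, which is what makes the test error a fair estimate of performance on new data.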
# Create a function that generates two-class Gaussian data for the KNN plots
two_class_gaussian_plot <- function(n_pos, n_neg, mu_pos, mu_neg, sigma_pos, sigma_neg, seed=NA){
  if(!is.na(seed)){
    set.seed(seed)
  }
  # Generate data from negative class
  class_neg <- rmvnorm(n=n_neg, mean=mu_neg, sigma=sigma_neg) %>% # rmvnorm comes from mvtnorm
    as_tibble() %>%
    mutate(y=-1) %>%
    rename(x1=V1, x2=V2)
  # Generate data from positive class
  class_pos <- rmvnorm(n=n_pos, mean=mu_pos, sigma=sigma_pos) %>%
    as_tibble() %>%
    mutate(y=1) %>%
    rename(x1=V1, x2=V2)
  # Put data into one data frame
  data <- rbind(class_pos, class_neg) %>%
    mutate(y = factor(y)) # class label should be a factor
  data
}
data1 <- two_class_gaussian_plot(n_pos=50, n_neg=51,
                                 mu_pos=c(1,0), mu_neg=c(-1,0),
                                 sigma_pos=diag(2), sigma_neg=diag(2),
                                 seed=100)
test_data1 <- two_class_gaussian_plot(n_pos=100, n_neg=100,
                                      mu_pos=c(1,0), mu_neg=c(-1,0),
                                      sigma_pos=diag(2), sigma_neg=diag(2),
                                      seed=3240)
# Plot the training data
ggplot() +
  geom_point(data=data1, aes(x=x1, y=x2, color=y, shape=y)) +
  theme(panel.background = element_blank()) +
  ggtitle('training data')
We typically try a few specific cases first, such as K = 1, K = 5, and so on.
In the run reported here, the training error rate is 0.049505 and the test error is 0.04. Note that in this particular run the training error is slightly higher than the test error, which can happen by chance with samples this small. It is far more common for a statistical algorithm (KNN included) to perform better on the data it was trained on than on an independent test set, which is why training error alone is an optimistic measure of performance (hence the problem of overfitting).
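Training and test error rates for a given K could be computed along the following lines. This is only a sketch: `make_two_class` and `knn_error` are hypothetical helpers introduced here, and the data are drawn with base R's `rnorm` rather than `rmvnorm`, so the numbers will not match the ones quoted above. Only `knn()` from the `class` package is taken from the tutorial's setup.

```r
library(class)  # provides knn()

# Hypothetical stand-in for data1 / test_data1: two Gaussian classes with
# means (1, 0) and (-1, 0) and identity covariance, drawn with base R.
make_two_class <- function(n_pos, n_neg, seed) {
  set.seed(seed)
  x1 <- c(rnorm(n_pos, mean = 1), rnorm(n_neg, mean = -1))
  x2 <- rnorm(n_pos + n_neg, mean = 0)
  data.frame(x1 = x1, x2 = x2,
             y = factor(c(rep(1, n_pos), rep(-1, n_neg))))
}

train <- make_two_class(50, 51, seed = 100)
test  <- make_two_class(100, 100, seed = 3240)

# Misclassification rate of KNN fit on `train` and evaluated on `newdata`
knn_error <- function(train, newdata, k) {
  pred <- knn(train = train[, c("x1", "x2")],
              test  = newdata[, c("x1", "x2")],
              cl    = train$y, k = k)
  mean(pred != newdata$y)
}

train_err_k1 <- knn_error(train, train, k = 1)
test_err_k1  <- knn_error(train, test,  k = 1)
train_err_k5 <- knn_error(train, train, k = 5)
test_err_k5  <- knn_error(train, test,  k = 5)

print(c(train_k1 = train_err_k1, test_k1 = test_err_k1,
        train_k5 = train_err_k5, test_k5 = test_err_k5))
```

With K = 1 every training point is its own nearest neighbor, so the training error is exactly 0 whenever no two points coincide. That is the clearest illustration of why training error understates the true error.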
Now let’s look at the predictions resulting from KNN for different values of K. First we show what the predictions will be at every point in the plane (or really every point in our test grid).
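As a rough idea of what a test grid means here, the sketch below builds a regular grid over the range of a toy training set and predicts the class at every grid point for a single K. The toy data, the grid resolution of 50 by 50, and the choice K = 5 are all assumptions made for illustration, not the author's settings.

```r
library(class)  # provides knn()

# Hypothetical toy training set (base R only), standing in for data1
set.seed(100)
train <- data.frame(x1 = c(rnorm(50, mean = 1), rnorm(51, mean = -1)),
                    x2 = rnorm(101),
                    y  = factor(c(rep(1, 50), rep(-1, 51))))

# Regular grid covering the range of the training data
grid <- expand.grid(
  x1 = seq(min(train$x1), max(train$x1), length.out = 50),
  x2 = seq(min(train$x2), max(train$x2), length.out = 50)
)

# Predicted class at every grid point for one value of K
grid$pred <- knn(train = train[, c("x1", "x2")],
                 test  = grid[, c("x1", "x2")],
                 cl    = train$y, k = 5)

print(table(grid$pred))
```

Coloring the grid points by `pred` shows the decision regions, and repeating the prediction for several values of K shows how the boundary smooths out as K grows.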
k_values <- c(1, 3, 5, 9, 15, 31, 51, 101)