# Third PAMI homework demo.
#
# Runs logistic regression, LDA, QDA and 1-nearest-neighbour on the spam
# dataset, then LDA, QDA and 1-NN on the handwritten-digit (zip) dataset,
# printing a confusion matrix and the test error rate for each method.
#
# Expects these files in the working directory:
#   spam.data, spam.traintest, zip.train.gz, zip.test.gz
#
# The script is interactive: it pauses on readline() between steps, so it is
# meant to be run with source() from an interactive R session.
#
# NOTE: the script deliberately does not call rm(list = ls()); wiping the
# caller's workspace is a surprising side effect, and every object read below
# is defined by the script itself.

library(MASS)   # lda(), qda()
library(class)  # knn()

# set the seed as the last three digits of your student ID
seed <- 111

# Refuse to run until the default seed has been edited.
if (seed == 111) {
  stop("\nThe random generator seed is still set to its default value.\nEdit the script and change it with the last three digits of your student ID\n\n", call. = FALSE)
}

set.seed(seed)

# number of points in the (optional, randomly generated) training set
trainSetSize <- 3000

# Rotate a matrix by 90 degrees clockwise (reverse every column, then
# transpose). Used so that image() draws the digit upright.
rotate <- function(x) {
  t(apply(x, 2, rev))
}

# Draw one digit of a zip-style dataset as a 16x16 grey-level image.
#
# dataset: data frame with the label in column 1 and 256 pixel values in the
#          remaining columns (256 = 16 * 16, as required by matrix() below).
# digitId: row index of the digit to draw.
# labels:  vector of labels shown in the plot title; defaults to the first
#          column of `dataset`. Pass predicted classes to show those instead.
showdigit <- function(dataset, digitId, labels = dataset[, 1]) {
  # (x + 1) / 2 rescales the pixels and 1 - ... inverts them; this presumably
  # relies on pixel values lying in [-1, 1] -- confirm against zip.info.txt.
  pixels <- 1 - as.double(dataset[digitId, -1] + 1) / 2
  m <- rotate(matrix(pixels, 16, 16, byrow = TRUE))
  image(m,
        axes = FALSE,
        col = grey(seq(0, 1, length.out = 256)),
        main = paste("Digit ID = ",digitId," (label = ",labels[digitId],")"))
}

# Block until the user presses [enter].
pause <- function() {
  invisible(readline(prompt = "Press [enter] to continue"))
}

################################# MAIN CODE ###########################################

cat("\n\n-------------------------------------------------------------------------------------\n")
cat("Welcome to the third PAMI homework.\n\n")
cat("In this demo you will run classification algorithms against two different datasets.\n")
cat("In the first experiment, you will test your own email spam filter on the HP Labs SPAM E-mail Database.\n")
cat("For more information about this dataset (and for the actual data), check out the following URLs:\n")
cat("[1] http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/spam.info.txt\n")
cat("[2] http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/spam.data\n")
cat("[3] http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/spam.traintest\n")
cat("[4] https://archive.ics.uci.edu/ml/datasets/Spambase\n\n")
cat("The dataset contains features describing 4601 email messages (of which 1813=39.4% are spam).\n")
cat("Note that you won't have access to the contents of the emails, but just to their 58 attributes:\n")
cat("(see https://archive.ics.uci.edu/ml/machine-learning-databases/spambase/spambase.names)\n")
cat("- 48 of them are frequencies of some relevant words\n")
cat("- 6 of them are frequencies of some relevant characters\n")
cat("- 3 of them relate to the use of capital letters in the message\n")
cat("- the last one is a nominal {0,1} class attribute of type spam (1 = email was spam, 0 otherwise)\n\n")
pause()

# ---- Experiment 1: spam filtering ----

spam <- read.table("spam.data")
test <- read.table("spam.traintest")

# The indicator file marks training rows with 0. Take its single column so
# that `train` is a plain logical vector (comparing the whole data frame
# would give a one-column logical matrix instead).
train <- test[[1]] == 0

# read.table() already returns a data frame; spamdf is kept as a separate
# name because the model-fitting calls below refer to it.
spamdf <- as.data.frame(spam)
rm(test)

# Describe the default training/test split (which can be replaced below).
cat("To evaluate our classification performance we are splitting the dataset in training and test sets.\n")
cat("The spam dataset comes with a ready-made split (useful for comparison with different methods), with\n")
cat("the following characteristics:\n")
cat("- training set size: ", sum(train),
    "(", sum(spam[train, 58] == 1), " spam, ",
    sum(spam[train, 58] == 0), " non-spam)\n")
cat("- test set size: ", sum(!train),
    "(", sum(spam[!train, 58] == 1), " spam, ",
    sum(spam[!train, 58] == 0), " non-spam)\n\n")
cat("As one of your task, you will be required to change this default set into a randomly generated\n")
cat("one (just check the source code of this demo to see how to do it).\n\n")

# Uncomment the following two lines of code to enable random generation of the training/test sets.
# train <- rep(FALSE, nrow(spam))
# train[sample(seq_len(nrow(spam)), trainSetSize)] <- TRUE

spam.train <- spam[train, ]
spam.test <- spam[!train, ]

# Ground-truth labels (column V58) for the two subsets.
V58.test <- spamdf$V58[!train]
V58.train <- spamdf$V58[train]

cat("We will now run logistic regression to classify the email messages. What you will see below are\n")
cat("the *confusion matrix* and the *test error rate* calculated by doing a default split at probability\n")
cat("threshold = 0.5 (i.e. everything with predicted p>.5 will be spam, everything else will be not).\n")
pause()

# classify with logistic regression
# (warnings raised by glm() while fitting are deliberately silenced)
glm.fit <- suppressWarnings(glm(V58 ~ ., data = spam.train, family = binomial))
glm.probs <- predict(glm.fit, spam.test, type = "response")

# Turn probabilities into class labels using the default 0.5 threshold.
glm.pred <- rep(0, nrow(spam.test))
glm.pred[glm.probs > .5] <- 1

cat("=== Classification results for logistic regression ===\n\n")
cat("Confusion matrix:\n")
# print() is explicit so the table is also shown when the script is source()d.
print(table(glm.pred, V58.test, dnn = c("Spam (predicted)","Spam (ground truth)")))
cat("\nTest error rate = ", mean(glm.pred != V58.test),"\n")
cat("======================================================\n\n")

cat("... not bad, isn't it? Let us now compare this results with the ones you get from other methods.\n")
pause()

# classify with LDA
lda.fit <- lda(V58 ~ ., data = spamdf, subset = train)
lda.pred <- predict(lda.fit, spam.test)
lda.class <- lda.pred$class

cat("========== Classification results for LDA ===========\n\n")
cat("Confusion matrix:\n")
print(table(lda.class, V58.test, dnn = c("Spam (predicted)","Spam (ground truth)")))
cat("\nTest error rate = ", mean(lda.class != V58.test),"\n")
cat("======================================================\n\n")
pause()

# classify with QDA
qda.fit <- qda(V58 ~ ., data = spamdf, subset = train)
qda.pred <- predict(qda.fit, spam.test)
qda.class <- qda.pred$class

cat("========== Classification results for QDA ===========\n\n")
cat("Confusion matrix:\n")
print(table(qda.class, V58.test, dnn = c("Spam (predicted)","Spam (ground truth)")))
cat("\nTest error rate = ", mean(qda.class != V58.test),"\n")
cat("======================================================\n\n")
pause()

# classify with KNN
# The class column (58) is removed from both feature sets: knn() uses every
# column it is given as a feature, so leaving the label in would leak the
# ground truth into the distance computation.
knn.pred <- knn(spam.train[, -58], spam.test[, -58], V58.train, k = 1)

cat("========== Classification results for KNN ===========\n\n")
cat("Confusion matrix:\n")
print(table(knn.pred, V58.test, dnn = c("Spam (predicted)","Spam (ground truth)")))
cat("\nTest error rate = ", mean(knn.pred != V58.test),"\n")
cat("======================================================\n\n")
pause()

# ---- Experiment 2: handwritten digit classification ----

cat("Let us move now to our second experiment: handwritten digit classification (check out the dataset\n")
cat("info here: http://statweb.stanford.edu/~tibs/ElemStatLearn/datasets/zip.info.txt).\n\n")
cat("In this case the features characterizing the digits are the digits themselves, i.e. the colors\n")
cat("(actually the gray level) of their pixels. To give you a grasp of it, here are few of them:\n")
pause()

digits.train <- read.table(gzfile("zip.train.gz"))
digits.test <- read.table(gzfile("zip.test.gz"))

# Show five randomly chosen training digits, one at a time.
randIdx <- sample(seq_len(nrow(digits.train)), 5)
for (i in randIdx) {
  showdigit(digits.train, i)
  pause()
}

cat("We will now classify these digits with LDA, QDA, and KNN (this might take a while, please\n")
cat("be patient...)\n")
pause()

# classify with LDA
lda.fit <- lda(V1 ~ ., data = digits.train)
lda.pred <- predict(lda.fit, digits.test)
lda.class <- lda.pred$class

cat("========== Classification results for LDA ===========\n\n")
cat("Confusion matrix:\n")
print(table(lda.class, digits.test$V1))
cat("\nTest error rate = ", mean(lda.class != digits.test$V1),"\n")
cat("======================================================\n\n")
pause()

#digits.test$V1[1:10]
#lda.class[1:10]

# classify with QDA
# add jitter before QDA to avoid exact multicolinearity
# (only the pixel columns are jittered; the label column V1 is left as is)
digits.train.J <- digits.train
digits.train.J[-1] <- lapply(digits.train[-1], jitter)

qda.fit <- qda(V1 ~ ., data = digits.train.J)
qda.pred <- predict(qda.fit, digits.test)
qda.class <- qda.pred$class

cat("========== Classification results for QDA ===========\n\n")
cat("Confusion matrix:\n")
print(table(qda.class, digits.test$V1))
cat("\nTest error rate = ", mean(qda.class != digits.test$V1),"\n")
cat("======================================================\n\n")
pause()

#digits.test$V1[1:10]
#qda.class[1:10]

# classify with KNN
cat("(don't panic... KNN is the slowest one ;-))\n")

# As for the spam data, the label column (V1) must not be used as a feature.
knn.pred <- knn(digits.train[, -1], digits.test[, -1], digits.train$V1, k = 1)

cat("========== Classification results for KNN ===========\n\n")
cat("Confusion matrix:\n")
print(table(knn.pred, digits.test$V1))
cat("\nTest error rate = ", mean(knn.pred != digits.test$V1),"\n")
cat("======================================================\n\n")
pause()

# uncomment the following code to show a set of misclassified pics (just choose the proper
# values for val1 and val2, you can guess them from a confusion matrix)
# val1 <- 0
# val2 <- 0
# wrongIndices <- which(digits.test$V1 == val1 & qda.class == val2)
# for (i in wrongIndices) {
#   showdigit(digits.test, i, qda.class)
#   pause()
# }

cat("That's it! If you are interested in the comparison between other methods on this dataset, also check:\n")
cat("http://blog.quantitations.com/machine%20learning/2013/02/27/comparing-classification-algorithms-for-handwritten-digits/\n\n")