
kNN Classification of Handwriting, in Python

Introduction

In today's blog, I will develop a simple classifier that recognizes handwritten digits.

I'll write a kNN (k-nearest-neighbors) classifier and test it on a set of scanned handwritten digit images. The images come from the MNIST data set. They have been pre-processed by image-processing software and stored as text files. Each image is the same size, 32x32 pixels, in black and white: '1's mark the dark (pen-stroke) pixels and '0's the white background. Here are examples of handwritten digits '9' and '2':

    00000000000000111111111010000000 00000000011111110000000000000000
    00000000000011111111111110000000 00000001111111111000000000000000
    00000000000111111111111111100000 00000011111111111100000000000000
    00000000000111111111111111110000 00000001111111111111000000000000
    00000000011111111111111111110000 00000001111111111111000000000000
    00000000011111111000000111110000 00000001111111111111000000000000
    00000000111111000000001111100000 00000001111100111111100000000000
    00000001111110000000001111100000 00000001111100011111100000000000
    00000001111100000000001111100000 00000001111000011111100000000000
    00000001111100000000001111100000 00000001110000011111100000000000
    00000001111000000001111111100000 00000000100000011111100000000000
    00000001111000000011111111000000 00000000000000011111000000000000
    00000001111000000111111110000000 00000000000000001111100000000000
    00000011111000111111111100000000 00000000000000001111100000000000
    00000011111001111111111100000000 00000000000000001111100000000000
    00000001111111111111111000000000 00000000000000011111100000000000
    00000001111111111111111000000000 00000000000000011111100000000000
    00000001111111111111111000000000 00000000000000111111000000000000
    00000000011111111111111000000000 00000000000000111110000000000000
    00000000001111100111110000000000 00000000000001111110000000000000
    00000000000000001111000000000000 00000000000011111100000000000000
    00000000000000011111000000000000 00000000000001111110000000000000
    00000000000000011111000000000000 00000000000011111110000000000000
    00000000000000111110000000000000 00000000001111111100000000000000
    00000000000000111110000000000000 00000000011111111000000000000000
    00000000000001111110000000000000 00000000111111111111111111100000
    00000000000011111110000000000000 00000001111111111111111111110000
    00000000000111111100000000000000 00000011111111111111111111110000
    00000000000111111000000000000000 00000011111111111111111111110000
    00000000000111111100000000000000 00000011111111111111111111110000
    00000000000111111100000000000000 00000001111111111111111111110000
    00000000000011110000000000000000 00000000111111111110000000000000

Preliminaries

kNN can be implemented quickly in Python or MATLAB. I will use Python.

The original data can be found here. (If you'd like to replicate what follows, download the data set: click the 'Raw' button on that page and unpack the zip file on your machine. Rename the two unpacked directories to train and test, then put the Python scripts that follow in the same directory.)

The images are stored in two directories: train, containing the training dataset, and test, containing the test dataset. The training dataset has about 1900 images, roughly 200 samples of each digit. The test dataset contains about 900 images.

I will use Python's numpy package in what follows. (numpy considerably simplifies array operations. However, it is not bundled by default with every Python distribution and may need to be installed separately, e.g. via pip install numpy.)

These are the modules we'll need:

  from numpy import *     # array operations: tile, zeros, argsort, ...
  import operator         # itemgetter, used to sort the vote counts
  from os import listdir  # to list the image files in the train/test directories

Writing the Classifier

Function knn_classify takes an image of a digit and outputs the label (0, 1, ..., or 9) that it thinks should be assigned to that image. It does this by measuring how 'close' the image is to every image in the training data set. It selects the k closest training images, looks at their labels, and assigns the majority label to the input image, breaking ties arbitrarily. For example, if our 3NN classifier found the 3 closest images to have labels '9', '9', and '3', it would assign the label '9' to its input image.
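
The majority-vote step by itself is tiny. Here is a minimal sketch of it using collections.Counter (just an illustration; the full classifier below does the same thing with operator.itemgetter instead):

from collections import Counter

# labels of the k nearest neighbors, as in the 3NN example above
nearest_labels = ['9', '9', '3']

# most_common(1) returns the (label, count) pair with the highest count;
# with equal counts the winner is effectively arbitrary
majority_label = Counter(nearest_labels).most_common(1)[0][0]
print(majority_label)   # prints: 9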

Of course, we need to be clear about what it means for two images to be 'close' to one another. Take two images and consider a pair of pixels, one from each, in the same position, say row i and column j. If the pixel values at some (i,j) disagree, that pair contributes to the distance between the images. Walking over all (i,j), the more positions at which the pixels agree, the closer the images are.

One way is to represent each 32x32 image as a 1x1024 array (vector) and count the number of mismatches at corresponding indices of the two vectors being compared. Or we could compute the usual Euclidean distance between the vectors. Since the pixels are all 0s and 1s, the two agree: the squared Euclidean distance is exactly the number of mismatched pixels. There are a few ways to skin this cat, and just about as many ways to write this classifier.
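
To see this concretely, here is a small sketch with two made-up binary vectors, showing that counting mismatches and summing squared differences give the same number:

from numpy import array

# two made-up binary 'images', flattened to vectors
a = array([0, 1, 1, 0, 1])
b = array([1, 1, 0, 0, 1])

mismatches = (a != b).sum()          # number of differing positions: 2
sq_euclidean = ((a - b)**2).sum()    # squared Euclidean distance: also 2
print(mismatches, sq_euclidean)      # prints: 2 2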

In a nutshell, here is what a kNN classifier is. (Thanks Koba for the link!)

Anyway, here's some Python code that does this. It takes as input a vector inX (the test image to be classified), a matrix dataMat (the Nx1024 matrix whose rows are the N training images), labels (the list of labels for the N training images, which must have length N), and k, the number of nearest neighbors to use in the majority vote. The k smallest distances are selected by sorting the distances from inX to each row vector of dataMat in ascending order and taking the top k.


def knn_classify(inX, dataMat, labels, k):
    N = dataMat.shape[0]
    diffMat = tile(inX, (N,1)) - dataMat
    sqDiffMat = diffMat**2
    sqDistances = sqDiffMat.sum(axis=1)
    distances = sqDistances**0.5     #don't really need euclidean distance, but OK
    sorted_ind = distances.argsort() #the indices of sorted distances      
    class_count={}          
    for i in range(k):  #the lowest k distances are used to vote on the class of inX
        vote_i_label = labels[sorted_ind[i]]
        class_count[vote_i_label] = class_count.get(vote_i_label,0) + 1
    sorted_class_count = sorted(class_count.items(), key=operator.itemgetter(1), reverse=True)  # items() in Python 3 (iteritems() in Python 2)
    return sorted_class_count[0][0]

(A Python note: class_count.items() turns the dictionary into a sequence of (label, count) tuples. The tuples are then sorted by their second item, the count, using the itemgetter function from the operator module imported earlier.)
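
Nothing in knn_classify depends on the vectors being 1024-dimensional, so we can sanity-check it on a tiny made-up data set (the points and labels below are invented purely for illustration):

from numpy import array

# four made-up 2-dimensional training points and their labels
dataMat = array([[1.0, 1.1],
                 [1.0, 1.0],
                 [0.0, 0.0],
                 [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']

# a point near the 'B' cluster should come out labeled 'B'
print(knn_classify(array([0.1, 0.1]), dataMat, labels, 3))   # prints: B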

Image to Vector

Next, let's write a small function img2vector that converts an image file to a 1x1024 vector.

 
def img2vector(file):
    vec = zeros((1,1024))
    with open(file) as fh:      # 'with' makes sure the file is closed when we're done
        for i in range(32):
            line = fh.readline()
            for j in range(32):
                vec[0,32*i+j] = int(line[j])   # pixel at row i, column j -> index 32*i+j
    return vec
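
For example, assuming the train directory contains a file named 0_13.txt (the file name here is a guess at the naming convention, label_sample.txt, implied by the label-parsing code below):

vec = img2vector('train/0_13.txt')   # hypothetical file name
print(vec.shape)                     # prints: (1, 1024)
print(vec[0,:32])                    # the first row of pixels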

Testing the kNN on Handwritten Digits

Now we have the data in a format that can be plugged into the classifier.

Finally, here's the function test_handwriting(). It checks how accurate our classifier is in predicting the labels:


def test_handwriting():
    labels = []
    training_files = listdir('train')           # get the list of training set files
    N = len(training_files)                     # N: number of training data points
    trainingMat = zeros((N,1024))
    for i in range(N):
        full_filename = training_files[i]        #file name including 'txt' extension
        filename = full_filename.split('.')[0]     #take off .txt
        label = int(filename.split('_')[0])
        labels.append(label)
        trainingMat[i,:] = img2vector('train/%s' % full_filename)
    test_files = listdir('test')        # get the list of test set files
    error_count = 0.0
    n_test = len(test_files)            # n_test: number of test data points
    for i in range(n_test):             # iterate through the test set
        full_filename = test_files[i]
        filename = full_filename.split('.')[0]     #take off .txt
        label = int(filename.split('_')[0])
        vectorUnderTest = img2vector('test/%s' % full_filename)
        classifierResult = knn_classify(vectorUnderTest, trainingMat, labels, 3)
        print "the kNN assigned label is: %d, the true label is: %d" % (classifierResult, label)
        if (classifierResult != label): error_count += 1.0
    print "\nthe total number of errors is: %d" % error_count
    #print "\nthe total error rate is: %f" % (error_count/float(n_test))
    print "\nthe accuracy is: %f" % (1.0 - error_count/float(n_test))


Put the above three functions in a file named knn.py and save the file in the same directory where you put the train and test directories containing the digit text files.

Then say this to your Python interpreter:

  >>> import knn   # or importlib.reload(knn) if already imported
  >>> knn.test_handwriting()

The output is interesting to observe. Depending on the speed of your computer, it may take a few minutes to complete. Note that most digits are classified correctly; every now and then a '7' gets misclassified as a '1', or a '9' as a '3', or something like that.

I got 98.84% accuracy. Pretty good, I would say.

d.
