机器学习实战 手写识别系统

构造使用k-近邻分类器的手写识别系统

step1.准备数据:将图像转换为测试向量

把32*32的二进制图像矩阵转换成1*1024的向量,这样就可以之前的分类器来处理了

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

#代码解释

fileObject.readline(size)
size -- 从文件中读取的字节数。

zeros(shape, dtype=float, order='C')

返回一个给定形状和类型的用0填充的数组;

shape:形状

dtype:数据类型

order:可选参数,c代表与c语言类似,行优先;F代表列优先

step2.测试算法:使用k-近邻算法识别手写数字

算法作用:

把训练集的每个图片改成一行的数组  m个图片形成m*1024的数组

把图片的文件名拆开  记录该图片代表的数字

对测试集进行同样的操作  将测试集每个图片的结果和真实数字比较  记录测试集中一共有多少个不同的(也就是判断错误的)

计算出错误率  错误率=测试集中判断错误的 / 测试集总数目

def handwritingClassTest():
    hwLabels = []

    # load the training set
    # trainingFileList是各个文件
    # m是文件个数
    # trainingMat是各个文件变成一行之后的整合矩阵
    #fileNameStr单个文件名

    trainingFileList = listdir('trainingDigits')          
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        # take off .txt
        #按.分开 fileStr取分开后的第一个元素 classNumStr记录这个数字是什么
        fileStr = fileNameStr.split('.')[0]
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        #通过img2vector把图片变成一行的 并放入trainingMat对应行
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
        
    # iterate through the test set
    testFileList = listdir('testDigits')        
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print ("the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr))
        #当测试结果和训练结果不同时 错误+1
        if (classifierResult != classNumStr): errorCount += 1.0
    print ("\nthe total number of errors is: %d" % errorCount)
    print ("\nthe total error rate is: %f" % (errorCount/float(mTest)))

#代码解释

os.listdir(path)
返回指定路径path下的文件和文件夹列表

str.split(str="", num=string.count(str)).
split( ) 通过指定分隔符对字符串进行切片,如果参数 num 有指定值,则分隔 num+1 个子字符串,返回分割后的字符串列表。 

str-- 分隔符,默认为所有的空字符,包括空格、换行(\n)、制表符(\t)等。

num-- 分割次数。默认为 -1, 即分隔所有。

 

 

 

 

 


版权声明:本文为qq_36895331原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。