标签:
之前学过svm相关知识,基本原理不算复杂,今天做了一个手写字识别程序,总算验证了svm的效果。
因为只是验证效果,实现上以简单为原则,使用python + libsvm + PIL(Python Imaging Library)。环境搭建花了一些时间:
PIL:
http://www.pythonware.com/products/pil/
下载源码包,解压之后运行:python setup.py install即可。
mac下python libsvm的安装与使用:http://blog.csdn.net/u012774963/article/details/14640583
libsvm python接口介绍:http://blog.csdn.net/lqhbupt/article/details/8599295
说是手写字,其实只是一到十这十个汉字,这样比较简单,而且收集的样本不太多。这十个汉字,在mac上用paintbrush前前后后画了259个80×80的png图片。图片缩放为16×16、二值化之后用一个256维的向量表示,简单粗暴。每个汉字对应一个训练文件ocr_i(one-vs-rest):该字的样本标为1,其余样本标为-1。生成训练数据文件的脚本:inittraindata.py
#! /usr/bin/env python
"""Build the one-vs-rest libsvm training files ocr_1 .. ocr_10.

For every class i in 1..10 the directory str(i) holds PNG samples of that
character.  Each image is resized to 16x16 (nearest neighbour) and
binarised: a pixel whose first channel is 255 becomes feature value 0, any
other pixel becomes 1.  The 256-feature line is written once to EVERY output
file -- with label 1 in the file of the image's own class and label -1 in
all the others -- so row r of every ocr_* file describes the same image.
"""
import Image
import os

NUM_CLASSES = 10
SIDE = 16  # images are resized to SIDE x SIDE before encoding


def image_to_features(path):
    """Return the libsvm feature text "1:b 2:b ... 256:b " for one image."""
    # convert('RGB') leaves the first channel of RGB/RGBA images unchanged
    # and also lets palette/greyscale PNGs be indexed with [0].
    img = Image.open(path).resize((SIDE, SIDE), Image.NEAREST).convert('RGB')
    pixdata = img.load()
    parts = []
    for k in range(SIDE * SIDE):
        # pixdata is indexed [x, y], so features run column by column; the
        # prediction script must use the same ordering.
        white = pixdata[k // SIDE, k % SIDE][0] == 255
        parts.append("%d:%d " % (k + 1, 0 if white else 1))
    return "".join(parts)


def main():
    outputs = [open('ocr_' + str(i), 'w') for i in range(1, NUM_CLASSES + 1)]
    try:
        for cls in range(1, NUM_CLASSES + 1):
            folder = str(cls)
            for item in os.listdir(folder):
                path = os.path.join(folder, item)
                if not (os.path.isfile(path) and path.endswith(".png")):
                    continue
                # Encode once; only the label differs between output files.
                features = image_to_features(path)
                for j, out in enumerate(outputs, 1):
                    label = "1 " if j == cls else "-1 "
                    out.write(label + features + "\n")
    finally:
        # Explicitly close (and flush) every output file, even on error.
        for out in outputs:
            out.close()


if __name__ == '__main__':
    main()
训练并保存模型的脚本save.py:
#! /usr/bin/env python
"""Train one binary RBF SVM per class and save it to disk.

Reads the training files ./ocr_1 .. ./ocr_10 and writes the corresponding
models to ./model_1 .. ./model_10.
"""
import sys
from svmutil import *
import Image
import random

# C = 3, gamma = 1/64 (0.015625) for every class.  An earlier experiment
# used '-c 10000' for classes 3 and 4 instead.
TRAIN_OPTIONS = '-c 3 -g 0.015625'


def train_and_save(index):
    """Train the model for class *index* and save it as ./model_<index>."""
    labels, samples = svm_read_problem('./ocr_' + str(index))
    model = svm_train(labels, samples, TRAIN_OPTIONS)
    svm_save_model('./model_' + str(index), model)


for class_index in range(1, 11):
    train_and_save(class_index)
#! /usr/bin/env python
"""Classify a single character image with the ten one-vs-rest SVM models.

Usage: predict.py IMAGE

Loads ./model_1 .. ./model_10, encodes IMAGE with the same 16x16 binary
feature scheme as the training data, prints each model's decision value
(signed so that > 0 means "this class") and finally the chosen class index.
"""
import sys
from svmutil import *
import Image

NUM_CLASSES = 10
SIDE = 16  # images are resized to SIDE x SIDE before encoding


def image_to_features(path):
    """Return the libsvm feature dict {1..256: 0.0 or 1.0} for the image at *path*.

    A pixel whose first channel is 255 maps to 0.0, anything else to 1.0;
    pixels are visited column by column (pixdata is indexed [x, y]).
    """
    # convert('RGB') leaves the first channel of RGB/RGBA images unchanged
    # and also lets palette/greyscale PNGs be indexed with [0].
    img = Image.open(path).resize((SIDE, SIDE), Image.NEAREST).convert('RGB')
    pixdata = img.load()
    features = {}
    for k in range(SIDE * SIDE):
        white = pixdata[k // SIDE, k % SIDE][0] == 255
        features[k + 1] = 0.0 if white else 1.0
    return features


def positive_decision_value(model, y, x):
    """Return *model*'s decision value on x[0], signed so > 0 means label 1."""
    label, acc, val = svm_predict(y, x, model)
    value = val[0][0]
    # libsvm reports binary decision values relative to model.get_labels()[0],
    # i.e. the first label that appears in the training file.  With the row
    # order inittraindata.py writes, ocr_1 starts with a +1 row but the other
    # files start with -1 rows, so raw signs are not comparable across models.
    if model.get_labels()[0] != 1:
        value = -value
    return value


def main():
    models = [svm_load_model('./model_' + str(i))
              for i in range(1, NUM_CLASSES + 1)]
    # Build the sample in memory instead of round-tripping through a scratch
    # file in the working directory; the label is only a placeholder.
    y = [-1]
    x = [image_to_features(sys.argv[1])]
    best_dist = float('inf')
    best_idx = -1
    for idx, model in enumerate(models, 1):
        value = positive_decision_value(model, y, x)
        print(value)
        # Selection rule (same as the grid-search script): the model whose
        # oriented decision value is closest to +1 wins.
        # NOTE(review): the conventional one-vs-rest rule is the largest
        # decision value; consider switching after re-tuning C/gamma.
        dist = abs(value - 1.0)
        if dist < best_dist:
            best_dist = dist
            best_idx = idx
    print("probably is:  %d" % best_idx)


if __name__ == '__main__':
    main()
#! /usr/bin/env python
"""Grid-search C and gamma for the ten one-vs-rest RBF SVMs.

Reads ./ocr_1 .. ./ocr_10, shuffles all ten problems with one shared
permutation (so row r is the same image in every problem) and scores each
(C, gamma) pair by 10-fold cross-validation of the combined classifier.
save.py uses the setting C=3, gamma=1/64.
"""
from svmutil import *
import random

NUM_CLASSES = 10
NUM_FOLDS = 10


def _fold_bounds(count, fold):
    """Return the [start, end) row range of test fold *fold*."""
    return count * fold // NUM_FOLDS, count * (fold + 1) // NUM_FOLDS


def _oriented_values(model, val):
    """Flatten svm_predict decision values, signed so > 0 means label 1."""
    # libsvm reports binary decision values relative to model.get_labels()[0],
    # the first label seen in the training rows; after shuffling that label
    # differs from model to model, so raw signs are not comparable.
    sign = 1.0 if model.get_labels()[0] == 1 else -1.0
    return [sign * v[0] for v in val]


def _closest_to_one(scores, row):
    """Return the class k (1-based) whose score on *row* is closest to +1."""
    # Ties go to the lowest k: min() returns the first minimal element.
    return min(range(1, NUM_CLASSES + 1),
               key=lambda k: abs(scores[k - 1][row] - 1.0))


def test(y, x, c, g):
    """Return the mean NUM_FOLDS-fold cross-validation accuracy for C=c, gamma=g.

    y, x -- per-class label/sample lists as read by svm_read_problem; all
            problems must hold the same images in the same row order.
    Raises ValueError when there are fewer rows than folds (an empty test
    fold would otherwise divide by zero).
    """
    count = len(y[0])
    if count < NUM_FOLDS:
        raise ValueError("need at least %d samples, got %d" % (NUM_FOLDS, count))
    options = '-c ' + str(c) + ' -g ' + str(g)
    correct_rate = 0.0
    for fold in range(NUM_FOLDS):
        lo, hi = _fold_bounds(count, fold)
        answers = [0] * (hi - lo)  # true class of each test row
        scores = []                # scores[k - 1][r]: model k on test row r
        for k in range(1, NUM_CLASSES + 1):
            yk, xk = y[k - 1], x[k - 1]
            # Train on every row outside the test fold [lo, hi).
            model = svm_train(yk[:lo] + yk[hi:], xk[:lo] + xk[hi:], options)
            for r in range(lo, hi):
                if yk[r] == 1:
                    answers[r - lo] = k
            # Score the whole test fold with one call per model.
            label, acc, val = svm_predict(yk[lo:hi], xk[lo:hi], model)
            scores.append(_oriented_values(model, val))
        print(answers)
        correct_count = 0
        for r in range(hi - lo):
            guess = _closest_to_one(scores, r)
            print("probably is %d  answer is %d" % (guess, answers[r]))
            if guess == answers[r]:
                correct_count += 1
        correct_rate += float(correct_count) / (hi - lo)
    correct_rate /= NUM_FOLDS
    print("c= %s g= %s avg_correct_rate= %s" % (c, g, correct_rate))
    return correct_rate


def main():
    yarr = []
    xarr = []
    for i in range(1, NUM_CLASSES + 1):
        y, x = svm_read_problem('./ocr_' + str(i))
        yarr.append(y)
        xarr.append(x)
    # One shared permutation keeps row r pointing at the same image in all
    # ten problems, which test() relies on to recover the true class.
    order = list(range(len(yarr[0])))
    random.shuffle(order)
    print("RANDOM ARR: %s" % str(order))
    yarr = [[y[r] for r in order] for y in yarr]
    xarr = [[x[r] for r in order] for x in xarr]

    # Grid: C in 1..15, gamma in {0, 1/256, ..., 255/256}.
    # NOTE(review): libsvm presumably replaces gamma == 0 with its default
    # 1/num_features, so the first gamma column is not really 0 -- confirm.
    maxcorrect = -1
    cpos = 0
    gpos = 0
    for c in range(1, 16):
        for gg in range(256):
            g = gg * 1.0 / 256
            ret = test(yarr, xarr, c, g)
            if ret > maxcorrect:
                maxcorrect = ret
                cpos = c
                gpos = g
            print("current c= %s g= %s maxcorrect= %s" % (cpos, gpos, maxcorrect))
    print("c= %s g= %s maxcorrect= %s" % (cpos, gpos, maxcorrect))


if __name__ == '__main__':
    main()
标签:
原文地址:http://blog.csdn.net/jollyjumper/article/details/45457223