码迷,mamicode.com
首页 > 其他好文 > 详细

我的第一个svm程序:手写字识别

时间:2015-05-03 12:04:11      阅读:172      评论:0      收藏:0      [点我收藏+]

标签:

之前学过svm相关知识,基本原理不算复杂,今天做了一个手写字识别程序,总算验证了svm的效果。

因为只是验证效果,实现上原则是简单,使用python + libsvm + PIL(python image library)。这部分工作花了一些时间:

PIL:

http://www.pythonware.com/products/pil/

下载源码包,解压之后运行:python setup.py install即可。


max下python libsvm安装使用:http://blog.csdn.net/u012774963/article/details/14640583
libsvm python接口介绍:http://blog.csdn.net/lqhbupt/article/details/8599295


说是手写字,其实只是一到十这十个汉字,这样比较简单,而且收集的样本不太多。这十个汉字,在mac上用paintbrush前前后后画了259个80*80的png图片。图片缩放为16*16,二值化之后用一个256维的向量表示,简单粗暴。准备训练数据文件:inittraindata.py

#! /usr/bin/env python

import Image
import os

f = []
for i in range(1,11):
  f.append(open('ocr_' + str(i), 'wb'))

for i in range(1,11):
  for item in os.listdir(str(i)):
    path = os.path.join(str(i), item)
    if os.path.isfile(path) and path.endswith(".png"):
      img_org = Image.open(path)
      img = img_org.resize((16,16), Image.NEAREST)
      pixdata = img.load()
      # -1
      for j in range(1,i):
        line = "-1 "
	for k in range(0, 256):
	  line += str(k + 1)
	  if pixdata[k / 16,k % 16][0] == 255:
	    line += ":0 "
	  else:
	    line += ":1 "
        f[j - 1].write(line + "\n")
      # -1
      for j in range(i + 1, 11):
        line = "-1 "
	for k in range(0, 256):
	  line += str(k + 1)
	  if pixdata[k / 16, k % 16][0] == 255:
	    line += ":0 "
	  else:
	    line += ":1 "
	f[j - 1].write(line + "\n")
      # 1
      line = "1 "
      for k in range(0, 256):
        line += str(k + 1)
	if pixdata[k / 16, k % 16][0] == 255:
	  line += ":0 "
	else:
	  line += ":1 "
      f[i - 1].write(line + "\n")

for o in f:
  o.close

训练数据并保存模型save.py:

#! /usr/bin/env python

import sys
from svmutil import *
import Image
import random

for i in range(1, 11):
  y, x = svm_read_problem('./ocr_' + str(i))
#  if i == 4 or i == 3:
#    m = svm_train(y, x, '-c 10000')
#  else:
  m = svm_train(y, x, '-c 3 -g 0.015625')
  svm_save_model('./model_' + str(i), m)

预测predict.py:

#! /usr/bin/env python

import sys
from svmutil import *
import Image

# load
m = []
for i in range(1, 11):
  m.append(svm_load_model('./model_' + str(i)))

# predict
path = sys.argv[1]
img_org = Image.open(path)
img = img_org.resize((16,16), Image.NEAREST)
pixdata = img.load()

line = "-1 "
tmpfile = open("tmpfile", "wb")
for i in range(0, 256):
  line += str(i + 1)
  if pixdata[i / 16, i % 16][0] == 255:
    line += ":0 "
  else:
    line += ":1 "
tmpfile.write(line + "\n")
tmpfile.close()

max = 100.0
maxidx = -1
for i in range(1, 11):
  y, x = svm_read_problem("tmpfile")
  label, acc, val = svm_predict(y, x, m[i - 1])
  print val[0][0]
  if abs(val[0][0] - 1.0) < max:
    max = abs(val[0][0] - 1.0)
    maxidx = i

print "probably is: ", maxidx


使用c-svm,核函数使用RBF,参数c=3,gama=1.0/64,参数怎么选的,用的是简单粗暴的grid search,gridsearch.py:

#! /usr/bin/env python

from svmutil import *
import random

def test(y, x, c, g):
  count = len(y[0])
  correct_rate = 0.0
  # n-fold cross-validation
  for i in range(0, 10):
    marr = []
    tarr = []
    answers = []
    for k in range(count*i/10, count*(i+1)*10):
      answers.append(0)

    for k in range(1, 11):
      # training sets
      yy = []
      xx = []
      for j in range(0, count*i/10):
        yy.append(y[k - 1][j])
        xx.append(x[k - 1][j])
      for j in range(count*(i + 1)/10, count):
        yy.append(y[k - 1][j])
        xx.append(x[k - 1][j])
      m = svm_train(yy, xx, '-c ' + str(c) + ' -g ' + str(g))
      marr.append(m)
      yyy = []
      xxx = []
      for j in range(count*i/10, count*(i+1)/10):
        yyy.append(y[k - 1][j])
	if y[k - 1][j] == 1:
	  answers[j - count*i/10] = k
	xxx.append(x[k - 1][j])
      # test sets
      tarr.append((yyy, xxx))

    print answers
    # predicting
    correct_count = 0
    for j in range(0, len(tarr[0][0])):
      max = 10000.0
      maxidx = -1
      for k in range(1, 11):
        label, acc, val = svm_predict(tarr[k - 1][0][j:j+1], tarr[k - 1][1][j:j+1], marr[k - 1])
	if abs(val[0][0] - 1.0) < max:
	  max = abs(val[0][0] - 1.0)
	  maxid = k
      print "probably is", maxid, " answer is", answers[j]  
      if answers[j] == maxid:
        correct_count += 1
    correct_rate += float(correct_count) / len(tarr[0][0])

  correct_rate /= 10
  print 'c=',c,'g=',g,'avg_correct_rate=',correct_rate
  return correct_rate

def main():
  yarr = []
  xarr = []
  for i in range(1, 11):
    y, x = svm_read_problem('./ocr_' + str(i))
    yarr.append(y)
    xarr.append(x)

  #shuffle
  arr = []
  for i in range(0, len(yarr[0])):
    arr.append(i)
  random.shuffle(arr)
  print "RANDOM ARR:",arr

  count = len(yarr[0])
  for i in range(1, 11):
    yy = []
    xx = []
    y = yarr[i - 1]
    x = xarr[i - 1]
    for j in range(0, count):
      yy.append(y[arr[j]])
      xx.append(x[arr[j]])
    yarr[i - 1] = yy
    xarr[i - 1] = xx

  # grid search
  maxcorrect = -1
  cpos = 0
  gpos = 0
  for c in range(1, 16, 1):
    for gg in range(0, 256, 1):
      g = gg * 1.0 / 256
      ret = test(yarr, xarr, c, g)
      if ret > maxcorrect:
        maxcorrect = ret
	cpos = c
	gpos = g
        print "current c=",cpos,"g=",gpos,"maxcorrect=",maxcorrect

  print "c=",cpos,"g=",gpos,"maxcorrect=",maxcorrect
  #test(yarr, xarr, 3, 1.0 / 64)

if __name__ == '__main__':
  main()


使用最优参数,估计出来的识别率在85%左右(参数调整影响只有几个点),和样本有关。如果写字比较规范,识别率应该在95%以上,可以想见用印刷体,识别率会有多高。如果歪着写字,或者大小比率比较奇怪,误识别率还是蛮高的。

我的第一个svm程序:手写字识别

标签:

原文地址:http://blog.csdn.net/jollyjumper/article/details/45457223

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!