DeepLearning tutorial(2)机器学习算法在训练过程中保存参数
@author:wepon
@blog:http://blog.csdn.net/u012162613/article/details/43169019
参考:pickle — Python object serialization、DeepLearning Getting started
用到python里的gzip以及cPickle模块,简单的使用代码如下,如果想详细了解可以参考上面给出的链接。
#以读取mnist.pkl.gz为例
import cPickle, gzip
f = gzip.open('mnist.pkl.gz', 'rb')
train_set, valid_set, test_set = cPickle.load(f)
f.close()其实就是分两步,先读取gz文件,再读取pkl文件。pkl文件的应用正是下文要讲的,我们用它来保存机器学习算法训练过程中的参数。
a=[1,2,3]
b={4:5,6:7}
#保存,cPickle.dump函数。/home/wepon/ab是路径,ab是保存的文件的名字,如果/home/wepon/下本来就有ab这个文件,将被覆写#,如果没有,则创建。'wb'表示以二进制可写的方式打开。dump中的-1表示使用highest protocol。
import cPickle
write_file=open('/home/wepon/ab','wb')
cPickle.dump(a,write_file,-1)
cPickle.dump(b,write_file,-1)
write_file.close()
#读取,cPickle.load函数。
read_file=open('/home/wepon/ab','rb')
a_1=cPickle.load(read_file)
b_1=cPickle.load(read_file)
print a,b
read_file.close()import cPickle
#保存
write_file = open('path', 'wb')  
cPickle.dump(w.get_value(borrow=True), write_file, -1)  
cPickle.dump(v.get_value(borrow=True), write_file, -1)  
cPickle.dump(u.get_value(borrow=True), write_file, -1) 
write_file.close()
#读取
read_file = open('path')
w.set_value(cPickle.load(read_file), borrow=True)
v.set_value(cPickle.load(read_file), borrow=True)
u.set_value(cPickle.load(read_file), borrow=True)
read_file.close()if this_validation_loss < best_validation_loss:
这句代码的意思就是判断当前的验证损失是否小于最佳的验证损失,是的话,下面会更新best_validation_loss,也就是说当前参数下,模型比之前的有了优化,因此我们可以在这个if语句后面加入保存参数的代码:
save_params(classifier.W,classifier.b)
save_params函数定义如下:
def save_params(param1,param2):
	import cPickle
	write_file = open('params', 'wb') 
	cPickle.dump(param1.get_value(borrow=True), write_file, -1)
	cPickle.dump(param2.get_value(borrow=True), write_file, -1)
	write_file.close()当然参数的个数根据需要去定义。在logistic_sgd.py中参数只有classifier.W,classifier.b,因此这里定义为save_params(param1,param2)。
import cPickle
f=open('params')
w=cPickle.load(f)
b=cPickle.load(f)
#w大小是(n_in,n_out),b大小时(n_out,),b的值如下,因为MINST有10个类别,n_out=10,下面正是10个数
array([-0.0888151 ,  0.16875755, -0.03238435, -0.06493175,  0.05245609,
        0.1754718 , -0.0155049 ,  0.11216578, -0.26740651, -0.03980861])class LogisticRegression(object):
    def __init__(self, input, n_in, n_out):
        self.W = theano.shared(
            value=numpy.zeros(
                (n_in, n_out),
                dtype=theano.config.floatX
            ),
            name='W',
            borrow=True
        )
        self.b = theano.shared(
            value=numpy.zeros(
                (n_out,),
                dtype=theano.config.floatX
            ),
            name='b',
            borrow=True
        )
#!!!
#加入的代码在这里,程序运行到这里将会判断当前路径下有没有params文件,有的话就拿来初始化W和b
	if os.path.exists('params'):
		f=open('params')
		self.W.set_value(cPickle.load(f), borrow=True)
		self.b.set_value(cPickle.load(f), borrow=True)DeepLearning tutorial(2)机器学习算法在训练过程中保存参数
原文地址:http://blog.csdn.net/u012162613/article/details/43169019