码迷,mamicode.com
首页 > 其他好文 > 详细

机器学习之实战matlab神经网络工具箱

时间:2015-08-11 23:33:36      阅读:6172      评论:0      收藏:0      [点我收藏+]

标签:神经网络   机器学习   matlab   算法   

上节在
《机器学习之从logistic到神经网络算法》
中,我们已经从原理上介绍了神经网络算法的来源与构造,并编程实战了简单神经网络对于线性与非线性数据的分类测试实验。看过上节的可能会发现,上节实现的算法对于非线性数据的分类效果并不是非常完美,有许多值得优化的地方。而matlab作为一个科学计算软件,本身集成了非常多的优化算法,其中的神经网络工具箱就是其中一个优秀的工具箱,本节将以工具箱中的函数重新实验上节中的分类实验。

首先来了解这个工具箱。我们说,一个简单的神经网络就如下图所示:
技术分享
这是我们上节用过的网络,含有两层,每层3个节点的网络。然而已经证明,其实只含有一个隐含层的网络是可以拟合出任意有限的输入输出映射问题的。基于此,在matlab集成的工具箱中我们只能看到只含有一个隐含层的网络,而网络的节点个数却是我们可以需要改变与设计的。简单来工具箱的一个简单神经网络可以表示如下:
技术分享
通过输入然后映射到输出,输出在映射到最终输出。整个网络包含权值w、v和常数b。网络需要设计的就是确定隐含层的节点个数(matlab默认10),同时映射函数需要注意,上节我们默认的都是采用sigmod函数,然而matlab默认的是隐含层输出采用sigmod,而输出后的映射采用的是线性映射,像上图中的在输出部分的红色所示。其实除了这两种映射函数外,还有正切等等函数,可以作为映射函数的共同点就是他们都是可导的(这是原理推导上所必须的)。那么matlab默认输出映射为什么是线性的而不是sigmod的呢?我们知道sigmod函数会把数据映射到0-1之间,然而我们实际应用中获得的数据的目标值(分类标签)并不一定都在0-1之间,那么如果用sigmod函数,那么我们需要把原始数据的目标输出转化到0-1之间,而线性映射就没有这个问题(因为它映射后的数据没有上下限)。

好了下面说说关于几个重要的函数:

  1. 网络创建函数:feedforwardnet(hiddenSizes,trainFcn) ,(matlab2012后较新的版本),在老版本的matlab,这个函数是newff。这个函数就是创建一个上述的(前馈)网络,包括两个参数,第一个hiddenSizes隐藏层的大小(实际就是节点数的多少,默认10),trainFcn是网络训练所采取的方法,这个方法包括:梯度下降算法、动量梯度下降算法、变学习率梯度下降算法等等数10种方法,各种方法各有优缺点,而有代表性的方法是有代表性的五种算法为:’traingdx’,’trainrp’,’trainscg’,’trainoss’,’trainlm’,默认的是’trainlm’。其实不需要了解的太详细,一般的数据默认的方法就非常的好了,所以这个参数可以不用管。给一个详细介绍这个函数的链接:

    matlab神经网络函数(feedforwardnet,fitnet,patternet)
    关于MATLAB神经网络命令feedforwardnet的一些记录

  2. 好了网络创建完以后,下面就是训练网络的参数了,这里的参数就是权值矩阵w,v,以及常数矩阵b。函数是train,关于这个函数其实可以有很多参数,也可以只有几个参数,因为好多参数都是有默认,而采用默认值就可以得到很好的效果,比如迭代次数,训练最小的允许误差等等。这里只说几个重要的参数,
    1.第一个参数,上述创建的网络net,
    2.第二个,训练数据,
    3.第三个,与训练数据对应的数据目标值(分类标签或者输出值等),这个输出值不光可以是一维的,还可以是多维的。

    部分详细参考如下:bp神经网络及matlab实现

而train函数出来的就是训练好的网络net,matlab出来的net是一个结构体数据,里面包括了网络的所有信息(训练方法,误差,包括我们熟悉的权值矩阵w,b等等),那么我们是否需要把w与b在提取出来呢?并不需要,在实际应用中,我们可以直接把需要测试的数据输入到net中就可以了。比如说现在网络训练好了就是net,那么来了一个测试样本sample,那么它的输出值就是net(sample),这样就可以达到预测的目的了。

好了,了解这些就可以实验了,还有许多详细的细节部分也可以改变,用不到的说多了可能也没有用就不说了。

下面同样实验上一节的两组人造样本集,线性集与非线性集,画出来如下:
技术分享
技术分享
代码如下:

%%  
% * matlab自带神经网络工具箱的分类设计
% * 线性与非线性分类
% 
%% 
clc
clear
close all
%% Load data
% * 数据预处理--分两类情况
data = load(‘data_test1.mat‘);
data = data.data‘;
% 将标签设置为0,1
data(:,3) = data(:,3) - 1;
%选择训练样本个数
num_train = 80;
%构造随机选择序列
choose = randperm(length(data));
train_data = data(choose(1:num_train),:);
gscatter(train_data(:,1),train_data(:,2),train_data(:,3));
label_train = train_data(:,end);
test_data = data(choose(num_train+1:end),:);
label_test = test_data(:,end);
%% 神经网络的构建与训练
% 构造神经网络(包含10个隐藏层的节点)
net = feedforwardnet(10);
% net.layers{2}.transferFcn = ‘tansig‘;% 输出的映射方法,默认purelin--线性映射
% 训练网络
net = train(net,train_data(:,1:end-1)‘,label_train‘);
% 显示构造的网络
view(net);
% 用这个网络来预测测试集的分类
y_test = net(test_data(:,1:end-1)‘);
%输出的值四舍五入,认为大于0.5的属于类‘1’,其他的属于类‘0’
predict = round(abs(y_test));

%% 显示结果--测试训
figure;
index1 = find(predict==0);
data1 = (test_data(index1,:))‘;
plot(data1(1,:),data1(2,:),‘or‘);
hold on
index2 = find(predict==1);
data2 = (test_data(index2,:))‘;
plot(data2(1,:),data2(2,:),‘*‘);
hold on
indexw = find(predict‘~=(label_test));
dataw = (test_data(indexw,:))‘;
plot(dataw(1,:),dataw(2,:),‘+g‘,‘LineWidth‘,3);
accuracy = length(find(predict‘==label_test))/length(test_data);
title([‘predict the training data and the accuracy is :‘,num2str(accuracy)]);

首先加载线性数据集:
技术分享
技术分享

绿色为错分的点,程序中我们采用的是10个隐含层节点,同时输出采用的默认的映射—线性映射。程序中也有sigmod映射(对应的注释掉的地方)。

下面进行非线性数据的测试:
技术分享
同样,绿色为错分的点。可以看到,准确率终于上升到80%以上(与上节对比,基本上不可能)。这里我们再把输出映射也改为sigmod映射,因为我们的输出标签已经改到了0、1标签,输入就在0-1之间,所以可以直接用。将上述对应的注释去掉,改完后的结果如下:
技术分享
可以看到输出采用sigmod函数效果似乎更好些。同时再来看看画出来的网络结构:
技术分享
看这个结果和上面那个结构,你发现了什么?是不是输出层的映射关系变化了。

这就是matlab下基本神经网络的训练与预测实验,其神经网络工具箱功能远远不止这。同时我这是采用函数命令实现的,matlab对于神经网络集成了GUI功能,可以直接在图形界面操作。在命令窗口输入:nnstart就会出现下面GUI界面:
技术分享
可以看到matlab下神经网络主要有四个方向用途:拟合数据,模式识别与分类数据,聚类数据,时间序列模型的数据处理,可以看到,我们这篇的这个问题其实是输入模式分类数据这部分的,考虑篇幅有限,感兴趣自己可以去详细研究其他用途,异常的强大。

总结一下该部分,matlab自带神经网络工具箱相比上节自己编的,对于线性数据准确率大概差不多,但是对于非线性数据的划分,工具箱的函数效果优化的非常好,同时用法也简单,运算速度快,可以说是一个非常好的一个分类方法。

版权声明:本文为博主原创文章,未经博主允许不得转载。

机器学习之实战matlab神经网络工具箱

标签:神经网络   机器学习   matlab   算法   

原文地址:http://blog.csdn.net/on2way/article/details/47428201

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!