
Generating the Double Moon Data Set for Classification Problems

Posted: 2015-10-25 09:39:09


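The book "Neural Networks and Learning Machines" frequently uses a data set known as the double moon data set. The author does not appear to provide code for generating it, so I found a MATLAB version online. It produces the following result: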
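Stated as point sets, using the parameters of the code (radius r, ring width w, separation d), the two classes look like this. This is a summary of the header comments in the MATLAB code, not a formula quoted from the book:

```latex
A = \left\{ (x, y) \;:\; r - \tfrac{w}{2} < \sqrt{x^2 + y^2} < r + \tfrac{w}{2},\; y \ge 0 \right\},
\qquad
B = \left\{ (x + r,\; -y - d) \;:\; (x, y) \in A \right\}
```

Every point of A has y ≥ 0 and every point of B has y ≤ -d, so for d > 0 the classes are separated by a horizontal strip of height d, while d < 0 lets their vertical ranges overlap.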

[Figure: the double moon data set generated by the MATLAB code]

The code is as follows:

```matlab
function data=dbmoon(N,d,r,w)
% Usage: data=dbmoon(N,d,r,w)
% dbmoon.m - generate the double moon data set in Haykin's book titled
% "Neural Networks and Learning Machines", third edition, 2009, Pearson
% Figure 1.8, p. 61
% The data set contains two regions A and B representing 2 classes;
% each region is a half ring with radius r = 10 and width w = 6, one is the
% upper half and the other is the lower half
% d: distance between the two regions
% Generates region A centered at (0, 0); region B is a mirror image
% of region A (w.r.t. the x axis) with an (r, -d) shift of origin
% N: number of samples per class, default = 1000
% d: separation of the two classes, a negative value means overlapping (default = 1)
% r: radius (default = 10), w: width of the ring (default = 6)
%
% (C) 2010 by Yu Hen Hu
% Created: Sept. 3, 2010

% clear all; close all;
if nargin < 4, w = 6; end
if nargin < 3, r = 10; end
if nargin < 2, d = 1; end
if nargin < 1, N = 1000; end

% generate region A:
% first generate uniformly distributed random points in the box from
% (-r-w/2, 0) to (r+w/2, r+w/2)
N1 = 10*N;  % generate more points and keep those that meet the criteria
w2 = w/2;
done = 0; data = []; tmp1 = [];
while ~done
    tmp = [2*(r+w2)*(rand(N1,1)-0.5) (r+w2)*rand(N1,1)];
    % 3rd column of tmp is the magnitude (distance from the origin) of each point
    tmp(:,3) = sqrt(tmp(:,1).*tmp(:,1) + tmp(:,2).*tmp(:,2));
    idx = find((tmp(:,3) > r-w2) & (tmp(:,3) < r+w2));
    tmp1 = [tmp1; tmp(idx,1:2)];
    % stop once enough points have been accumulated; otherwise draw more
    if size(tmp1,1) >= N
        done = 1;
    end
end
% region A data gets class label 0
% region B data is region A data with the y coordinate flipped and shifted
% by -d, and the x coordinate shifted by +r; it gets class label 1
data = [tmp1(1:N,:) zeros(N,1);
    [tmp1(1:N,1)+r -tmp1(1:N,2)-d ones(N,1)]];

plot(data(1:N,1), data(1:N,2), '.r', data(N+1:end,1), data(N+1:end,2), '.b');
title(['Double moon data set, d = ' num2str(d)]);
axis([-r-w2 2*r+w2 -r-w2-d r+w2])

save dbmoon N r w d data;
```
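The loop above is rejection sampling: candidates are drawn uniformly from the box [-(r+w/2), r+w/2] × [0, r+w/2] and only those inside the half ring are kept. The expected fraction kept is the ratio of the half-ring area, π·r·w, to the box area, 2(r+w/2)², which is why drawing N1 = 10N candidates per pass normally yields enough points in a single pass. A minimal Python sketch of that calculation with a Monte Carlo cross-check; `acceptance_rate` and `empirical_rate` are hypothetical helper names, not part of the original code, and the formula assumes w/2 ≤ r:

```python
import numpy as np

def acceptance_rate(r=10.0, w=6.0):
    """Expected fraction of box samples landing in the half ring (assumes w/2 <= r)."""
    half_ring_area = np.pi * r * w            # (pi/2) * ((r+w/2)^2 - (r-w/2)^2)
    box_area = 2.0 * (r + w / 2.0) ** 2       # width 2(r+w/2) times height (r+w/2)
    return half_ring_area / box_area

def empirical_rate(r=10.0, w=6.0, n=200_000, seed=0):
    """Monte Carlo estimate using the same box sampling as dbmoon."""
    rng = np.random.default_rng(seed)
    w2 = w / 2.0
    x = 2.0 * (r + w2) * (rng.random(n) - 0.5)
    y = (r + w2) * rng.random(n)
    dist = np.hypot(x, y)                     # distance from the origin
    return np.mean((dist > r - w2) & (dist < r + w2))

p = acceptance_rate()
print(f"expected acceptance rate: {p:.3f}")    # 0.558 for r=10, w=6
print(f"empirical acceptance rate: {empirical_rate():.3f}")
print(f"expected kept points per pass with N1 = 10N: {10 * p:.2f} N")
```

With the default parameters roughly 5.6N points survive each pass, so the `while` loop almost always runs only once.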

 

Since I have been using Python recently, I translated the MATLAB code into a Python version of my own.

The code is as follows:

```python
# Usage: data = dbmoon(N, d, r, w)
# dbmoon.py - generate the double moon data set in Haykin's book titled
# "Neural Networks and Learning Machines", third edition, 2009, Pearson
# Figure 1.8, p. 61
# The data set contains two regions A and B representing 2 classes;
# each region is a half ring with radius r = 10 and width w = 6, one is the
# upper half and the other is the lower half
# d: distance between the two regions
# Generates region A centered at (0, 0); region B is a mirror image
# of region A (w.r.t. the x axis) with an (r, -d) shift of origin
# N: number of samples per class, default = 1000
# d: separation of the two classes, a negative value means overlapping (default = 1)
# r: radius (default = 10), w: width of the ring (default = 6)
#
# (C) 2015 by Wanqian Luo
# Created: Oct. 25, 2015

import numpy as np

def dbmoon(N=1000, d=1, r=10, w=6):
    N1 = 10 * N   # draw more candidates than needed, keep those inside the ring
    w2 = w / 2
    done = False
    data = np.empty(0)
    while not done:
        # uniform candidates in the box [-(r+w2), r+w2] x [0, r+w2]
        tmp_x = 2 * (r + w2) * (np.random.random([N1, 1]) - 0.5)
        tmp_y = (r + w2) * np.random.random([N1, 1])
        tmp = np.concatenate((tmp_x, tmp_y), axis=1)
        tmp_ds = np.sqrt(tmp_x * tmp_x + tmp_y * tmp_y)   # distance from origin

        # keep only the points whose distance lies in (r - w2, r + w2)
        idx = np.logical_and(tmp_ds > (r - w2), tmp_ds < (r + w2))
        idx = idx.nonzero()[0]

        if data.shape[0] == 0:
            data = tmp.take(idx, axis=0)
        else:
            data = np.concatenate((data, tmp.take(idx, axis=0)), axis=0)
        if data.shape[0] >= N:
            done = True

    # rows 0..N-1: region A (upper half ring)
    db_moon = data[0:N, :]
    # region B: flip y, then shift x by +r and y by -d
    data_t = np.empty([N, 2])
    data_t[:, 0] = data[0:N, 0] + r
    data_t[:, 1] = -data[0:N, 1] - d
    # rows N..2N-1: region B (lower half ring)
    db_moon = np.concatenate((db_moon, data_t), axis=0)
    return db_moon
```
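The Python function above returns a (2N, 2) array whose first N rows are region A and last N rows region B, but unlike the MATLAB version it attaches no class labels. As an alternative to the rejection loop, the half ring can also be sampled directly in polar coordinates: an angle uniform on [0, π) and a radius ρ = sqrt(a² + U(b² − a²)) with a = r − w/2, b = r + w/2 and U uniform on [0, 1). That radius formula is inverse-CDF sampling, so the points are uniform over the ring's area, the same distribution the rejection loop produces. This is a swapped-in technique, not the original author's code; the name `dbmoon_polar` and the `seed` parameter are hypothetical, and the third label column (0 for A, 1 for B) follows the MATLAB version:

```python
import numpy as np

def dbmoon_polar(N=1000, d=1, r=10, w=6, seed=None):
    """Sketch: double moon data via direct polar sampling (hypothetical helper).

    Returns an array of shape (2N, 3) with columns x, y, label
    (label 0 = region A, upper half ring; 1 = region B, lower half ring).
    Assumes w/2 <= r so the inner radius is non-negative.
    """
    rng = np.random.default_rng(seed)
    a, b = r - w / 2.0, r + w / 2.0            # inner and outer radius
    theta = np.pi * rng.random(N)              # angle in [0, pi): upper half plane
    # inverse-CDF sampling: uniform over the ring's area, not uniform in radius
    rho = np.sqrt(a ** 2 + rng.random(N) * (b ** 2 - a ** 2))
    xa, ya = rho * np.cos(theta), rho * np.sin(theta)

    # region B: mirror region A in the x axis, then shift by (r, -d),
    # reusing the same N points as the original MATLAB and Python code do
    xb, yb = xa + r, -ya - d

    region_a = np.column_stack((xa, ya, np.zeros(N)))
    region_b = np.column_stack((xb, yb, np.ones(N)))
    return np.vstack((region_a, region_b))

data = dbmoon_polar(N=1000, d=1, seed=0)
print(data.shape)                              # (2000, 3)
X, labels = data[:, :2], data[:, 2]            # features and class labels
```

Avoiding the `while` loop makes the output size fixed and the run time predictable; the trade-off is that the code no longer mirrors the book's figure-generation approach line by line.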

It produces the following result:

[Figure: the double moon data set generated by the Python code]

The code can be downloaded here; the archive also includes a demo showing how to call it: http://pan.baidu.com/s/1bnsNrm3 (password: 2ekd)


Original post: http://www.cnblogs.com/cpointer/p/4908216.html
