码迷,mamicode.com
首页 > 其他好文 > 详细

山东大学 机器学习 实验报告 实验2 模式分类 上机练习

时间:2019-10-28 00:47:56      阅读:197      评论:0      收藏:0      [点我收藏+]

标签:实验报告   raw   sort   处理   输出   hashmap   value   利用   rgs   

 

【17级的同辈们,这是我实验报告真实且全部的内容,求求求求你们,不要让我后悔提前发布 ╥﹏╥... 。真的挺简单的,1天就能搞定,而且在书里的位置我都标注出来了,让我们来一起学习吧!!!( ̄▽ ̄)",当然错了也概不负责哈~~~~】

 

3、实验内容及说明 
使用上面给出的三维数据: 
1. 编写程序,对类 1 和类 2 中的 3 个特征
x i 分别求解最大似然估计的均值???和方差?? ?2。 
2. 编写程序,处理二维数据的情形??(??)~??(??,??)。对类 1 和类 2 中任意两个特征的组合分别求解最大
似然估计的均值?? ?和方差?? ?(每个类有 3 种可能) 。 
3. 编写程序,处理三维数据的情形??(??)~??(??,??)。对类 1 和类 2 中三个特征求解最大似然估计的均值
?? ?和方差?? ?。 
4. 假设三维高斯模型是可分离的,即?? = ????????(??1 2,??2 2,??3 2),编写程序估计类 1 和类 2 中的均值和协方
差矩阵中的参数。 
5. 比较前 4 种方法计算出来的每一个特征的均值 μi的异同,并加以解释。 
6. 比较前 4 种方法计算出来的每一个特征的方差 σi的异同,并加以解释。 
 

 

 

1、 

a)预处理二列矩阵,就是类1里选1个,3种可能、类2里选1个,3种可能

b)均值的最大似然估计就是样本均值,协方差的最大似然估计就是协方差矩阵均值,运用书P71的式(16)、式(17)计算均值???和方差?? ?2。

 

2、 

a)  预处理四列矩阵,就是类1里选2个,3种可能、类2里选2个,3种可能

b)  同上

 

3、

a)六列矩阵不预处理

b)同上

 

4、

    a)该题属于3.2.3 ??已知而?? ?未知的情况,用式(11)估算出?? ?值,(事实上,可以发现?? ?的计算方式并没有改变)

        b)由3.2.4节知,??的参数可能是1/n和1/(n-1),我们希望选择的参数能导致分类结果最优,也就是能让最大值和第二大值差别更大。我的思路是,用单位向量作为测试点,得到所有判别值,选择第一和第二的差的绝对值最大的参数

 

5、6:异同嘛……写就好了

 

每一道题的类包含相应预处理,内核方法在math包里:

 1 import static math.M.*;
 2 
 3 //1题
 4 public class a {
 5     //        预处理二列
 6     public static void main(String[] args) {
 7         int group = 2;
 8         int len0 = 1;
 9         for (int i = 0; i < 3; i++) {
10             for (int j = 0; j < 3; j++) {
11                 double[][] x0 = new double[raw.length][2];
12                 addCols(raw, i, x0, 0);
13                 addCols(raw, j, x0, 1);
14                 double[][][] xx = ini(x0, group, len0);
15 
16                 double[][] u = getU(xx, group, len0);
17                 double[][][] sigma = getSigma(xx, group, len0, xx.length);
18 
19                 System.out.println("组合:" + (i + 1) + " " + (j + 1));
20                 System.out.println("\uD835\uDF07?");
21                 print(u);
22                 System.out.println("\uD835\uDF0E ?2");
23                 print(sigma);
24                 System.out.println();
25             }
26         }
27     }
28 }

 

 1 import static math.M.*;
 2 
 3 //2题
 4 public class b {
 5     //    预处理出四列
 6     public static void main(String[] args) {
 7         int group = 2;
 8         int len0 = 2;
 9 
10         double[][] x0 = new double[raw.length][4];
11         int[][]x={
12                 {0,1},
13                 {0,2},
14                 {1,2},
15         };
16 
17         for (int i = 0; i < 3; i++) {
18             for (int j = 0; j < 3; j++) {
19                 print2(x0,group,len0,x[i][0],x[i][1],x[j][0],x[j][1]);
20             }
21         }
22     }
23 
24     static void print2(double[][] x, int group, int len0, int x1, int x2, int x3, int x4) {
25         System.out.println("类1 特征" + (x1 + 1) + (x2 + 1) + " 类2 特征" + (x3 + 1) + (x4 + 1));
26         addCols(raw, x1, x, 0);
27         addCols(raw, x2, x, 1);
28         addCols(raw, x3 + 3, x, 2);
29         addCols(raw, x4 + 3, x, 3);
30         double[][][] xx = ini(x, group, len0);
31         print(getU(xx, group, len0), getSigma(xx, group, len0,x.length));
32 
33     }
34 }

 

 1 import static math.M.*;
 2 
 3 //3题
 4 public class c {
 5 //    六列不处理
 6     public static void main(String[] args) {
 7         double[][]X=raw;
 8         int group = 2;
 9         int len0 = 3;
10         double[][][] x = ini(X, group, len0);
11         double[][] u = getU(x, group, len0);
12         double[][][] sigma = getSigma(x, group, len0,x.length);
13 
14         System.out.println("\uD835\uDF07?");
15         print(u);
16         System.out.println();
17         System.out.println("\uD835\uDF0E ?2");
18         print(sigma);
19     }
20 }

 

 1 import java.util.Arrays;
 2 
 3 import static java.lang.Math.abs;
 4 import static math.Classify.g;
 5 import static math.M.*;
 6 
 7 //4题
 8 public class d {
 9     public static void main(String[] args) {
10         System.out.println("n: "+raw.length);
11         System.out.println(getSigmaDiv(ini(raw, 2, 3), new double[]{1 / 2, 1 / 2}, 2, 3));
12     }
13 
14     static int getSigmaDiv(double[][][] x, double[] p, int group, int len0) {
15         double[][][] sigma1 = getSigma(x, group, len0, x.length);
16         double[][][] sigma2 = getSigma(x, group, len0, x.length - 1);
17         double[][] u = getU(x, group, len0);
18 
19         double[][] x0 = new double[1][u[0].length];
20         Arrays.fill(x0[0], 1);
21 
22         double[] t1 = new double[group];
23         double[] t2 = new double[group];
24 
25         for (int j = 0; j < group; j++) {
26             t1[j] = g(x0[0], u[j], sigma1[j], p[j]);
27             t2[j] = g(x0[0], u[j], sigma2[j], p[j]);
28         }
29         Arrays.sort(t1);
30         Arrays.sort(t2);
31 
32         return abs(t1[t1.length - 1] - t1[t1.length - 2]) > abs(t2[t2.length - 1] - t2[t2.length - 2]) ?
33                 x.length : x.length - 1;
34     }
35 }

 

下面是math包里的内容

 

  1 package math;
  2 
  3 import java.util.Arrays;
  4 
  5 public class M {
  6     public static double[][] raw = {
  7             {0.011, 1.03, -0.21, 1.36, 2.17, 0.14},
  8             {1.27, 1.28, 0.08, 1.41, 1.45, -0.38},
  9             {0.13, 3.12, 0.16, 1.22, 0.99, 0.69},
 10             {-0.21, 1.23, -0.11, 2.46, 2.19, 1.31},
 11             {-2.18, 1.39, -0.19, 0.68, 0.79, 0.87},
 12             {0.34, 1.96, -0.16, 2.51, 3.22, 1.35},
 13             {-1.38, 0.94, 0.45, 0.60, 2.44, 0.92},
 14             {-1.02, 0.82, 0.17, 0.64, 0.13, 0.97},
 15             {-1.44, 2.31, 0.14, 0.85, 0.58, 0.99},
 16             {0.26, 1.94, 0.08, 0.66, 0.51, 0.88},
 17 
 18     };
 19 
 20     //    输入:x[i][j]第i个样本j类的向量,group类数量,len0维度
 21     //    输出:sigma[i]第i组的协方差矩阵
 22     public static double[][][] getSigma(double[][][] x, int group, int len0,int div) {
 23         double[][] u = getU(x, group, len0);
 24         double[][][] sigma = new double[group][len0][len0];
 25 
 26         //1/n*sigma(xk-u‘)*(xk-u‘).T
 27         for (int i = 0; i < group; i++) {
 28             SUB(x, u[i], i);
 29             double[][] t = new double[len0][len0];
 30             for (int j = 0; j < x.length; j++) ADD(t, cov0(x[j][i]));
 31             DIVI(div, t);
 32             sigma[i] = t;
 33         }
 34         return sigma;
 35     }
 36 
 37 
 38 
 39     //    输入:向量a
 40     //    输出:a的方差
 41     public static double[][] cov0(double[] a) {
 42         double[][] row = new double[1][a.length];
 43         row[0] = a;
 44         double[][] col = new double[a.length][1];
 45         for (int i = 0; i < a.length; i++) col[i][0] = a[i];
 46 
 47         return mult(col, row);
 48     }
 49 
 50 
 51     //    输入:x[i][j]第i个样本j类的向量,group类数量,len0维度
 52     //    输出:u[i]第i组的均值向量
 53     public static double[][] getU(double[][][] x, int group, int len0) {
 54         double[][] u = new double[group][len0];
 55         //        1/n*sigma(xk)
 56         for (int i = 0; i < group; i++) {
 57             double[] t = new double[len0];
 58             for (int j = 0; j < x.length; j++) ADD(t, x[j][i]);
 59             DIVI(x.length, t);
 60             u[i] = t;
 61         }
 62         return u;
 63     }
 64 
 65 
 66     //    输入:样本X,group类总数,len0维度
 67     //    输出:x[i][j]样本i的j类的向量
 68     public static double[][][] ini(double[][] X, int group, int len0) {
 69         double[][][] ret = new double[X.length][group][];
 70         for (int i = 0; i < X.length; i++) {
 71             for (int j = 0; j < group; j++) {
 72                 int s = j * len0, e = s + len0;
 73                 double[] t = new double[len0];
 74                 for (int z = s, p = 0; z < e; z++, p++) t[p] = X[i][z];
 75                 ret[i][j] = t;
 76             }
 77         }
 78         return ret;
 79     }
 80 
 81 //    ------------------------工具方法-----------------------------------
 82 
 83 
 84     //    利用副作用,把SUB的col列取出,-=b
 85     public static void SUB(double[][][] SUB, double[] b, int col) {
 86         for (int i = 0; i < SUB.length; i++) {
 87             for (int j = 0; j < SUB[0][0].length; j++) {
 88                 SUB[i][col][j] -= b[j];
 89             }
 90         }
 91     }
 92 
 93     public static double[] sub(double[] x, double[] u) {
 94         double[] ret = new double[x.length];
 95         for (int i = 0; i < x.length; i++) {
 96             ret[i] = x[i] - u[i];
 97         }
 98         return ret;
 99     }
100 
101 
102     //    利用副作用,ALL+=b
103     public static void ADD(double[] ALL, double[] b) {
104         for (int i = 0; i < ALL.length; i++) {
105             ALL[i] += b[i];
106         }
107     }
108 
109     //    利用副作用,ALL+=b
110     public static void ADD(double[][] ALL, double[][] b) {
111         for (int i = 0; i < ALL.length; i++) {
112             for (int j = 0; j < ALL[0].length; j++) {
113                 ALL[i][j] += b[i][j];
114             }
115         }
116     }
117     //  n*矩阵
118     public static double[][] mult(double n, double[][] a) {
119         double[][] b = new double[a.length][a[0].length];
120         for (int i = 0; i < a.length; i++) {
121             for (int j = 0; j < a[0].length; j++) {
122                 b[i][j] = a[i][j] * n;
123             }
124         }
125         return b;
126     }
127     //    矩阵乘法
128     public static double[][] mult(double[][] a, double[][] b) {
129         double[][] c = new double[a.length][b[0].length];
130         for (int i = 0; i < a.length; i++) {
131             for (int j = 0; j < b[0].length; j++) {
132                 for (int k = 0; k < a[0].length; k++) {
133                     c[i][j] += (a[i][k] * b[k][j]);
134                 }
135             }
136         }
137         return c;
138     }
139 
140     //    矩阵/n
141     public static void DIVI(double n, double[] a) {
142         for (int i = 0; i < a.length; i++) {
143             a[i] /= n;
144         }
145     }
146 
147     public static void DIVI(double n, double[][] a) {
148         for (int i = 0; i < a.length; i++) {
149             for (int j = 0; j < a[0].length; j++) {
150                 a[i][j] /= n;
151             }
152         }
153     }
154 
155 
156 
157     public static double[][] rowToCol(double[] x) {
158         double[][] ret = new double[x.length][1];
159         for (int i = 0; i < x.length; i++) {
160             ret[i][0] = x[i];
161         }
162         return ret;
163     }
164 
165     public static double[][] colToRow(double[][] wi) {
166         double[][] ret = new double[1][wi.length];
167         for (int i = 0; i < wi.length; i++) {
168             ret[0][i] = wi[i][0];
169         }
170         return ret;
171     }
172 
173 
174     //    输入:a样本数据,col a要加的col列,all被加数组,col2 加到all的col2列
175     public static void addCols(double[][] a, int col, double[][] all, int col2) {
176         for (int i = 0; i < all.length; i++) all[i][col2] = a[i][col];
177     }
178 
179     public static void print(double[][] u, double[][][] sigma) {
180         System.out.println("\uD835\uDF07?");
181         print(u);
182         System.out.println("\uD835\uDF0E ?2");
183         print(sigma);
184         System.out.println();
185     }
186 
187     public static void print(double[][][] a) {
188         for (int i = 0; i < a.length; i++) {
189             print(a[i]);
190             System.out.println();
191         }
192     }
193 
194     public static void print(double[][] a) {
195         for (int i = 0; i < a.length; i++) print(a[i]);
196     }
197 
198     public static void print(double[] a) {
199         System.out.println(Arrays.toString(a));
200     }
201 }

 

 1 package math;
 2 
 3 import java.util.Arrays;
 4 import java.util.HashMap;
 5 
 6 import static java.lang.Math.log;
 7 import static math.Inv.det;
 8 import static math.Inv.inv;
 9 import static math.M.*;
10 import static math.M.mult;
11 
12 public class Classify {
13 
14     public static int[] classify(double[][] x0, double[][][] x, double[] p, int group, int len0) {
15         double[][][] sigma = getSigma(x, group, len0, x.length);
16         double[][] u = getU(x, group, len0);
17 
18         int[] clas = new int[x.length];
19         double[] t = new double[group];
20 
21         for (int i = 0; i < x.length; i++) {
22             double min = Double.MIN_VALUE;
23             for (int j = 0; j < group; j++) {
24                 t[j] = g(x0[i], u[j], sigma[j], p[j]);
25                 if (t[j] > min) {
26                     min = t[j];
27                     clas[i] = j + 1;
28                 }
29             }
30 
31         }
32         return clas;
33     }
34 
35     //    输入:x测试点,u某个类的均值向量,sigma某个类的协方差矩阵,p某个类的先验概率
36 //    输出:x的判别函数值(方程对应《模式分类第二版》P32 (66)~(69))
37     public static double g(double[] x0, double[] u0, double[][] sigma0, double p) {
38         double[][] Wi = mult(-0.5, inv(sigma0));
39         double[][] t1 = mult(new double[][]{x0}, Wi);
40         double t2 = mult(t1, rowToCol(sub(x0, u0)))[0][0];
41 
42         double[][] wi = mult(inv(sigma0), rowToCol(u0));
43         double t3 = mult(colToRow(wi), rowToCol(x0))[0][0];
44 
45         double[][] t4 = mult(new double[][]{u0}, inv(sigma0));
46         double[][] t5 = mult(t4, rowToCol(u0));
47         double t6 = mult(-0.5, t5)[0][0];
48         t6 = t6 - 0.5 * log(det(sigma0)) + log(p);
49 
50         return t2 + t3 + t6;
51     }
52 }

 

java里实现inv()和det(),我直接找这位兄弟的:https://blog.csdn.net/qiyu93422/article/details/46921095

 

山东大学 机器学习 实验报告 实验2 模式分类 上机练习

标签:实验报告   raw   sort   处理   输出   hashmap   value   利用   rgs   

原文地址:https://www.cnblogs.com/towerbird/p/11745116.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!