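Whether you are researching machine-learning algorithms in a lab or developing them at a company, sooner or later you will need to improve an algorithm yourself. This post shows how to add such an improved algorithm to Weka.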
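I. Steps for adding a classification algorithm
1. The new classifier must extend Classifier or one of its subclasses. The simple ZeroR classifier is used as the example below.
2. Override buildClassifier. This is one of the main methods: it builds the classifier, i.e. trains the model.
3. Override classifyInstance, which returns the prediction for one instance (the class index for a nominal class, the predicted value for a numeric class), or implement distributionForInstance, which returns the probability distribution over all classes.
4. Override getCapabilities, which declares what kind of data the classifier can handle. The GUI uses it to decide whether the classifier can be selected for the current dataset; otherwise it is greyed out.
5. Provide set/get methods for the classifier's options (parameters).
6. Implement globalInfo and the tip-text methods for each option (for example seedTipText); they supply the descriptions shown in the GUI.
7. Register the classifier in the Weka application, as described in part II.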
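Steps 2 and 3 fit together: Weka's base class derives classifyInstance from distributionForInstance (for a nominal class it returns the index of the largest probability, or a missing value if every probability is 0) and vice versa, so a subclass must override at least one of the two. The sketch below mirrors that contract in plain Java so it runs without Weka. SketchClassifier and MajorityVote are hypothetical names, and an instance is simplified to a double[] whose last element is the nominal class index. Note also that in Weka 3.7 and later, weka.classifiers.Classifier is an interface and new classifiers usually extend weka.classifiers.AbstractClassifier instead; the ZeroR source targets the older API in which Classifier is an abstract class.

```java
import java.util.Arrays;

/**
 * Hypothetical, Weka-free mirror of the classifier contract.
 * An instance is a double[] whose last element is the nominal class index.
 */
abstract class SketchClassifier {

    /** Trains the model (the role of buildClassifier). */
    abstract void buildClassifier(double[][] rows);

    /** Returns one probability per class (the role of distributionForInstance). */
    abstract double[] distributionForInstance(double[] row);

    /**
     * Default prediction, modelled on Weka's handling of a nominal class:
     * the index of the largest probability, or NaN ("missing") if all are 0.
     */
    double classifyInstance(double[] row) {
        double[] dist = distributionForInstance(row);
        int best = 0;
        double max = 0;
        for (int i = 0; i < dist.length; i++) {
            if (dist[i] > max) {
                max = dist[i];
                best = i;
            }
        }
        return max > 0 ? best : Double.NaN;
    }
}

/** Toy subclass: predicts the class frequencies seen during training. */
class MajorityVote extends SketchClassifier {
    private final int numClasses;
    private double[] dist;

    MajorityVote(int numClasses) {
        this.numClasses = numClasses;
    }

    @Override
    void buildClassifier(double[][] rows) {
        dist = new double[numClasses];
        for (double[] row : rows) {
            dist[(int) row[row.length - 1]]++;   // count each class label
        }
        for (int i = 0; i < numClasses; i++) {
            dist[i] = rows.length > 0 ? dist[i] / rows.length : 0;
        }
    }

    @Override
    double[] distributionForInstance(double[] row) {
        return Arrays.copyOf(dist, dist.length); // copy so callers cannot modify the model
    }
}
```

Overriding only distributionForInstance, as MajorityVote does, is enough to get a working classifyInstance; ZeroR overrides both because its prediction is already stored in m_ClassValue.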
ZeroR.java source:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ZeroR.java
 *    Copyright (C) 1999 Eibe Frank
 *
 */

package weka.classifiers.rules;

import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import java.io.*;
import java.util.*;
import weka.core.*;

/**
 * Class for building and using a 0-R classifier. Predicts the mean
 * (for a numeric class) or the mode (for a nominal class).
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.11 $
 */
public class ZeroR extends Classifier implements WeightedInstancesHandler {

  /** The class value 0R predicts. */
  private double m_ClassValue;

  /** The number of instances in each class (null if class numeric). */
  private double [] m_Counts;

  /** The class attribute. */
  private Attribute m_Class;

  /**
   * Returns a string describing classifier
   * @return a description suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Class for building and using a 0-R classifier. Predicts the mean "
      + "(for a numeric class) or the mode (for a nominal class).";
  }

  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    double sumOfWeights = 0;

    m_Class = instances.classAttribute();
    m_ClassValue = 0;
    switch (instances.classAttribute().type()) {
    case Attribute.NUMERIC:
      m_Counts = null;
      break;
    case Attribute.NOMINAL:
      m_Counts = new double [instances.numClasses()];
      for (int i = 0; i < m_Counts.length; i++) {
        m_Counts[i] = 1;
      }
      sumOfWeights = instances.numClasses();
      break;
    default:
      throw new Exception("ZeroR can only handle nominal and numeric class"
                          + " attributes.");
    }
    Enumeration enu = instances.enumerateInstances();
    while (enu.hasMoreElements()) {
      Instance instance = (Instance) enu.nextElement();
      if (!instance.classIsMissing()) {
        if (instances.classAttribute().isNominal()) {
          m_Counts[(int)instance.classValue()] += instance.weight();
        } else {
          m_ClassValue += instance.weight() * instance.classValue();
        }
        sumOfWeights += instance.weight();
      }
    }
    if (instances.classAttribute().isNumeric()) {
      if (Utils.gr(sumOfWeights, 0)) {
        m_ClassValue /= sumOfWeights;
      }
    } else {
      m_ClassValue = Utils.maxIndex(m_Counts);
      Utils.normalize(m_Counts, sumOfWeights);
    }
  }

  /**
   * Classifies a given instance.
   *
   * @param instance the instance to be classified
   * @return index of the predicted class
   */
  public double classifyInstance(Instance instance) {
    return m_ClassValue;
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if class is numeric
   */
  public double [] distributionForInstance(Instance instance) throws Exception {
    if (m_Counts == null) {
      double[] result = new double[1];
      result[0] = m_ClassValue;
      return result;
    } else {
      return (double []) m_Counts.clone();
    }
  }

  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {
    if (m_Class == null) {
      return "ZeroR: No model built yet.";
    }
    if (m_Counts == null) {
      return "ZeroR predicts class value: " + m_ClassValue;
    } else {
      return "ZeroR predicts class value: " + m_Class.value((int) m_ClassValue);
    }
  }

  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String [] argv) {
    try {
      System.out.println(Evaluation.evaluateModel(new ZeroR(), argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
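ZeroR's buildClassifier does two different things depending on the class type. For a nominal class it starts every count at 1 (so no class ends up with probability 0), adds the instance weights, predicts the index of the largest count and divides the counts by the total weight to get the distribution returned by distributionForInstance; for a numeric class it predicts the weighted mean. The sketch below restates that arithmetic in plain Java so it can be checked by hand. ZeroRSketch and its methods are hypothetical names, and instances are reduced to parallel arrays of class values and weights, with NaN standing for a missing class value.

```java
import java.util.Arrays;

/** Hypothetical, Weka-free restatement of ZeroR's training arithmetic. */
final class ZeroRSketch {
    private ZeroRSketch() {}

    /**
     * Nominal class: counts start at 1, instance weights are added, and the
     * counts are divided by the total weight (ZeroR's Utils.normalize step).
     * A NaN class value stands for a missing class and is skipped.
     */
    static double[] nominalDistribution(double[] classValues, double[] weights, int numClasses) {
        double[] counts = new double[numClasses];
        Arrays.fill(counts, 1.0);
        double sumOfWeights = numClasses;
        for (int i = 0; i < classValues.length; i++) {
            if (Double.isNaN(classValues[i])) continue;
            counts[(int) classValues[i]] += weights[i];
            sumOfWeights += weights[i];
        }
        for (int c = 0; c < numClasses; c++) counts[c] /= sumOfWeights;
        return counts;
    }

    /** Predicted class: index of the largest value (the first one wins on ties). */
    static int mode(double[] dist) {
        int best = 0;
        for (int c = 1; c < dist.length; c++) {
            if (dist[c] > dist[best]) best = c;
        }
        return best;
    }

    /** Numeric class: weighted mean, or 0 when there is no positive total weight. */
    static double weightedMean(double[] classValues, double[] weights) {
        double sum = 0, sumOfWeights = 0;
        for (int i = 0; i < classValues.length; i++) {
            if (Double.isNaN(classValues[i])) continue;
            sum += weights[i] * classValues[i];
            sumOfWeights += weights[i];
        }
        return sumOfWeights > 0 ? sum / sumOfWeights : 0;
    }
}
```

For example, three classes with class values 0, 0, 1 and one missing value, all of weight 1, give counts 3, 2, 1 over a total weight of 6, i.e. the distribution 0.5, 0.333, 0.167 and the prediction 0.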
II. Steps for adding a fuzzy clustering algorithm
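Fuzzy c-means gives every instance a degree of membership in every cluster instead of a single label. With C clusters, n instances $x_j$, centroids $c_i$, memberships $u_{ij}$, distances $d_{ij}$ and a fuzzifier $m > 1$ (m_fuzzifier in the code), the standard formulation alternates the two updates below while minimizing the objective $J_m$; the comments in FuzzyCMeans.java use the same symbols.

$$
J_m = \sum_{i=1}^{C} \sum_{j=1}^{n} u_{ij}^{\,m}\, d_{ij}^{2}, \qquad d_{ij} = \lVert c_i - x_j \rVert
$$

$$
u_{ij} = \left[ \sum_{k=1}^{C} \left( \frac{d_{ij}}{d_{kj}} \right)^{2/(m-1)} \right]^{-1}, \qquad
c_i = \frac{\sum_{j=1}^{n} u_{ij}^{\,m}\, x_j}{\sum_{j=1}^{n} u_{ij}^{\,m}}
$$

Because each $u_{ij}$ is $d_{ij}^{-2/(m-1)}$ normalised over the C clusters, every column of U sums to 1. As m approaches 1 the memberships approach hard 0/1 assignments as in k-means; larger m makes them softer.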
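1. Write the fuzzy clustering algorithm against Weka's clusterer interface. The source, FuzzyCMeans.java, is given at the end of this post.
2. Copy the source file into the weka.clusterers package (the weka/clusterers directory of the source tree).
3. Edit weka/gui/GenericObjectEditor.props: in the section headed "#Lists the Clusterers I want to choose from", add weka.clusterers.FuzzyCMeans to the list under weka.clusterers.Clusterer=\ . The entries in that list are separated by commas and every line except the last ends with a backslash, so keep that format when inserting the new line.
4. Check weka/gui/GenericPropertiesCreator.props as well. No change is needed here, because the package weka.clusterers is already listed; only if you put the class into a new package must you add that package to this file.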
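Weka loads .props files as standard Java properties, so a backslash at the end of a line continues the value on the next line, the continuation's leading blanks are dropped, and the class names end up as one comma-separated string. A quick way to check an edited entry is to load it with java.util.Properties and split on commas, as in this sketch. PropsCheck is a hypothetical helper, and the neighbouring class names in the test text (weka.clusterers.EM, weka.clusterers.SimpleKMeans) are only illustrative; the real file lists more entries.

```java
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;

/** Hypothetical helper for checking an edited .props fragment. */
final class PropsCheck {
    private PropsCheck() {}

    /** Returns the comma-separated class names stored under key, trimmed. */
    static List<String> classNames(String propsText, String key) {
        Properties props = new Properties();
        try {
            // load() joins lines ending in '\' and drops the continuation's leading blanks
            props.load(new StringReader(propsText));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        List<String> names = new ArrayList<>();
        for (String part : props.getProperty(key, "").split(",")) {
            String name = part.trim();
            if (!name.isEmpty()) {
                names.add(name);
            }
        }
        return names;
    }
}
```

If the line before the new entry lacks its trailing backslash, the value ends there and weka.clusterers.FuzzyCMeans is parsed as a separate key rather than as part of the clusterer list; the second case in the check covers exactly that mistake.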
FuzzyCMeans.java source:
/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    FCM.java
 *    Copyright (C) 2007 Wei Xiaofei
 *
 */

package weka.clusterers;

import weka.classifiers.rules.DecisionTableHashKey;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;
import weka.core.matrix.Matrix;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;

import java.util.Enumeration;
import java.util.HashMap;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Cluster data using the Fuzzy C means algorithm
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -N <num>
 *  number of clusters.
 *  (default 2).</pre>
 *
 * <pre> -F <num>
 *  exponent.
 *  (default 2).</pre>
 *
 * <pre> -S <num>
 *  Random number seed.
 *  (default 10)</pre>
 *
 <!-- options-end -->
 *
 * @author Wei Xiaofei
 * @version 1.03
 * @see RandomizableClusterer
 */
public class FuzzyCMeans
  extends RandomizableClusterer
  implements NumberOfClustersRequestable, WeightedInstancesHandler {

  /** for serialization */
  static final long serialVersionUID = -2134543132156464L;

  /** upper bound on iterations, so the main loop terminates even if the objective keeps changing by a few ulps */
  private static final int MAX_ITERATIONS = 100;

  /** replace missing values in training instances */
  private ReplaceMissingValues m_ReplaceMissingFilter;

  /** number of clusters to generate */
  private int m_NumClusters = 2;

  /** D: d(i,j) = ||c(i) - x(j)||, the Euclidean distance between centroid i and instance j */
  private Matrix D;

  /** holds the fuzzifier (weighting exponent m) */
  private double m_fuzzifier = 2;

  /** holds the cluster centroids */
  private Instances m_ClusterCentroids;

  /** Holds the standard deviations of the numeric attributes in each cluster */
  private Instances m_ClusterStdDevs;

  /** For each cluster, holds the frequency counts for the values of each nominal attribute */
  private int [][][] m_ClusterNominalCounts;

  /** The number of instances in each cluster */
  private int [] m_ClusterSizes;

  /** attribute min values */
  private double [] m_Min;

  /** attribute max values */
  private double [] m_Max;

  /** Keep track of the number of iterations completed before convergence */
  private int m_Iterations = 0;

  /** Holds the squared errors for all clusters */
  private double [] m_squaredErrors;

  /** the default constructor */
  public FuzzyCMeans () {
    super();
    m_SeedDefault = 10; // default random seed
    setSeed(m_SeedDefault);
  }

  /**
   * Returns a string describing this clusterer
   * @return a description of the evaluator suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Cluster data using the fuzzy c-means algorithm";
  }

  /**
   * Returns default capabilities of the clusterer.
   *
   * @return the capabilities of this clusterer
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();
    result.disableAll();
    result.enable(Capability.NO_CLASS);

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    return result;
  }

  /**
   * Generates a clusterer. Has to initialize all fields of the clusterer
   * that are not being set via options.
   *
   * @param data set of instances serving as training data
   * @throws Exception if the clusterer has not been
   * generated successfully
   */
  public void buildClusterer(Instances data) throws Exception {

    // can clusterer handle the data?
    getCapabilities().testWithFail(data);

    m_Iterations = 0;

    m_ReplaceMissingFilter = new ReplaceMissingValues();
    Instances instances = new Instances(data);
    instances.setClassIndex(-1);
    m_ReplaceMissingFilter.setInputFormat(instances);
    instances = Filter.useFilter(instances, m_ReplaceMissingFilter);

    m_Min = new double [instances.numAttributes()];
    m_Max = new double [instances.numAttributes()];
    for (int i = 0; i < instances.numAttributes(); i++) {
      m_Min[i] = m_Max[i] = Double.NaN; // NaN = "no value seen yet"
    }

    m_ClusterCentroids = new Instances(instances, m_NumClusters);
    int[] clusterAssignments = new int [instances.numInstances()];

    for (int i = 0; i < instances.numInstances(); i++) {
      updateMinMax(instances.instance(i));
    }

    Random RandomO = new Random(getSeed());
    int instIndex;
    HashMap initC = new HashMap();
    DecisionTableHashKey hk = null;

    /* pick distinct random instances as the initial centroids;
       DecisionTableHashKey is only used to detect duplicate instances */
    for (int j = instances.numInstances() - 1; j >= 0; j--) {
      instIndex = RandomO.nextInt(j+1);
      hk = new DecisionTableHashKey(instances.instance(instIndex),
                                    instances.numAttributes(), true);
      if (!initC.containsKey(hk)) {
        m_ClusterCentroids.add(instances.instance(instIndex));
        initC.put(hk, null);
      }
      instances.swap(j, instIndex);

      if (m_ClusterCentroids.numInstances() == m_NumClusters) {
        break;
      }
    }

    // number of clusters = number of distinct initial centroids found
    m_NumClusters = m_ClusterCentroids.numInstances();

    // distances from every centroid to every instance
    D = new Matrix(solveD(instances).getArray());

    int i, j;
    int n = instances.numInstances();
    Instances [] tempI = new Instances[m_NumClusters];
    m_ClusterNominalCounts = new int [m_NumClusters][instances.numAttributes()][0];

    Matrix U = new Matrix(solveU(instances).getArray()); // initial membership matrix U
    double q = 0; // objective value of the previous iteration

    while (true) {
      m_Iterations++;
      m_squaredErrors = new double [m_NumClusters]; // errors of this iteration only

      // hard-assign every instance to its nearest centroid
      for (i = 0; i < instances.numInstances(); i++) {
        Instance toCluster = instances.instance(i);
        int newC = clusterProcessedInstance(toCluster, true);
        clusterAssignments[i] = newC;
      }

      // update centroids: c(i) = sum_k U(i,k)^m * x(k) / sum_k U(i,k)^m
      m_ClusterCentroids = new Instances(instances, m_NumClusters);
      for (i = 0; i < m_NumClusters; i++) {
        tempI[i] = new Instances(instances, 0);
      }
      for (i = 0; i < instances.numInstances(); i++) {
        tempI[clusterAssignments[i]].add(instances.instance(i));
      }
      for (i = 0; i < m_NumClusters; i++) {
        double[] vals = new double[instances.numAttributes()];
        for (j = 0; j < instances.numAttributes(); j++) {
          double sum1 = 0, sum2 = 0;
          for (int k = 0; k < n; k++) {
            double w = Math.pow(U.get(i, k), getFuzzifier()); // U(i,k)^m
            sum1 += w * instances.instance(k).value(j);
            sum2 += w;
          }
          vals[j] = sum1 / sum2;
        }
        m_ClusterCentroids.add(new Instance(1.0, vals));
      }

      D = new Matrix(solveD(instances).getArray());
      U = new Matrix(solveU(instances).getArray()); // new membership matrix U

      double q1 = 0; // new objective value
      for (i = 0; i < m_NumClusters; i++) {
        for (j = 0; j < n; j++) {
          // objective: q1 += U(i,j)^m * d(i,j)^2
          q1 += Math.pow(U.get(i, j), getFuzzifier()) * D.get(i, j) * D.get(i, j);
        }
      }

      // stop when the objective no longer changes (threshold: machine epsilon 2.2204e-16);
      // the absolute value is needed because the objective decreases from one iteration to the next
      if (Math.abs(q1 - q) < 2.2204e-16 || m_Iterations >= MAX_ITERATIONS) {
        break;
      }
      q = q1;
    }

    // standard deviations, computed from the hard assignments as in k-means
    m_ClusterStdDevs = new Instances(instances, m_NumClusters);
    m_ClusterSizes = new int [m_NumClusters];
    for (i = 0; i < m_NumClusters; i++) {
      double [] vals2 = new double[instances.numAttributes()];
      for (j = 0; j < instances.numAttributes(); j++) {
        if (instances.attribute(j).isNumeric()) {
          vals2[j] = Math.sqrt(tempI[i].variance(j));
        } else {
          vals2[j] = Instance.missingValue();
        }
      }
      m_ClusterStdDevs.add(new Instance(1.0, vals2)); // weight 1.0, attribute values vals2
      m_ClusterSizes[i] = tempI[i].numInstances();
    }
  }

  /**
   * Clusters an instance that has been through the filters: computes the
   * distance to every centroid and returns the index of the nearest one.
   *
   * @param instance the instance to assign a cluster to
   * @param updateErrors if true, update the within clusters sum of errors
   * @return a cluster number
   */
  private int clusterProcessedInstance(Instance instance, boolean updateErrors) {
    double minDist = Double.MAX_VALUE;
    int bestCluster = 0;
    for (int i = 0; i < m_NumClusters; i++) {
      double dist = distance(instance, m_ClusterCentroids.instance(i));
      if (dist < minDist) {
        minDist = dist;
        bestCluster = i;
      }
    }
    if (updateErrors) {
      // distance() is Euclidean, so square it for the sum of squared errors
      m_squaredErrors[bestCluster] += minDist * minDist;
    }
    return bestCluster;
  }

  /**
   * Classifies a given instance.
   *
   * @param instance the instance to be assigned to a cluster
   * @return the number of the assigned cluster as an integer
   * if the class is enumerated, otherwise the predicted value
   * @throws Exception if instance could not be classified
   * successfully
   */
  public int clusterInstance(Instance instance) throws Exception {
    m_ReplaceMissingFilter.input(instance);
    m_ReplaceMissingFilter.batchFinished();
    Instance inst = m_ReplaceMissingFilter.output();

    return clusterProcessedInstance(inst, false);
  }

  /**
   * Computes the distance matrix D, d(i,j) = ||c(i) - x(j)||.
   * Zero distances are replaced by 1e-12 so that solveU never divides by zero.
   */
  private Matrix solveD(Instances instances) {
    int n = instances.numInstances();
    Matrix D = new Matrix(m_NumClusters, n);
    for (int i = 0; i < m_NumClusters; i++) {
      for (int j = 0; j < n; j++) {
        D.set(i, j, distance(instances.instance(j), m_ClusterCentroids.instance(i)));
        if (D.get(i, j) == 0) {
          D.set(i, j, 0.000000000001);
        }
      }
    }
    return D;
  }

  /**
   * Computes the membership matrix U:
   * U(i,j) = 1 / sum_k (d(i,j) / d(k,j))^(2/(m-1))
   */
  private Matrix solveU(Instances instances) {
    int n = instances.numInstances();
    int i, j;
    Matrix U = new Matrix(m_NumClusters, n);

    for (i = 0; i < m_NumClusters; i++) {
      for (j = 0; j < n; j++) {
        double sum = 0;
        for (int k = 0; k < m_NumClusters; k++) {
          // (d(i,j) / d(k,j))^(2/(m-1))
          sum += Math.pow(D.get(i, j) / D.get(k, j), 2 / (getFuzzifier() - 1));
        }
        U.set(i, j, Math.pow(sum, -1));
      }
    }
    return U;
  }

  /**
   * Calculates the Euclidean distance between two instances
   *
   * @param first the first instance
   * @param second the second instance
   * @return the distance between the two given instances
   */
  private double distance(Instance first, Instance second) {
    double val1;
    double val2;
    double dist = 0.0;

    for (int i = 0; i < first.numAttributes(); i++) {
      val1 = first.value(i);
      val2 = second.value(i);
      dist += (val1 - val2) * (val1 - val2);
    }
    dist = Math.sqrt(dist);
    return dist;
  }

  /**
   * Updates the minimum and maximum values for all the attributes
   * based on a new instance (same as in SimpleKMeans).
   *
   * @param instance the new instance
   */
  private void updateMinMax(Instance instance) {
    for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
      if (!instance.isMissing(j)) {
        if (Double.isNaN(m_Min[j])) {
          m_Min[j] = instance.value(j);
          m_Max[j] = instance.value(j);
        } else {
          if (instance.value(j) < m_Min[j]) {
            m_Min[j] = instance.value(j);
          } else {
            if (instance.value(j) > m_Max[j]) {
              m_Max[j] = instance.value(j);
            }
          }
        }
      }
    }
  }

  /**
   * Returns the number of clusters.
   *
   * @return the number of clusters generated for a training dataset.
   * @throws Exception if number of clusters could not be returned
   * successfully
   */
  public int numberOfClusters() throws Exception {
    return m_NumClusters;
  }

  /**
   * Returns the fuzzifier (weighting exponent).
   *
   * @return the fuzzifier
   * @throws Exception if the fuzzifier could not be returned successfully
   */
  public double fuzzifier() throws Exception {
    return m_fuzzifier;
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions () {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tnumber of clusters.\n"
        + "\t(default 2).",
        "N", 1, "-N <num>"));

    result.addElement(new Option(
        "\texponent.\n"
        + "\t(default 2.0).",
        "F", 1, "-F <num>"));

    Enumeration en = super.listOptions();
    while (en.hasMoreElements())
      result.addElement(en.nextElement());

    return result.elements();
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String numClustersTipText() {
    return "set number of clusters";
  }

  /**
   * set the number of clusters to generate
   *
   * @param n the number of clusters to generate
   * @throws Exception if number of clusters is not positive
   */
  public void setNumClusters(int n) throws Exception {
    if (n <= 0) {
      throw new Exception("Number of clusters must be > 0");
    }
    m_NumClusters = n;
  }

  /**
   * gets the number of clusters to generate
   *
   * @return the number of clusters to generate
   */
  public int getNumClusters() {
    return m_NumClusters;
  }

  /**
   * Returns the tip text for this property
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String fuzzifierTipText() {
    return "set fuzzifier";
  }

  /**
   * set the fuzzifier
   *
   * @param f fuzzifier
   * @throws Exception if the fuzzifier is not greater than 1
   */
  public void setFuzzifier(double f) throws Exception {
    if (f <= 1) {
      throw new Exception("F must be > 1");
    }
    m_fuzzifier = f;
  }

  /**
   * get the fuzzifier
   *
   * @return m_fuzzifier
   */
  public double getFuzzifier() {
    return m_fuzzifier;
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -N <num>
   *  number of clusters.
   *  (default 2).</pre>
   *
   * <pre> -F <num>
   *  fuzzifier.
   *  (default 2.0).</pre>
   *
   * <pre> -S <num>
   *  Random number seed.
   *  (default 10)</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions (String[] options) throws Exception {

    String optionString = Utils.getOption('N', options);
    if (optionString.length() != 0) {
      setNumClusters(Integer.parseInt(optionString));
    }

    optionString = Utils.getOption('F', options);
    if (optionString.length() != 0) {
      setFuzzifier(Double.parseDouble(optionString));
    }

    super.setOptions(options);
  }

  /**
   * Gets the current settings of FuzzyCMeans
   *
   * @return an array of strings suitable for passing to setOptions()
   */
  public String[] getOptions () {
    int i;
    Vector result;
    String[] options;

    result = new Vector();

    result.add("-N");
    result.add("" + getNumClusters());

    result.add("-F");
    result.add("" + getFuzzifier());

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * return a string describing this clusterer
   *
   * @return a description of the clusterer as a string
   */
  public String toString() {
    if (m_ClusterCentroids == null) {
      return "No clusterer built yet!";
    }

    int maxWidth = 0;
    for (int i = 0; i < m_NumClusters; i++) {
      for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
        if (m_ClusterCentroids.attribute(j).isNumeric()) {
          double width = Math.log(Math.abs(m_ClusterCentroids.instance(i).value(j)))
            / Math.log(10.0);
          width += 1.0;
          if ((int)width > maxWidth) {
            maxWidth = (int)width;
          }
        }
      }
    }
    StringBuffer temp = new StringBuffer();
    String naString = "N/A";
    for (int i = 0; i < maxWidth+2; i++) {
      naString += " ";
    }
    temp.append("\nFuzzy C-means\n======\n");
    temp.append("\nNumber of iterations: " + m_Iterations + "\n");
    temp.append("Within cluster sum of squared errors: " + Utils.sum(m_squaredErrors));

    temp.append("\n\nCluster centroids:\n");
    for (int i = 0; i < m_NumClusters; i++) {
      temp.append("\nCluster " + i + "\n\tMean: ");
      for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
        if (m_ClusterCentroids.attribute(j).isNumeric()) {
          temp.append(" " + Utils.doubleToString(m_ClusterCentroids.instance(i).value(j),
                                                 maxWidth+5, 4));
        } else {
          temp.append(" " + naString);
        }
      }
      temp.append("\n\tStd Devs: ");
      for (int j = 0; j < m_ClusterStdDevs.numAttributes(); j++) {
        if (m_ClusterStdDevs.attribute(j).isNumeric()) {
          temp.append(" " + Utils.doubleToString(m_ClusterStdDevs.instance(i).value(j),
                                                 maxWidth+5, 4));
        } else {
          temp.append(" " + naString);
        }
      }
    }
    temp.append("\n\n");
    return temp.toString();
  }

  /**
   * Gets the cluster centroids
   *
   * @return the cluster centroids
   */
  public Instances getClusterCentroids() {
    return m_ClusterCentroids;
  }

  /**
   * Gets the standard deviations of the numeric attributes in each cluster
   *
   * @return the standard deviations of the numeric attributes
   * in each cluster
   */
  public Instances getClusterStandardDevs() {
    return m_ClusterStdDevs;
  }

  /**
   * Returns for each cluster the frequency counts for the values of each
   * nominal attribute
   *
   * @return the counts
   */
  public int [][][] getClusterNominalCounts() {
    return m_ClusterNominalCounts;
  }

  /**
   * Gets the squared error for all clusters
   *
   * @return the squared error
   */
  public double getSquaredError() {
    return Utils.sum(m_squaredErrors);
  }

  /**
   * Gets the number of instances in each cluster
   *
   * @return The number of instances in each cluster
   */
  public int [] getClusterSizes() {
    return m_ClusterSizes;
  }

  /**
   * Main method for testing this class.
   *
   * @param argv should contain the following arguments: <p>
   * -t training file [-N number of clusters]
   */
  public static void main (String[] argv) {
    runClusterer(new FuzzyCMeans (), argv);
  }
}
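The core of buildClusterer is the alternation between the membership update (solveD/solveU) and the weighted centroid update. To see those steps in isolation, here is a Weka-free sketch of one fuzzy c-means iteration on points stored as double[] rows. FcmStep and its methods are hypothetical names; like solveD, it replaces a zero distance by 1e-12, but it does no missing-value replacement or random initialisation, so the caller supplies the starting centroids.

```java
/**
 * Hypothetical, Weka-free sketch of one fuzzy c-means update on double[] points.
 * It mirrors the distance, membership and centroid formulas, without
 * missing-value handling or random initialisation.
 */
final class FcmStep {
    private FcmStep() {}

    /** Euclidean distance; 0 is replaced by 1e-12 to avoid dividing by zero. */
    static double distance(double[] a, double[] b) {
        double s = 0;
        for (int k = 0; k < a.length; k++) {
            s += (a[k] - b[k]) * (a[k] - b[k]);
        }
        double d = Math.sqrt(s);
        return d == 0 ? 1e-12 : d;
    }

    /** Membership matrix: u[i][j] = 1 / sum_k (d_ij / d_kj)^(2/(m-1)). */
    static double[][] memberships(double[][] x, double[][] centroids, double m) {
        int c = centroids.length, n = x.length;
        double[][] u = new double[c][n];
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < c; i++) {
                double dij = distance(centroids[i], x[j]);
                double sum = 0;
                for (int k = 0; k < c; k++) {
                    sum += Math.pow(dij / distance(centroids[k], x[j]), 2.0 / (m - 1));
                }
                u[i][j] = 1.0 / sum;
            }
        }
        return u;
    }

    /** Centroid update: c_i = sum_j u_ij^m x_j / sum_j u_ij^m. */
    static double[][] centroids(double[][] x, double[][] u, double m) {
        int c = u.length, dims = x[0].length;
        double[][] out = new double[c][dims];
        for (int i = 0; i < c; i++) {
            double denom = 0;
            for (int j = 0; j < x.length; j++) {
                double w = Math.pow(u[i][j], m);
                denom += w;
                for (int k = 0; k < dims; k++) {
                    out[i][k] += w * x[j][k];
                }
            }
            for (int k = 0; k < dims; k++) {
                out[i][k] /= denom;
            }
        }
        return out;
    }

    /** Objective J_m = sum_i sum_j u_ij^m * d_ij^2. */
    static double objective(double[][] x, double[][] centroids, double[][] u, double m) {
        double total = 0;
        for (int i = 0; i < centroids.length; i++) {
            for (int j = 0; j < x.length; j++) {
                double d = distance(centroids[i], x[j]);
                total += Math.pow(u[i][j], m) * d * d;
            }
        }
        return total;
    }
}
```

Repeating memberships and centroids until objective stops changing gives the complete algorithm. On the one-dimensional points 0, 1, 9, 10 with starting centroids 0 and 10 and m = 2, one update moves the centroids to roughly 0.495 and 9.505, and every column of U sums to 1.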
Original post: http://www.cnblogs.com/rongyux/p/5396812.html