标签:this set str top 标签 amp des memset forward
#include <iostream> #include <algorithm> #include "MnistFile.cpp" #include <cmath> using namespace std; const int synapseNums = 800; class Node { public: //int weight = 1;//权重 double value = 0;//保持计算和 Node *pre[synapseNums];//pre[0]指向前一层节点 Node *next;//链接下一个节点 }; class Chain {//循环链表 public: // start->input->others Node *start;//指向输出层的开始节点 Node *input;//指向输入层的开始节点 Node *others;//指向中间层的节点,动态增删 Node *end;//整个链的最后节点,下一个节点指向start Node null;//删除节点造成的空指针集中指向的节点 Node **outputLink;//指向输出层的节点 }; class Network { public: Chain chain; const int inputNodeNums;//输入节点维度 const int outputNodeNums;//输出节点维度 int othersNodeNums = 0;//中间层节点数目 int inputIndexValue = -1;//当前标签值 public: Network(int in, int out); ~Network(); void init();//建立网络结构 void setNull(Node *p); void setNodes(Node **p, int nums); void inputValue(vector<double> input, double index); double activate(double x); int softMax(double **x, int n); int retIndexOf(Node *p, Node *q); int retNullIndex(Node *p); void forward(); double getIndexValue(int index);//得到标签的输出值 int getIndex();//得到正确的标签值 double getChange();//得到改变边后输出的变化 void newNode(); void deleteNode(Node *p); void deleteNode2(Node *p); void newEdge(Node **p); void newEdge2(Node **ip, vector<double> &labels, vector<vector<double> >&images); void newOutputEdge(); //void deleteedge(); void train(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images); void train2(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images); int predictIndex();//预测的标签值 //void writeDate(); //void readData(string s); double evalStudyRate(int testSampleNums, vector<double> &labels, vector<vector<double> >&images); }; void Network::setNull(Node *p) { for (int i = 0; i < synapseNums; ++i) p->pre[i] = &chain.null; } Network::Network(int in, int out) : inputNodeNums(in), outputNodeNums(out) { //设置null节点 chain.null.next = nullptr; chain.null.value = 0; setNull(&chain.null); //初始化 init(); //设置outputLink chain.outputLink = new Node*[outputNodeNums]; Node *p = chain.start; for (int i = 0; i < outputNodeNums; ++i) { chain.outputLink[i] = p; p = p->next; } } Network::~Network() { cout << "~Network ..." << endl; Node *p = nullptr; Node *q = nullptr; p = chain.start; chain.start = nullptr; while (p != chain.input) { //cout << "p isn‘t null" << endl; q = p->next; delete p; p = q; } delete [] chain.outputLink; } void Network::setNodes(Node **p, int nums) { for (int i = 0; i < nums; i++) { Node *q = new Node; setNull(q); (*p)->next = q; *p = q; } } void Network::init() { cout << "init ..." << endl; //将所有节点拉成链 //建立输出层 Node *p = new Node; setNull(p); chain.start = p;//p前节点 setNodes(&p, outputNodeNums - 1); //建立输入层 setNodes(&p, 1); chain.input = p;//p前节点 setNodes(&p, inputNodeNums - 1); //连接others setNodes(&p, 1); othersNodeNums = 1; //连接最后节点指针 chain.others = chain.end = p; chain.end->next = chain.start;//循环链表 } double Network::activate(double x) { return 1.0 / (1.0 + exp(-x)); } int Network::softMax(double **x, int n) { double *p = *x; int k = 0; for (int i = 0; i < n; ++i) { if (p[k] < p[i]) k = i; } return k; } void Network::inputValue(vector<double> input, double index) { Node *p = chain.input; double s = 0; for (int j = 0; j < inputNodeNums; ++j) { s += input[j]; } for (int i = 0; i < inputNodeNums; ++i) { p->value = input[i];// / s;//正规化 p = p->next; } inputIndexValue = index; //cout << "Network::inputValue(vector<double> input, double index); end!" << endl; } void Network::forward() { Node *p = chain.others; double sum = 0; while (p != chain.input) { for (int i = 0; i < synapseNums; ++i) { sum = sum + p->pre[i]->value; } p->value = activate(sum); //cout << p->value << " "; p = p->next; } } void Network::newNode() { cout << "newNode ..." << endl; Node *p = new Node; setNull(p); othersNodeNums++; chain.end->next = p; chain.end = p; chain.end->next = chain.start; } void Network::deleteNode(Node *p) { // //删除一个节点,需要找到指向该节点的指针,然后再删除该节点 //找到p的前驱 Node *q = chain.others; while (q) { if (q->next == p || q->next == chain.start) { break; } else { q = q->next; } } if (q->next != p) return;//没找到这个节点 //找到指向p的边修改为null Node *t = chain.others; while (t) { for (int i = 0; i < synapseNums; ++i) { if (t->pre[i] == p) { t->pre[i] = &chain.null;//t->pre[i] == nullptr; } } t = t->next; if (t == chain.input) break; } //删除节点 q->next = p->next; delete p; othersNodeNums--; } void Network::deleteNode2(Node *p) { p->value = 0; } double Network::getIndexValue(int index) { return chain.outputLink[index]->value; } double Network::getChange() { forward(); int index = getIndex(); return getIndexValue(index); } int Network::getIndex() { return this->inputIndexValue; } int Network::predictIndex() { double x[outputNodeNums]; memset(x, 0, outputNodeNums); Node *p = chain.start; for (int i = 0; i < outputNodeNums; ++i) { x[i] = p->value; p = p->next; } double *q = x; return softMax(&q, outputNodeNums); } //返回p指向q的下标 int Network::retIndexOf(Node *p, Node *q) { //p->q? for (int i = 0; i < synapseNums; ++i) { if (p->pre[i] == q) { return i;//找到了,p的第i个指针指向q } } return -1;//没找到 } //寻找空闲指针 int Network::retNullIndex(Node *p){ Node *q = &chain.null; for (int i = 0; i < synapseNums; ++i) { if (p->pre[i] == q) return i; } return -1;//说明满了,没有空闲的 } void Network::newEdge(Node **ip) { cout << "newEdge ..." << endl; Node *p = *ip; Node *t = chain.input; double old = 0; double now = 0; for (int i = 0; i < synapseNums;) { //测试当前right index的输出值 old = getChange(); //连接边 // i = retNullIndex(p); // if (i == -1) // break;//这里不对,与下面的i++冲突 p->pre[i] = t;//当前节点指向t //修改output层的正确输出的指针指向p int re = retIndexOf(chain.outputLink[getIndex()], p); if (re == -1){ int rn = retNullIndex(chain.outputLink[getIndex()]); if (rn == -1) cerr << "retNullIndex is overfill!!\n"; chain.outputLink[getIndex()]->pre[rn] = p; } //测试当前连接后的right index的输出值 now = getChange(); //比较,若大于,则连接,否则连接下一个节点 if (now > old) ++i; t = t->next; if (t == chain.start) break; if(predictIndex() == getIndex()) break;//如果得到正确的输出就跳出 // cout << i << " "; } } //void //void Network::newOutputEdge() { // Node *p = chain.start;//输出层节点指针 // double old = 0; // double now = 0; // Node *q = chain.input;//输入层和中间层节点指针 // // while (p != chain.input){ // int r = retNullIndex(p);//空闲指针 // // p = p->next; // // } //} double Network::evalStudyRate(int testSampleNums, vector<double> &labels, vector<vector<double> >&images) { cout << "evalStudyRate : " ; //根据输入的测试集数目来评估正确率 int rightNums = 0; for (int i = 0; i < testSampleNums; ++i) { inputValue(images[50000+i], labels[50000+i]); forward();//前向传播 if (predictIndex() == inputIndexValue) //预测 rightNums++; } return (double)rightNums / testSampleNums * 100.0; } void Network::newEdge2(Node **ip, vector<double> &labels, vector<vector<double> >&images) { cout << "newEdge ..." << endl; Node *p = *ip; Node *t = chain.input; double old = 0; double now = 0; for (int i = 0; i < synapseNums;) { //测试当前right index的输出值 old = evalStudyRate(100, labels, images); //连接边 p->pre[i] = t;//当前节点指向t // //修改output层的正确输出的指针指向p // int re = retIndexOf(chain.outputLink[getIndex()], p); // if (re == -1){ // int rn = retNullIndex(chain.outputLink[getIndex()]); // if (rn == -1) // cerr << "retNullIndex is overfill!!\n"; // chain.outputLink[getIndex()]->pre[rn] = p; // } //测试当前连接后的right index的输出值 now = evalStudyRate(100, labels, images); //比较,若大于,则连接,否则连接下一个节点 if (now > old) ++i; t = t->next; if (t == chain.start) break; if(predictIndex() == getIndex()) break;//如果得到正确的输出就跳出 //cout << i << " "; } } //训练这个网络,得到网络结构 void Network::train(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images) { cout << "train ... " << endl; //训练 for (int i = 0; i < trainSampleNums; ++i) { //输入数据进入网络 inputValue(images[i], labels[i]); //连接边 newEdge(&chain.end); //新建节点 newNode(); cout << "othersNodeNums = " << othersNodeNums << endl; if (othersNodeNums >= 50) { cout << "if (othersNodeNums >= 8) stop" << endl; break; } } //评估学习率 cout << evalStudyRate(9000, labels, images) << "%" << endl; } //训练这个网络,得到网络结构 void Network::train2(int trainSampleNums, vector<double>&labels, vector<vector<double> > &images) { cout << "train ... " << endl; //训练 for (int i = 0; i < trainSampleNums; ++i) { //输入数据进入网络 inputValue(images[i], labels[i]); //连接边 newEdge(&chain.end); //新建节点 newNode(); cout << "othersNodeNums = " << othersNodeNums << endl; if (othersNodeNums >= 4) { cout << "if (othersNodeNums >= 8) stop" << endl; break; } } // Node *p = chain.others; // while (p != chain.input) { // newEdge2(&p, labels, images); // } //评估学习率 cout << evalStudyRate(9000, labels, images) << "%" << endl; } int main(){ //读数据 cout << "read data..." << endl; vector<double>labels; read_Mnist_Label("train-labels.idx1-ubyte", labels); vector<vector<double>> images; read_Mnist_Images("train-images.idx3-ubyte", images); Network network(784, 10); network.train2(10, labels, images); }
标签:this set str top 标签 amp des memset forward
原文地址:https://www.cnblogs.com/niubidexiebiao/p/13039469.html