标签:
如题,该算法是来自德国的牛逼的数学家strassen搞出来的,因为把n*n矩阵之间的乘法复杂度降低到n^(lg7)(lg的底是2),一开始想当然地认为朴素的做法是n^3,哪里还能有复杂度更低的做法,但是牛逼的strassen先生简直刷新了我的线性代数观和算法观
思路:基本的思路网上有,此处不再赘述,下面说说怎么实现strassen算法(数据处理)
m*n的矩阵和n*p的矩阵相乘,得到m*p的矩阵,因为每次要二分,遇到奇数的做法就是在行尾(列尾)加全零行(列),因为加入全零行(列)是不会影响就算结果的,从而使得可以二分。我为了省事直接在输入两个矩阵后就直接把它们统一扩展成了2^y*2^y的矩阵,其中,y=2^「lg(Max{m,p})」 (「」是向上取整,输入法里没找到合适的符号....;另外,Max(m,n)和Max(n,p)是相等的,取哪个都一样)
代码:
# include <iostream.h> # include "..\Sort\IO_tools.cpp" void strassen(int** &C,int** &A,int** &B,int arow,int acol,int brow,int bcol,int size); void stra_plus(int** &C,int** &A,int** &B,int crow,int ccol,int arow,int acol,int brow,int bcol,int size,int symbol); void main(){ int** A=NULL;int** B=NULL;int** C=NULL; int A_row,A_col,B_row,B_col,size; cout<<"size of A<row,col>:"<<endl; cin>>A_row>>A_col; size=input2A(A,A_row,A_col);//万能的传引用,绝对正确! cout<<"size of B<row,col>:"<<endl; cin>>B_row>>B_col; input2A(B,B_row,B_col);//万能的传引用,绝对正确! strassen(C,A,B,0,0,0,0,size); output2A(C,A_row,B_col); //stra_plus(C,A,B,0,0,0,0,0,0,size,1); //output2A(C,size,size); } void strassen(int** &C,int** &A,int** &B,int arow,int acol,int brow,int bcol,int size){ C=(int**)new int* [size]; for(int i=0;i<size;i++){ C[i]=new int[size]; } /* 对于size>1的要进一步拆分(其实strassen算法递归计算时要申请这么多内存,size不够大时反而降低了效率, 故而size达到下限时可以采用朴素的矩阵乘法计算方法而不必继续调用strassen算法,此处出于偷懒就省点事儿把size下限设为1) */ if(size>1){ //S(1-10)初始化 int** S1=NULL;int** S2=NULL;int** S3=NULL;int** S4=NULL;int** S5=NULL;int** S6=NULL;int** S7=NULL;int** S8=NULL;int** S9=NULL;int** S10=NULL; stra_plus(S1,B,B,0,0,brow,bcol+size/2,brow+size/2,bcol+size/2,size/2,-1); stra_plus(S2,A,A,0,0,arow,acol,arow,acol+size/2,size/2,1); stra_plus(S3,A,A,0,0,arow+size/2,acol,arow+size/2,acol+size/2,size/2,1); stra_plus(S4,B,B,0,0,brow+size/2,bcol,brow,bcol,size/2,-1); stra_plus(S5,A,A,0,0,arow,acol,arow+size/2,acol+size/2,size/2,1); stra_plus(S6,B,B,0,0,brow,bcol,brow+size/2,bcol+size/2,size/2,1); stra_plus(S7,A,A,0,0,arow,acol+size/2,arow+size/2,acol+size/2,size/2,-1); stra_plus(S8,B,B,0,0,brow+size/2,bcol,brow+size/2,bcol+size/2,size/2,1); stra_plus(S9,A,A,0,0,arow,acol,arow+size/2,acol,size/2,-1); stra_plus(S10,B,B,0,0,brow,bcol,brow,bcol+size/2,size/2,1); //P(1-7)初始化 int** P1=NULL;int** P2=NULL;int** P3=NULL;int** P4=NULL;int** P5=NULL;int** P6=NULL;int** P7=NULL; strassen(P1,A,S1,arow,acol,0,0,size/2); strassen(P2,S2,B,0,0,brow+size/2,bcol+size/2,size/2); strassen(P3,S3,B,0,0,brow,bcol,size/2); strassen(P4,A,S4,arow+size/2,acol+size/2,0,0,size/2); strassen(P5,S5,S6,0,0,0,0,size/2); strassen(P6,S7,S8,0,0,0,0,size/2); strassen(P7,S9,S10,0,0,0,0,size/2); //计算结果C(依次是C11,C12,C21,C22) stra_plus(C,P4,P5,0,0,0,0,0,0,size/2,1);stra_plus(C,C,P2,0,0,0,0,0,0,size/2,-1);stra_plus(C,C,P6,0,0,0,0,0,0,size/2,1); stra_plus(C,P1,P2,0,size/2,0,0,0,0,size/2,1); stra_plus(C,P3,P4,size/2,0,0,0,0,0,size/2,1); stra_plus(C,P5,P1,size/2,size/2,0,0,0,0,size/2,1);stra_plus(C,C,P3,size/2,size/2,size/2,size/2,0,0,size/2,-1);stra_plus(C,C,P7,size/2,size/2,size/2,size/2,0,0,size/2,-1); } /*到达下限*/ else{ C[0][0]=A[arow][acol]*B[brow][bcol]; } } //参与运算的是A,B,C的size*size的(子)矩阵,<arow,acol>是A的参与运算的子矩阵的左上角坐标,<brow,bcol>同理,C是保存结果的 void stra_plus(int** &C,int** &A,int** &B,int crow,int ccol,int arow,int acol,int brow,int bcol,int size,int symbol){ if(C==NULL){ C=(int**)new int* [size]; for(int i=0;i<size;i++){ C[i]=new int[size]; } } for(int i=0;i<size;i++){ for(int j=0;j<size;j++){ C[i+crow][j+ccol]=A[i+arow][j+acol]+symbol*B[i+brow][j+bcol]; } } }
# include <iostream.h> # include <stdlib.h> # include <math.h> int get_Upper_2Pow(int row,int col); void inputA(int A[],int n){ int i=n; cout<<"Input Array:"; while(i--){ cin>>A[n-i-1]; } } void outputA(int A[],int n){ int i=n; cout<<"output Array:"; while(i--){ cout<<A[n-i-1]<<" "; } cout<<endl; } int input2A(int** &A,int row,int col){ int size=get_Upper_2Pow(row,col); A=(int**)new int* [size]; for(int i=0;i<size;i++){ A[i]=new int[size]; } cout<<"Input 2th Array:"<<endl; for(int r=0;r<size;r++){ for(int c=0;c<size;c++){ A[r][c]=0; } } for(r=0;r<row;r++){ for(int c=0;c<col;c++){ cin>>A[r][c]; } } return size; } void output2A(int** A,int row,int col){ cout<<"output:"<<endl; for(int r=0;r<row;r++){ for(int c=0;c<col;c++){ cout<<A[r][c]<<" "; } cout<<endl; } } int get_Upper_2Pow(int row,int col){ for(int i=0;pow(2,i)<row||pow(2,i)<col;i++); return (int)pow(2,i); }
标签:
原文地址:http://www.cnblogs.com/zpfly2008/p/5575126.html