码迷,mamicode.com
首页 > 其他好文 > 详细

实用矩阵类(Matrix)(带测试)

时间:2015-12-07 12:29:45      阅读:194      评论:0      收藏:0      [点我收藏+]

标签:

引言:

无意间看到国外一个网站写的Matrix类,实现了加减乘除基本运算以及各自的const版本等等,功能还算比较完善,于是记录下来,以备后用:

  1 #ifndef MATRIX_H
  2 #define MATRIX_H
  3 
#include <algorithm>
#include <functional>
#include <iostream>
#include <stdexcept>

using namespace std;
  9 
 10 template <size_t Rows,size_t Cols,typename ElemType = double>
 11 class Matrix
 12 {
 13 public:
 14     Matrix();        //默认构造函数
 15     template<typename InputIterator>
 16     Matrix(InputIterator begin, InputIterator end);    //用2迭代器构造
 17     ~Matrix();    //析构函数
 18 
 19     ElemType & at(size_t row, size_t col);    //获取某个元素(引用)
 20     const ElemType & at(size_t row, size_t col) const;
 21     
 22     size_t numRows() const;        //获得行数、列数
 23     size_t numCols() const;
 24     size_t size() const;        //返回总元素数
 25 
 26     //多维矩阵
 27     class MutableReference;
 28     class ImmutableReference;
 29     MutableReference operator[] (size_t row);
 30     ImmutableReference operator[] (size_t row) const;
 31 
 32     typedef ElemType* iterator; //将元素类型指针定义为迭代器
 33     typedef const ElemType* const_iterator;
 34 
 35     iterator begin();
 36     iterator end();
 37     const_iterator begin() const;
 38     const_iterator end() const;
 39 
 40     iterator row_begin(size_t row);
 41     iterator row_end(size_t row);
 42     const_iterator row_begin(size_t row) const;
 43     const_iterator row_end(size_t row) const;
 44     
 45     Matrix& operator+= (const Matrix& rhs);
 46     Matrix& operator-= (const Matrix& rhs);
 47     Matrix& operator*= (const ElemType& scalar);
 48     Matrix& operator/= (const ElemType& scalar);
 49 
 50     //打印矩阵
 51     void printMatrix(void) const;
 52 private:
 53     ElemType elems[Rows*Cols];        //矩阵元素的数组
 54 
 55 };
 56 //两矩阵相加
 57 template <size_t M,size_t N,typename T>
 58 const Matrix<M, N, T> operator+ (const Matrix<M, N, T> &lhs, const Matrix<M, N, T> &rhs);
 59 //两矩阵相减
 60 template <size_t M, size_t N, typename T>
 61 const Matrix<M, N, T> operator- (const Matrix<M, N, T> &lhs, const Matrix<M, N, T> &rhs);
 62 //矩阵数乘(右乘)
 63 template <size_t M, size_t N, typename T>
 64 const Matrix<M, N, T> operator* (const Matrix<M, N, T> &lhs, const T& scalar);
 65 //矩阵数乘(左乘)
 66 template <size_t M, size_t N, typename T>
 67 const Matrix<M, N, T> operator* (const T& scalar, const Matrix<M, N, T> &rhs);
 68 //矩阵除以一个数
 69 template <size_t M, size_t N, typename T>
 70 const Matrix<M, N, T> operator/ (const Matrix<M, N, T>& lhs,const T& scalar);
 71 //一元运算的加减  相当于添加符号
 72 template <size_t M, size_t N, typename T>
 73 const Matrix<M, N, T> operator+ (const Matrix<M, N, T>& operand);
 74 template <size_t M, size_t N, typename T>
 75 const Matrix<M, N, T> operator- (const Matrix<M, N, T>& operand);
 76 //2矩阵相乘
 77 template <size_t M, size_t N, size_t P, typename T>
 78 const Matrix<M, P, T> operator*(const Matrix<M, N, T>& lhs,const Matrix<N, P, T>& rhs);
 79 //矩阵的比较操作
 80 template <size_t M, size_t N, typename T>
 81 bool operator== (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs);
 82 
 83 template <size_t M, size_t N, typename T>
 84 bool operator!= (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs);
 85 
 86 template <size_t M, size_t N, typename T>
 87 bool operator<  (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs);
 88 
 89 template <size_t M, size_t N, typename T>
 90 bool operator<= (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs);
 91 
 92 template <size_t M, size_t N, typename T>
 93 bool operator>= (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs);
 94 
 95 template <size_t M, size_t N, typename T>
 96 bool operator>  (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs);
 97 //是否为单位矩阵
 98 template <size_t M, typename T>
 99 Matrix<M, M, T> Identity();
100 //矩阵转置
101 template <size_t M, size_t N, typename T>
102 const Matrix<N, M, T> Transpose(const Matrix<M, N, T>& m);
103 
104 
105 //////////////////////////////////////////////////////////////////////////
106 /************************************************************************/
107 /*
108     函数的实现部分
109 */
110 /************************************************************************/
111 //默认构造函数
112 template <size_t M, size_t N, typename T>
113 Matrix<M, N, T>::Matrix() 
114 {
115 }
116 //迭代器构造函数
117 template <size_t M, size_t N, typename T>
118 template <typename InputIterator>
119 Matrix<M, N, T>::Matrix(InputIterator rangeBegin, InputIterator rangeEnd) 
120 {
121     std::copy(rangeBegin, rangeEnd, begin());
122 }
123 //析构函数
124 template <size_t M, size_t N, typename T>
125 Matrix<M, N, T>::~Matrix()
126 {
127 }
128 //得到row,col处元素(引用) 常量版本
129 template <size_t M, size_t N, typename T>
130 const T& Matrix<M, N, T>::at(size_t row, size_t col) const 
131 {
132     return *(begin() + row * numCols() + col);
133 }
134 //得到row,col处元素(引用) 非常量版本
135 template <size_t M, size_t N, typename T>
136 T& Matrix<M, N, T>::at(size_t row, size_t col) 
137 {
138     return const_cast<T&>(static_cast<const Matrix<M, N, T>*>(this)->at(row, col));
139 }
140 //得到行数
141 template<size_t M,size_t N,typename T>
142 size_t Matrix<M, N, T>::numRows() const
143 {
144     return M;
145 }
146 template<size_t M,size_t N,typename T>
147 size_t Matrix<M, N, T>::numCols() const
148 {
149     return N;
150 }
151 
152 template<size_t M,size_t N,typename T>
153 size_t Matrix<M,N,T>::size() const
154 {
155     return M*N;
156 }
157 //迭代器返回首地址指针 //注意返回是迭代器类型
158 template<size_t M,size_t N,typename T>
159 typename Matrix<M, N, T>::iterator Matrix<M,N,T>::begin()
160 {
161     return elems;
162 }
163 //迭代器返回首地址指针的常量版本
164 template<size_t M, size_t N, typename T>
165 typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::begin() const
166 {
167     return elems;
168 }
169 //尾迭代器获取
170 template<size_t M, size_t N, typename T>
171 typename Matrix<M, N, T>::iterator Matrix<M, N, T>::end()
172 {
173     return begin()+size();
174 }
175 //尾迭代器获取(常量版本)
176 template<size_t M, size_t N, typename T>
177 typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::end() const
178 {
179     return begin() + size();
180 }
181 //行迭代器(跳过指定元素获取)
182 template<size_t M,size_t N,typename T>
183 typename Matrix<M, N, T>::iterator Matrix<M, N, T>::row_begin(size_t row)
184 {
185     return begin() + row*numCols();
186 }
187 //行迭代器(跳过指定元素获取) 常量版本
188 template <size_t M, size_t N, typename T>
189 typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::row_begin(size_t row) const
190 {
191     return begin() + row*numCols();
192 }
193 //获得行尾迭代器
194 template<size_t M,size_t N,typename T>
195 typename Matrix<M, N, T>::iterator Matrix<M, N, T>::row_end(size_t row)
196 {
197     return row_begin(row) + N;
198 }
199 //获得行尾迭代器 const版本
200 template<size_t M, size_t N, typename T>
201 typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::row_end(size_t row) const
202 {
203     return row_begin(row) + N;
204 }
205 /************************************************************************/
206 /*
207     方括号[]操作返回引用的实现(非const版本)
208 */
209 /************************************************************************/
210 template <size_t M,size_t N,typename T>
211 class Matrix<M, N, T>::MutableReference
212 {
213 public:
214     T& operator[] (size_t col)
215     {
216         return parent->at(row, col);
217     }
218 private:
219     //私有构造函数  是获得此类实例的为例方法(有元类Matrix可以访问)
220     MutableReference(Matrix* owner, size_t row) :parent(owner), row(row)
221     {
222 
223     }
224     friend class Matrix;
225     const size_t row;
226     Matrix *const parent;
227 };
228 /************************************************************************/
229 /*
230 方括号[]操作返回引用的实现(const版本)
231 */
232 /************************************************************************/
233 template <size_t M, size_t N, typename T>
234 class Matrix<M, N, T>::ImmutableReference
235 {
236 public:
237     const T& operator[] (size_t col) const
238     {
239         return parent->at(row, col);
240     }
241 private:
242     //私有构造函数  是获得此类实例的为例方法(有元类Matrix可以访问)
243     ImmutableReference(const Matrix* owner, size_t row) :parent(owner), row(row)
244     {
245 
246     }
247     friend class Matrix;
248     const size_t row;
249     const Matrix *const parent;
250 };
251 //方括号返回引用的真真实现(用了上面的类)
252 template<size_t M,size_t N,typename T>
253 typename Matrix<M, N, T>::MutableReference Matrix<M, N, T>::operator [] (size_t row)
254 {
255     return MutableReference(this, row);
256 }
257 template<size_t M, size_t N, typename T>
258 typename Matrix<M, N, T>::ImmutableReference Matrix<M, N, T>::operator [] (size_t row) const
259 {
260     return ImmutableReference(this, row);
261 }
262 /************************************************************************/
263 /*
264     复合运算符实现
265 */
266 /************************************************************************/
267 template <size_t M, size_t N, typename T>
268 Matrix<M, N, T>& Matrix<M, N, T>::operator+= (const Matrix<M, N, T>& rhs) 
269 {
270     std::transform(begin(), end(),  // First input range is lhs
271         rhs.begin(),     // Start of second input range is rhs
272         begin(),         // Overwrite lhs
273         std::plus<T>()); // Using addition
274     return *this;
275 }
276 
277 template <size_t M, size_t N, typename T>
278 Matrix<M, N, T>& Matrix<M, N, T>::operator-= (const Matrix<M, N, T>& rhs)
279 {
280     std::transform(begin(), end(),   // First input range is lhs
281         rhs.begin(),      // Start of second input range is rhs
282         begin(),          // Overwrite lhs
283         std::minus<T>()); // Using subtraction
284     return *this;
285 }
286 template <size_t M, size_t N, typename T>
287 Matrix<M, N, T>& Matrix<M, N, T>::operator*= (const T& scalar) 
288 {
289     std::transform(begin(), end(), // Input range is lhs
290         begin(),        // Output overwrites lhs
291         std::bind2nd(std::multiplies<T>(), scalar)); // Scalar mult.
292     return *this;
293 }
294 template <size_t M, size_t N, typename T>
295 Matrix<M, N, T>& Matrix<M, N, T>::operator/= (const T& scalar) 
296 {
297     std::transform(begin(), end(), // Input range is lhs
298         begin(),        // Output overwrites lhs
299         std::bind2nd(std::divides<T>(), scalar)); // Divide by scalar
300     return *this;
301 }
302 /************************************************************************/
303 /*
304     双目运算符实现
305 */
306 /************************************************************************/
307 template <size_t M, size_t N, typename T>
308 const Matrix<M, N, T> operator+ (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs) 
309 {
310     return Matrix<M, N, T>(lhs) += rhs;        //用到了复合运算符(成员函数)
311 }
312 template <size_t M, size_t N, typename T>
313 const Matrix<M, N, T> operator- (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs) 
314 {
315     return Matrix<M, N, T>(lhs) -= rhs;
316 }
317 template <size_t M, size_t N, typename T>  //(右乘一个数)
318 const Matrix<M, N, T> operator* (const Matrix<M, N, T>& lhs,const T& scalar) 
319 {
320     return Matrix<M, N, T>(lhs) *= scalar;
321 }
322 template <size_t M, size_t N, typename T> //左乘一个数
323 const Matrix<M, N, T> operator* (const T& scalar,const Matrix<M, N, T>& rhs) 
324 {
325     return Matrix<M, N, T>(rhs) *= scalar;    
326 }
327 template <size_t M, size_t N, typename T>
328 const Matrix<M, N, T> operator/ (const Matrix<M, N, T>& lhs,const T& scalar) 
329 {
330     return Matrix<M, N, T>(lhs) /= scalar;
331 }
332 //一元运算符+
333 template <size_t M, size_t N, typename T>
334 const Matrix<M, N, T> operator+ (const Matrix<M, N, T>& operand) {
335     return operand;
336 }
337 //一元运算符-
338 template <size_t M, size_t N, typename T>
339 const Matrix<M, N, T> operator- (const Matrix<M, N, T>& operand) 
340 {
341     return Matrix<M, N, T>(operand) *= T(-1);
342 }
343 //2矩阵相乘
344 template <size_t M, size_t N, size_t P, typename T>
345 const Matrix<M, P, T> operator*(const Matrix<M, N, T>& one,const Matrix<N, P, T>& two) 
346 {
347     /* Create a result matrix of the right size and initialize it to zero. */
348     Matrix<M, P, T> result;
349     std::fill(result.begin(), result.end(), T(0));    //初始化结果变量
350 
351     /* Now go fill it in. */
352     for (size_t row = 0; row < result.numRows(); ++row)
353         for (size_t col = 0; col < result.numCols(); ++col)
354             for (size_t i = 0; i < N; ++i)
355                 result[row][col] += one[row][i] * two[i][col];
356 
357     return result;
358 }
359 //matrix1*=matrix运算实现
360 template <size_t M, typename T>
361 Matrix<M, M, T>& operator*= (Matrix<M, M, T>& lhs,const Matrix<M, M, T>& rhs) 
362 {
363     return lhs = lhs * rhs; // Nothing fancy here.
364 }
365 //比较运算符实现
366 template <size_t M, size_t N, typename T>
367 bool operator== (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs) 
368 {
369     return std::equal(lhs.begin(), lhs.end(), rhs.begin());
370 }
371 template <size_t M, size_t N, typename T>
372 bool operator!= (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs) 
373 {
374     return !(lhs == rhs); //用了==运算符
375 }
376 /* The less-than operator uses the std::mismatch algorithm to chase down
377 * the first element that differs in the two matrices, then returns whether
378 * the lhs element is less than the rhs element.  This is essentially a
379 * lexicographical comparison optimized on the assumption that the two
380 * sequences have the same size.
381 */
382 //小于运算符
383 template <size_t M, size_t N, typename T>
384 bool operator<  (const Matrix<M, N, T>& lhs,const Matrix<M, N, T>& rhs) 
385 {
386     /* Compute the mismatch. */
387     std::pair<typename Matrix<M, N, T>::const_iterator,
388         typename Matrix<M, N, T>::const_iterator> disagreement =
389         std::mismatch(lhs.begin(), lhs.end(), rhs.begin());
390 
391     /* lhs < rhs only if there is a mismatch and the lhs‘s element is
392     * lower than the rhs‘s element.
393     */
394     return disagreement.first != lhs.end() &&
395         *disagreement.first < *disagreement.second;
396 }
397 
398 /* The remaining relational operators are implemented in terms of <. */
399 template <size_t M, size_t N, typename T>
400 bool operator<= (const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
401 {
402     /* x <= y  iff  !(x > y)  iff  !(y < x) */
403     return !(rhs < lhs);
404 }
405 template <size_t M, size_t N, typename T>
406 bool operator>= (const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
407 {
408     /* x >= y  iff  !(y > x)  iff  !(x < y) */
409     return !(lhs < rhs);
410 }
411 template <size_t M, size_t N, typename T>
412 bool operator>(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
413 {
414     /* x > y  iff  y < x */
415     return !(rhs < lhs);
416 }
417 
418 /* Transposition is reasonably straightforward. */ //转置
419 template <size_t M, size_t N, typename T>
420 const Matrix<N, M, T> Transpose(const Matrix<M, N, T>& m)
421 {
422     Matrix<N, M, T> result;
423     for (size_t row = 0; row < m.numRows(); ++row)
424         for (size_t col = 0; col < m.numCols(); ++col)
425             result[col][row] = m[row][col];
426     return result;
427 }
428 
429 /* Identity matrix just fills in the diagonal. */
430 template <size_t M, typename T> Matrix<M, M, T> Identity()
431 {
432     Matrix<M, M, T> result;
433     for (size_t row = 0; row < result.numRows(); ++row)
434         for (size_t col = 0; col < result.numCols(); ++col)
435             result[row][col] = (row == col ? T(1) : T(0));
436     return result;
437 }
438 template <size_t M,size_t N,typename T>
439 void Matrix<M, N, T>::printMatrix(void) const
440 {
441     for (size_t i = 0; i < this->numRows();++i)
442     {
443         for (size_t j = 0; j < this->numCols();++j)
444         {
445             cout << this->at(i, j)<<" ";
446         }
447         cout << endl;
448     }
449 }
450 
455 #endif

测试代码:

原站上并没有测试代码,为了验证类的正确性,自己写了一个简单的测试代码,仅供参考:

 1 #include <iostream>
 2 #include <vector>
 3 #include "matrix.h"
 4 
 5 using namespace std;
 6 
 7 
 8 void testMatrixClass();
 9 
10 int main()
11 {
12     testMatrixClass();
13     
14     return 0;
15 }
16 void testMatrixClass()
17 {
18     vector<int>  vec1,vec2;
19     for (int i = 0; i < 6;++i)
20     {
21         vec1.push_back(i);
22     }
23     for (int i = 0; i < 12; ++i)
24     {
25         vec2.push_back(i+1);
26     }
27     vector<int>::iterator itBegin = vec1.begin();
28     vector<int>::iterator itEnd = vec1.end();
29 
30 
31     Matrix<2, 3, int>m_matrix1(itBegin,itEnd );    //用迭代器构造矩阵对象
32     Matrix<3, 4, int>m_matrix2(vec2.begin(),vec2.end());
33     cout << "---------Matrix 1 = :-----------------" << endl;
34     m_matrix1.printMatrix();
35     cout << "---------Matrix 2 = :-----------------" << endl;
36     m_matrix2.printMatrix();
37     cout << "-----matrix1(1,1) (从0开始)= " << m_matrix1.at(1, 1) << endl;
38     cout << "---matrix1‘s size = " << m_matrix1.size() << " rows = " << m_matrix1.numRows()
39         << " cols = " << m_matrix1.numCols() << endl;
40     cout << "----matrix1 *3 = " <<  endl;
41     (m_matrix1 *= 3).printMatrix();
42     cout << "----matrix1 * matrix 2 = " << endl;
43     Matrix<2, 4, int> result;
44     result = m_matrix1*m_matrix2;
45     result.printMatrix();
46 
47 }

 

实用矩阵类(Matrix)(带测试)

标签:

原文地址:http://www.cnblogs.com/xiaogangpao/p/5025524.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!