标签:
引言:
无意间看到国外一个网站写的Matrix类,实现了加、减、乘、除(除法仅支持除以标量)等基本运算以及各自的const版本等,功能还算比较完善,于是记录下来,以备后用:
#ifndef MATRIX_H
#define MATRIX_H

#include <algorithm>
#include <cstddef>     // size_t
#include <functional>  // std::plus, std::minus
#include <iostream>    // std::cout in printMatrix
#include <utility>     // std::pair in operator<

// NOTE(review): a using-directive in a header leaks all of namespace std into
// every file that includes it. It is kept only for backward compatibility with
// includers that may rely on it; the code below qualifies std:: names itself.
using namespace std;

// Fixed-size Rows x Cols matrix. Elements live in an inline array in row-major
// order: element (r, c) is stored at index r * Cols + c.
template <size_t Rows, size_t Cols, typename ElemType = double>
class Matrix
{
public:
    Matrix();                                        // value-initializes every element
    template <typename InputIterator>
    Matrix(InputIterator begin, InputIterator end);  // fills row-major from a range
    ~Matrix();

    ElemType& at(size_t row, size_t col);            // element access (no bounds check)
    const ElemType& at(size_t row, size_t col) const;

    size_t numRows() const;                          // Rows
    size_t numCols() const;                          // Cols
    size_t size() const;                             // Rows * Cols

    // Row proxies that make m[row][col] syntax work.
    class MutableReference;
    class ImmutableReference;
    MutableReference operator[](size_t row);
    ImmutableReference operator[](size_t row) const;

    // Iterators are plain pointers into the row-major element array.
    typedef ElemType* iterator;
    typedef const ElemType* const_iterator;

    iterator begin();
    iterator end();
    const_iterator begin() const;
    const_iterator end() const;

    // [row_begin(r), row_end(r)) spans the Cols elements of row r.
    iterator row_begin(size_t row);
    iterator row_end(size_t row);
    const_iterator row_begin(size_t row) const;
    const_iterator row_end(size_t row) const;

    Matrix& operator+=(const Matrix& rhs);
    Matrix& operator-=(const Matrix& rhs);
    Matrix& operator*=(const ElemType& scalar);
    Matrix& operator/=(const ElemType& scalar);

    // Writes the matrix to std::cout, one row per line.
    void printMatrix() const;

private:
    ElemType elems[Rows * Cols];  // row-major element storage
};

// ---- Free-function declarations ------------------------------------------

// Element-wise sum of two matrices.
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
// Element-wise difference of two matrices.
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
// matrix * scalar
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator*(const Matrix<M, N, T>& lhs, const T& scalar);
// scalar * matrix
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator*(const T& scalar, const Matrix<M, N, T>& rhs);
// matrix / scalar
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator/(const Matrix<M, N, T>& lhs, const T& scalar);
// Unary plus (returns a copy) and unary minus (element-wise negation).
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& operand);
template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& operand);
// (M x N) * (N x P) matrix product.
template <size_t M, size_t N, size_t P, typename T>
const Matrix<M, P, T> operator*(const Matrix<M, N, T>& lhs, const Matrix<N, P, T>& rhs);
// Comparisons: == and != are element-wise; <, <=, >, >= compare the row-major
// element sequences lexicographically.
template <size_t M, size_t N, typename T>
bool operator==(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <size_t M, size_t N, typename T>
bool operator!=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <size_t M, size_t N, typename T>
bool operator<(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <size_t M, size_t N, typename T>
bool operator<=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <size_t M, size_t N, typename T>
bool operator>=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
template <size_t M, size_t N, typename T>
bool operator>(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs);
// Returns the M x M identity matrix.
template <size_t M, typename T>
Matrix<M, M, T> Identity();
// Returns the N x M transpose of an M x N matrix.
template <size_t M, size_t N, typename T>
const Matrix<N, M, T> Transpose(const Matrix<M, N, T>& m);

// ---- Member function definitions -----------------------------------------

// Default constructor: value-initializes every element (zero for arithmetic
// types) so a default-constructed matrix never exposes indeterminate values.
template <size_t M, size_t N, typename T>
Matrix<M, N, T>::Matrix()
    : elems()
{
}

// Range constructor: copies elements in row-major order from
// [rangeBegin, rangeEnd). At most size() elements are consumed; if the range
// is shorter, the remaining elements stay value-initialized.
template <size_t M, size_t N, typename T>
template <typename InputIterator>
Matrix<M, N, T>::Matrix(InputIterator rangeBegin, InputIterator rangeEnd)
    : elems()
{
    // Bounded copy: an unbounded std::copy would write past the end of elems
    // when the range holds more than M * N values.
    iterator out = begin();
    while (rangeBegin != rangeEnd && out != end())
    {
        *out = *rangeBegin;
        ++out;
        ++rangeBegin;
    }
}

template <size_t M, size_t N, typename T>
Matrix<M, N, T>::~Matrix()
{
}

// Returns a const reference to element (row, col). No bounds checking.
template <size_t M, size_t N, typename T>
const T& Matrix<M, N, T>::at(size_t row, size_t col) const
{
    return *(begin() + row * numCols() + col);
}

// Non-const overload, implemented via the const one to avoid duplicating the
// index arithmetic; the const_cast is safe because *this is non-const here.
template <size_t M, size_t N, typename T>
T& Matrix<M, N, T>::at(size_t row, size_t col)
{
    return const_cast<T&>(static_cast<const Matrix<M, N, T>*>(this)->at(row, col));
}

template <size_t M, size_t N, typename T>
size_t Matrix<M, N, T>::numRows() const
{
    return M;
}

template <size_t M, size_t N, typename T>
size_t Matrix<M, N, T>::numCols() const
{
    return N;
}

template <size_t M, size_t N, typename T>
size_t Matrix<M, N, T>::size() const
{
    return M * N;
}

// Pointer to the first element.
template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::begin()
{
    return elems;
}

template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::begin() const
{
    return elems;
}

// One past the last element.
template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::end()
{
    return begin() + size();
}

template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::end() const
{
    return begin() + size();
}

// First element of the given row (skips row * N elements).
template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::row_begin(size_t row)
{
    return begin() + row * numCols();
}

template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::row_begin(size_t row) const
{
    return begin() + row * numCols();
}

// One past the last element of the given row.
template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::iterator Matrix<M, N, T>::row_end(size_t row)
{
    return row_begin(row) + N;
}

template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::const_iterator Matrix<M, N, T>::row_end(size_t row) const
{
    return row_begin(row) + N;
}

// Proxy returned by the non-const operator[]; its own operator[] yields a
// mutable reference to element (row, col) of the owning matrix.
template <size_t M, size_t N, typename T>
class Matrix<M, N, T>::MutableReference
{
public:
    T& operator[](size_t col)
    {
        return parent->at(row, col);
    }

private:
    // Private constructor: only the friend Matrix can create proxies.
    MutableReference(Matrix* owner, size_t r)
        : parent(owner), row(r)
    {
    }
    friend class Matrix;

    // Declared in the same order as the initializer list above.
    Matrix* const parent;
    const size_t row;
};

// Proxy returned by the const operator[]; yields const element references.
template <size_t M, size_t N, typename T>
class Matrix<M, N, T>::ImmutableReference
{
public:
    const T& operator[](size_t col) const
    {
        return parent->at(row, col);
    }

private:
    // Private constructor: only the friend Matrix can create proxies.
    ImmutableReference(const Matrix* owner, size_t r)
        : parent(owner), row(r)
    {
    }
    friend class Matrix;

    const Matrix* const parent;
    const size_t row;
};

template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::MutableReference Matrix<M, N, T>::operator[](size_t row)
{
    return MutableReference(this, row);
}

template <size_t M, size_t N, typename T>
typename Matrix<M, N, T>::ImmutableReference Matrix<M, N, T>::operator[](size_t row) const
{
    return ImmutableReference(this, row);
}

// ---- Compound assignment --------------------------------------------------

// Element-wise addition in place.
template <size_t M, size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator+=(const Matrix<M, N, T>& rhs)
{
    std::transform(begin(), end(), rhs.begin(), begin(), std::plus<T>());
    return *this;
}

// Element-wise subtraction in place.
template <size_t M, size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator-=(const Matrix<M, N, T>& rhs)
{
    std::transform(begin(), end(), rhs.begin(), begin(), std::minus<T>());
    return *this;
}

// Multiplies every element by scalar.
template <size_t M, size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator*=(const T& scalar)
{
    // Snapshot the scalar: it may alias one of this matrix's own elements
    // (e.g. m *= m[0][0]) and every element must be scaled by the same value.
    // A plain loop replaces std::bind2nd, which was removed in C++17.
    const T factor = scalar;
    for (iterator it = begin(); it != end(); ++it)
        *it = *it * factor;
    return *this;
}

// Divides every element by scalar.
template <size_t M, size_t N, typename T>
Matrix<M, N, T>& Matrix<M, N, T>::operator/=(const T& scalar)
{
    // Snapshot for the same aliasing reason as operator*=.
    const T divisor = scalar;
    for (iterator it = begin(); it != end(); ++it)
        *it = *it / divisor;
    return *this;
}

// ---- Binary and unary operators (built on the compound forms) -------------

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return Matrix<M, N, T>(lhs) += rhs;
}

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return Matrix<M, N, T>(lhs) -= rhs;
}

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator*(const Matrix<M, N, T>& lhs, const T& scalar)
{
    return Matrix<M, N, T>(lhs) *= scalar;
}

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator*(const T& scalar, const Matrix<M, N, T>& rhs)
{
    return Matrix<M, N, T>(rhs) *= scalar;
}

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator/(const Matrix<M, N, T>& lhs, const T& scalar)
{
    return Matrix<M, N, T>(lhs) /= scalar;
}

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator+(const Matrix<M, N, T>& operand)
{
    return operand;
}

template <size_t M, size_t N, typename T>
const Matrix<M, N, T> operator-(const Matrix<M, N, T>& operand)
{
    return Matrix<M, N, T>(operand) *= T(-1);
}

// Standard O(M*N*P) matrix product.
template <size_t M, size_t N, size_t P, typename T>
const Matrix<M, P, T> operator*(const Matrix<M, N, T>& one, const Matrix<N, P, T>& two)
{
    Matrix<M, P, T> result;
    std::fill(result.begin(), result.end(), T(0));

    for (size_t row = 0; row < result.numRows(); ++row)
        for (size_t col = 0; col < result.numCols(); ++col)
            for (size_t i = 0; i < N; ++i)
                result[row][col] += one[row][i] * two[i][col];

    return result;
}

// In-place product; only defined for square matrices so the shape is kept.
template <size_t M, typename T>
Matrix<M, M, T>& operator*=(Matrix<M, M, T>& lhs, const Matrix<M, M, T>& rhs)
{
    return lhs = lhs * rhs;
}

// ---- Comparisons ----------------------------------------------------------

template <size_t M, size_t N, typename T>
bool operator==(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
}

template <size_t M, size_t N, typename T>
bool operator!=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return !(lhs == rhs);
}

// Lexicographic less-than: std::mismatch finds the first differing element,
// and lhs < rhs exactly when such an element exists and lhs's is smaller.
// Both sequences have the same length, so no length tie-break is needed.
template <size_t M, size_t N, typename T>
bool operator<(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    std::pair<typename Matrix<M, N, T>::const_iterator,
              typename Matrix<M, N, T>::const_iterator> disagreement =
        std::mismatch(lhs.begin(), lhs.end(), rhs.begin());

    return disagreement.first != lhs.end() &&
           *disagreement.first < *disagreement.second;
}

// The remaining relational operators are expressed in terms of <.

// x <= y  iff  !(y < x)
template <size_t M, size_t N, typename T>
bool operator<=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return !(rhs < lhs);
}

// x >= y  iff  !(x < y)
template <size_t M, size_t N, typename T>
bool operator>=(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return !(lhs < rhs);
}

// x > y  iff  y < x
template <size_t M, size_t N, typename T>
bool operator>(const Matrix<M, N, T>& lhs, const Matrix<M, N, T>& rhs)
{
    return rhs < lhs;
}

// ---- Utilities ------------------------------------------------------------

template <size_t M, size_t N, typename T>
const Matrix<N, M, T> Transpose(const Matrix<M, N, T>& m)
{
    Matrix<N, M, T> result;
    for (size_t row = 0; row < m.numRows(); ++row)
        for (size_t col = 0; col < m.numCols(); ++col)
            result[col][row] = m[row][col];
    return result;
}

// Fills the diagonal with T(1) and everything else with T(0).
template <size_t M, typename T>
Matrix<M, M, T> Identity()
{
    Matrix<M, M, T> result;
    for (size_t row = 0; row < result.numRows(); ++row)
        for (size_t col = 0; col < result.numCols(); ++col)
            result[row][col] = (row == col ? T(1) : T(0));
    return result;
}

// Prints each row on its own line, elements separated by a space.
template <size_t M, size_t N, typename T>
void Matrix<M, N, T>::printMatrix() const
{
    for (size_t i = 0; i < this->numRows(); ++i)
    {
        for (size_t j = 0; j < this->numCols(); ++j)
        {
            std::cout << this->at(i, j) << " ";
        }
        std::cout << std::endl;
    }
}

#endif
测试代码:
原站上并没有测试代码,为了验证类的正确性,自己写了一个简单的测试代码,仅供参考:
1 #include <iostream> 2 #include <vector> 3 #include "matrix.h" 4 5 using namespace std; 6 7 8 void testMatrixClass(); 9 10 int main() 11 { 12 testMatrixClass(); 13 14 return 0; 15 } 16 void testMatrixClass() 17 { 18 vector<int> vec1,vec2; 19 for (int i = 0; i < 6;++i) 20 { 21 vec1.push_back(i); 22 } 23 for (int i = 0; i < 12; ++i) 24 { 25 vec2.push_back(i+1); 26 } 27 vector<int>::iterator itBegin = vec1.begin(); 28 vector<int>::iterator itEnd = vec1.end(); 29 30 31 Matrix<2, 3, int>m_matrix1(itBegin,itEnd ); //用迭代器构造矩阵对象 32 Matrix<3, 4, int>m_matrix2(vec2.begin(),vec2.end()); 33 cout << "---------Matrix 1 = :-----------------" << endl; 34 m_matrix1.printMatrix(); 35 cout << "---------Matrix 2 = :-----------------" << endl; 36 m_matrix2.printMatrix(); 37 cout << "-----matrix1(1,1) (从0开始)= " << m_matrix1.at(1, 1) << endl; 38 cout << "---matrix1‘s size = " << m_matrix1.size() << " rows = " << m_matrix1.numRows() 39 << " cols = " << m_matrix1.numCols() << endl; 40 cout << "----matrix1 *3 = " << endl; 41 (m_matrix1 *= 3).printMatrix(); 42 cout << "----matrix1 * matrix 2 = " << endl; 43 Matrix<2, 4, int> result; 44 result = m_matrix1*m_matrix2; 45 result.printMatrix(); 46 47 }
标签:
原文地址:http://www.cnblogs.com/xiaogangpao/p/5025524.html