标签:
numpy中有numpy.dot函数,用于处理矩阵之间的相乘:
In [2]: a = np.reshape(np.arange(6), (2,3)) In [3]: b = np.reshape(np.arange(6), (3,2)) In [4]: np.dot(a,b) Out[4]: array([[10, 13], [28, 40]])
但如果我们要实现矩阵a * b * c * d,则写三个numpy.dot则略显麻烦.我们可以定义一个函数mdot(a,b,c,d)来完成.
1. 使用reduce
我们可以使用reduce来实现mdot:
In [9]: def mdot(*args): ...: return reduce(np.dot, args) ...: In [10]: a = np.reshape(np.arange(6), (2,3)) In [11]: b = np.reshape(np.arange(6), (3,2)) In [12]: mdot(a,b,a,b) Out[12]: array([[ 464, 650], [1400, 1964]])
2. 控制矩阵相乘的顺序
假设我们想通过()来让mdot执行有序的相乘,则我们需要编写一个递归的函数来完成:
In [13]: import types In [14]: def mdot(*args): ....: if len(args) == 1: ....: return args[0] ....: elif len(args) == 2: ....: return _mdot_r(args[0], args[1]) ....: else: ....: return _mdot_r(args[:-1], args[-1]) ....: In [15]: def _mdot_r(a, b): ....: if type(a) == types.TupleType: ....: if len(a) > 1: ....: a = mdot(*a) ....: else: ....: a = a[0] ....: if type(b) == types.TupleType: ....: if len(b) > 1: ....: b = mdot(*b) ....: else: ....: b = b[0] ....: return np.dot(a, b) ....: In [16]: mdot(b, ((a, b), a)) Out[16]: array([[ 120, 188, 256], [ 438, 688, 938], [ 756, 1188, 1620]]) In [17]: a Out[17]: array([[0, 1, 2], [3, 4, 5]]) In [18]: b Out[18]: array([[0, 1], [2, 3], [4, 5]])
1. 使用名称来标识数组
有两种方式来达到"使用名称来标识数组":recarrays和structured arrays.
structured arrays如下:
In [33]: from numpy import * In [34]: ones(3, dtype=dtype([(‘foo‘, int), (‘bar‘, float)])) Out[34]: array([(1, 1.0), (1, 1.0), (1, 1.0)], dtype=[(‘foo‘, ‘<i8‘), (‘bar‘, ‘<f8‘)]) In [35]: r = _ In [36]: r[‘foo‘] Out[36]: array([1, 1, 1])
而我们可以使用recarray将r转换为:recarray类型
In [48]: r2 = r.view(recarray) In [49]: r2 Out[49]: rec.array([(1, 1.0), (1, 1.0), (1, 1.0)], dtype=[(‘foo‘, ‘<i8‘), (‘bar‘, ‘<f8‘)]) In [50]: r2.foo Out[50]: array([1, 1, 1])
但是r和r2的区别在哪里?
In [56]: r == r2 Out[56]: rec.array([ True, True, True], dtype=bool) In [57]: r.dtype == r2.dtype Out[57]: True In [58]: r.shape == r2.shape Out[58]: True In [59]: type(r) == type(r2) Out[59]: False In [60]: type(r) Out[60]: numpy.ndarray In [61]: type(r2) Out[61]: numpy.core.records.recarray
标签:
原文地址:http://my.oschina.net/u/2422076/blog/484511