如何理解 Numpy 中的 axis

使用 numpy 的时候, 很多对 array 的操作都会涉及到 axis, 比如 rot90 (旋转), sum (求和). 二维的 array 还比较容易在脑中形象化或者在纸上画出来, 三维的就已经有难度了, 更高维的 array 就基本上不可能了 (不知道有没有人能够做到, 也许那些数学家物理学家可以?). 无法理解 axis, 在使用 numpy 的过程中就会遇到很多问题, 虽然现在能用到的对二维 array 进行的操作, 但是不好好理解一下总是心里没底. 在这里我记录一下对 axis 的理解, 表达能力有限, 主要目的还是留给自己以后查阅.

矩阵的维度和 axis

通常线性代数教材上介绍的矩阵都是二维矩阵, 即多个向量排列起来. 三维矩阵则是多个二维矩阵排列起来, 更高维的矩阵以此类推. 每叠加一层, 就增加一个维度.

使用 numpy array 创建一个三维矩阵并查看维度:

# 创建一个三维矩阵
a = np.array([[[1,2,3,4],[1,3,4,5]],[[2,4,7,5],[8,4,3,5]],[[2,5,7,3],[1,5,3,7]]])
# 查看维度
np.ndim(a)  # 3
# 查看 shape
a.shape     # (3,2,4)

每个维度对应一个 axis, 因此一个三维的矩阵有三个 axis. 如上所示的三维矩阵中, 三个 axis 0, 1, 2 的长度分别是 3, 2, 4. 如果把矩阵摆在空间中 (例如, 一个三维矩阵摆在三维空间中, 可以形成立方体的形状) 任意两个维度都是正交的, 在三维空间中比较容易形象的理解, 更高维就不太好形象化了, 但是仍然是两两正交. 如何定义或者说明一个 axis 是什么, 我现在还想不到很准确又形象的方法, 不过可以想象一下在空间中, 沿着一个 axis 看过去的情景, 因为这个 axis 是和其他 axis 正交的, 一条视线上就有这个 axis 上的元素, 元素的个数即为 axis 的长度, 在上例中 axis 0 的长度就是 3

array.sum 中的 axis

那么在对矩阵的运算操作中如何理解 axis 呢? 下面举两个例子 (sum 和 sort) 来说明

b = np.array([[[1,2,3,4],[1,3,4,5]],[[2,4,7,5],[8,4,3,5]],[[2,5,7,3],[1,5,3,7]]])
b
# array([[[1, 2, 3, 4],
#         [1, 3, 4, 5]],
#
#        [[2, 4, 7, 5],
#         [8, 4, 3, 5]],
#
#        [[2, 5, 7, 3],
#         [1, 5, 3, 7]]])

b.shape
# (3, 2, 4)

在 axis 0 上进行 sum, 相当与对其他的 axis 上的元素, 在 axis 0 上 iterate 一遍并 sum 就将 axis 0 reduce 掉了, 原先的 shape (3,2,4) 变成了 (2,4)

b1 = b.sum(axis=0)
b1
# array([[ 5, 11, 17, 12],
#        [10, 12, 10, 17]])

b1.shape
# (2, 4)

同理, 在 axis 0 上 sort, 也即在其他 axis 上, 取 axis 0 的元素, 对其进行排序, 排序完成后的 array 中, 除 axis 0 上的 index 以外, 其他 axis 的 index 不变

b = np.array([[[5,2,3,4],[1,0,5,4]],[[2,4,7,5],[8,4,3,5]],[[2,5,7,3],[1,5,3,7]]])
b.sort(axis=0)
b
# array([[[2, 2, 3, 3],
#         [1, 0, 3, 4]],
#
#        [[2, 4, 7, 4],
#         [1, 4, 3, 5]],
#
#        [[5, 5, 7, 5],
#         [8, 5, 5, 7]]])

理解 numpy.rot90 中的 axis

m = np.array([[[5,2,3,4],[1,0,5,4]],[[2,4,7,5],[8,4,3,5]],[[2,5,7,3],[1,5,3,7]]])
m
# array([[[5, 2, 3, 4],
#         [1, 0, 5, 4]],
#
#        [[2, 4, 7, 5],
#         [8, 4, 3, 5]],
#
#        [[2, 5, 7, 3],
#         [1, 5, 3, 7]]])

m.shape
# (3, 2, 4)

在 axis 0 和 1 构成的平面上旋转, 原先位置上的元素旋转到新的 array 后, 除了 axis 0 和 axis 1, 其他 axis 上的 index 不会改变 原先的 axis 0 变成 axis 1, 原先的 axis 1 变成 axis 0

m1 = np.rot90(m, 1, axes=(0,1))
m1
# array([[[1, 0, 5, 4],
#         [8, 4, 3, 5],
#         [1, 5, 3, 7]],
#
#        [[5, 2, 3, 4],
#         [2, 4, 7, 5],
#         [2, 5, 7, 3]]])

m1.shape
# (2, 3, 4)

另一个例子来理解一下旋转的方向

n = np.array([[1, -1], [0, 0]])
n
# array([[ 1, -1],
#        [ 0,  0]])

rot90 的 axes 参数是一个两元素的 tuple, 分别标记为 a1, a2, 表示旋转方向是从 axis a1 的正方向向 axis a2 的正方向旋转 这个例子中即从 axis 0 的正方向向 axis 1 的正方向旋转

np.rot90(n, 1, axes=(0,1))
# array([[-1,  0],
#        [ 1,  0]])

参考资料

  1. Numpy小记——有关axis/axes的理解 https://zhuanlan.zhihu.com/p/25761406

标签

归档