最近看NMS的实现的时候发现代码里有个np.where(…)[0],这个0就显得很有灵性了,以为是取第一个元素,后来发现并不是这样。查了一波后用代码自己写了下,发现某些情况下有点和想象的不一样。尤其是这个函数的两种完全不同的用法。
一、np.where(condition)
先来份代码:
1 | import numpy as np |
输出:
1 | a.shape: (10,) |
可以发现输出的类型是基本元素类型为numpy.ndarray组成的tuple类(由于单个元素作tuple时,元素后面要加逗号,因此上面的int64括号后有个逗号),那么一维的情况下,得到实际需要的下标就取第一个元素c[0],像这样:
1 | print(c[0]) # 这里修改成c[0]就可以拿到实际你需要的下表列表 |
再来个二维的情况:
1 | import numpy as np |
输出:
1 | a.shape: (2, 3) |
由此可以发现np.where实际输出的是符合条件的元素本身的下标在各自维度组成的多个numpy.ndarray打包成的tuple。这样输出一下:
1 | for x, y in zip(b[0], b[1]): |
输出:
1 | 4 |
总结一下,假如condition中有n个元素符合条件,设它们为m维元素,记为
那么np.where返回的就是
二、np.where(condition, a, b)
此函数比较复杂,讲下简单的用法:
1 | import numpy as np |
输出:
1 | [False False True True] |
输出一个numpy.ndarray,shape与condition中比较的参数形状相同,若condition为真,则该位置赋值为a,否则赋值为b。
二维情况:
1 | import numpy as np |
输出:
1 | [[False False False False] |
个人感觉主要用于减少循环语句的使用,因为python numpy模块就是注重效率和运算向量化,因此有必要加强封装性和减少显式循环代码