Blackops

初心易得,始终难守

0%

numpy.where的两种用法

最近看NMS的实现的时候发现代码里有个np.where(…)[0],这个0就显得很有灵性了,以为是取第一个元素,后来发现并不是这样。查了一波后用代码自己写了下,发现某些情况下有点和想象的不一样。尤其是这个函数的两种完全不同的用法。

一、np.where(condition)

先来份代码:

1
2
3
4
5
6
7
8
9
import numpy as np

a = np.arange(0, 10)
b = a[::-1]
c = np.where(a > b)

print('a.shape: ', a.shape)
print('type(c): ', type(c))
print('c: ', c)

输出:

1
2
3
a.shape:  (10,)
type(c): <class 'tuple'>
c: (array([5, 6, 7, 8, 9], dtype=int64),)

可以发现输出的类型是基本元素类型为numpy.ndarray组成的tuple类(由于单个元素作tuple时,元素后面要加逗号,因此上面的int64括号后有个逗号),那么一维的情况下,得到实际需要的下标就取第一个元素c[0],像这样:

1
print(c[0]) # 这里修改成c[0]就可以拿到实际你需要的下表列表

再来个二维的情况:

1
2
3
4
5
6
7
8
import numpy as np

a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.where(a > 3)
print('a.shape: ', a.shape)
print(b)
print('b[0]: ', b[0])
print('b[1]: ', b[1])

输出:

1
2
3
4
a.shape:  (2, 3)
(array([1, 1, 1], dtype=int64), array([0, 1, 2], dtype=int64))
b[0]: [1 1 1]
b[1]: [0 1 2]

由此可以发现np.where实际输出的是符合条件的元素本身的下标在各自维度组成的多个numpy.ndarray打包成的tuple。这样输出一下:

1
2
for x, y in zip(b[0], b[1]):
print(a[x][y])

输出:

1
2
3
4
5
6

总结一下,假如condition中有n个元素符合条件,设它们为m维元素,记为

那么np.where返回的就是


二、np.where(condition, a, b)

此函数比较复杂,讲下简单的用法:

1
2
3
4
5
6
import numpy as np

a = np.array([1, 2, 3, 4])
b = a[::-1]

print(np.where(a > b, True, False))

输出:

1
[False False  True  True]

输出一个numpy.ndarray,shape与condition中比较的参数形状相同,若condition为真,则该位置赋值为a,否则赋值为b。

二维情况:

1
2
3
4
5
6
import numpy as np

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
b = a[::-1]
print(b)
print(np.where(a > b, True, False))

输出:

1
2
[[False False False False]
[ True True True True]]

个人感觉主要用于减少循环语句的使用,因为python numpy模块就是注重效率和运算向量化,因此有必要加强封装性和减少显式循环代码