PaddlePaddle中如何实现torch.masked_fill类似的功能?
收藏
torch.masked_fill参数可以传入一个mask,将Tensor中的指定位置掩盖掉,但是查看PaddlePaddle的API文档,貌似没有响应的API,那么如何实现类似的功能呢?
答案是可以使用paddle.where这个API来实现类似的功能:
import paddle
x = paddle.rand([3, 3], dtype='float32')
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[0.00276479, 0.45899123, 0.96637046],
# [0.66818708, 0.05855134, 0.33184195],
# [0.34202638, 0.95503175, 0.33745834]])
mask = paddle.randint(0, 2, [3, 3]).astype('bool')
# Tensor(shape=[3, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,
# [[True , True , False],
# [True , True , True ],
# [True , True , True ]])
def masked_fill(x, mask, value):
y = paddle.full(x.shape, value, x.dtype)
return paddle.where(mask, y, x)
out = masked_fill(x, mask, 2)
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[2. , 2. , 0.96637046],
# [2. , 2. , 2. ],
# [2. , 2. , 2. ]])
0
收藏
请登录后评论
这里是paddle.where这个API的详细信息: