首页 Paddle框架 帖子详情
PaddlePaddle中如何实现torch.masked_fill类似的功能? 已解决
收藏
快速回复
Paddle框架 其他学习资料 936 1
PaddlePaddle中如何实现torch.masked_fill类似的功能? 已解决
收藏
快速回复
Paddle框架 其他学习资料 936 1

torch.masked_fill参数可以传入一个mask,将Tensor中的指定位置掩盖掉,但是查看PaddlePaddle的API文档,貌似没有响应的API,那么如何实现类似的功能呢?

答案是可以使用paddle.where这个API来实现类似的功能:

import paddle

x = paddle.rand([3, 3], dtype='float32')
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#        [[0.00276479, 0.45899123, 0.96637046],
#         [0.66818708, 0.05855134, 0.33184195],
#         [0.34202638, 0.95503175, 0.33745834]])

mask = paddle.randint(0, 2, [3, 3]).astype('bool')
# Tensor(shape=[3, 3], dtype=bool, place=CUDAPlace(0), stop_gradient=True,
#        [[True , True , False],
#         [True , True , True ],
#         [True , True , True ]])

def masked_fill(x, mask, value):
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)

out = masked_fill(x, mask, 2)
# Tensor(shape=[3, 3], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
#        [[2.        , 2.        , 0.96637046],
#         [2.        , 2.        , 2.        ],
#         [2.        , 2.        , 2.        ]])

 

DeepGeGe
已解决
2# 回复于2021-12
这里是paddle.where这个API的详细信息: [图片]
展开
0
收藏
回复
全部评论(1)
时间顺序
DeepGeGe
#2 回复于2021-12

这里是paddle.where这个API的详细信息:

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户