one_hot API: parameter has no effect
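```python
import numpy as np
import paddle.fluid as fluid

train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
    with fluid.program_guard(train_program, startup_program):
        label = fluid.layers.data(name="label", shape=[4], append_batch_size=False, dtype="int64")
        # depth=4, so valid indices are 0..3; with allow_out_of_range=True an
        # out-of-range index should yield an all-zero row instead of an error
        one_hot_label = fluid.one_hot(input=label, depth=4, allow_out_of_range=True)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)
compiled_prog = fluid.compiler.CompiledProgram(train_program)

# index 5 is out of range for depth=4; dtype matches the int64 input declared above
x1 = np.array([[1, 1, 5, 0]], dtype="int64")
res = exe.run(compiled_prog,
              feed={"label": x1},
              fetch_list=[one_hot_label])[0][0]  # fetch by variable, not by generated name
expect = [[0., 1., 0., 0.],
          [0., 1., 0., 0.],
          [0., 0., 0., 0.],
          [1., 0., 0., 0.]]
np.testing.assert_array_equal(res, expect)  # zero-filled row expected for index 5
```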
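According to the documentation (https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/one_hot_cn.html), setting allow_out_of_range=True should zero-fill the one-hot row for any index outside [0, depth), as in `expect` above. In practice, the allow_out_of_range parameter does not take effect.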
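To make the expected behavior concrete, here is a minimal NumPy sketch of what the documentation describes for `allow_out_of_range`: indices in [0, depth) become one-hot rows, while out-of-range indices either raise an error (False, the default) or become all-zero rows (True). `one_hot_reference` is a hypothetical helper for comparison only, not a Paddle API, and raising `ValueError` is an assumed stand-in for Paddle's own error.

```python
import numpy as np


def one_hot_reference(indices, depth, allow_out_of_range=False):
    """Hypothetical NumPy reference for the documented one_hot semantics."""
    indices = np.asarray(indices, dtype=np.int64)
    in_range = (indices >= 0) & (indices < depth)
    # Assumption: mirror the documented error case with a ValueError
    if not allow_out_of_range and not in_range.all():
        raise ValueError("one_hot index out of range [0, %d)" % depth)
    out = np.zeros(indices.shape + (depth,), dtype=np.float32)
    # Scatter a 1 only at valid positions; out-of-range positions stay all-zero
    valid_pos = np.nonzero(in_range)
    out[valid_pos + (indices[valid_pos],)] = 1.0
    return out


if __name__ == "__main__":
    # Same input as the report: index 5 is out of range for depth=4
    print(one_hot_reference([1, 1, 5, 0], depth=4, allow_out_of_range=True))
```

Comparing the fetched `res` against `one_hot_reference(x1[0], 4, allow_out_of_range=True)` gives an independent check of the documented result that does not depend on Paddle.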