首页 Paddle框架 帖子详情
Paddlefl相关问题 已解决
收藏
快速回复
Paddle框架 问答学习资料 458 1
Paddlefl相关问题 已解决
收藏
快速回复
Paddle框架 问答学习资料 458 1

您好,我在使用paddlefl框架时遇到了如下问题:下面的乘法代码运行时会报如下错误,当我注释掉“ op_ge = pfl_mpc.layers.greater_equal(x=x, y=y)”这一行代码时,乘法可以正常运行。我的理解是pfl_mpc.layers某些算法不能放在一起使用,例如基于混淆电路的greater_equal和基于算数电路的mul。如果我想在一起使用这些算法的话,需要怎么做呢?

 

FAIL: test_ge (main.TestOptest)
Traceback (most recent call last):
File "test.py", line 51, in test_ge
self.assertEqual(ret[0], True)
AssertionError: EnforceNotMet('\n\n----------------------[5701 chars]or]') != True

 

import unittest
from multiprocessing import Manager

import numpy as np
import paddle.fluid as fluid
import paddle_fl.mpc as pfl_mpc
import test_op_base
from paddle_fl.mpc.data_utils.data_utils import get_datautils

aby3 = get_datautils('aby3')

class TestOptest(test_op_base.TestOpBase):

def ge(self, **kwargs):
    role = kwargs['role']
    d_1 = kwargs['data_1'][role]
    d_2 = kwargs['data_2'][role]
    d_3 = kwargs['data_3'][role]
    expected_out = kwargs['expect_results'][role]
    pfl_mpc.init("aby3", role, "localhost", self.server, int(self.port))
    x = pfl_mpc.data(name='x', shape=[1], dtype='int64')
    y = fluid.data(name='y', shape=[1], dtype='float32')
    z = pfl_mpc.data(name='z', shape=[1], dtype='int64')
    op_ge = pfl_mpc.layers.greater_equal(x=x, y=y)
    exe = fluid.Executor(place=fluid.CPUPlace())
    #results = exe.run(feed={'x': d_1, 'y': d_2}, fetch_list=[op_ge])
    #print(results)
    #self.assertEqual(results[0].shape, (1, ))
    #self.assertTrue(np.allclose(results[0], expected_out))

    print("here")
    op_mul = pfl_mpc.layers.mul(x=x, y=z)
    results2 = exe.run(feed={'x': d_1, 'z': d_1}, fetch_list=[op_mul])
    print(results2)

def test_ge(self):
    data_1 = np.full((1), fill_value=6553.6)
    data_1_shares = aby3.make_shares(data_1)
    data_1_all3shares = np.array([aby3.get_shares(data_1_shares, i) for i in range(3)])
    data_2 = [np.array([65536]).astype('float32')] * self.party_num
    data_3 = np.full((1), fill_value=655.36)
    data_3_shares = aby3.make_shares(data_1)
    data_3_all3shares = np.array([aby3.get_shares(data_1_shares, i) for i in range(3)])
    expect_results = [np.array([0])] * self.party_num
    ret = self.multi_party_run(target=self.ge,
                               data_1=data_1_all3shares,
                               data_2=data_2,
                               data_3=data_3_all3shares,
                               expect_results=expect_results)
    self.assertEqual(ret[0], True)
if name == 'main':
    unittest.main()
s
swjkz
已解决
2# 回复于2022-03
0
收藏
回复
全部评论(1)
时间顺序
s
swjkz
#2 回复于2022-03
0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户