训练的时候没问题,inference的时候出错
收藏
# -*- coding: UTF-8 -*- import paddle import paddle.fluid as fluid import numpy as np import pickle import pdb #读取数据 def get_data(seq_path = "../bandwidth_dataset/data_interval_300.pickle"): with open(seq_path, 'rb') as f: data = pickle.load(f) return data def prediction(future): data = get_data() mean = np.mean(data) std = np.std(data) input = np.expand_dims(data[-12:], axis = 1).tolist() predict = [] place = fluid.CPUPlace() exe = fluid.Executor(place) path = "./model" [inference_program, feed_target_names, fetch_targets] = \ fluid.io.load_inference_model(dirname=path, executor=exe) for i in range(future): normal_input = (input-mean)/std #normal_input:(12,1)的np array lod_tensor_input = fluid.create_lod_tensor(normal_input,[[12]], place) results = exe.run(inference_program, \ feed={feed_target_names[0]:normal_input},fetch_list=fetch_targets) prediction(1)
报错如下
Traceback (most recent call last): File "predict.py", line 86, in prediction(1) File "predict.py", line 75, in prediction feed={feed_target_names[0]:normal_input},fetch_list=fetch_targets) File "/home/work/.pyenv/versions/2.7.13/lib/python2.7/site-packages/paddle/fluid/executor.py", line 443, in run self.executor.run(program.desc, scope, 0, True, True) paddle.fluid.core.EnforceNotMet: DataType of Paddle Op mul must be the same. Get 6 != 5 at [/paddle/paddle/fluid/framework/operator.cc:722] PaddlePaddle Call Stacks: 0 0x7fbb1b4e3376p paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int) + 486 1 0x7fbb1bcfe304p paddle::framework::OperatorWithKernel::IndicateDataType(paddle::framework::ExecutionContext const&) const + 580 2 0x7fbb1bcfe46fp paddle::framework::OperatorWithKernel::GetExpectedKernelType(paddle::framework::ExecutionContext const&) const + 47 3 0x7fbb1bcfe97bp paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, boost::variant const&) const + 235 4 0x7fbb1bcfc450p paddle::framework::OperatorBase::Run(paddle::framework::Scope const&, boost::variant const&) + 208 5 0x7fbb1b576cdfp paddle::framework::Executor::RunPreparedContext(paddle::framework::ExecutorPrepareContext*, paddle::framework::Scope*, bool, bool, bool) + 255 6 0x7fbb1b577d30p paddle::framework::Executor::Run(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool) + 128 7 0x7fbb1b4f9fabp void pybind11::cpp_function::initialize(void (paddle::framework::Executor::*)(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool)#1}, void, paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool, pybind11::name, pybind11::is_method, pybind11::sibling>(pybind11::cpp_function::initialize(void (paddle::framework::Executor::*)(paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool)#1}&&, void (*)(paddle::framework::Executor*, paddle::framework::ProgramDesc const&, paddle::framework::Scope*, int, bool, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&)::{lambda(pybind11::detail::function_call&)#3}::_FUN(pybind11::detail::function_call) + 555 8 0x7fbb1b4f246cp pybind11::cpp_function::dispatcher(_object*, _object*, _object*) + 2540 9 0x4a9cb8p PyEval_EvalFrameEx + 31592 10 0x4ab507p PyEval_EvalCodeEx + 2167 11 0x4a9a77p PyEval_EvalFrameEx + 31015 12 0x4a9b98p PyEval_EvalFrameEx + 31304 13 0x4ab507p PyEval_EvalCodeEx + 2167 14 0x4ab612p PyEval_EvalCode + 50 15 0x4cc74ep PyRun_FileExFlags + 318 16 0x4cc977p PyRun_SimpleFileExFlags + 231 17 0x415112p Py_Main + 2850 18 0x318ae1ecddp __libc_start_main + 253 19 0x4141f9p
0
收藏
请登录后评论
已经解决了,训练的时候数据类型是np.float32,预测的时候读进去默认是float64,哎,注意格式一致啊!!