首页 Paddle框架 帖子详情
构建skipgram模型时出错
收藏
快速回复
Paddle框架 问答深度学习模型训练 1250 1
构建skipgram模型时出错
收藏
快速回复
Paddle框架 问答深度学习模型训练 1250 1

利用paddlepaddle构建skipgram模型,其中调用了fluid.layers.nce接口,在使用CPU的情况下正常,在使用GPU的情况下报错
源代码

embed_word = fluid.layers.data(name = 'embed_word', shape = [1], dtype = 'int64')
label_word = fluid.layers.data(name = 'label_word', shape = [1], dtype = 'int64')
embed = fluid.layers.embedding(input = embed_word, size = [self.vocabulary_size, self.embedding_size], param_attr = 'embed_w', is_sparse = is_sparse)
loss = fluid.layers.nce(input = embed, label = label_word, num_total_classes = self.vocabulary_size, param_attr = 'nce_w', bias_attr = 'nce_b', num_neg_samples = 10)
avg_loss = fluid.layers.mean(loss)

错误

paddle.fluid.core.EnforceNotMet: op nce does not have kernel for data_type[float]:data_layout[ANY_LAYOUT]:place[CUDAPlace(0)]:library_type[PLAIN] at [/paddle/paddle/fluid/framework/operator.cc:678]
PaddlePaddle Call Stacks: 
0       0x7f36f331ba16p paddle::platform::EnforceNotMet::EnforceNotMet(std::__exception_ptr::exception_ptr, char const*, int) + 486
1       0x7f36f4216c5fp paddle::framework::OperatorWithKernel::RunImpl(paddle::framework::Scope const&, boost::variant const&) const + 1231
2       0x7f36f4213a2cp paddle::framework::OperatorBase::Run(paddle::framework::Scope const&, boost::variant const&) + 252
3       0x7f36f4057ac7p
4       0x7f36f4075c50p
5       0x7f36f40754b5p paddle::framework::details::OpHandleBase::RunAndRecordEvent(std::function const&) + 805
6       0x7f36f405759fp paddle::framework::details::ComputationOpHandle::RunImpl() + 95
7       0x7f36f4076555p paddle::framework::details::OpHandleBase::Run(bool) + 117
8       0x7f36f4035e5ap
9       0x7f36f3eb19e3p std::_Function_handler (), std::__future_base::_Task_setter, std::__future_base::_Result_base::_Deleter>, void> >::_M_invoke(std::_Any_data const&) + 35
10      0x7f36f3471a77p std::__future_base::_State_base::_M_do_set(std::function ()>&, bool&) + 39
11      0x7f3770848a99p
12      0x7f36f4034c62p
13      0x7f36f34735b4p ThreadPool::ThreadPool(unsigned long)::{lambda()#1}::operator()() const + 404
14      0x7f370101cc5cp
15      0x7f37708416bap
16      0x7f376fe6741dp clone + 109
0
收藏
回复
全部评论(1)
时间顺序
liberyu
#2 回复于2018-09

数据类型错误,你查看下nce这个代码封装的数据输入应该是什么类型的,在gpu状态下可能需要将其中的一个输入类型做转换

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户