fluid.layers.split: splitting along the batch-size dimension fails the assert because that dimension is -1
Paddle Framework · Q&A · Deep Learning · 52352 · 8
  File "thirdparty/paddlemodels/metric_learning/trainfast_aiflow_elem_multitask.py", line 360, in main
    train_async(args)
  File "thirdparty/paddlemodels/metric_learning/trainfast_aiflow_elem_multitask.py", line 228, in train_async
    train_batch_split=train_batch_size)
  File "thirdparty/paddlemodels/metric_learning/trainfast_aiflow_elem_multitask.py", line 183, in build_program
    avg_cost, acc_top1, acc_top5, out = net_config(image, label, stn, model, args, dtype, is_train, train_batch_split)
  File "thirdparty/paddlemodels/metric_learning/trainfast_aiflow_elem_multitask.py", line 124, in net_config
    embedding_split = fluid.layers.split(out['embedding'], train_batch_split, dim=0)
  File "/home/users/chenkaibing/env/paddle-release/python-gcc482-paddle/lib/python2.7/site-packages/paddle/fluid/layers/nn.py", line 6712, in split
    dim], 'len(num_or_sections) must not be more than input.shape[dim].'
AssertionError: len(num_or_sections) must not be more than input.shape[dim].
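
The failing check can be reproduced in isolation. The sketch below is a hypothetical re-creation of the shape assertion (not Paddle's actual source): with a dynamic batch axis, input.shape[0] is -1 at network-build time, so comparing len(num_or_sections) against it fails for any list of sections.

```python
# Hypothetical re-creation of the failing check, for illustration only
# (not Paddle's actual source code).
def check_split_sections(shape, num_or_sections, dim=0):
    if isinstance(num_or_sections, (list, tuple)):
        assert len(num_or_sections) <= shape[dim], \
            'len(num_or_sections) must not be more than input.shape[dim].'

# A dynamic batch axis is -1 at build time, so any list of sections fails:
try:
    check_split_sections([-1, 256], [4, 4], dim=0)
except AssertionError as e:
    print(e)

# With a concrete batch size, the same list passes:
check_split_sections([8, 256], [4, 4], dim=0)
```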
All comments (8)
Sorted by time
AIStudio790020
#2 · Replied 2019-11

I have written a sample to test the split op based on the develop branch:

import paddle.fluid as fluid

# fluid.layers.data requires a name for the variable
input = fluid.layers.data(name='data', shape=[-1, 4], dtype="float32")
x0, x1 = fluid.layers.split(input, num_or_sections=2, dim=0)
print(x0.shape)
print(x1.shape)

but no error message was found. Which version of Paddle do you use?

AIStudio792130
#3 · Replied 2019-11


You should test with num_or_sections=[2, 2]:

import paddle.fluid as fluid
input = fluid.layers.data(name='data', shape=[-1, 4], dtype="float32")
x0, x1 = fluid.layers.split(input, num_or_sections=[2, 2], dim=0)
print(x0.shape)
print(x1.shape)
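
The distinction between the two forms of num_or_sections is the crux here: an integer means "this many equal sections", while a list gives explicit section sizes. NumPy's np.split can stand in to illustrate the same idea without building a Paddle graph (note NumPy's list argument takes cut *indices*, not sizes, so sizes are converted via cumsum):

```python
import numpy as np

x = np.arange(16, dtype='float32').reshape(4, 4)  # stands in for a (4, 4) tensor

# An integer means "split into this many equal sections":
a, b = np.split(x, 2, axis=0)

# fluid's list form gives section sizes; NumPy wants cut indices,
# so convert sizes [2, 2] into the cumulative index [2]:
sizes = [2, 2]
parts = np.split(x, np.cumsum(sizes)[:-1], axis=0)

print(a.shape, b.shape)          # (2, 4) (2, 4)
print([p.shape for p in parts])  # [(2, 4), (2, 4)]
```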
AIStudio790020
#4 · Replied 2019-11

Reference: https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/layers_cn/split_cn.html#split
If your batch size is fixed as 4 = 2 + 2, you had better fix the batch size in the data layer.

AIStudio792130
#5 · Replied 2019-11


I use py_reader to build the network. If I fix the batch size in the py_reader, it produces a tensor with 5 dimensions, because it appends -1 to the shapes.
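
The behavior described above can be sketched in plain Python. This is a hypothetical re-creation of the shape normalization, not Paddle's actual source: if shape[0] is already a concrete batch size, a -1 batch axis gets prepended, which yields the unexpected extra dimension.

```python
def normalize_shape(shape):
    # Hypothetical sketch of the behavior described in this comment
    # (not Paddle's actual implementation): a -1 batch axis is prepended
    # whenever shape[0] is not already -1.
    if shape[0] != -1:
        return [-1] + list(shape)
    return list(shape)

print(normalize_shape([8, 3, 224, 224]))   # [-1, 8, 3, 224, 224] -> 5-D
print(normalize_shape([-1, 3, 224, 224]))  # unchanged, still 4-D
```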

AIStudio792130
#6 · Replied 2019-11

Test code:

import numpy as np
import paddle
import paddle.fluid as fluid


queue_capacity = 64
batch_size = -1
# batch_size = 8
image_shape = [3, 224, 224]
py_reader = fluid.layers.py_reader(
    capacity=queue_capacity,
    shapes=[[batch_size] + image_shape, [batch_size, 1]],
    lod_levels=[0, 0],
    dtypes=["float32", "int64"],
    use_double_buffer=True)
image, label = fluid.layers.read_file(py_reader)

conv = fluid.layers.conv2d(image, 16, 3, 2, 1)
output, _ = fluid.layers.split(conv, num_or_sections=[4, 4], dim=0)

print(output.shape)
fluid.layers.Print(output)

def reader():
    for _ in range(100):
        # cast to float32 so the dtype matches the py_reader declaration
        image = np.random.rand(3, 224, 224).astype('float32')
        label = 0
        yield (image, label)

train_reader = paddle.batch(reader, batch_size=8, drop_last=True)
py_reader.decorate_paddle_reader(train_reader)
py_reader.start()
fetch_list = [output.name]

place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

res = exe.run(
    fluid.default_main_program(),
    fetch_list=fetch_list)
print(res[0].shape)
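
The shapes in the test above can be checked with plain arithmetic. A small sketch using the standard conv output-size formula (assuming zero dilation, which matches the conv2d call's kernel 3, stride 2, padding 1):

```python
# Standard conv output-size formula: floor((size + 2p - k) / s) + 1.
def conv_out_size(size, kernel=3, stride=2, padding=1):
    return (size + 2 * padding - kernel) // stride + 1

side = conv_out_size(224)
print(side)  # 112
# So with batch_size=8 split as num_or_sections=[4, 4] on dim 0, each half
# of the conv output should be (4, 16, 112, 112).
```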
AIStudio790020
#7 · Replied 2019-11

In summary, the user needs to specify a batch size with py_reader, but py_reader always appends a dimension when shape[0] is not -1, which is not as expected.

AIStudio792130
#8 · Replied 2019-11

Right.
