首页 Paddle框架 帖子详情
用自己加载的数据集,DataLoader运行报错
收藏
快速回复
Paddle框架 问答深度学习模型训练 1116 6
用自己加载的数据集,DataLoader运行报错
收藏
快速回复
Paddle框架 问答深度学习模型训练 1116 6

李宏毅机器学习课程练习题,加载数据集就出问题了,求好心大佬看看怎么处理

"""
构建数据集集
"""
import numpy as np
import pandas as pd
import paddle
from paddle.io import Dataset
from PIL import Image
import glob,os

"""
图片大小不一致,需要进行缩放
"""
def resize_tensor(img,W,H):
    r_img=img.resize((W,H), Image.ANTIALIAS)
    t=paddle.vision.transforms.to_tensor(r_img)
    return t

class MyDataset(Dataset):
    """
    步骤一:继承paddle.io.Dataset类
    """
    def __init__(self, mode='train'):
        """
        步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
        """
        super(MyDataset, self).__init__()
        self.data=[]
        if mode == 'train':
            path='work/food-11/training'
            file=glob.glob(os.path.join(path, "*.jpg"))
            for f in file:
                data=[Image.open(f),f[f.find('/',len(path))+1:f.find('_')]]
                self.data.append(data)
        else:
            path='work/food-11/validation'
            file=glob.glob(os.path.join(path, "*.jpg"))
            for f in file:
                data=[Image.open(f),f[f.find('/',len(path))+1:f.find('_')]]
                self.data.append(data)

    def __getitem__(self, index):
        """
        步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
        """
        img_t = resize_tensor(self.data[index][0],512,512) #图片大部分是512*512
        label = self.data[index][1]

        return img_t, label

    def __len__(self):
        """
        步骤四:实现__len__方法,返回数据集总数目
        """
        return len(self.data)
train_data=MyDataset(mode='train')
validation_data=MyDataset(mode='validation')

train_loader = paddle.io.DataLoader(train_data, batch_size=32, shuffle=True)
for batch_id, data in enumerate(train_loader()):
    print(batch_id)

错误提示截图如下:

求好心大佬帮忙看看,谢谢啦

0
收藏
回复
全部评论(6)
时间顺序
韩泽
#2 回复于2021-06

解决了吗我也遇到了这个问题

0
回复
FutureSI
#3 回复于2021-06

报这个错,一般都是 __getitem__() 里的错误

0
回复
FutureSI
#4 回复于2021-06

可以检查下数据是否被正确读取

0
回复
FutureSI
#5 回复于2021-06

由于这个报错不会显示具体哪行代码出错,可以用 print 语句定位下

0
回复
g
gameaholic
#6 回复于2021-06

问题解决了,是因为__getitem__,返回的lable必须为int类型。 改成label = int(self.data[index][1])就可以了

0
回复
FutureSI
#7 回复于2021-06
问题解决了,是因为__getitem__,返回的lable必须为int类型。 改成label = int(self.data[index][1])就可以了

一般label要int,其它的要float32类型

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户