以下是完整的测试代码
#测试性的DataLoader
import paddle
from paddle.io import Dataset,DataLoader
class TestDataset(Dataset):
def __init__(self,num):
self.num_samples = num
def __getitem__(self,idx):
random_img = paddle.randn((3,64,64))
return random_img
def __len__(self):
return self.num_samples
dataset = TestDataset(10000)
#从Dataset中获取的单个元素,预期shape是 [3,64,64]
print(dataset[10].shape) #实际输出确实是 [3,64,64]
#使用DataLoader进行加载,BatchSize设置成 256,预期获取到的数据 是一个 Shape 为 [256,3,64,64] 的Tensor (与pytorch行为一致)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=0)
for data in dataloader:
print(len(data)) #data 不是一个Tensor,而变成了一个List,长度为 3
print(data[0].shape) #data中每个元素的 shape 为 [256,64,64]
break
收藏
点赞
0
个赞
请登录后评论
TOP
切换版块
将 random_img = paddle.randn((3,64,64)) 修改成 random_img = paddle.randn((1,3,64,64)) 可以解决问题,如果是实际图片,可以进行一次reshape(1,3,64,64)来解决问题