高层DataSet底层DataLoader
收藏
高层DataSet底层DataLoader,是这么搭配的吗?
# dataset与mnist网络结构的定义与第一部分内容一致
# 用 DataLoader 实现数据加载
train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)
mnist.train()
# 设置迭代次数
epoch_num = 5
# 设置优化器
optim = paddle.optimizer.Adam(parameters=model.parameters())
# 设置损失函数
loss_fn = paddle.nn.CrossEntropyLoss()
for epoch in range(epoch_num):
for batch_id, data in enumerate(train_loader):
inputs = data[0] # 训练数据
labels = data[1] # 训练数据标签
predicts = mnist(inputs) # 预测结果
# 计算损失 等价于 prepare 中loss的设置
loss = loss_fn(predicts, labels)
# 计算准确率 等价于 prepare 中metrics的设置
acc = paddle.metric.accuracy(predicts, labels)
# 反向传播
loss.backward()
if batch_id % 100 == 0:
print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy()))
# 更新参数
optim.step()
# 梯度清零
optim.clear_grad()
0
收藏
请登录后评论
DataSet 确实与 DataLoader搭配
有事也用BatchSampler或DistributedBatchSampler
Sample好多好多
凡是讲过的我都来抄一遍。
好学生,学习了~