高层DataSet底层DataLoader
收藏
高层DataSet底层DataLoader,是这么搭配的吗?
# dataset与mnist网络结构的定义与第一部分内容一致 # 用 DataLoader 实现数据加载 train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True) mnist.train() # 设置迭代次数 epoch_num = 5 # 设置优化器 optim = paddle.optimizer.Adam(parameters=model.parameters()) # 设置损失函数 loss_fn = paddle.nn.CrossEntropyLoss() for epoch in range(epoch_num): for batch_id, data in enumerate(train_loader): inputs = data[0] # 训练数据 labels = data[1] # 训练数据标签 predicts = mnist(inputs) # 预测结果 # 计算损失 等价于 prepare 中loss的设置 loss = loss_fn(predicts, labels) # 计算准确率 等价于 prepare 中metrics的设置 acc = paddle.metric.accuracy(predicts, labels) # 反向传播 loss.backward() if batch_id % 100 == 0: print("epoch: {}, batch_id: {}, loss is: {}, acc is: {}".format(epoch, batch_id, loss.numpy(), acc.numpy())) # 更新参数 optim.step() # 梯度清零 optim.clear_grad()
0
收藏
请登录后评论
DataSet 确实与 DataLoader搭配
有事也用BatchSampler或DistributedBatchSampler
Sample好多好多
凡是讲过的我都来抄一遍。
好学生,学习了~