model.train()
收藏
代码中的 model.train() 与 model.eval() 分别有什么作用?
def train(model):
    """Train ``model`` and report validation accuracy/loss after every epoch.

    The function relies on module-level names that are defined outside this
    snippet: ``epochs`` (number of passes over the training data),
    ``train_loader`` and ``test_loader`` (iterables whose items expose the
    input at index 0 and the label at index 1), plus ``paddle`` and ``np``.

    Args:
        model: The network to optimise. It must provide ``train()``,
            ``eval()``, ``parameters()`` and be callable on a batch of inputs.

    Returns:
        None. Progress is reported through ``print``.
    """
    # Training mode: per the Paddle docs this is what makes layers such as
    # Dropout/BatchNorm use their training-time behaviour.
    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader):
            sent = data[0]
            label = data[1]

            logits = model(sent)
            loss = paddle.nn.functional.cross_entropy(logits, label)

            if batch_id % 500 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))

            loss.backward()
            opt.step()
            # Gradients accumulate across backward() calls, so reset them
            # before the next batch.
            opt.clear_grad()

        # Evaluate the model after each epoch. Switch to evaluation mode so
        # that mode-dependent layers behave deterministically.
        model.eval()
        accuracies = []
        losses = []

        # No parameter update happens during validation, so skip autograd
        # bookkeeping; the computed metrics are unaffected.
        with paddle.no_grad():
            for data in test_loader:
                sent = data[0]
                label = data[1]

                logits = model(sent)
                loss = paddle.nn.functional.cross_entropy(logits, label)
                acc = paddle.metric.accuracy(logits, label)

                accuracies.append(acc.numpy())
                losses.append(loss.numpy())

        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))

        # Restore training mode before the next epoch starts.
        model.train()


model = MyNet()
train(model)
0
收藏
请登录后评论
应该是开启训练模式吧:model.train() 会把模型切换到训练模式,model.eval() 则切换到评估模式,两者会影响 Dropout、BatchNorm 等层的行为。更具体的作用我也不太清楚,希望有大牛解释一下。