首页 Paddle框架 帖子详情
model.train()
收藏
快速回复
Paddle框架 问答模型训练 791 1
model.train()
收藏
快速回复
Paddle框架 问答模型训练 791 1

代码中的model.train()与model.eval()有什么作用?

def train(model):

    model.train()
    opt = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())

    for epoch in range(epochs):
        for batch_id, data in enumerate(train_loader):

            sent = data[0]
            label = data[1]

            logits = model(sent)
            loss = paddle.nn.functional.cross_entropy(logits, label)

            if batch_id % 500 == 0:
                print("epoch: {}, batch_id: {}, loss is: {}".format(epoch, batch_id, loss.numpy()))

            loss.backward()
            opt.step()
            opt.clear_grad()

        # evaluate model after one epoch
        model.eval()
        accuracies = []
        losses = []

        for batch_id, data in enumerate(test_loader):

            sent = data[0]
            label = data[1]

            logits = model(sent)
            loss = paddle.nn.functional.cross_entropy(logits, label)
            acc = paddle.metric.accuracy(logits, label)

            accuracies.append(acc.numpy())
            losses.append(loss.numpy())

        avg_acc, avg_loss = np.mean(accuracies), np.mean(losses)
        print("[validation] accuracy/loss: {}/{}".format(avg_acc, avg_loss))

        model.train()

model = MyNet()
train(model)
0
收藏
回复
全部评论(1)
时间顺序
快播下载的幸福
#2 回复于2021-11

开启训练模式吧, 我也不知道具体什么作用,希望有大牛解释下

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户