首页 Paddle框架 帖子详情
飞桨的高层API来做线性回归训练的问题
收藏
快速回复
Paddle框架 其他深度学习模型训练 1274 1
飞桨的高层API来做线性回归训练的问题
收藏
快速回复
Paddle框架 其他深度学习模型训练 1274 1

官网文档手册:https://www.paddlepaddle.org.cn/documentation/docs/zh/tutorial/quick_start/linear_regression/linear_regression.html

里面的最后部分,高层API来做线性回归训练,发现官网案例效果极差,loss几乎没有收敛。

训练loss没有收敛,eval loss在每次减少2个点的收敛..... 至少这不是一个合格的例子。 我试了下,提高lr可以改善一点,另外就是用SGD效果要好很多。

The loss value printed in the log is the current step, and the metric is the average value of previous steps.
Epoch 1/5
step 51/51 [==============================] - loss: 624.0728 - 2ms/step
Eval begin...
step 13/13 [==============================] - loss: 397.2567 - 878us/step
Eval samples: 102
Epoch 2/5
step 51/51 [==============================] - loss: 422.2296 - 1ms/step
Eval begin...
step 13/13 [==============================] - loss: 394.6901 - 750us/step
Eval samples: 102
Epoch 3/5
step 51/51 [==============================] - loss: 417.4614 - 1ms/step
Eval begin...
step 13/13 [==============================] - loss: 392.1667 - 810us/step
Eval samples: 102
Epoch 4/5
step 51/51 [==============================] - loss: 423.6764 - 1ms/step
Eval begin...
step 13/13 [==============================] - loss: 389.6587 - 772us/step
Eval samples: 102
Epoch 5/5
step 51/51 [==============================] - loss: 461.0751 - 1ms/step
Eval begin...
step 13/13 [==============================] - loss: 387.1344 - 828us/step

 

官网代码:

   

import paddle
paddle.set_default_dtype("float64")

# step1:用高层API定义数据集,无需进行数据处理等,高层API为你一条龙搞定
train_dataset = paddle.text.datasets.UCIHousing(mode='train')
eval_dataset = paddle.text.datasets.UCIHousing(mode='test')

# step2:定义模型
class UCIHousing(paddle.nn.Layer):
    def __init__(self):
        super(UCIHousing, self).__init__()
        self.fc = paddle.nn.Linear(13, 1, None)

    def forward(self, input):
        pred = self.fc(input)
        return pred

# step3:训练模型
model = paddle.Model(UCIHousing())
model.prepare(paddle.optimizer.Adam(parameters=model.parameters()),
              paddle.nn.MSELoss())
model.fit(train_dataset, eval_dataset, epochs=5, batch_size=8, verbose=1)

 

1
收藏
回复
全部评论(1)
时间顺序
JavaRoom
#2 回复于2021-10

哈哈哈,干得漂亮

0
回复
需求/bug反馈?一键提issue告诉我们
发现bug?如果您知道修复办法,欢迎提pr直接参与建设飞桨~
在@后输入用户全名并按空格结束,可艾特全站任一用户