首页 百问百答 帖子详情
softmax层的作用
收藏
快速回复
百问百答 问答学习资料 553 0
softmax层的作用
收藏
快速回复
百问百答 问答学习资料 553 0

softmax一般用于多分类任务中,将输出总和归一化,从而成为预测类别的概率分布,通常之后可以接交叉熵损失函数。

1、Softmax公式
对于一个向量例如[x1,x2,x3,x4,x5],做Softmax运算:

sum = e ** (x1) + e ** (x2) + e ** (x3) + e ** (x4) + e ** (x5)
softmax(x1) = e ** (x1) / sum,同理对于x2,x3,x4,x5也是一样
我们可以看到经过softmax之后,输出值加和为1,且每个值都处于0~1之间,这样就符合概率的定义了。运行如下代码我们可以可视化softmax函数对[2,-1,3,0,1]做的变换。

import matplotlib.pyplot as plt
%matplotlib inline

x=[1,2,3,4,5]
y=[2,-1,3,0,1]  # 原始向量
sum_y, e = 0, 2.71828
for item in y:
    sum_y += e ** item
soft_y = [e ** item / sum_y for item in y]

plt.subplot(1, 2, 1)
color=['peru','peru','peru','peru','peru']
x_label=['x1','x2','x3','x4','x5']
plt.xticks(x, x_label)  # 绘制x刻度标签
plt.bar(x, y,color=color)  # 绘制y刻度标签

#设置网格刻度
plt.grid(True,linestyle=':',color='r',alpha=0.6)

plt.subplot(1, 2, 2)
color=['black','black','black','black','black']
x_label=['x1','x2','x3','x4','x5']
plt.xticks(x, x_label)  # 绘制x刻度标签
plt.bar(x, soft_y,color=color)  # 绘制y刻度标签

#设置网格刻度
plt.grid(True,linestyle=':',color='r',alpha=0.6)

plt.show()

2、Softmax改进版
由于Softmax运算中用到了指数运算,而指数运算则会出现数值爆炸(上溢出)以及数值下溢为0的问题。指数不应过大或过小。我们使用刚才计算Softmax公式来计算[20, 40, 5000],如下代码会报错。

y=[20, 40, 5000]  # 原始向量
sum_y, e = 0, 2.71828
for item in y:
    sum_y += e ** item
soft_y = [e ** item / sum_y for item in y]

一个可行的办法是将指数减去原向量中的最大值来避免上溢。如下代码所示,这两个结果是一样的,但是第二种会更加稳定。

# 第一种
y = [22, 30, 10]
sum_y, e = 0, 2.71828
for item in y:
    sum_y += e ** item
soft_y = [e ** item / sum_y for item in y]
print('第一种', soft_y)

# 第二种
y = [22, 30, 10]
max_y = max(y)
sum_y, e = 0, 2.71828
for item in y:
    sum_y += e ** (item - max_y)
soft_y = [e ** (item - max_y) / sum_y for item in y]
print('第二种', soft_y)

0
收藏
回复
在@后输入用户全名并按空格结束,可艾特全站任一用户