softmax层的作用
收藏
softmax一般用于多分类任务中,将输出总和归一化,从而成为预测类别的概率分布,通常之后可以接交叉熵损失函数。
1、Softmax公式
对于一个向量例如[x1,x2,x3,x4,x5],做Softmax运算:
sum = e ** (x1) + e ** (x2) + e ** (x3) + e ** (x4) + e ** (x5)
softmax(x1) = e ** (x1) / sum,同理对于x2,x3,x4,x5也是一样
我们可以看到经过softmax之后,输出值加和为1,且每个值都处于0~1之间,这样就符合概率的定义了。运行如下代码我们可以可视化softmax函数对[2,-1,3,0,1]做的变换。
import matplotlib.pyplot as plt %matplotlib inline x=[1,2,3,4,5] y=[2,-1,3,0,1] # 原始向量 sum_y, e = 0, 2.71828 for item in y: sum_y += e ** item soft_y = [e ** item / sum_y for item in y] plt.subplot(1, 2, 1) color=['peru','peru','peru','peru','peru'] x_label=['x1','x2','x3','x4','x5'] plt.xticks(x, x_label) # 绘制x刻度标签 plt.bar(x, y,color=color) # 绘制y刻度标签 #设置网格刻度 plt.grid(True,linestyle=':',color='r',alpha=0.6) plt.subplot(1, 2, 2) color=['black','black','black','black','black'] x_label=['x1','x2','x3','x4','x5'] plt.xticks(x, x_label) # 绘制x刻度标签 plt.bar(x, soft_y,color=color) # 绘制y刻度标签 #设置网格刻度 plt.grid(True,linestyle=':',color='r',alpha=0.6) plt.show()
2、Softmax改进版
由于Softmax运算中用到了指数运算,而指数运算则会出现数值爆炸(上溢出)以及数值下溢为0的问题。指数不应过大或过小。我们使用刚才计算Softmax公式来计算[20, 40, 5000],如下代码会报错。
y=[20, 40, 5000] # 原始向量 sum_y, e = 0, 2.71828 for item in y: sum_y += e ** item soft_y = [e ** item / sum_y for item in y]
一个可行的办法是将指数减去原向量中的最大值来避免上溢。如下代码所示,这两个结果是一样的,但是第二种会更加稳定。
# 第一种 y = [22, 30, 10] sum_y, e = 0, 2.71828 for item in y: sum_y += e ** item soft_y = [e ** item / sum_y for item in y] print('第一种', soft_y) # 第二种 y = [22, 30, 10] max_y = max(y) sum_y, e = 0, 2.71828 for item in y: sum_y += e ** (item - max_y) soft_y = [e ** (item - max_y) / sum_y for item in y] print('第二种', soft_y)
0
收藏
请登录后评论