码迷,mamicode.com
首页 > 编程语言 > 详细

softmax函数python实现

时间:2019-05-29 16:29:44      阅读:705      评论:0      收藏:0      [点我收藏+]

标签:code   div   soft   mft   代码   axis   else   softmax   def   

import numpy as np

def softmax(x):
    """
    对输入x的每一行计算softmax。

    该函数对于输入是向量(将向量视为单独的行)或者矩阵(M x N)均适用。
    
    代码利用softmax函数的性质: softmax(x) = softmax(x + c)

    参数:
    x -- 一个N维向量,或者M x N维numpy矩阵.

    返回值:
    x -- 在函数内部处理后的x
    """
    orig_shape = x.shape

    # 根据输入类型是矩阵还是向量分别计算softmax
    if len(x.shape) > 1:
        # 矩阵
        tmp = np.max(x,axis=1) # 得到每行的最大值,用于缩放每行的元素,避免溢出。 shape为(x.shape[0],)
        x -= tmp.reshape((x.shape[0],1)) # 利用性质缩放元素
        x = np.exp(x) # 计算所有值的指数
        tmp = np.sum(x, axis = 1) # 每行求和        
        x /= tmp.reshape((x.shape[0], 1)) # 求softmax
    else:
        # 向量
        tmp = np.max(x) # 得到最大值
        x -= tmp # 利用最大值缩放数据
        x = np.exp(x) # 对所有元素求指数        
        tmp = np.sum(x) # 求元素和
        x /= tmp # 求somftmax
    return x

x = np.array([[1,2,3],[4,7,6]])
print(softmax(x))

 

softmax函数python实现

标签:code   div   soft   mft   代码   axis   else   softmax   def   

原文地址:https://www.cnblogs.com/ljygoodgoodstudydaydayup/p/10944431.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!