
Implementing Fashion-MNIST classification with softmax

#!/usr/bin/env python
# coding: utf-8

# In[1]:


get_ipython().run_line_magic('matplotlib', 'inline')
import gluonbook as gb
from mxnet import autograd,nd


# In[2]:


batch_size = 256
train_iter,test_iter = gb.load_data_fashion_mnist(batch_size)
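# A quick peek at one batch (an added sketch, not in the original notebook):
# each image is a single-channel 28x28 NDArray and each label is a class
# index in 0-9, so a batch should have shapes (256, 1, 28, 28) and (256,).
for X, y in train_iter:
    print(X.shape, y.shape)   # expected: (256, 1, 28, 28) (256,)
    break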


# In[3]:


num_inputs = 784
num_outputs = 10

W = nd.random.normal(scale=0.01, shape=(num_inputs, num_outputs))   # small random weights
b = nd.zeros(num_outputs)                                           # zero biases


# In[4]:


# allocate gradient buffers so autograd can record operations on W and b
W.attach_grad()
b.attach_grad()


# The softmax operation

# In[5]:


X = nd.array([[1, 2, 3], [4, 5, 6]])
X.sum(axis=0, keepdims=True)   # column sums, keeping the 2-D shape: [[5, 7, 9]]
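# (Added illustration) softmax below instead sums along axis=1 with
# keepdims=True, so each row total keeps shape (2, 1) and broadcasts
# cleanly when each row is divided by its own sum:
X.sum(axis=1, keepdims=True)   # [[6], [15]]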


# In[6]:


def softmax(X):
    # softmax(X)[i, j] = exp(X[i, j]) / sum_k exp(X[i, k]):
    # exponentiate, then normalize each row so it sums to 1
    X_exp = X.exp()
    partition = X_exp.sum(axis=1, keepdims=True)
    return X_exp / partition


# For example

# In[7]:


X = nd.random.normal(shape=(2, 5))
X_prob = softmax(X)
X_prob, X_prob.sum(axis=1)   # every entry lies in (0, 1) and each row sums to 1


# Define the model

# In[8]:


def net(X):
    # flatten each image to a length-784 vector, apply the affine map XW + b,
    # then turn the scores into probabilities with softmax
    return softmax(nd.dot(X.reshape((-1, num_inputs)), W) + b)
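# Shape sanity check (an added sketch using dummy data, not in the original
# post): a batch of 4 fake 28x28 images should yield 4 rows of 10 class
# probabilities.
X_check = nd.random.normal(shape=(4, 1, 28, 28))
net(X_check).shape   # expected: (4, 10)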


# Define the loss function

# In[9]:


# nd.pick(y_hat, y) selects, for each row of y_hat, the entry at the index
# given by y: here the probabilities the model assigns to the true classes.
y_hat = nd.array([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y = nd.array([0, 2])
nd.pick(y_hat, y)   # [0.1, 0.5]


# The cross-entropy loss function

# In[10]:


def cross_entropy(y_hat, y):
    # negative log of the probability assigned to the true class
    return -nd.pick(y_hat, y).log()
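# Sanity check on the toy y_hat and y above (added, not in the original post):
# -log(0.1) ≈ 2.303 for the first example and -log(0.5) ≈ 0.693 for the second.
cross_entropy(y_hat, y)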


# Compute classification accuracy

# In[11]:


def accuracy(y_hat, y):
    # fraction of examples whose highest-probability class matches the label
    return (y_hat.argmax(axis=1) == y.astype('float32')).mean().asscalar()


# In[12]:


accuracy(y_hat, y)   # 0.5: the first toy prediction is wrong, the second is right


# Evaluate the accuracy of net on data_iter

# In[13]:


def evaluate_accuracy(data_iter, net):
    # average the per-batch accuracy over all batches of data_iter
    acc = 0
    for X, y in data_iter:
        acc += accuracy(net(X), y)
    return acc / len(data_iter)
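# Note (added): averaging per-batch accuracies is exact only when every batch
# has the same size. A sketch that weights by batch size instead, under the
# same conventions as above:
def evaluate_accuracy_weighted(data_iter, net):
    acc_sum, n = 0.0, 0
    for X, y in data_iter:
        y = y.astype('float32')
        acc_sum += (net(X).argmax(axis=1) == y).sum().asscalar()
        n += y.size
    return acc_sum / n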


# In[14]:


evaluate_accuracy(test_iter, net)   # roughly 0.1 before training: random guessing over 10 classes


# Train the model

# In[15]:


num_epochs, lr = 5, 0.1

def train(net, train_iter, test_iter, loss, num_epochs, batch_size,
          params=None, lr=None, trainer=None):
    for epoch in range(num_epochs):
        train_l_sum = 0
        train_acc_sum = 0
        for X, y in train_iter:     # loop over mini-batches
            with autograd.record():
                y_hat = net(X)      # forward pass: softmax(XW + b)
                l = loss(y_hat, y)  # cross-entropy loss for the batch
            l.backward()            # back-propagate the gradients

            gb.sgd(params, lr, batch_size)   # update W and b by mini-batch SGD

            train_l_sum += l.mean().asscalar()
            train_acc_sum += accuracy(y_hat, y)
        test_acc = evaluate_accuracy(test_iter, net)
        print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'
              % (epoch + 1, train_l_sum / len(train_iter),
                 train_acc_sum / len(train_iter), test_acc))

train(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, [W, b], lr)


# Prediction

# In[16]:


# take the first test batch and compare true labels with the predicted ones
for X, y in test_iter:
    break

true_labels = gb.get_fashion_mnist_labels(y.asnumpy())
pred_labels = gb.get_fashion_mnist_labels(net(X).argmax(axis=1).asnumpy())
titles = [true + '\n' + pred for true, pred in zip(true_labels, pred_labels)]

gb.show_fashion_mnist(X[0:9], titles[0:9])
Original post: https://www.cnblogs.com/TreeDream/p/10002918.html
