码迷,mamicode.com
首页 > Web开发 > 详细

encode与decode

时间:2018-12-28 15:26:50      阅读:334      评论:0      收藏:0      [点我收藏+]

标签:init   value   roo   tick   view   ase   self   map   load   

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
import torchvision
from mpl_toolkits.mplot3d import Axes3D    #画3D图
from matplotlib import cm
# Hyper parameters
EPOCH = 10              # number of passes over the training set
BATCH_SIZE = 64
LR = 0.005              # learning rate for Adam
DOWNLOAD_MNIST = False  # set to True on the first run, when ./mnist/ does not exist yet
N_TEST_IMG = 5          # number of images in the original-vs-reconstruction preview

# MNIST training split; ToTensor() turns each image into a float tensor in [0, 1].
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)

train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps a 784-dim input through 128 and 64 tanh hidden units
    down to a ``latent_dim``-dim code; the decoder mirrors that path back to
    784 outputs passed through a sigmoid, so reconstructions lie in [0, 1]
    like the scaled input pixels.

    Args:
        latent_dim: size of the bottleneck code. The default of 12 reproduces
            the original architecture; 3 yields a code that can be plotted
            directly in 3-D.
    """

    def __init__(self, latent_dim: int = 12):
        super(AutoEncoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),  # squash reconstructions into the pixel range of the inputs
        )

    def forward(self, x):
        """Encode ``x`` and decode the resulting code.

        Args:
            x: tensor whose last dimension is 784 (flattened 28x28 images),
                e.g. shape (batch, 784).

        Returns:
            A pair ``(encoded, decoded)``: the latent code with last dimension
            ``latent_dim`` and the reconstruction with last dimension 784.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


# Instantiate the model. NOTE: this rebinds the name ``AutoEncoder`` from the
# class to the model instance; the rest of the script uses the instance under
# that name.
AutoEncoder = AutoEncoder()

optimizer = torch.optim.Adam(AutoEncoder.parameters(), lr=LR)  # optimize all autoencoder parameters
loss_func = nn.MSELoss()  # pixel-wise reconstruction error

# 2 x N_TEST_IMG grid: top row shows the originals, bottom row the reconstructions.
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))

plt.ion()  # interactive mode so the figure can be redrawn during training

# First N_TEST_IMG training images, flattened to 784 values and divided by 255
# so they are on the same [0, 1] scale as the ToTensor() training inputs.
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28 * 28).type(torch.FloatTensor) / 255

for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    # Labels are ignored: the autoencoder is trained to reproduce its input.
    for step, (x, _) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)  # flattened input batch
        b_y = x.view(-1, 28 * 28)  # reconstruction target is the input itself
        encoded, decoded = AutoEncoder(b_x)
        loss = loss_func(decoded, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch:', epoch, '| train loss: %.4f' % loss.item())
            # Reconstruct the fixed preview images (not the current training
            # batch) so the bottom row lines up with the originals on top.
            # no_grad: this pass is for display only, no autograd graph needed.
            with torch.no_grad():
                _, decoded_data = AutoEncoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)
plt.ioff()
plt.show()

# Encode the first 200 training images and place each one in 3-D at its code,
# drawn as its digit label with a per-class background colour.
view_data = train_data.train_data[:200].view(-1, 28 * 28).type(torch.FloatTensor) / 255
with torch.no_grad():  # inference only
    encoded_data, _ = AutoEncoder(view_data)
fig = plt.figure(2)
# add_subplot(..., projection='3d') attaches the 3-D axes to the figure;
# constructing Axes3D(fig) directly no longer adds itself to the figure in
# recent Matplotlib, which leaves the plot blank.
ax = fig.add_subplot(111, projection='3d')
# Only the first three code dimensions are plotted; if the encoder's code is
# wider than 3, this is a partial view of the latent space.
X, Y, Z = encoded_data[:, 0].numpy(), encoded_data[:, 1].numpy(), encoded_data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))  # map digit 0..9 onto the colormap's 0..255 range
    ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()

从训练集中选出前五张图片,用于对比原始图像与重建图像。

图像按2行×5列显示:上面一行是原始图像,下面一行是经过编码再解码后的重建图像。

encode与decode

标签:init   value   roo   tick   view   ase   self   map   load   

原文地址:https://www.cnblogs.com/wmy-ncut/p/10190482.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!