标签:init value roo tick view ase self map load
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
import torchvision
# Kept for older matplotlib releases, where importing Axes3D is what
# registers the "3d" projection used further down.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
from matplotlib import cm

# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005  # learning rate
# With download=False torchvision expects the MNIST files to already be
# under ./mnist/; set this to True for the first run.
DOWNLOAD_MNIST = False
N_TEST_IMG = 5  # number of images shown in the reconstruction preview

train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
train_loader = Data.DataLoader(
    dataset=train_data, batch_size=BATCH_SIZE, shuffle=True
)


class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened 28x28 images.

    The encoder maps a 784-value input to a 12-value code and the decoder
    maps that code back to 784 values squashed into [0, 1] by a sigmoid.
    """

    def __init__(self):
        super().__init__()
        # NOTE: the code has 12 dimensions.  Appending Tanh + Linear(12, 3)
        # here (and prepending Linear(3, 12) + Tanh to the decoder) would
        # give a true 3-D bottleneck for the scatter plot at the end.
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, 12),
        )
        self.decoder = nn.Sequential(
            nn.Linear(12, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch of flattened images."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


# Bind the instance to its own name so the class stays usable afterwards.
autoencoder = AutoEncoder()
# print(autoencoder)

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)  # all autoencoder parameters
loss_func = nn.MSELoss()

# Row 0: original images, row 1: their current reconstructions.
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()  # continuously plot

# Fixed preview images, scaled to [0, 1] like the ToTensor() training inputs.
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255
for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, _) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)
        b_y = b_x  # an autoencoder is trained to reproduce its own input

        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch:|', epoch, 'train loss:%0.4f' % loss.item())

            # Reconstruct the *preview* images.  Plotting `decoded` here
            # would show the last training batch, which does not line up
            # with the originals in the top row.
            with torch.no_grad():
                _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(
                    np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray'
                )
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# Scatter the codes of the first 200 training images in 3-D, labelled and
# coloured by digit.  Only the first three of the 12 code dimensions are shown.
view_data = train_data.data[:200].view(-1, 28 * 28).float() / 255
with torch.no_grad():
    encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
# Axes3D(fig) no longer attaches itself to the figure on current matplotlib.
ax = fig.add_subplot(projection='3d')
X, Y, Z = (encoded_data[:, k].numpy() for k in range(3))
values = train_data.targets[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))  # spread digits 0..9 over the colormap
    ax.text(x, y, z, str(s), backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()
从训练集中选出前五张图片(N_TEST_IMG=5),用于在训练过程中对比显示重建效果。
图像按 2 行 × 5 列显示:上面一行是原始图像,下面一行是同一批图像经过编码再解码后重建的图像。
标签:init value roo tick view ase self map load
原文地址:https://www.cnblogs.com/wmy-ncut/p/10190482.html