""" 自动编码的核心就是各种全连接的组合,它是一种无监督的形式,因为他的标签是自己。 """ import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt from mpl_toolkits.mplot3d import Axes3D from matplotlib import cm import numpy as np # 超参数 EPOCH = 10 BATCH_SIZE = 64 LR = 0.005 DOWNLOAD_MNIST = False N_TEST_IMG = 5 # Mnist数据集 train_data = torchvision.datasets.MNIST( root=‘./mnist/‘, train=True, transform=torchvision.transforms.ToTensor(), download=DOWNLOAD_MNIST, ) print(train_data.train_data.size()) # (60000, 28, 28) print(train_data.train_labels.size()) # (60000) # 显示出一个例子 plt.imshow(train_data.train_data[2].numpy(), cmap=‘gray‘) plt.title(‘%i‘ % train_data.train_labels[2]) plt.show() # 将数据集分为多批数据 train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 搭建自编码网络框架 class AutoEncoder(nn.Module): def __init__(self): super(AutoEncoder, self).__init__() self.encoder = nn.Sequential( nn.Linear(28*28, 128), nn.Tanh(), nn.Linear(128, 64), nn.Tanh(), nn.Linear(64, 12), nn.Tanh(), nn.Linear(12, 3), ) self.decoder = nn.Sequential( nn.Linear(3, 12), nn.Tanh(), nn.Linear(12, 64), nn.Tanh(), nn.Linear(64, 128), nn.Tanh(), nn.Linear(128, 28*28), nn.Sigmoid(), # 将输出结果压缩到0到1之间,因为train_data的数据在0到1之间 ) def forward(self, x): encoded = self.encoder(x) decoded = self.decoder(encoded) return encoded, decoded autoencoder = AutoEncoder() optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR) loss_func = nn.MSELoss() # initialize figure f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2)) plt.ion() # 设置为实时打印 # 第一行是原始图片 view_data = Variable(train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.) for i in range(N_TEST_IMG): a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap=‘gray‘); a[0][i].set_xticks(()); a[0][i].set_yticks(()) for epoch in range(EPOCH): for step, (x, y) in enumerate(train_loader): b_x = Variable(x.view(-1, 28*28)) b_y = Variable(x.view(-1, 28*28)) encoded, decoded = autoencoder(b_x) loss = loss_func(decoded, b_y) optimizer.zero_grad() # 将上一部的梯度清零 loss.backward() # 反向传播,计算梯度 optimizer.step() # 优化网络中的各个参数 if step % 100 == 0: print(‘Epoch: ‘, epoch, ‘| train loss: %.4f‘ % loss.data[0]) # 第二行画出解码后的图片 _, decoded_data = autoencoder(view_data) for i in range(N_TEST_IMG): a[1][i].clear() a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap=‘gray‘) a[1][i].set_xticks(()); a[1][i].set_yticks(()) plt.draw(); plt.pause(0.05) plt.ioff() plt.show() # 可视化三维图 view_data = Variable(train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.) encoded_data, _ = autoencoder(view_data) fig = plt.figure(2); ax = Axes3D(fig) X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy() values = train_data.train_labels[:200].numpy() for x, y, z, s in zip(X, Y, Z, values): c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c) ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max()) plt.show()
Original post: https://www.cnblogs.com/czz0508/p/10347065.html