码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch实现简单的变分自动编码器VAE

时间:2020-03-30 23:49:45      阅读:281      评论:0      收藏:0      [点我收藏+]

标签:not   exp   htm   enc   element   func   test   函数   lamp   

      在上一篇博客中我们介绍并实现了自动编码器,本文将用PyTorch实现变分自动编码器(Variational AutoEncoder, VAE)。自动变分编码器原理与一般的自动编码器的区别在于需要在编码过程增加一点限制,迫使它生成的隐含向量能够粗略的遵循标准正态分布。这样一来,当需要生成一张新图片时,只需要给解码器一个标准正态分布的隐含随机向量就可以了。

      在实际操作中,实际上不是生成一个隐含向量,而是生成两个向量:一个表示均值,一个表示标准差,然后通过这两个统计量合成隐含向量,用一个标准正态分布先乘标准差再加上均值就行了。具体关于变分自动编码器的内容,可参考廖星宇的《深度学习之PyTorch》的第六章,下面的代码也是来自这个资料,但本文对原代码做了一点改动。

import os
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image

# Hyper parameters
EPOCH = 1
LR = 1e-3
BATCHSIZE = 128

im_tfs = tfs.Compose([
    tfs.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to
                       # torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]
    tfs.Normalize([0.5], [0.5])   # 把[0.0, 1.0]的数据扩大范围到[-1., 1]
])

train_set = MNIST(
    root=‘/Users/wangpeng/Desktop/all/CS/Courses/Deep Learning/mofan_PyTorch/mnist/‘,   # mnist has been downloaded before, use it directly
    train=True,
    transform=im_tfs,
)
train_loader = DataLoader(train_set, batch_size=BATCHSIZE, shuffle=True)


class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)   # mean
        self.fc22 = nn.Linear(400, 20)   # var
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        std = logvar.mul(0.5).exp_()                     # 矩阵点对点相乘之后再把这些元素作为e的指数
        eps = torch.FloatTensor(std.size()).normal_()    # 生成随机数组
        if torch.cuda.is_available():
            eps = eps.cuda()
        return eps.mul(std).add_(mu)    # 用一个标准正态分布乘标准差,再加上均值,使隐含向量变为正太分布

    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)          # 编码
        z = self.reparametrize(mu, logvar)   # 重新参数化成正态分布
        return self.decode(z), mu, logvar    # 解码,同时输出均值方差


net = VAE()  # 实例化网络
if torch.cuda.is_available():
    net = net.cuda()

reconstruction_function = nn.MSELoss(size_average=False)


def loss_function(recon_x, x, mu, logvar):
    """
    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)
    # KL divergence
    return MSE + KLD


optimizer = torch.optim.Adam(net.parameters(), lr=LR)


def to_img(x):   # x shape (bachsize, 28*28), x中每个像素点的大小范围[-1., 1.]
    ‘‘‘
    定义一个函数将最后的结果转换回图片
    ‘‘‘
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x


for epoch in range(EPOCH):
    for iteration, (im, y) in enumerate(train_loader):
        im = im.view(im.shape[0], -1)
        if torch.cuda.is_available():
            im = im.cuda()
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0]   # 将 loss 平均
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if iteration % 100 == 0:
            print(‘epoch: {:2d} | iteration: {:4d} | Loss: {:.4f}‘.format(epoch, iteration, loss.data.numpy()))
            save = to_img(recon_im.cpu().data)
            if not os.path.exists(‘./vae_img‘):
                os.mkdir(‘./vae_img‘)
            save_image(save, ‘./vae_img/image_{}_{}.png‘.format(epoch, iteration))


# test
code = torch.randn(1, 20)   # 随机给一个符合正态分布的张量
out = net.decode(code)
img = to_img(out)
save_image(img, ‘./vae_img/test_img.png‘)

PyTorch实现简单的变分自动编码器VAE

标签:not   exp   htm   enc   element   func   test   函数   lamp   

原文地址:https://www.cnblogs.com/picassooo/p/12601785.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!