Implementing a VAE in PyTorch

Date: 2017-09-28 13:15:27

I. Structure of the VAE

[Figure: diagram of the VAE structure (image not available)]

II. PyTorch implementation of the VAE

1. Load and normalize MNIST

Import the required classes:

from __future__ import print_function
import argparse
import torch
import torch.utils.data
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms

Set the parameters:

parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
# --no-cuda is a switch: passing it turns CUDA off even when a GPU is available
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
print(args)

#Sets the seed for generating random numbers. And returns a torch._C.Generator object.
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)    

 

Output:

Namespace(batch_size=128, cuda=True, epochs=10, log_interval=10, no_cuda=False, seed=1)
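
When parse_args() is called with no arguments, as above, it reads sys.argv. That tends to fail inside a notebook, where the kernel passes its own flags. A minimal sketch of one workaround is to pass the argument list explicitly. The build_parser helper is a hypothetical name, and only three of the options are reproduced here:

```python
import argparse


def build_parser():
    # Hypothetical helper: a trimmed copy of the parser defined above.
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N')
    parser.add_argument('--epochs', type=int, default=10, metavar='N')
    parser.add_argument('--no-cuda', action='store_true', default=False)
    return parser


# An explicit empty list yields the defaults without touching sys.argv.
defaults = build_parser().parse_args([])
# Dashes in option names become underscores in the Namespace attributes.
custom = build_parser().parse_args(['--batch-size', '64', '--no-cuda'])
print(defaults.batch_size, defaults.no_cuda)
print(custom.batch_size, custom.no_cuda)
```

The same pattern is handy for quick experiments: change the list instead of editing the defaults.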

Download the dataset into the ../data directory:

# Extra DataLoader options that only make sense when training on a GPU
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}
trainset = datasets.MNIST('../data', train=True, download=True,
                          transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=args.batch_size, shuffle=True, **kwargs)
# The test split is read from the files fetched by the download above
testset = datasets.MNIST('../data', train=False, transform=transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(
    testset,
    batch_size=args.batch_size, shuffle=True, **kwargs)
image, label = trainset[0]  
print(len(trainset))
print(image.size())
image, label = testset[0]  
print(len(testset))
print(image.size())

Output:

60000
torch.Size([1, 28, 28])
10000
torch.Size([1, 28, 28])
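
The sizes printed above determine how many batches each loader produces per epoch. As a rough arithmetic sketch, assuming the default batch size of 128 and the DataLoader default of keeping the final partial batch (drop_last=False); the batch_layout helper is a hypothetical name used only for illustration:

```python
import math


def batch_layout(num_samples, batch_size):
    """Return (number of batches, size of the last batch) when the
    final partial batch is kept rather than dropped."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch


train_batches, train_last = batch_layout(60000, 128)
test_batches, test_last = batch_layout(10000, 128)
print(train_batches, train_last)  # 469 96
print(test_batches, test_last)    # 79 16
```

This is why len(train_loader) (a count of batches) and len(train_loader.dataset) (a count of images) are different quantities in the training code.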

2. Define the VAE

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)
        self.fc22 = nn.Linear(400, 20)
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

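    def reparametrize(self, mu, logvar):
        # logvar holds log(sigma^2), so sigma = exp(0.5 * logvar)
        std = logvar.mul(0.5).exp_()
        # eps ~ N(0, I), sampled with the same shape, dtype and device as std
        eps = torch.randn_like(std)
        # z = mu + sigma * eps keeps sampling differentiable w.r.t. mu and logvar
        return eps.mul(std).add_(mu)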

    def decode(self, z):
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, 784))
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar

model = VAE()
if args.cuda:
    model.cuda()
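
The reparametrize method draws z as mu + sigma * eps with eps ~ N(0, 1), so the randomness sits in eps and gradients can flow through mu and logvar. A framework-free sketch of the same arithmetic for a single scalar follows. The function name, the seed, and the sample count are illustrative assumptions, not part of the original code:

```python
import math
import random


def reparameterize(mu, logvar, eps):
    # sigma = exp(0.5 * log(sigma^2)), the scalar analogue of logvar.mul(0.5).exp_()
    std = math.exp(0.5 * logvar)
    return mu + std * eps


# With eps fixed, the result is a deterministic function of mu and logvar.
z_fixed = reparameterize(1.0, math.log(4.0), 0.5)

# With eps drawn from N(0, 1), the samples follow N(mu, sigma^2).
rng = random.Random(0)
samples = [reparameterize(1.0, math.log(4.0), rng.gauss(0.0, 1.0))
           for _ in range(20000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(z_fixed, round(mean, 2), round(var, 2))
```

With mu = 1 and sigma^2 = 4, the sample mean and variance should land close to 1 and 4 respectively.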

3. Define a loss function


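# Sum (rather than average) the per-pixel binary cross-entropy so that it is
# on the same scale as the summed KL term. Older PyTorch releases spelled this
# size_average=False; setting that attribute after construction no longer works.
reconstruction_function = nn.BCELoss(reduction='sum')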

def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x.view(-1, 784))

    # See Appendix B of the VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD


optimizer = optim.Adam(model.parameters(), lr=1e-3)
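
The loss adds a summed binary cross-entropy reconstruction term to the closed-form KL divergence between N(mu, sigma^2) and the prior N(0, 1). As a sanity check of both formulas, here is a pure-Python sketch on plain lists. The helper names are hypothetical, and it is not a replacement for the tensor version above:

```python
import math


def bce_sum(recon, target):
    # Summed (not averaged) binary cross-entropy over all pixels.
    return -sum(t * math.log(r) + (1 - t) * math.log(1 - r)
                for r, t in zip(recon, target))


def kld(mu, logvar):
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    return -0.5 * sum(1 + lv - m ** 2 - math.exp(lv)
                      for m, lv in zip(mu, logvar))


# A posterior equal to the prior N(0, 1) incurs zero KL cost.
kl_prior = kld([0.0, 0.0], [0.0, 0.0])
# Shifting the mean by 1 while keeping unit variance costs 0.5 nats.
kl_shifted = kld([1.0], [0.0])
# Predicting 0.5 for a pixel whose target is 1 costs log(2) nats.
bce_half = bce_sum([0.5], [1.0])
print(kl_prior, kl_shifted, bce_half)
```

Because both terms are sums over pixels and latent dimensions, dividing the accumulated loss by the number of images gives a per-image value.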

4. Train the network on the training data

We simply loop over the data iterator, feed each batch to the network, and optimize.

for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)

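Here train and test are defined as follows. In a script, these definitions must come before the loop above: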

def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = Variable(data)
        if args.cuda:
            data = data.cuda()
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        # .item() extracts the Python float from the scalar loss tensor
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    # The loss is summed over images, so divide by the dataset size
    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch, train_loss / len(train_loader.dataset)))


def test(epoch):
    model.eval()
    test_loss = 0
    # torch.no_grad() replaces the removed volatile=True flag:
    # no autograd graph is built during evaluation
    with torch.no_grad():
        for data, _ in test_loader:
            if args.cuda:
                data = data.cuda()
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))

 

Original article: http://www.cnblogs.com/xueqiuqiu/p/7605796.html
