码迷,mamicode.com
首页 > 其他好文 > 详细

使用生成对抗网络(GAN)生成手写字

时间:2019-12-08 15:36:48      阅读:156      评论:0      收藏:0      [点我收藏+]

标签:feed   stat   sha   nat   shape   tor   dom   init   保留   

先放结果

技术图片
这是通过GAN迭代训练30W次,耗时3小时生成的手写字图片效果,大部分的还是能看出来是数字的。

实现原理

简单说下原理,生成对抗网络需要训练两个任务,一个叫生成器,一个叫判别器,如字面意思,一个负责生成图片,一个负责判别图片,生成器不断生成新的图片,然后判别器去判断哪儿哪儿不行,生成器再不断去改进,不断的像真实的图片靠近。

这就如同一个造假团伙一样,A负责生产,B负责就鉴定,刚开始的时候,两个人都是菜鸟,A随便画了一幅画拿给B看,B说你这不行,然后A再改进,当然需要改进的不止A,随着A的改进,B也得不断提升,B需要发现更细微的差异,直至他们觉得已经没什么差异了(实际肯定还存在差异),他们便决定停止"训练",开始卖吧。

实现代码
# -*- coding: utf-8 -*-

# @author: Awesome_Tang
# @date: 2019-02-22
# @version: python2.7


import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from datetime import datetime
import numpy as np
import os
import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


class Config:
    alpha = 1e-2
    drop_rate = 0.5  # 保留比例
    steps = 300000  # 迭代次数
    batch_size = 128  # 每批次训练样本数
    epochs = 100  # 训练轮次

    num_units = 128
    size = 784
    noise_size = 100

    smooth = 0.01
    learning_rate = 1e-4

    print_per_step = 1000


class Gan:

    def __init__(self):
        print('Loading data......')
        # 读取MNIST数据集
        self.mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

        # 定义占位符,真实图片和生成的图片
        self.real_images = tf.placeholder(tf.float32, [None, Config.size], name='real_images')
        self.noise = tf.placeholder(tf.float32, [None, Config.noise_size], name='noise')
        self.drop_rate = tf.placeholder('float')

        self.train_step()

    def generator_graph(self, noise, n_units, out_dim, alpha, reuse=False):

        with tf.variable_scope('generator', reuse=reuse):
            # Hidden layer
            h1 = tf.layers.dense(noise, n_units, activation=None)
            # Leaky ReLU
            h1 = tf.maximum(alpha * h1, h1)
            h1 = tf.layers.dropout(h1, rate=self.drop_rate)
            # Logits and tanh output
            logits = tf.layers.dense(h1, out_dim, activation=None)
            out = tf.tanh(logits)

        return out

    @staticmethod
    def discriminator_graph(image, n_units, alpha, reuse=False):

        with tf.variable_scope('discriminator', reuse=reuse):
            # Hidden layer
            h1 = tf.layers.dense(image, n_units, activation=None)
            # Leaky ReLU
            h1 = tf.maximum(alpha * h1, h1)

            logits = tf.layers.dense(h1, 1, activation=None)
            # out = tf.sigmoid(logits)

        return logits

    def net(self):
        # generator
        fake_image = self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha)

        # discriminator
        real_logits = self.discriminator_graph(self.real_images, Config.num_units, Config.alpha)
        fake_logits = self.discriminator_graph(fake_image, Config.num_units, Config.alpha, reuse=True)

        # discriminator的loss
        # 识别真实图片
        d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits, labels=tf.ones_like(real_logits)) * (
                    1 - Config.smooth))
        # 识别生成的图片
        d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.zeros_like(fake_logits)))
        # 总体loss
        d_loss = tf.add(d_loss_real, d_loss_fake)

        # generator的loss
        g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits, labels=tf.ones_like(fake_logits)) * (
                    1 - Config.smooth))

        net_vars = tf.trainable_variables()

        # generator中的tensor
        g_vars = [var for var in net_vars if var.name.startswith("generator")]
        # discriminator中的tensor
        d_vars = [var for var in net_vars if var.name.startswith("discriminator")]

        # optimizer
        dis_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(d_loss, var_list=d_vars)
        gen_optimizer = tf.train.AdamOptimizer(Config.learning_rate).minimize(g_loss, var_list=g_vars)

        return dis_optimizer, gen_optimizer, d_loss, g_loss

    def train_step(self):
        dis_optimizer, gen_optimizer, d_loss, g_loss = self.net()

        print('Training & Evaluating......')
        start_time = datetime.now()
        sess = tf.Session()
        sess.run(tf.global_variables_initializer())

        for step in range(Config.steps):
            real_image, _ = self.mnist.train.next_batch(Config.batch_size)

            real_image = real_image * 2 - 1

            # generator的输入噪声
            batch_noise = np.random.uniform(-1, 1, size=(Config.batch_size, Config.noise_size))

            sess.run(gen_optimizer, feed_dict={self.noise: batch_noise, self.drop_rate: Config.drop_rate})
            sess.run(dis_optimizer, feed_dict={self.noise: batch_noise, self.real_images: real_image})

            if step % Config.print_per_step == 0:
                dis_loss = sess.run(d_loss, feed_dict={self.noise: batch_noise, self.real_images: real_image})
                gen_loss = sess.run(g_loss, feed_dict={self.noise: batch_noise, self.drop_rate: 1.})
                end_time = datetime.now()
                time_diff = (end_time - start_time).seconds

                msg = 'Step {:3}k Dis_Loss:{:6.2f}, Gen_Loss:{:6.2f}, Time_Usage:{:6.2f} mins.'
                print(msg.format(int(step / 1000), dis_loss, gen_loss, time_diff / 60.))

        self.gen_image(sess)

    def gen_image(self, sess):
        sample_noise = np.random.uniform(-1, 1, size=(25, Config.noise_size))
        samples = sess.run(
            self.generator_graph(self.noise, Config.num_units, Config.size, Config.alpha, reuse=True),
            feed_dict={self.noise: sample_noise})

        plt.figure(figsize=(8, 8), dpi=80)
        for i in range(25):
            img = samples[i]
            plt.subplot(5, 5, i + 1)
            plt.imshow(img.reshape((28, 28)), cmap='Greys_r')
            plt.axis('off')
        plt.show()


if __name__ == "__main__":
    Gan()

Peace~~

使用生成对抗网络(GAN)生成手写字

标签:feed   stat   sha   nat   shape   tor   dom   init   保留   

原文地址:https://www.cnblogs.com/awesometang/p/12005661.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!