码迷,mamicode.com
首页 > Web开发 > 详细

Context_Encoder在mnist的实战

时间:2019-10-16 16:21:53      阅读:222      评论:0      收藏:0      [点我收藏+]

标签:set   row   github   二维   rand   www   from   mpi   ros   

Context Encoder 是一种基于 GAN 的图像修复(inpainting)框架,后面附有简单的理论讲解链接。论文中图片被破坏的方式有三种:在图片(矩阵)中挖去一个正方形区域,使其中的像素值变为 0;在图片中随机挖去 n 个正方形区域,使其中的像素值变为 0;以及将图片(矩阵)中任意一些像素值置为 0。第三种方式最受关注,也最接近实际情况。Keras-GAN 项目(并非 Keras 官方教程)给出了针对 3 通道 CIFAR-10 图像(三维数据)的修复代码,处理的是第一种破坏方式。我在这份代码的基础上,将其修改为适用于 MNIST 单通道(二维)手写数字数据集的图像修复。

理论讲解:

1.https://blog.csdn.net/qq_33594380/article/details/85317922

2.https://www.cnblogs.com/wmr95/p/10636804.html

Keras-GAN 项目中基于 CIFAR-10 的原始代码:

https://github.com/eriklindernoren/Keras-GAN/blob/master/context_encoder/context_encoder.py

 

以下是我修改后的代码:

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply, GaussianNoise
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K

import matplotlib.pyplot as plt

import numpy as np

class ContextEncoder:
    """Context Encoder (GAN-based inpainting) adapted to single-channel MNIST.

    The generator receives a full-size image in which one square patch has
    been zeroed out and predicts the missing ``mask_height x mask_width``
    patch. The discriminator classifies a patch as real or generated. The
    combined model trains the generator on a weighted sum of a reconstruction
    loss (MSE) and an adversarial loss (binary cross-entropy).
    """

    def __init__(self):
        """Build and compile the discriminator, generator and combined model."""
        self.img_rows = 28
        self.img_cols = 28
        self.mask_height = 8
        self.mask_width = 8
        self.channels = 1
        self.num_classes = 2
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.missing_shape = (self.mask_height, self.mask_width, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes a masked image as input and generates the
        # missing patch of that image.
        masked_img = Input(shape=self.img_shape)
        gen_missing = self.generator(masked_img)

        # For the combined model we only train the generator. The flag is set
        # after the discriminator was compiled above, so the discriminator
        # still updates when trained on its own.
        self.discriminator.trainable = False

        # The discriminator judges whether the generated patch looks real.
        valid = self.discriminator(gen_missing)

        # Combined model (generator stacked with the frozen discriminator):
        # output 0 is the patch (MSE vs. the true patch), output 1 is the
        # discriminator verdict (BCE). Reconstruction dominates the weights.
        self.combined = Model(masked_img, [gen_missing, valid])
        self.combined.compile(loss=['mse', 'binary_crossentropy'],
                              loss_weights=[0.999, 0.001],
                              optimizer=optimizer)

    def build_generator(self):
        """Return a Model mapping a masked image to the predicted missing patch.

        With padding="same" and stride 2, each encoder conv halves the spatial
        size (rounding up): 28 -> 14 -> 7 -> 4 -> 2. The decoder then upsamples
        twice (2 -> 4 -> 8). For the default 28x28 input this yields an 8x8
        output, which matches ``missing_shape``.
        """
        model = Sequential()

        # Encoder
        model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        # Bottleneck
        model.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.5))

        # Decoder
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        # tanh keeps outputs in [-1, 1], the same range train() rescales to.
        model.add(Activation('tanh'))

        model.summary()

        masked_img = Input(shape=self.img_shape)
        gen_missing = model(masked_img)

        return Model(masked_img, gen_missing)

    def build_discriminator(self):
        """Return a Model mapping a patch of ``missing_shape`` to P(real)."""
        model = Sequential()

        model.add(Conv2D(64, kernel_size=3, strides=2, input_shape=self.missing_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(256, kernel_size=3, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.missing_shape)
        validity = model(img)

        return Model(img, validity)

    def mask_randomly(self, imgs):
        """Zero out one randomly placed ``mask_height x mask_width`` patch per image.

        Args:
            imgs: array of shape (batch, img_rows, img_cols, channels).
                It is not modified.

        Returns:
            ``(masked_imgs, missing_parts, (y1, y2, x1, x2))`` where
            ``masked_imgs`` is a copy of ``imgs`` with each patch set to 0,
            ``missing_parts`` holds the original patch contents with shape
            (batch, mask_height, mask_width, channels), and ``y1``/``y2``/
            ``x1``/``x2`` are per-image half-open patch bounds.
        """
        n = imgs.shape[0]
        # randint's upper bound is exclusive, so +1 allows the patch to reach
        # the bottom/right edge. Column offsets are bounded by img_cols, not
        # img_rows; the two only coincide for square images.
        y1 = np.random.randint(0, self.img_rows - self.mask_height + 1, n)
        y2 = y1 + self.mask_height
        x1 = np.random.randint(0, self.img_cols - self.mask_width + 1, n)
        x2 = x1 + self.mask_width

        masked_imgs = imgs.copy()
        missing_parts = np.empty((n, self.mask_height, self.mask_width, self.channels))
        for i, (top, left) in enumerate(zip(y1, x1)):
            rows = slice(top, top + self.mask_height)
            cols = slice(left, left + self.mask_width)
            missing_parts[i] = imgs[i, rows, cols, :]
            masked_imgs[i, rows, cols, :] = 0

        return masked_imgs, missing_parts, (y1, y2, x1, x2)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Train the GAN on MNIST digits 3 and 5.

        Each "epoch" here is a single iteration over one random batch: the
        discriminator is trained on real vs. generated patches, then the
        generator is trained through the combined model.

        Args:
            epochs: number of training iterations.
            batch_size: images per batch.
            sample_interval: every this many iterations, show sample repairs.
        """
        # Load the dataset
        (X_train, y_train), (_, _) = mnist.load_data()

        # Keep only the digits 3 and 5. This mirrors the two-class subset
        # (cats/dogs) used by the CIFAR version of this script.
        X_threes = X_train[(y_train == 3).flatten()]
        X_fives = X_train[(y_train == 5).flatten()]
        X_train = np.vstack((X_threes, X_fives))
        X_train = X_train.reshape(-1, self.img_rows, self.img_cols, self.channels)
        # Rescale from [0, 255] to [-1, 1] to match the generator's tanh output
        X_train = X_train / 127.5 - 1.

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            masked_imgs, missing_parts, _ = self.mask_randomly(imgs)

            # Generate a batch of predicted patches
            gen_missing = self.generator.predict(masked_imgs)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(missing_parts, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_missing, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # g_loss is [total, mse, bce] for the two combined-model outputs
            g_loss = self.combined.train_on_batch(masked_imgs, [missing_parts, valid])

            # Plot the progress
            print("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

            # If at save interval => show generated image samples
            if epoch % sample_interval == 0:
                idx = np.random.randint(0, X_train.shape[0], 6)
                imgs = X_train[idx]
                self.sample_images(epoch, imgs)

    def plot_image(self, image):
        """Display a single 2-D image in a small (2x2 inch) figure."""
        fig = plt.gcf()
        fig.set_size_inches(2, 2)
        plt.imshow(image, cmap='binary')
        plt.show()

    def sample_images(self, epoch, imgs):
        """Mask ``imgs``, inpaint them, and show original/masked/filled images.

        Only the image at index 1 is displayed, so ``imgs`` must contain at
        least two images. ``epoch`` is accepted for interface compatibility
        but is not used.
        """
        masked_imgs, missing_parts, (y1, y2, x1, x2) = self.mask_randomly(imgs)
        gen_missing = self.generator.predict(masked_imgs)

        # Map from [-1, 1] back to [0, 1] for display. Masked pixels (0 in
        # [-1, 1] space) therefore show up as mid-gray (0.5).
        imgs = 0.5 * imgs + 0.5                # complete images
        masked_imgs = 0.5 * masked_imgs + 0.5  # images with the patch removed
        gen_missing = 0.5 * gen_missing + 0.5  # predicted patches

        shown = 1
        filled_in = imgs[shown].copy()
        filled_in[y1[shown]:y2[shown], x1[shown]:x2[shown], :] = gen_missing[shown]

        # Drop the channel axis so matplotlib receives 2-D images.
        imgs = imgs.reshape(-1, self.img_rows, self.img_cols)
        masked_imgs = masked_imgs.reshape(-1, self.img_rows, self.img_cols)
        filled_in = filled_in.reshape(-1, self.img_rows, self.img_cols)

        self.plot_image(imgs[shown])
        self.plot_image(masked_imgs[shown])
        self.plot_image(filled_in[0])

        plt.close()

    def save_model(self):
        """Save generator and discriminator architecture (JSON) and weights (HDF5).

        Files are written to ``saved_model/<name>.json`` and
        ``saved_model/<name>_weights.hdf5``. The directory is created if
        missing.
        """
        import os

        def save(model, model_name):
            model_path = "saved_model/%s.json" % model_name
            weights_path = "saved_model/%s_weights.hdf5" % model_name
            options = {"file_arch": model_path,
                       "file_weight": weights_path}
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            json_string = model.to_json()
            # Context manager ensures the file is flushed and closed.
            with open(options['file_arch'], 'w') as f:
                f.write(json_string)
            model.save_weights(options['file_weight'])

        save(self.generator, "generator")
        save(self.discriminator, "discriminator")


if __name__ == '__main__':
    # Train for 2000 single-batch iterations; with sample_interval=1999,
    # samples are shown at iterations 0 and 1999.
    context_encoder = ContextEncoder()
    context_encoder.train(epochs=2000, batch_size=64, sample_interval=1999)

  

Context_Encoder在mnist的实战

标签:set   row   github   二维   rand   www   from   mpi   ros   

原文地址:https://www.cnblogs.com/nanhaijindiao/p/11686105.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!