import os
import time
import math
from glob import glob

from PIL import Image
import tensorflow as tf
import numpy as np

import ops    # module wrapping the layer functions
import utils  # other helper functions
def conv_out_size_same(size, stride):
    # ceiling of the float division (the smallest integer not less than size / stride)
    return int(math.ceil(float(size) / float(stride)))
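For concreteness (this worked example is mine, not the post's): with the default 64x64 output, repeatedly halving with this helper yields the per-stage feature-map sizes that generator() and sampler() compute below.

for s in (64, 32, 16, 8):
    print(s, '->', conv_out_size_same(s, 2))
# 64 -> 32, 32 -> 16, 16 -> 8, 8 -> 4
# the ceiling matters for odd sizes: conv_out_size_same(7, 2) == 4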
Class-level attributes are not used in this example, although in practice they often are.
Class attributes & __init__ initialization: the constructor receives the parameters and derives the low-level attribute values from them; data loading, or at least the list of data file names, usually also lives in __init__.
class DCGAN():
    def __init__(self, sess, input_height=108, input_width=108, crop=True,
                 batch_size=64, sample_num=64, output_height=64, output_width=64,
                 z_dim=100, gf_dim=64, df_dim=64, gfc_dim=1024, dfc_dim=1024,
                 c_dim=3, dataset_name='default', input_fname_pattern='*.jpg',
                 checkpoint_dir=None, sample_dir=None):
        """
        Args:
            sess: TensorFlow session
            batch_size: The size of batch. Should be specified before training.
            z_dim: (optional) Dimension of dim for Z. [100]
            gf_dim: (optional) Dimension of gen filters in first conv layer. [64]
            df_dim: (optional) Dimension of discrim filters in first conv layer. [64]
            gfc_dim: (optional) Dimension of gen units for fully connected layer. [1024]
            dfc_dim: (optional) Dimension of discrim units for fully connected layer. [1024]
            c_dim: (optional) Dimension of image color. For grayscale input, set to 1. [3]
        """
        self.sess = sess
        self.batch_size = batch_size
        self.sample_num = sample_num

        # crop ties the input size to the output size:
        # if crop is True, images are cropped so the output size is the network input size
        # if crop is False, the input goes into the network's input layer directly
        self.crop = crop
        self.input_height = input_height
        self.input_width = input_width
        self.output_height = output_height
        self.output_width = output_width

        self.z_dim = z_dim
        self.gf_dim = gf_dim
        self.df_dim = df_dim
        self.dfc_dim = dfc_dim
        self.gfc_dim = gfc_dim

        self.g_bn0 = ops.batch_norm(name='g_bn0')
        self.g_bn1 = ops.batch_norm(name='g_bn1')
        self.g_bn2 = ops.batch_norm(name='g_bn2')
        self.g_bn3 = ops.batch_norm(name='g_bn3')

        self.d_bn1 = ops.batch_norm(name='d_bn1')
        self.d_bn2 = ops.batch_norm(name='d_bn2')
        self.d_bn3 = ops.batch_norm(name='d_bn3')

        '''Load the data'''
        self.dataset_name = dataset_name
        self.input_fname_pattern = input_fname_pattern
        self.checkpoint_dir = checkpoint_dir
        # collect the paths of all images
        self.data = glob(os.path.join('./data', self.dataset_name, self.input_fname_pattern))

        '''Read one image to determine the channel count'''
        imreadImg = np.asarray(Image.open(self.data[0]))
        if len(imreadImg.shape) >= 3:
            self.c_dim = imreadImg.shape[-1]
        else:
            self.c_dim = 1
        self.grayscale = (self.c_dim == 1)
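For reference, the glob above assumes images sit under ./data/<dataset_name>/ and match input_fname_pattern. The channel-detection logic can be tried standalone; the file path here is hypothetical:

# Standalone check of the channel detection (hypothetical path):
import numpy as np
from PIL import Image

img = np.asarray(Image.open('./data/celebA/000001.jpg'))  # hypothetical file
c_dim = img.shape[-1] if len(img.shape) >= 3 else 1
print(c_dim)  # 3 for RGB JPEGs, 1 for grayscale images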
Because of the peculiarities of GANs, the model construction is split up: build_model(self) is the backbone, with discriminator(self, image, reuse=False) and generator(self, z) as the modules it calls. Together they cover the whole pipeline from feeding data into the network through to computing the loss functions.
    def build_model(self):
        if self.crop:
            image_dims = [self.output_height, self.output_width, self.c_dim]
        else:
            image_dims = [self.input_height, self.input_width, self.c_dim]

        '''Input layer'''
        # note: list.extend() returns None, so concatenate with + here
        self.input_layer = tf.placeholder(
            tf.float32, [self.batch_size] + image_dims, name='input_layer')
        inputs = self.input_layer

        self.z = tf.placeholder(tf.float32, [None, self.z_dim], name='z')
        self.z_sum = tf.summary.histogram('z', self.z)

        '''Main computation nodes'''
        # build
        self.G = self.generator(self.z)
        self.D, self.D_logits = self.discriminator(inputs, reuse=False)
        # this rebinds self.sampler from the method to its output tensor
        self.sampler = self.sampler(self.z)
        self.D_, self.D_logits_ = self.discriminator(self.G, reuse=True)
        # summaries
        self.G_sum = tf.summary.image('G', self.G)
        self.D_sum = tf.summary.histogram('D', self.D)
        self.D__sum = tf.summary.histogram('D_', self.D_)

        '''Loss functions'''
        # build (in TF 1.x sigmoid_cross_entropy_with_logits requires keyword arguments)
        self.d_loss_real = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits, labels=tf.ones_like(self.D)))
        self.d_loss_fake = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
        self.g_loss = tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=self.D_logits_, labels=tf.ones_like(self.D_)))
        self.d_loss = self.d_loss_real + self.d_loss_fake
        # summaries (the module is tf.summary, not tf.Summary)
        self.d_loss_real_sum = tf.summary.scalar("d_loss_real", self.d_loss_real)
        self.d_loss_fake_sum = tf.summary.scalar("d_loss_fake", self.d_loss_fake)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)

        # separate the trainable variables of the two sub-networks
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'd_' in var.name]
        self.g_vars = [var for var in t_vars if 'g_' in var.name]

        # saver object
        self.saver = tf.train.Saver()

    def discriminator(self, image, reuse=False):
        with tf.variable_scope('discriminator', reuse=reuse):
            h0 = ops.lrelu(ops.conv2d(image, self.df_dim, name='d_h0_conv'))
            h1 = ops.lrelu(self.d_bn1(ops.conv2d(h0, self.df_dim * 2, name='d_h1_conv')))
            h2 = ops.lrelu(self.d_bn2(ops.conv2d(h1, self.df_dim * 4, name='d_h2_conv')))
            h3 = ops.lrelu(self.d_bn3(ops.conv2d(h2, self.df_dim * 8, name='d_h3_conv')))
            h4 = ops.linear(tf.reshape(h3, [self.batch_size, -1]), 1, 'd_h4_lin')
            return tf.nn.sigmoid(h4), h4

    def generator(self, z):
        with tf.variable_scope('generator'):
            # target image size, then the size at each upsampling stage
            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            # batch_size stays fixed; h and w double at every layer while c halves
            # linear projection layer
            self.z_, self.h0_w, self.h0_b = ops.linear(
                z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin', with_w=True)
            self.h0 = tf.reshape(self.z_, [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(self.h0))

            # transposed-convolution layers
            self.h1, self.h1_w, self.h1_b = ops.deconv2d(
                h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1', with_w=True)
            h1 = tf.nn.relu(self.g_bn1(self.h1))

            h2, self.h2_w, self.h2_b = ops.deconv2d(
                h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2', with_w=True)
            h2 = tf.nn.relu(self.g_bn2(h2))

            h3, self.h3_w, self.h3_b = ops.deconv2d(
                h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3', with_w=True)
            h3 = tf.nn.relu(self.g_bn3(h3))

            h4, self.h4_w, self.h4_b = ops.deconv2d(
                h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4', with_w=True)

            return tf.nn.tanh(h4)
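The ops module itself is not shown in this post. For orientation, here is a minimal sketch of what lrelu, conv2d, deconv2d, and linear could look like, assuming the standard DCGAN conventions (5x5 kernels, stride 2, small-stddev normal initialization); the author's actual implementation may differ in details:

# A sketch of the ops helpers, matching the call sites above -- an assumption,
# not the author's actual code.
import tensorflow as tf

def lrelu(x, leak=0.2, name='lrelu'):
    # leaky ReLU: max(x, leak * x)
    return tf.maximum(x, leak * x, name=name)

def conv2d(input_, output_dim, k=5, s=2, stddev=0.02, name='conv2d'):
    with tf.variable_scope(name):
        w = tf.get_variable('w', [k, k, input_.get_shape()[-1], output_dim],
                            initializer=tf.truncated_normal_initializer(stddev=stddev))
        b = tf.get_variable('b', [output_dim], initializer=tf.constant_initializer(0.0))
        conv = tf.nn.conv2d(input_, w, strides=[1, s, s, 1], padding='SAME')
        return tf.nn.bias_add(conv, b)

def deconv2d(input_, output_shape, k=5, s=2, stddev=0.02, name='deconv2d', with_w=False):
    with tf.variable_scope(name):
        # transposed-conv filter shape: [height, width, out_channels, in_channels]
        w = tf.get_variable('w', [k, k, output_shape[-1], input_.get_shape()[-1]],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        b = tf.get_variable('b', [output_shape[-1]], initializer=tf.constant_initializer(0.0))
        deconv = tf.nn.conv2d_transpose(input_, w, output_shape=output_shape,
                                        strides=[1, s, s, 1])
        deconv = tf.nn.bias_add(deconv, b)
        return (deconv, w, b) if with_w else deconv

def linear(input_, output_size, scope='linear', stddev=0.02, with_w=False):
    shape = input_.get_shape().as_list()
    with tf.variable_scope(scope):
        w = tf.get_variable('w', [shape[1], output_size],
                            initializer=tf.random_normal_initializer(stddev=stddev))
        b = tf.get_variable('b', [output_size], initializer=tf.constant_initializer(0.0))
        out = tf.matmul(input_, w) + b
        return (out, w, b) if with_w else out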
This is the part that in an ordinary network would predict labels; in a GAN it corresponds to generating the fake images. It does not take part in training.
    def sampler(self, z):
        # exactly the same structure as the generator, sharing its variables;
        # the only difference is that batch norm runs with is_training=False,
        # which switches the two moving-average statistics to inference mode
        with tf.variable_scope("generator") as scope:
            scope.reuse_variables()

            s_h, s_w = self.output_height, self.output_width
            s_h2, s_w2 = conv_out_size_same(s_h, 2), conv_out_size_same(s_w, 2)
            s_h4, s_w4 = conv_out_size_same(s_h2, 2), conv_out_size_same(s_w2, 2)
            s_h8, s_w8 = conv_out_size_same(s_h4, 2), conv_out_size_same(s_w4, 2)
            s_h16, s_w16 = conv_out_size_same(s_h8, 2), conv_out_size_same(s_w8, 2)

            h0 = tf.reshape(ops.linear(z, self.gf_dim * 8 * s_h16 * s_w16, 'g_h0_lin'),
                            [-1, s_h16, s_w16, self.gf_dim * 8])
            h0 = tf.nn.relu(self.g_bn0(h0, train=False))

            h1 = ops.deconv2d(h0, [self.batch_size, s_h8, s_w8, self.gf_dim * 4], name='g_h1')
            h1 = tf.nn.relu(self.g_bn1(h1, train=False))

            h2 = ops.deconv2d(h1, [self.batch_size, s_h4, s_w4, self.gf_dim * 2], name='g_h2')
            h2 = tf.nn.relu(self.g_bn2(h2, train=False))

            h3 = ops.deconv2d(h2, [self.batch_size, s_h2, s_w2, self.gf_dim * 1], name='g_h3')
            h3 = tf.nn.relu(self.g_bn3(h3, train=False))

            h4 = ops.deconv2d(h3, [self.batch_size, s_h, s_w, self.c_dim], name='g_h4')

            return tf.nn.tanh(h4)  # squash to [-1, 1], matching the generator
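The batch_norm wrapper from ops is likewise not shown. A plausible sketch, built on the common tf.contrib.layers.batch_norm pattern (an assumption, not necessarily the author's code), makes the train flag above concrete:

# A guess at how ops.batch_norm might be wrapped (TF 1.x contrib API).
import tensorflow as tf

class batch_norm(object):
    def __init__(self, epsilon=1e-5, momentum=0.9, name='batch_norm'):
        self.epsilon = epsilon
        self.momentum = momentum
        self.name = name

    def __call__(self, x, train=True):
        # train=True: normalize with batch statistics and update the moving averages;
        # train=False (the sampler): use the accumulated moving mean/variance.
        # updates_collections=None makes the moving averages update in place,
        # since the training loop never runs UPDATE_OPS explicitly.
        return tf.contrib.layers.batch_norm(x,
                                            decay=self.momentum,
                                            updates_collections=None,
                                            epsilon=self.epsilon,
                                            scale=True,
                                            is_training=train,
                                            scope=self.name)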
Now for the really fiddly part:
    def train(self, config):
        # discriminator optimizer (for the combined d_loss)
        d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        # generator optimizer
        g_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)
        tf.global_variables_initializer().run()

        # track how the individual values evolve over the iterations
        # (again tf.summary, not tf.Summary; and the histogram defined
        # earlier is self.D_sum, not self.d_sum)
        self.g_sum = tf.summary.merge([self.z_sum, self.D__sum,
                                       self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
        self.d_sum = tf.summary.merge([self.z_sum, self.D_sum,
                                       self.d_loss_real_sum, self.d_loss_sum])
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        # read sample_num images
        sample_files = self.data[0:self.sample_num]
        sample = [utils.get_image(sample_file,
                                  input_height=self.input_height,
                                  input_width=self.input_width,
                                  resize_height=self.output_height,
                                  resize_width=self.output_width,
                                  crop=self.crop) for sample_file in sample_files]
        sample_inputs = np.array(sample).astype(np.float32)
        sample_z = np.random.uniform(-1, 1, size=(self.sample_num, self.z_dim))

        counter = 1
        start_time = time.time()
        # reload the model and resume training if a checkpoint exists
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        for epoch in range(config.epoch):
            self.data = glob(os.path.join(
                "./data", config.dataset, self.input_fname_pattern))
            batch_idxs = min(len(self.data), config.train_size) // config.batch_size

            for idx in range(0, batch_idxs):
                # read one batch of real images x
                batch_files = self.data[idx * config.batch_size:(idx + 1) * config.batch_size]
                batch = [utils.get_image(batch_file,
                                         input_height=self.input_height,
                                         input_width=self.input_width,
                                         resize_height=self.output_height,
                                         resize_width=self.output_width,
                                         crop=self.crop) for batch_file in batch_files]
                batch_images = np.array(batch).astype(np.float32)
                # draw the noise z
                batch_z = np.random.uniform(-1, 1, [config.batch_size, self.z_dim]) \
                    .astype(np.float32)

                # Update D network
                _, summary_str = self.sess.run(
                    [d_optim, self.d_sum],
                    feed_dict={self.input_layer: batch_images, self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # Update G network
                _, summary_str = self.sess.run(
                    [g_optim, self.g_sum], feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)
                # what the writer records here are plain scalar values,
                # not logs in the usual sense

                # Update G network again
                # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
                _, summary_str = self.sess.run(
                    [g_optim, self.g_sum], feed_dict={self.z: batch_z})
                self.writer.add_summary(summary_str, counter)

                # evaluate the loss values
                errD_fake = self.d_loss_fake.eval({self.z: batch_z})
                errD_real = self.d_loss_real.eval({self.input_layer: batch_images})
                errG = self.g_loss.eval({self.z: batch_z})

                counter += 1
                print("Epoch: [%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f"
                      % (epoch, idx, batch_idxs,
                         time.time() - start_time, errD_fake + errD_real, errG))

                if np.mod(counter, 100) == 1:
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.input_layer: sample_inputs,
                            },
                        )
                        utils.save_images(samples, utils.image_manifold_size(samples.shape[0]),
                                          './{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
                    except Exception:
                        print("one pic error!...")

                if np.mod(counter, 500) == 2:
                    self.save(config.checkpoint_dir, counter)
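utils is also external to the post. Judging from the call sites, get_image presumably center-crops, resizes, and scales pixels to [-1, 1] to match the generator's tanh output; a rough sketch under that assumption (not the author's actual helper, and RGB-only for brevity):

# A guess at utils.get_image, inferred from how it is called above.
import numpy as np
from PIL import Image

def get_image(path, input_height, input_width,
              resize_height=64, resize_width=64, crop=True):
    img = Image.open(path)
    if crop:
        # cut an input_height x input_width window out of the image center
        w, h = img.size
        left = (w - input_width) // 2
        top = (h - input_height) // 2
        img = img.crop((left, top, left + input_width, top + input_height))
    # resize to the network size and scale pixels from [0, 255] to [-1, 1]
    img = img.resize((resize_width, resize_height), Image.BILINEAR)
    return np.asarray(img).astype(np.float32) / 127.5 - 1.0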
Personally I find the functionality a bit bloated, but there is still a lot here worth borrowing.
For example, hiding a method behind a property with a decorator strikes me as unnecessary, since it is only ever called internally anyway...
The standard idiom for checking a directory, on the other hand, is a keeper:
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
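In Python 3.2 and later, the same check-and-create pair collapses into a single call:

os.makedirs(checkpoint_dir, exist_ok=True)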
The author put real effort into organizing file names so the same code can run on different datasets, which makes the load module fairly complex; accordingly, I added somewhat more comments here.
    '''Model saving & loading'''
    # layout: checkpoint_dir/datasetname_batchsize_outputheight_outputwidth/model files
    @property
    def model_dir(self):
        return "{}_{}_{}_{}".format(
            self.dataset_name, self.batch_size,
            self.output_height, self.output_width)

    def save(self, checkpoint_dir, step):
        model_name = "DCGAN.model"
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        import re
        print(" [*] Reading checkpoints...")
        # join the checkpoint root with the per-dataset directory
        checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)

        # checkpoint directory -> name of the most recent model file
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # strip the path from the model file name; this feels redundant,
            # since the name stored in the checkpoint file carries no path anyway
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            # restore the parameters
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            # extract the training step counter
            counter = int(next(re.finditer(r"(\d+)", ckpt_name)).group(0))
            print(" [*] Success to read {}".format(ckpt_name))
            return True, counter
        else:
            print(" [*] Failed to find a checkpoint")
            return False, 0
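A quick demonstration of the counter-extraction regex, using a hypothetical checkpoint file name (the Saver appends the global step after a dash):

import re
ckpt_name = 'DCGAN.model-1500'  # hypothetical name produced by saver.save
counter = int(next(re.finditer(r"(\d+)", ckpt_name)).group(0))
print(counter)  # -> 1500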
import os
import pprint
import numpy as np
import tensorflow as tf

from model import DCGAN

# receiving command-line arguments takes three steps: first define the flags...
flags = tf.app.flags
flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
flags.DEFINE_integer("input_width", None, "The size of image to use (will be center cropped). If None, same value as input_height [None]")
flags.DEFINE_integer("output_height", 64, "The size of the output images to produce [64]")
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
# ...then read them into FLAGS...
FLAGS = flags.FLAGS

# main must take an argument, or you get:
# 'TypeError: main() takes no arguments (1 given)';
# the parameter name itself can be anything
def main(_):
    # the pprint module displays data structures more readably
    pp = pprint.PrettyPrinter()
    pp.pprint(flags.FLAGS.__flags)

    if FLAGS.input_width is None:
        FLAGS.input_width = FLAGS.input_height
    if FLAGS.output_width is None:
        FLAGS.output_width = FLAGS.output_height

    if not os.path.exists(FLAGS.checkpoint_dir):
        os.makedirs(FLAGS.checkpoint_dir)
    if not os.path.exists(FLAGS.sample_dir):
        os.makedirs(FLAGS.sample_dir)

    run_config = tf.ConfigProto()
    # TensorFlow grabs GPU memory greedily by default;
    # switch to allocating on demand instead
    run_config.gpu_options.allow_growth = True
    # the alternative below claims a fixed fraction:
    # run_config.gpu_options.per_process_gpu_memory_fraction = 0.333

    with tf.Session(config=run_config) as sess:
        dcgan = DCGAN(
            sess,
            input_width=FLAGS.input_width,
            input_height=FLAGS.input_height,
            output_width=FLAGS.output_width,
            output_height=FLAGS.output_height,
            batch_size=FLAGS.batch_size,
            sample_num=FLAGS.batch_size,
            dataset_name=FLAGS.dataset,
            input_fname_pattern=FLAGS.input_fname_pattern,
            crop=FLAGS.crop,
            checkpoint_dir=FLAGS.checkpoint_dir,
            sample_dir=FLAGS.sample_dir)

        if FLAGS.train:
            dcgan.train(FLAGS)
        else:
            if not dcgan.load(FLAGS.checkpoint_dir)[0]:
                raise Exception("[!] Train a model first, then run test mode")

# ...and finally hand control to tf.app.run(), which parses argv and calls main
if __name__ == '__main__':
    tf.app.run()
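Assuming this script is saved as main.py (the from model import DCGAN line suggests the class lives in model.py), a typical training run would be launched as:

python main.py --dataset celebA --input_height 108 --crop --train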
The prediction part was not written well, so I left it out here, but that does not get in the way of understanding the overall approach.
One thing worth mentioning is dcgan.train(FLAGS): FLAGS is passed in directly, the train function receives it as config, and values are read as config.<name>. This calling pattern is very convenient, and it also helps in appreciating the convenience of scripted TF programs; see 『TensorFlow』脚本化使用方法.
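The same duck typing in isolation (a hypothetical demo of mine): train() only reads attributes off config, so any object carrying the right attributes works in place of FLAGS.

from argparse import Namespace

config = Namespace(epoch=25, learning_rate=0.0002, beta1=0.5,
                   train_size=float('inf'), batch_size=64, dataset='celebA',
                   checkpoint_dir='checkpoint', sample_dir='samples')
print(config.learning_rate)   # accessed exactly like FLAGS.learning_rate
# dcgan.train(config)         # would behave just like dcgan.train(FLAGS)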
Original post: http://www.cnblogs.com/hellcat/p/7396094.html