标签:unknown 初始化 ror ssi min ict tmp learning flag
高级:
def linear_regression(): """ 自实现一个线性回归 :return: """ with tf.compat.v1.variable_scope("prepare_data"): # 1)准备数据 X = tf.compat.v1.random_normal(shape=[100, 1], name="feature") y_true = tf.matmul(X, [[0.8]]) + 0.7 with tf.compat.v1.variable_scope("create_model"): # 2)构造模型 # 定义模型参数 用 变量 weights = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1, 1]), name="Weights") bias = tf.Variable(initial_value=tf.compat.v1.random_normal(shape=[1, 1]), name="Bias") y_predict = tf.matmul(X, weights) + bias with tf.compat.v1.variable_scope("loss_function"): # 3)构造损失函数 error = tf.reduce_mean(tf.square(y_predict - y_true)) with tf.compat.v1.variable_scope("optimizer"): # 4)优化损失 optimizer = tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error) # 2_收集变量 tf.summary.scalar("error", error) tf.summary.histogram("weights", weights) tf.summary.histogram("bias", bias) # 3_合并变量 merged = tf.compat.v1.summary.merge_all() # 创建Saver对象 saver = tf.compat.v1.train.Saver() # 显式地初始化变量 init = tf.compat.v1.global_variables_initializer() # 开启会话 with tf.compat.v1.Session() as sess: # 初始化变量 sess.run(init) # 1_创建事件文件 file_writer = tf.compat.v1.summary.FileWriter("./tmp/linear", graph=sess.graph) # 查看初始化模型参数之后的值 print("训练前模型参数为:权重%f,偏置%f,损失为%f" % (weights.eval(), bias.eval(), error.eval())) # 开始训练 # for i in range(100): # sess.run(optimizer) # print("第%d次训练后模型参数为:权重%f,偏置%f,损失为%f" % (i+1, weights.eval(), bias.eval(), error.eval())) # # # 运行合并变量操作 # summary = sess.run(merged) # # 将每次迭代后的变量写入事件文件 # file_writer.add_summary(summary, i) # # # 保存模型 # if i % 10 ==0: # saver.save(sess, "./tmp/model/my_linear.ckpt") # 加载模型 if os.path.exists("./tmp/model/checkpoint"): saver.restore(sess, "./tmp/model/my_linear.ckpt") print("训练后模型参数为:权重%f,偏置%f,损失为%f" % (weights.eval(), bias.eval(), error.eval())) return None
# 1)定义命令行参数 tf.compat.v1.app.flags.DEFINE_integer("max_step", 100, "训练模型的步数") tf.compat.v1.app.flags.DEFINE_string("model_dir", "Unknown", "模型保存的路径+模型名字") # 2)简化变量名 FLAGS = tf.compat.v1.app.flags.FLAGS def command_demo(): """ 命令行参数演示 :return: """ print("max_step:\n", FLAGS.max_step) print("model_dir:\n", FLAGS.model_dir) return None
标签:unknown 初始化 ror ssi min ict tmp learning flag
原文地址:https://www.cnblogs.com/dazhi151/p/14436752.html