码迷,mamicode.com
首页 > 编程语言 > 详细

TensorFlow(四) 用TensorFlow实现弹性网络回归算法(多线性回归)

时间:2018-06-12 16:10:10      阅读:179      评论:0      收藏:0      [点我收藏+]

标签:运行   div   code   abs   AC   feed   宽度   [1]   range   

弹性网络回归算法是综合lasso回归和岭回归的一种回归算法,通过在损失函数中增加L1正则和L2正则项,进而控制单个系数对结果的影响

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from sklearn import  datasets
sess=tf.Session()
#加载鸢尾花集
iris=datasets.load_iris()
#花瓣长度,花瓣宽度,花萼宽度 预测 花萼长度
x_vals=np.array([ [x[1],x[2],x[3]] for x in iris.data])
y_vals=np.array([y[0] for y in iris.data])

learning_rate=0.001
batch_size=50

x_data=tf.placeholder(shape=[None,3],dtype=tf.float32)
y_target=tf.placeholder(shape=[None,1],dtype=tf.float32)

A=tf.Variable(tf.random_normal(shape=[3,1]))
b=tf.Variable(tf.random_normal(shape=[1,1]))

#增加线性模型y=Ax+b  x*a==>shape(None,1)+b==>shape(NOne,1)
model_out=tf.add(tf.matmul(x_data,A),b)
#参数1,2
elastic_p1=tf.constant(1.)
elastic_p2=tf.constant(1.)


#声明损失函数 包含斜率的L1正则和L2正则。
#创建正则项
l1_a_loss=tf.reduce_mean(tf.abs(A))
l2_a_loss=tf.reduce_mean(tf.square(A))
e1_term=tf.multiply(elastic_p1,l1_a_loss)
e2_term=tf.multiply(elastic_p2,l2_a_loss)
#这里A是不规则的shape即3,1的数组形式  对应的loss也扩展成数组形式
loss=tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target-model_out)),e1_term),e2_term),0)

#初始化变量
init=tf.global_variables_initializer()
sess.run(init)


#梯度下降
my_opt=tf.train.GradientDescentOptimizer(learning_rate)
train_step=my_opt.minimize(loss)

#循环迭代
loss_rec=[]
for i in range(1000):
    rand_index=np.random.choice(len(x_vals),size=batch_size)
    #shape(None,3)
    rand_x= x_vals[rand_index]
    rand_y= np.transpose([y_vals[rand_index]])
    #运行
    sess.run(train_step,feed_dict={x_data:rand_x,y_target:rand_y})
    temp_loss =sess.run(loss,feed_dict={x_data:rand_x,y_target:rand_y})

    #添加记录
    loss_rec.append(temp_loss)
    #打印
    if (i+1)%250==0:
        print(Step: %d A=%s b=%s%(i,str(sess.run(A)),str(sess.run(b))))
        print(Loss:%s% str(temp_loss[0]))

#弹性网络回归迭代图形
plt.plot(loss_rec,k-,label=Loss)
plt.title(Loss per Generation)
plt.xlabel(Generation)
plt.ylabel( loss )
plt.show()

技术分享图片

 

TensorFlow(四) 用TensorFlow实现弹性网络回归算法(多线性回归)

标签:运行   div   code   abs   AC   feed   宽度   [1]   range   

原文地址:https://www.cnblogs.com/x0216u/p/9173106.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!