标签:spl eval ons ssi %s ada const mini bat
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 15 10:54:53 2018
@author: myhaspl
@email:myhaspl@myhaspl.com
糖尿病预测
"""
import tensorflow as tf
import os
# Number of optimizer steps to run: the training loop keeps going while the
# global step counter (trainStep) is below this value.
trainCount=10000
# Number of feature columns per example; width of the input placeholder x.
inputNodeCount=8
# Batch size of each validation batch (drawn from diabetes.csv).
validateCount=20
# Batch size of each training batch (drawn from diabetes.csv).
sampleCount=100
# Batch size of the test batch (drawn from diabetes_test.csv).
testCount=7
# Width of the label placeholder y (one outcome value per example).
outputNodeCount=1
# Dedicated graph: every op below is created inside it, and the session runs it.
g=tf.Graph()
with g.as_default():
def getWeights(shape, wname):
    """Create a weight variable initialised from a truncated normal.

    :param shape: shape of the new variable.
    :param wname: name given to the tf.Variable.
    :return: the new tf.Variable (initial stddev 0.1).
    """
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial, name=wname)
def getBias(shape, bname):
    """Create a bias variable whose every element starts at 0.1.

    :param shape: shape of the new variable.
    :param bname: name given to the tf.Variable.
    :return: the new tf.Variable.
    """
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial, name=bname)
def inferenceInput(x):
    """Return the linear model output (logits): matmul(x, w1) + b1.

    Reads the module-level variables w1 and b1, so they must already exist
    when this is called.

    :param x: feature tensor whose last dimension matches w1's first.
    :return: logits tensor (no sigmoid applied).
    """
    weighted = tf.matmul(x, w1)
    return tf.add(weighted, b1)
def inference(x):
    """Return sigmoid probabilities for the features x.

    :param x: feature tensor accepted by inferenceInput.
    :return: tensor of values in (0, 1).
    """
    return tf.sigmoid(inferenceInput(x))
def loss():
    """Return the mean sigmoid cross-entropy on the placeholders x and y.

    Uses the module-level placeholders x (features) and y (labels). The
    logits from inferenceInput are passed unsquashed, because the
    cross-entropy op applies the sigmoid itself.
    """
    logits = inferenceInput(x)
    perExample = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits)
    return tf.reduce_mean(perExample)
def train(learningRate, trainLoss, trainStep):
    """Build an Adadelta step that minimises trainLoss.

    :param learningRate: learning rate passed to AdadeltaOptimizer.
    :param trainLoss: scalar loss tensor to minimise.
    :param trainStep: counter variable passed as global_step, which
        minimize() increments once per executed training op.
    :return: the training op.
    """
    optimizer = tf.train.AdadeltaOptimizer(learningRate)
    return optimizer.minimize(trainLoss, global_step=trainStep)
def evaluate(x):
    """Return hard 0/1 predictions as float32.

    A prediction is 1.0 where the sigmoid probability exceeds 0.5, else 0.0.
    """
    isPositive = inference(x) > 0.5
    return tf.cast(isPositive, tf.float32)
def accuracy(x, y, count):
    """Return the fraction of examples whose 0/1 prediction equals the label.

    :param x: feature tensor.
    :param y: float label tensor with the same shape as evaluate(x).
    :param count: accepted for the callers' convenience but not used here.
    :return: scalar float32 accuracy.
    """
    hits = tf.equal(evaluate(x), y)
    return tf.reduce_mean(tf.cast(hits, tf.float32))
def inputs(fileName, skipLines=1):
    """Build a queue-based line reader over a single text/CSV file.

    :param fileName: path of the file to read.
    :param skipLines: number of header lines to skip.
    :return: string tensor holding one line of the file per read.
    """
    # Queue of file names that the reader pulls from.
    nameQueue = tf.train.string_input_producer([fileName])
    # The reader yields (key, line) pairs; only the line text is needed.
    lineReader = tf.TextLineReader(skip_header_lines=skipLines)
    _, line = lineReader.read(nameQueue)
    return line
def getNextBatch(n, values):
    """Decode CSV lines and group them into shuffled mini-batches.

    Each line is parsed into 9 float columns, with missing fields defaulting
    to 1.0. Column order: pregnancies, glucose, bloodPressure, skinThickness,
    insulin, bmi, diabetespedigreefunction, age, outcome.

    :param n: batch size.
    :param values: string tensor of CSV lines (e.g. the result of inputs()).
    :return: (features, labels) with shapes [n, 8] and [n, 1].
    """
    defaults = [[1.] for _ in range(9)]
    row = tf.decode_csv(values, record_defaults=defaults)
    # NOTE(review): min_after_dequeue=1 allows very little mixing between
    # batches; the TF docs recommend a larger value. Confirm this is intended.
    batch = tf.train.shuffle_batch(row, batch_size=n, capacity=1000, min_after_dequeue=1)
    featureColumns, outcome = batch[:8], batch[8]
    # Stacking along axis 1 puts one example per row.
    features = tf.stack(featureColumns, axis=1)
    labels = tf.stack([outcome], axis=1)
    return (features, labels)
# Input pipelines: each *InputX / *InputY pair is produced by one
# shuffle_batch dequeue and must be fetched together to stay aligned.
with tf.name_scope("inputSamples"):
    samples = inputs(os.path.join(os.getcwd(), "diabetes.csv"), 1)
    inputX, inputY = getNextBatch(sampleCount, samples)
with tf.name_scope("validateSamples"):
    # NOTE(review): validation batches are pulled from the same reader over
    # diabetes.csv as the training batches, so this is not a held-out split.
    # Consider a separate validation file.
    validateInputX, validateInputY = getNextBatch(validateCount, samples)
with tf.name_scope("testSamples"):
    testSamples = inputs(os.path.join(os.getcwd(), "diabetes_test.csv"), 1)
    testInputX, testInputY = getNextBatch(testCount, testSamples)
# Placeholders fed with training batches by the training loop.
with tf.name_scope("inputDatas"):
    x = tf.placeholder(dtype=tf.float32, shape=[None, inputNodeCount], name="input_x")
    y = tf.placeholder(dtype=tf.float32, shape=[None, outputNodeCount], name="input_y")
with tf.name_scope("Variable"):
    # getWeights attaches the name to the tf.Variable itself. Passing name=
    # to tf.truncated_normal would only name the initializer op and leave
    # the variable with the default name "Variable".
    w1 = getWeights([inputNodeCount, outputNodeCount], "w1")
    b1 = tf.Variable(0.1, dtype=tf.float32, name="b1")
    # Global step counter incremented by the optimizer; not trained itself.
    trainStep = tf.Variable(0, dtype=tf.int32, name="tcount", trainable=False)
with tf.Session(graph=g) as sess:
    # Build every op before initialising and starting the queues, so the
    # graph does not keep growing while it is being executed.
    trainLoss = loss()
    accuracyOp = accuracy(validateInputX, validateInputY, validateCount)
    trainOp = train(0.025, trainLoss, trainStep)
    testAccuracyOp = accuracy(testInputX, testInputY, testCount)
    predictOp = evaluate(inputX)
    sess.run(tf.global_variables_initializer())
    # The CSV pipelines are queue based: runner threads must be started
    # before any batch tensor is fetched, and stopped again on exit.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        nowStep = sess.run(trainStep)
        while nowStep < trainCount:
            # inputX and inputY are outputs of the same shuffle_batch dequeue.
            # Each sess.run call dequeues a fresh batch, so both must be
            # fetched in ONE call to keep features and labels paired.
            sampleX, sampleY = sess.run([inputX, inputY])
            sess.run(trainOp, feed_dict={x: sampleX, y: sampleY})
            nowStep = sess.run(trainStep)
            if nowStep % 500 == 0:
                validate_acc = sess.run(accuracyOp)
                print("%d次后=>正确率%g" % (nowStep, validate_acc))
        print("测试样本正确率%g" % sess.run(testAccuracyOp))
        # One dequeue for both, so the printed labels and predictions refer
        # to the same examples.
        labels, predictions = sess.run([inputY, predictOp])
        print(labels)
        print(predictions)
    finally:
        coord.request_stop()
        coord.join(threads)
标签:spl eval ons ssi %s ada const mini bat
原文地址:http://blog.51cto.com/13959448/2318748