码迷,mamicode.com
首页 > 其他好文 > 详细

阿里云AI-深度学习糖尿病预测

时间:2018-12-04 20:03:14      阅读:133      评论:0      收藏:0      [点我收藏+]

标签:ros   6.4   node   initial   equal   深度学习   validate   9.4   dev   

#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Sat Sep 15 10:54:53 2018

@author: myhaspl
@email:myhaspl@myhaspl.com

Diabetes prediction with a multi-layer (one hidden layer) network.

CSV format: pregnancies, glucose, blood pressure, skin thickness, insulin,
BMI, diabetes pedigree function, age, outcome.

The script builds a TensorFlow 1.x graph with queue-based CSV input, trains
with Adam for ``trainCount`` steps, logs accuracy every 500 steps and finally
writes the log (including the test batch and its predictions) to
``dataPath + "ai_log.txt"``.
"""
import sys

import tensorflow as tf

trainCount = 10000       # number of optimizer steps to run
inputNodeCount = 8       # feature columns per CSV record
hiddenNodeCount = 12     # width of the single hidden layer
validateCount = 50       # batch size of the validation pipeline
sampleCount = 200        # batch size of the training pipeline
testCount = 10           # batch size of the test pipeline
outputNodeCount = 1      # one logit: diabetic or not
ossPath = "oss://myhaspl-ai.oss-cn-beijing-internal.aliyuncs.com/"
localPath = "./"
dataPath = ossPath

g = tf.Graph()
with g.as_default():

    def getWeights(shape, wname):
        """Return a weight Variable of `shape` drawn from a truncated normal (stddev 0.1)."""
        return tf.Variable(tf.truncated_normal(shape, stddev=0.1), name=wname)

    def getBias(shape, bname):
        """Return a bias Variable of `shape` filled with the constant 0.1."""
        return tf.Variable(tf.constant(0.1, shape=shape), name=bname)

    def inferenceInput(x):
        """Return the raw output logits for input `x`.

        Reads the module-level variables w1, b1, w2 and b2, so it must only
        be called after they have been created.
        """
        layer1 = tf.nn.relu(tf.add(tf.matmul(x, w1), b1))
        return tf.add(tf.matmul(layer1, w2), b2)

    def inference(x):
        """Return the sigmoid of the logits for `x`."""
        return tf.sigmoid(inferenceInput(x))

    def loss():
        """Return the mean sigmoid cross-entropy over the `x` / `y` placeholders."""
        yp = inferenceInput(x)
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=yp))

    def train(learningRate, trainLoss, trainStep):
        """Return an Adam op minimizing `trainLoss`; each run increments `trainStep`."""
        optimizer = tf.train.AdamOptimizer(learningRate)
        return optimizer.minimize(trainLoss, global_step=trainStep)

    def evaluate(x):
        """Return hard 0./1. predictions for `x` (sigmoid output thresholded at 0.5)."""
        return tf.cast(inference(x) > 0.5, tf.float32)

    def accuracy(x, y, count=None):
        """Return the fraction of predictions for `x` that equal the labels `y`.

        `count` is accepted for backward compatibility only; it is not used.
        """
        yp = evaluate(x)
        return tf.reduce_mean(tf.cast(tf.equal(yp, y), tf.float32))

    def inputFromFile(fileName, skipLines=1):
        """Return a string tensor yielding one text line of `fileName` per evaluation.

        The first `skipLines` lines of the file are skipped as header.
        """
        # Build the file-name queue.
        fileNameQueue = tf.train.string_input_producer([fileName])
        # Read (key, value) records; only the line content is needed.
        reader = tf.TextLineReader(skip_header_lines=skipLines)
        _, value = reader.read(fileNameQueue)
        return value

    def getNextBatch(n, values):
        """Decode CSV lines from `values` and return a shuffled batch of `n` records.

        Returns:
            (features, y): `features` stacks the first eight columns as
            [n, 8]; `y` holds the ninth column (outcome) as [n, 1].
        """
        # Nine float columns; 1.0 is the fill-in value for empty fields.
        recordDefaults = [[1.]] * 9
        decoded = tf.decode_csv(values, record_defaults=recordDefaults)
        columns = tf.train.shuffle_batch(
            decoded, batch_size=n, capacity=1000, min_after_dequeue=1)
        features = tf.transpose(tf.stack(columns[:inputNodeCount]))
        y = tf.transpose([columns[inputNodeCount]])
        return (features, y)

    def getTestData(fileName, skipLines=1, n=10):
        """Return a shuffled (features, y) batch of `n` records read from `fileName`.

        Same pipeline as inputFromFile + getNextBatch, with its own reader
        and queues.
        """
        # NOTE(review): the sample output contains repeated rows in the test
        # batch, presumably because the file queue keeps cycling over the
        # file - confirm whether a single-pass test read is wanted.
        return getNextBatch(n, inputFromFile(fileName, skipLines))

    with tf.name_scope("inputSample"):
        samples = inputFromFile(dataPath + "diabetes.csv", 1)
        inputDs = getNextBatch(sampleCount, samples)

    with tf.name_scope("validateSamples"):
        # Draws from the same reader as the training batches, through a
        # separate shuffle queue.
        validateInputs = getNextBatch(validateCount, samples)

    with tf.name_scope("testSamples"):
        testInputs = getTestData(dataPath + "diabetes_test.csv", n=testCount)

    with tf.name_scope("inputDatas"):
        x = tf.placeholder(dtype=tf.float32, shape=[None, inputNodeCount],
                           name="input_x")
        y = tf.placeholder(dtype=tf.float32, shape=[None, outputNodeCount],
                           name="input_y")

    with tf.name_scope("Variable"):
        w1 = getWeights([inputNodeCount, hiddenNodeCount], "w1")
        # NOTE(review): shape () makes b1 a single scalar shared by every
        # hidden unit; a per-unit bias would be [hiddenNodeCount] - confirm
        # which is intended before changing the model.
        b1 = getBias((), "b1")
        w2 = getWeights([hiddenNodeCount, outputNodeCount], "w2")
        b2 = getBias((), "b2")
        trainStep = tf.Variable(0, dtype=tf.int32, name="tcount",
                                trainable=False)

    with tf.name_scope("train"):
        trainLoss = loss()
        trainOp = train(0.005, trainLoss, trainStep)

    with tf.name_scope("evaluate"):
        # Build the evaluation ops once, on the placeholders. Calling
        # accuracy()/evaluate() inside the training loop would add new ops
        # (and embed the batch as constants) on every call.
        predictOp = evaluate(x)
        accuracyOp = accuracy(x, y)

    init = tf.global_variables_initializer()

with tf.Session(graph=g) as sess:
    sess.run(init)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    logParts = []
    try:
        nowStep = sess.run(trainStep)
        while nowStep < trainCount:
            sampleX, sampleY = sess.run(inputDs)
            sess.run(trainOp, feed_dict={x: sampleX, y: sampleY})
            # Fetched in a separate run so the value read is the one after
            # the increment performed by trainOp.
            nowStep = sess.run(trainStep)
            if nowStep % 500 == 0:
                # Measure on a validation batch rather than on the batch the
                # optimizer has just been fitted to.
                validateX, validateY = sess.run(validateInputs)
                validate_acc = sess.run(
                    accuracyOp, feed_dict={x: validateX, y: validateY})
                logParts.append(
                    "%d次后=>正确率%g" % (nowStep, validate_acc) + "\n")
                # Progress marker; valid on both Python 2 and Python 3.
                sys.stdout.write(". ")
                sys.stdout.flush()

        testInputX, testInputY = sess.run(testInputs)
        testFeed = {x: testInputX, y: testInputY}
        logParts.append("测试样本正确率%g" % sess.run(accuracyOp, feed_dict=testFeed))
        logParts.append(str((testInputX, testInputY)))
        logParts.append(str(sess.run(predictOp, feed_dict={x: testInputX})))
        logStr = "".join(logParts)

        with tf.gfile.GFile(dataPath + "ai_log.txt", "wb") as f:
            f.write(logStr)
    finally:
        # Always stop and join the queue-runner threads, even when training
        # or log writing fails.
        coord.request_stop()
        coord.join(threads)
500次后=>正确率0.7
1000次后=>正确率0.8
1500次后=>正确率0.83
2000次后=>正确率0.78
2500次后=>正确率0.775
3000次后=>正确率0.76
3500次后=>正确率0.885
4000次后=>正确率0.85
4500次后=>正确率0.785
5000次后=>正确率0.79
5500次后=>正确率0.795
6000次后=>正确率0.87
6500次后=>正确率0.85
7000次后=>正确率0.79
7500次后=>正确率0.805
8000次后=>正确率0.775
8500次后=>正确率0.87
9000次后=>正确率0.84
9500次后=>正确率0.815
10000次后=>正确率0.805
测试样本正确率1(array([[1.00e+00, 8.90e+01, 6.60e+01, 2.30e+01, 9.40e+01, 2.81e+01,
        1.67e-01, 2.10e+01],
       [8.00e+00, 1.83e+02, 6.40e+01, 0.00e+00, 0.00e+00, 2.33e+01,
        6.72e-01, 3.20e+01],
       [1.00e+00, 1.26e+02, 6.00e+01, 0.00e+00, 0.00e+00, 3.01e+01,
        3.49e-01, 4.70e+01],
       [1.00e+00, 9.30e+01, 7.00e+01, 3.10e+01, 0.00e+00, 3.04e+01,
        3.15e-01, 2.30e+01],
       [8.00e+00, 1.83e+02, 6.40e+01, 0.00e+00, 0.00e+00, 2.33e+01,
        6.72e-01, 3.20e+01],
       [5.00e+00, 1.16e+02, 7.40e+01, 0.00e+00, 0.00e+00, 2.56e+01,
        2.01e-01, 3.00e+01],
       [8.00e+00, 1.83e+02, 6.40e+01, 0.00e+00, 0.00e+00, 2.33e+01,
        6.72e-01, 3.20e+01],
       [1.00e+00, 8.50e+01, 6.60e+01, 2.90e+01, 0.00e+00, 2.66e+01,
        3.51e-01, 3.10e+01],
       [6.00e+00, 1.48e+02, 7.20e+01, 3.50e+01, 0.00e+00, 3.36e+01,
        6.27e-01, 5.00e+01],
       [9.00e+00, 8.90e+01, 6.20e+01, 0.00e+00, 0.00e+00, 2.25e+01,
        1.42e-01, 3.30e+01]], dtype=float32), array([[0.],
       [1.],
       [1.],
       [0.],
       [1.],
       [0.],
       [1.],
       [0.],
       [1.],
       [0.]], dtype=float32))[[0.]
 [1.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [0.]
 [1.]
 [0.]]

阿里云AI-深度学习糖尿病预测

标签:ros   6.4   node   initial   equal   深度学习   validate   9.4   dev   

原文地址:http://blog.51cto.com/13959448/2326090

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!