码迷,mamicode.com
首页 > 其他好文 > 详细

《Tensorflow基础泰坦尼克获救预测》-- 网易云课堂

时间:2018-12-01 11:16:39      阅读:231      评论:0      收藏:0      [点我收藏+]

标签:style   batch   0.00   分享   ini   show   lse   tle   nump   

https://study.163.com/course/courseMain.htm?courseId=1004937015

 1 # -*- coding: utf-8 -*-
 2 
 3 import tensorflow as tf
 4 import pandas as pd
 5 import numpy as np
 6 
 7 data = pd.read_csv(train.csv)
 8 data = data[[Survived, Pclass, Sex, Age, SibSp, Parch, Fare, Cabin, Embarked]]
 9 
10 data[Age] = data[Age].fillna(data[Age].mean())
11 data[Cabin] = pd.factorize(data[Cabin])[0]
12 data.fillna(0, inplace=True)
13 data[Sex] = [1 if x==male else 0 for x in data[Sex]]
14 data[p1] = np.array(data[Pclass]==1).astype(np.int32)
15 data[p2] = np.array(data[Pclass]==2).astype(np.int32)
16 data[p3] = np.array(data[Pclass]==3).astype(np.int32)
17 del data[Pclass]
18 data[e1] = np.array(data[Embarked]==S).astype(np.int32)
19 data[e2] = np.array(data[Embarked]==C).astype(np.int32)
20 data[e3] = np.array(data[Embarked]==Q).astype(np.int32)
21 del data[Embarked]
22 
23 data_train = data[[ Sex, Age, SibSp, Parch, Fare, Cabin, p1, p2, p3, e1, e2, e3]]
24 data_target = data[Survived].values.reshape(len(data), 1)
25 
26 x = tf.placeholder("float", shape=[None, 12])
27 y = tf.placeholder("float", shape=[None, 1])
28 
29 weight = tf.Variable(tf.random_normal([12, 1]))
30 bias = tf.Variable(tf.random_normal([1]))
31 output = tf.matmul(x, weight) + bias
32 pred = tf.cast(tf.sigmoid(output) > 0.5, tf.float32)
33 
34 loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=output))
35 train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss)
36 accurary = tf.reduce_mean(tf.cast(tf.equal(pred, y), tf.float32))
37 
38 data_test = pd.read_csv(test.csv)
39 data_test = data_test[[Pclass, Sex, Age, SibSp, Parch, Fare, Cabin, Embarked]]
40 data_test[Age] = data_test[Age].fillna(data_test[Age].mean())
41 data_test[Cabin] = pd.factorize(data_test[Cabin])[0]
42 data_test.fillna(0, inplace=True)
43 data_test[Sex] = [1 if x==male else 0 for x in data_test[Sex]]
44 data_test[p1] = np.array(data_test[Pclass]==1).astype(np.int32)
45 data_test[p2] = np.array(data_test[Pclass]==2).astype(np.int32)
46 data_test[p3] = np.array(data_test[Pclass]==3).astype(np.int32)
47 del data_test[Pclass]
48 data_test[e1] = np.array(data_test[Embarked]==S).astype(np.int32)
49 data_test[e2] = np.array(data_test[Embarked]==C).astype(np.int32)
50 data_test[e3] = np.array(data_test[Embarked]==Q).astype(np.int32)
51 del data_test[Embarked]
52 
53 test_label = pd.read_csv(gender_submission.csv)
54 test_label = np.reshape(test_label[Survived].values.astype(np.float32), (418,1))
55 
56 sess = tf.Session()
57 sess.run(tf.global_variables_initializer())
58 loss_train = []
59 train_acc = []
60 test_acc = []
61 
62 data_train = data_train.values
63 for i in range(25000):
64     index = np.random.permutation(len(data_target))
65     data_train = data_train[index]
66     data_target = data_target[index]
67     for n in range(len(data_target)//100 + 1):
68         batch_xs = data_train[n*100:n*100+100]
69         batch_ys = data_target[n*100:n*100+100]
70         sess.run(train_step, feed_dict={x:batch_xs, y:batch_ys})
71         
72     if i%1000 == 0:
73         loss_temp = sess.run(loss, feed_dict={x:batch_xs, y:batch_ys})
74         loss_train.append(loss_temp)
75         train_acc_temp = sess.run(accurary, feed_dict={x:batch_xs, y:batch_ys})
76         train_acc.append(train_acc_temp)
77         test_acc_temp = sess.run(accurary, feed_dict={x:data_test, y:test_label})
78         test_acc.append(test_acc_temp)
79         print(loss_temp,train_acc_temp,test_acc_temp)
80         
81 import matplotlib.pyplot as plt
82 
83 plt.plot(loss_train, k-)
84 plt.title(train loss)
85 plt.show()
86 
87 plt.plot(train_acc, b-, label=train_acc)
88 plt.plot(test_acc, r--, label=test_acc)
89 plt.title(train and test accuracy)
90 plt.legend()
91 plt.show()

技术分享图片

技术分享图片

技术分享图片

 

《Tensorflow基础泰坦尼克获救预测》-- 网易云课堂

标签:style   batch   0.00   分享   ini   show   lse   tle   nump   

原文地址:https://www.cnblogs.com/LearnFromNow/p/10048167.html

(0)
(0)
   
举报
评论 一句话评论(0)
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!