TensorFlow BiLSTM official example

Posted: 2017-06-10 18:30:05

'''
A Bidirectional Recurrent Neural Network (LSTM) implementation example using TensorFlow library.
This example is using the MNIST database of handwritten digits (http://yann.lecun.com/exdb/mnist/)
Long Short Term Memory paper: http://deeplearning.cs.cmu.edu/pdfs/Hochreiter97_lstm.pdf

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
'''

from __future__ import print_function

# This listing targets TensorFlow 1.x (tf.contrib was removed in TensorFlow 2.x)
import tensorflow as tf
from tensorflow.contrib import rnn
import numpy as np

# Import MNIST data; one_hot=True encodes each label as a 10-element vector
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

'''
To classify images using a bidirectional recurrent neural network, we consider
every image row as a sequence of pixels. Because MNIST images are 28x28 pixels,
we handle every sample as a sequence of 28 steps, each step being one 28-pixel row.
'''
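
# Parameters
learning_rate = 0.001

# Roughly the total number of training samples: the loop below runs while
# step * batch_size < training_iters, i.e. 781 steps of 128 (99,968 samples)
training_iters = 100000

# Number of samples per training batch
batch_size = 128

# Print minibatch loss and accuracy every display_step steps
display_step = 10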

# Network Parameters
# n_steps * n_input is the image itself: each of its rows is fed in as one time step.
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps

# Hidden layer size (per direction)
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)

# tf Graph input
# In [None, n_steps, n_input], None leaves the batch dimension unspecified
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Define weights
weights = {
    # Hidden layer weights => 2*n_hidden because of forward + backward cells
    'out': tf.Variable(tf.random_normal([2*n_hidden, n_classes]))
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}


def BiRNN(x, weights, biases):

    # Prepare data shape to match `bidirectional_rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Unstack along axis 1 to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.unstack(x, n_steps, 1)

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Get lstm cell output
    try:
        outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                              dtype=tf.float32)
    except Exception: # Old TensorFlow version only returns outputs not states
        outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                        dtype=tf.float32)

    # Linear activation, using rnn inner loop last output.
    # Each outputs[t] is the forward and backward outputs concatenated,
    # shape (batch_size, 2*n_hidden). At the last step the forward half has
    # read all rows, while the backward half has only seen the last row.
    return tf.matmul(outputs[-1], weights['out']) + biases['out']

pred = BiRNN(x, weights, biases)

# Define loss and optimizer
# softmax_cross_entropy_with_logits: measures the probability error in discrete classification tasks in which the classes are mutually exclusive.
# It returns a 1-D Tensor of length batch_size, of the same type as logits, holding the softmax cross-entropy loss.
# reduce_mean with no axis given averages over all elements, i.e. the mean loss over the batch.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Initializing the variables
init = tf.global_variables_initializer()

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until the maximum number of iterations is reached
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape each flat 784-pixel image into 28 time steps of 28 pixels
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
            print("Iter " + str(step*batch_size) + ", Minibatch Loss= " +
                  "{:.6f}".format(loss) + ", Training Accuracy= " +
                  "{:.5f}".format(acc))
        step += 1
    print("Optimization Finished!")

    # Calculate accuracy for 128 mnist test images
    test_len = 128
    test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:test_len]
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={x: test_data, y: test_label}))
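The NumPy sketch below shows the arithmetic behind `cost` and `accuracy` for a single batch. It is a sketch under assumptions, not TensorFlow's own implementation. It takes a `logits` array and a one-hot `labels` array, both of shape (batch_size, n_classes), which play the roles of `pred` and `y`. The helper names `softmax_cross_entropy` and `batch_accuracy` are hypothetical.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Per-example loss, mirroring tf.nn.softmax_cross_entropy_with_logits (hypothetical helper)."""
    # Subtract the row max before exponentiating for numerical stability
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # One-hot labels pick out the log-probability of the true class
    return -(labels * log_probs).sum(axis=1)  # shape (batch_size,)

def batch_accuracy(logits, labels):
    """Fraction of rows whose predicted class matches the labelled class."""
    correct = np.equal(np.argmax(logits, 1), np.argmax(labels, 1))
    return correct.astype(np.float32).mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
labels = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])
cost = softmax_cross_entropy(logits, labels).mean()  # like tf.reduce_mean over the batch
acc = batch_accuracy(logits, labels)
print(cost, acc)  # acc is 0.5: the second row predicts class 2 but is labelled 1
```

Averaging the per-example losses is what `tf.reduce_mean` contributes in the listing. In the same way, casting the boolean matches to float and averaging them gives the accuracy.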

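The official BiLSTM example is written very clearly. Since this was my first time reading it, I still had to look up quite a few things, especially the data-processing operations.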
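The sketch below follows the array shapes through the data handling in the example. It assumes a random batch of 4 flat MNIST-sized images. A list comprehension stands in for `tf.unstack(x, n_steps, 1)`. Random arrays stand in for the forward and backward LSTM outputs, so nothing is learned here; only the shapes are meaningful.

```python
import numpy as np

batch_size, n_steps, n_input, n_hidden, n_classes = 4, 28, 28, 128, 10

# mnist.train.next_batch yields flat images of 784 pixels each
flat = np.random.rand(batch_size, n_steps * n_input)

# Reshape: each image becomes 28 rows (time steps) of 28 pixels
x = flat.reshape((batch_size, n_steps, n_input))

# Counterpart of tf.unstack(x, n_steps, 1): a list of n_steps arrays
steps = [x[:, t, :] for t in range(n_steps)]

# Random stand-ins for the per-step forward and backward LSTM outputs
fw = [np.random.rand(batch_size, n_hidden) for _ in range(n_steps)]
bw = [np.random.rand(batch_size, n_hidden) for _ in range(n_steps)]
# The bidirectional outputs are the two halves concatenated on the feature axis
outputs = [np.concatenate([f, b], axis=1) for f, b in zip(fw, bw)]

# Output layer: last step's output times a (2*n_hidden, n_classes) matrix
w_out = np.random.randn(2 * n_hidden, n_classes)
b_out = np.random.randn(n_classes)
logits = outputs[-1] @ w_out + b_out

print(len(steps), steps[0].shape, outputs[-1].shape, logits.shape)
# 28 (4, 28) (4, 256) (4, 10)
```

The concatenation is why `weights['out']` has 2*n_hidden rows rather than n_hidden.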

Data processing (https://segmentfault.com/a/1190000008793389)

Joining (concat and stack):

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0) ==> [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 1) ==> [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
tf.stack([t1, t2], 0)  ==> [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
tf.stack([t1, t2], 1)  ==> [[[1, 2, 3], [7, 8, 9]], [[4, 5, 6], [10, 11, 12]]]
tf.stack([t1, t2], 2)  ==> [[[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]]]
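NumPy's `np.concatenate` and `np.stack` follow the same rules as `tf.concat` and `tf.stack`, so you can check the results above without building a graph. This is a minimal sketch, using NumPy as a stand-in for TensorFlow:

```python
import numpy as np

t1 = np.array([[1, 2, 3], [4, 5, 6]])
t2 = np.array([[7, 8, 9], [10, 11, 12]])

# concat joins along an existing axis
c0 = np.concatenate([t1, t2], axis=0)
c1 = np.concatenate([t1, t2], axis=1)

# stack adds a new axis at the given position and joins along it
s0 = np.stack([t1, t2], axis=0)
s1 = np.stack([t1, t2], axis=1)
s2 = np.stack([t1, t2], axis=2)

print(c1.tolist())  # [[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]]
print(s2.tolist())  # [[[1, 7], [2, 8], [3, 9]], [[4, 10], [5, 11], [6, 12]]]
```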

From the shape point of view (* marks the new axis that stack inserts):

t1 = [[1, 2, 3], [4, 5, 6]]
t2 = [[7, 8, 9], [10, 11, 12]]
tf.concat([t1, t2], 0)  # [2,3] + [2,3] ==> [4, 3]
tf.concat([t1, t2], 1)  # [2,3] + [2,3] ==> [2, 6]
tf.stack([t1, t2], 0)   # [2,3] + [2,3] ==> [2*,2,3]
tf.stack([t1, t2], 1)   # [2,3] + [2,3] ==> [2,2*,3]
tf.stack([t1, t2], 2)   # [2,3] + [2,3] ==> [2,3,2*]
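The two shape rules can be stated directly. Concat adds up the sizes along `axis`. Stack inserts a new axis at position `axis`, whose length is the number of inputs. The helpers `concat_shape` and `stack_shape` below are hypothetical and handle non-negative axes only. The sketch checks them against NumPy.

```python
import numpy as np

def concat_shape(shapes, axis):
    """Shape after concatenating same-rank inputs along an existing axis."""
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)
    return tuple(out)

def stack_shape(shapes, axis):
    """Shape after stacking identically shaped inputs along a new axis."""
    out = list(shapes[0])
    out.insert(axis, len(shapes))  # the new axis, one entry per input
    return tuple(out)

t = np.zeros((2, 3))
for axis in (0, 1):
    assert concat_shape([t.shape, t.shape], axis) == np.concatenate([t, t], axis).shape
for axis in (0, 1, 2):
    assert stack_shape([t.shape, t.shape], axis) == np.stack([t, t], axis).shape

print(concat_shape([(2, 3), (2, 3)], 1), stack_shape([(2, 3), (2, 3)], 2))
# (2, 6) (2, 3, 2)
```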

Extraction (slice and gather):

input = [[[1, 1, 1], [2, 2, 2]],
         [[3, 3, 3], [4, 4, 4]],
         [[5, 5, 5], [6, 6, 6]]]
tf.slice(input, [1, 0, 0], [1, 1, 3]) ==> [[[3, 3, 3]]]
tf.slice(input, [1, 0, 0], [1, 2, 3]) ==> [[[3, 3, 3],
                                            [4, 4, 4]]]
tf.slice(input, [1, 0, 0], [2, 1, 3]) ==> [[[3, 3, 3]],
                                           [[5, 5, 5]]]
                                           
tf.gather(input, [0, 2]) ==> [[[1, 1, 1], [2, 2, 2]],
                              [[5, 5, 5], [6, 6, 6]]]
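Both operations map onto NumPy indexing:

- `tf.slice(input, begin, size)` takes `size[i]` elements starting at `begin[i]` in every dimension. A size of -1 means "all remaining elements".
- `tf.gather(input, indices)` as used above picks whole entries along the first axis.

The minimal sketch below reproduces the examples above in NumPy. `tf_like_slice` is a hypothetical helper, not a TensorFlow function.

```python
import numpy as np

inp = np.array([[[1, 1, 1], [2, 2, 2]],
                [[3, 3, 3], [4, 4, 4]],
                [[5, 5, 5], [6, 6, 6]]])

def tf_like_slice(arr, begin, size):
    """Mimic tf.slice: size[i] elements from begin[i]; -1 means to the end."""
    index = tuple(slice(b, None if s == -1 else b + s) for b, s in zip(begin, size))
    return arr[index]

print(tf_like_slice(inp, [1, 0, 0], [1, 1, 3]).tolist())  # [[[3, 3, 3]]]
print(tf_like_slice(inp, [1, 0, 0], [2, 1, 3]).tolist())  # [[[3, 3, 3]], [[5, 5, 5]]]

# Gathering indices [0, 2] along axis 0 is NumPy integer-array indexing
print(np.take(inp, [0, 2], axis=0).tolist())
# [[[1, 1, 1], [2, 2, 2]], [[5, 5, 5], [6, 6, 6]]]
```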

 

Original post: http://www.cnblogs.com/linyx/p/6979119.html
