|
| 1 | +""" Using convolutional net on MNIST dataset of handwritten digits |
| 2 | +MNIST dataset: http://yann.lecun.com/exdb/mnist/ |
| 3 | +CS 20: "TensorFlow for Deep Learning Research" |
| 4 | +cs20.stanford.edu |
| 5 | +Chip Huyen (chiphuyen@cs.stanford.edu) |
| 6 | +Lecture 07 |
| 7 | +""" |
| 8 | +import os |
| 9 | +os.environ['TF_CPP_MIN_LOG_LEVEL']='2' |
| 10 | +import time |
| 11 | + |
| 12 | +import tensorflow as tf |
| 13 | + |
| 14 | +import utils |
| 15 | + |
def conv_relu(inputs, filters, k_size, stride, padding, scope_name):
    """Apply a 2-D convolution followed by a ReLU activation.

    Two variables are created (or reused, because the scope is opened with
    ``tf.AUTO_REUSE``) under ``scope_name``: ``kernel`` with shape
    ``[k_size, k_size, in_channels, filters]`` and ``biases`` with shape
    ``[filters]``.

    Args:
        inputs: Input tensor; its last dimension is read as the number of
            input channels.
        filters: Number of output channels.
        k_size: Side length of the square convolution kernel.
        stride: Stride applied to the two middle dimensions of ``inputs``.
        padding: Padding mode forwarded to ``tf.nn.conv2d``.
        scope_name: Name of the variable scope holding the layer variables;
            also used as the name of the returned ReLU op.

    Returns:
        The tensor ``relu(conv2d(inputs, kernel) + biases)``.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE) as scope:
        kernel_shape = [k_size, k_size, inputs.shape[-1], filters]
        kernel = tf.get_variable(
            'kernel',
            kernel_shape,
            initializer=tf.truncated_normal_initializer())
        biases = tf.get_variable(
            'biases',
            [filters],
            initializer=tf.random_normal_initializer())
        window_strides = [1, stride, stride, 1]
        conv = tf.nn.conv2d(inputs, kernel,
                            strides=window_strides,
                            padding=padding)
    return tf.nn.relu(conv + biases, name=scope.name)
| 30 | + |
def maxpool(inputs, ksize, stride, padding='VALID', scope_name='pool'):
    """Max-pool ``inputs`` with a square window.

    Args:
        inputs: Tensor to pool.
        ksize: Side length of the pooling window (applied to the two middle
            dimensions).
        stride: Stride of the window along the two middle dimensions.
        padding: Padding mode forwarded to ``tf.nn.max_pool``.
        scope_name: Name of the variable scope the pooling op is created in.

    Returns:
        The pooled tensor.
    """
    window = [1, ksize, ksize, 1]
    window_strides = [1, stride, stride, 1]
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        return tf.nn.max_pool(inputs,
                              ksize=window,
                              strides=window_strides,
                              padding=padding)
| 39 | + |
def fully_connected(inputs, out_dim, scope_name='fc'):
    """Apply a fully connected linear layer (no activation) to ``inputs``.

    Creates or reuses ``weights`` of shape ``[in_dim, out_dim]`` and
    ``biases`` of shape ``[out_dim]`` under ``scope_name``, where ``in_dim``
    is the last dimension of ``inputs``.

    Args:
        inputs: Input tensor; multiplied on the right by ``weights``.
        out_dim: Number of output units.
        scope_name: Name of the variable scope holding the layer variables.

    Returns:
        The tensor ``matmul(inputs, weights) + biases``.
    """
    with tf.variable_scope(scope_name, reuse=tf.AUTO_REUSE):
        weights = tf.get_variable(
            'weights',
            [inputs.shape[-1], out_dim],
            initializer=tf.truncated_normal_initializer())
        biases = tf.get_variable(
            'biases',
            [out_dim],
            initializer=tf.constant_initializer(0.0))
        return tf.matmul(inputs, weights) + biases
| 52 | + |
class ConvNet(object):
    """Convolutional network for MNIST built as a TensorFlow 1.x graph.

    Architecture (see ``inference``): conv(32, 5x5) -> maxpool(2) ->
    conv(64, 5x5) -> maxpool(2) -> fc(1024) + ReLU + dropout -> fc(n_classes).

    Typical use: ``model = ConvNet(); model.build(); model.train(n_epochs)``.
    ``build`` must be called exactly once per instance: ``loss()`` rebinds
    ``self.loss`` to the loss tensor, so the method is no longer callable
    afterwards.
    """

    def __init__(self):
        """Set hyper-parameters and create the global-step variable."""
        self.lr = 0.001
        self.batch_size = 128
        # Dropout keep probability. A placeholder with a default keeps the
        # training value (0.75) when nothing is fed, while allowing
        # evaluation to feed 1.0 so that dropout is disabled at test time.
        self.keep_prob = tf.placeholder_with_default(
            tf.constant(0.75), shape=[], name='keep_prob')
        self.gstep = tf.Variable(0, dtype=tf.int32,
                                 trainable=False, name='global_step')
        self.n_classes = 10
        self.skip_step = 20   # print the loss every `skip_step` steps
        self.n_test = 10000   # divisor used to turn correct counts into accuracy
        self.training = True

    def get_data(self):
        """Create the input pipeline and the train/test iterator initializers.

        Sets ``self.img`` (reshaped to ``[-1, 28, 28, 1]``), ``self.label``,
        ``self.train_init`` and ``self.test_init``.
        """
        with tf.name_scope('data'):
            # NOTE(review): presumably returns two batched tf.data.Dataset
            # objects of (image, label) pairs -- confirm in utils.
            train_data, test_data = utils.get_mnist_dataset(self.batch_size)
            # One reinitializable iterator shared by both datasets.
            iterator = tf.data.Iterator.from_structure(train_data.output_types,
                                                       train_data.output_shapes)
            img, self.label = iterator.get_next()
            # Reshape the image to make it work with tf.nn.conv2d.
            self.img = tf.reshape(img, shape=[-1, 28, 28, 1])

            self.train_init = iterator.make_initializer(train_data)  # initializer for train_data
            self.test_init = iterator.make_initializer(test_data)    # initializer for test_data

    def inference(self):
        """Build the forward pass and store the class scores in ``self.logits``."""
        conv1 = conv_relu(inputs=self.img,
                          filters=32,
                          k_size=5,
                          stride=1,
                          padding='SAME',
                          scope_name='conv1')
        pool1 = maxpool(conv1, 2, 2, 'VALID', 'pool1')
        conv2 = conv_relu(inputs=pool1,
                          filters=64,
                          k_size=5,
                          stride=1,
                          padding='SAME',
                          scope_name='conv2')
        pool2 = maxpool(conv2, 2, 2, 'VALID', 'pool2')
        # Flatten every dimension except the leading one for the dense layers.
        feature_dim = pool2.shape[1] * pool2.shape[2] * pool2.shape[3]
        pool2 = tf.reshape(pool2, [-1, feature_dim])
        fc = fully_connected(pool2, 1024, 'fc')
        dropout = tf.nn.dropout(tf.nn.relu(fc), self.keep_prob, name='relu_dropout')
        self.logits = fully_connected(dropout, self.n_classes, 'logits')

    def loss(self):
        """Define the loss: mean softmax cross entropy over the batch.

        Softmax is applied internally by the cross-entropy op.

        NOTE: this rebinds ``self.loss`` from this method to the loss tensor,
        so the method can only be called once per instance.
        """
        with tf.name_scope('loss'):
            # assumes self.label holds one probability row per example
            # (e.g. one-hot) -- TODO confirm against utils.get_mnist_dataset
            entropy = tf.nn.softmax_cross_entropy_with_logits(labels=self.label,
                                                              logits=self.logits)
            self.loss = tf.reduce_mean(entropy, name='loss')

    def optimize(self):
        """Define the training op: Adam minimizing ``self.loss``.

        Each run of ``self.opt`` also increments ``self.gstep``.
        """
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss,
                                                            global_step=self.gstep)

    def summary(self):
        """Create the merged summary op written to TensorBoard.

        Requires ``self.loss`` and ``self.accuracy`` to exist already.
        """
        with tf.name_scope('summaries'):
            tf.summary.scalar('loss', self.loss)
            tf.summary.scalar('accuracy', self.accuracy)
            tf.summary.histogram('histogram loss', self.loss)
            self.summary_op = tf.summary.merge_all()

    def eval(self):
        """Count the number of right predictions in a batch.

        ``self.accuracy`` is the *count* of correct predictions in the batch
        (a sum, not a ratio); ``eval_once`` divides the total by ``n_test``.
        """
        with tf.name_scope('predict'):
            preds = tf.nn.softmax(self.logits)
            correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(self.label, 1))
            self.accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32))

    def build(self):
        """Build the computation graph.

        The order matters: each step reads attributes set by earlier ones
        (``summary`` needs both the loss and the accuracy tensors).
        """
        self.get_data()
        self.inference()
        self.loss()
        self.optimize()
        self.eval()
        self.summary()

    def train_one_epoch(self, sess, saver, init, writer, epoch, step):
        """Run one pass over the training data, then save a checkpoint.

        Args:
            sess: Session in which to run the ops.
            saver: ``tf.train.Saver`` used to write the checkpoint.
            init: Iterator initializer op for the training dataset.
            writer: Summary writer that receives the per-step summaries.
            epoch: Epoch index, used only for logging.
            step: Step counter value at the start of the epoch.

        Returns:
            The step counter after the epoch.
        """
        start_time = time.time()
        sess.run(init)
        self.training = True
        total_loss = 0
        n_batches = 0
        try:
            # The iterator raises OutOfRangeError once the dataset is exhausted.
            while True:
                _, batch_loss, summaries = sess.run(
                    [self.opt, self.loss, self.summary_op])
                writer.add_summary(summaries, global_step=step)
                if (step + 1) % self.skip_step == 0:
                    print('Loss at step {0}: {1}'.format(step, batch_loss))
                step += 1
                total_loss += batch_loss
                n_batches += 1
        except tf.errors.OutOfRangeError:
            pass
        saver.save(sess, 'checkpoints/convnet_mnist/mnist-convnet', step)
        # Guard against an empty dataset instead of raising ZeroDivisionError.
        average_loss = total_loss / n_batches if n_batches else float('nan')
        print('Average loss at epoch {0}: {1}'.format(epoch, average_loss))
        print('Took: {0} seconds'.format(time.time() - start_time))
        return step

    def eval_once(self, sess, init, writer, epoch, step):
        """Evaluate on the test data and print the accuracy.

        Dropout is disabled during evaluation by feeding a keep probability
        of 1.0.

        Args:
            sess: Session in which to run the ops.
            init: Iterator initializer op for the test dataset.
            writer: Summary writer that receives the per-batch summaries.
            epoch: Epoch index, used only for logging.
            step: Global step under which the summaries are recorded.
        """
        start_time = time.time()
        sess.run(init)
        self.training = False
        total_correct_preds = 0
        eval_feed = {self.keep_prob: 1.0}  # keep every unit at test time
        try:
            while True:
                accuracy_batch, summaries = sess.run(
                    [self.accuracy, self.summary_op], feed_dict=eval_feed)
                writer.add_summary(summaries, global_step=step)
                total_correct_preds += accuracy_batch
        except tf.errors.OutOfRangeError:
            pass

        print('Accuracy at epoch {0}: {1} '.format(epoch, total_correct_preds/self.n_test))
        print('Took: {0} seconds'.format(time.time() - start_time))

    def train(self, n_epochs):
        """Alternate between training one epoch and evaluating.

        Restores the latest checkpoint from ``checkpoints/convnet_mnist`` if
        one exists, and writes summaries to ``./graphs/convnet``.

        Args:
            n_epochs: Number of train/evaluate rounds to run.
        """
        utils.safe_mkdir('checkpoints')
        utils.safe_mkdir('checkpoints/convnet_mnist')
        writer = tf.summary.FileWriter('./graphs/convnet', tf.get_default_graph())

        # try/finally so the event file is flushed and closed even if
        # training is interrupted by an exception.
        try:
            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                ckpt = tf.train.get_checkpoint_state(
                    os.path.dirname('checkpoints/convnet_mnist/checkpoint'))
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)

                # Resume the step counter from the (possibly restored) global step.
                step = sess.run(self.gstep)

                for epoch in range(n_epochs):
                    step = self.train_one_epoch(sess, saver, self.train_init,
                                                writer, epoch, step)
                    self.eval_once(sess, self.test_init, writer, epoch, step)
        finally:
            writer.close()
| 206 | + |
if __name__ == '__main__':
    # Script entry point: construct the model, build its graph once, then
    # alternate training and evaluation for 30 epochs.
    model = ConvNet()
    model.build()
    model.train(n_epochs=30)
0 commit comments