|
| 1 | +""" A neural chatbot using sequence to sequence model with |
| 2 | +attentional decoder. |
| 3 | +
|
| 4 | +This is based on Google Translate Tensorflow model |
| 5 | +https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/ |
| 6 | +
|
| 7 | +Sequence to sequence model by Cho et al.(2014) |
| 8 | +
|
| 9 | +Created by Chip Huyen as the starter code for assignment 3, |
| 10 | +class CS 20SI: "TensorFlow for Deep Learning Research" |
| 11 | +cs20si.stanford.edu |
| 12 | +
|
| 13 | +This file contains the code to run the model. |
| 14 | +
|
| 15 | +See readme.md for instructions on how to run the starter code. |
| 16 | +""" |
| 17 | +from __future__ import division |
| 18 | +from __future__ import print_function |
| 19 | + |
| 20 | +import argparse |
| 21 | +import os |
| 22 | +import random |
| 23 | +import sys |
| 24 | +import time |
| 25 | + |
| 26 | +import numpy as np |
| 27 | +import tensorflow as tf |
| 28 | + |
| 29 | +from model import ChatBotModel |
| 30 | +import config |
| 31 | +import data |
| 32 | + |
| 33 | +def _get_random_bucket(train_buckets_scale): |
| 34 | + """ Get a random bucket from which to choose a training sample """ |
| 35 | + rand = random.random() |
| 36 | + return min([i for i in xrange(len(train_buckets_scale)) |
| 37 | + if train_buckets_scale[i] > rand]) |
| 38 | + |
| 39 | +def _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks): |
| 40 | + """ Assert that the encoder inputs, decoder inputs, and decoder masks are |
| 41 | + of the expected lengths """ |
| 42 | + if len(encoder_inputs) != encoder_size: |
| 43 | + raise ValueError("Encoder length must be equal to the one in bucket," |
| 44 | + " %d != %d." % (len(encoder_inputs), encoder_size)) |
| 45 | + if len(decoder_inputs) != decoder_size: |
| 46 | + raise ValueError("Decoder length must be equal to the one in bucket," |
| 47 | + " %d != %d." % (len(decoder_inputs), decoder_size)) |
| 48 | + if len(decoder_masks) != decoder_size: |
| 49 | + raise ValueError("Weights length must be equal to the one in bucket," |
| 50 | + " %d != %d." % (len(decoder_masks), decoder_size)) |
| 51 | + |
def run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, forward_only):
    """ Run one step of the model on a single batch.

    @sess: session used to execute the graph.
    @model: object exposing encoder_inputs, decoder_inputs, decoder_masks,
    batch_size, train_ops, gradient_norms, losses and outputs.
    @encoder_inputs, @decoder_inputs, @decoder_masks: per-step feed values;
    their lengths must match the sizes in config.BUCKETS[bucket_id].
    @bucket_id: index into config.BUCKETS selecting which bucket to run.
    @forward_only: boolean value to decide whether the backward path is run.
    Set it to True to just evaluate on the test set or to run the bot in
    chat mode; set it to False to also run the training op.

    Returns a (gradient_norm, loss, outputs) triple:
    (gradient_norm, loss, None) when training, and
    (None, loss, outputs) when forward_only is True.

    Raises ValueError if an input length disagrees with the bucket sizes.
    """
    encoder_size, decoder_size = config.BUCKETS[bucket_id]
    _assert_lengths(encoder_size, decoder_size, encoder_inputs, decoder_inputs, decoder_masks)

    # input feed: encoder inputs, decoder inputs, target_weights, as provided.
    # range() instead of xrange() so this runs on Python 2 and 3.
    input_feed = {}
    for step in range(encoder_size):
        input_feed[model.encoder_inputs[step].name] = encoder_inputs[step]
    for step in range(decoder_size):
        input_feed[model.decoder_inputs[step].name] = decoder_inputs[step]
        input_feed[model.decoder_masks[step].name] = decoder_masks[step]

    # One extra decoder input (index decoder_size) is fed with zeros.
    # NOTE(review): presumably needed because targets are the decoder inputs
    # shifted by one position - confirm against the model definition.
    last_target = model.decoder_inputs[decoder_size].name
    input_feed[last_target] = np.zeros([model.batch_size], dtype=np.int32)

    # output feed: depends on whether we do a backward step or not.
    if not forward_only:
        output_feed = [model.train_ops[bucket_id],  # update op that does SGD.
                       model.gradient_norms[bucket_id],  # gradient norm.
                       model.losses[bucket_id]]  # loss for this batch.
    else:
        output_feed = [model.losses[bucket_id]]  # loss for this batch.
        for step in range(decoder_size):  # output logits.
            output_feed.append(model.outputs[bucket_id][step])

    outputs = sess.run(output_feed, input_feed)
    if not forward_only:
        return outputs[1], outputs[2], None  # Gradient norm, loss, no outputs.
    return None, outputs[0], outputs[1:]  # No gradient norm, loss, outputs.
| 86 | + |
def _get_buckets():
    """ Load the dataset into buckets based on their lengths.

    Returns (test_buckets, data_buckets, train_buckets_scale), where the
    first two come from data.load_data and train_buckets_scale is the
    cumulative fraction of training samples up to and including each
    bucket. It is the interval that'll help us choose a random bucket
    later on.
    """
    test_buckets = data.load_data('test_ids.enc', 'test_ids.dec')
    data_buckets = data.load_data('train_ids.enc', 'train_ids.dec')
    # range() instead of xrange() so this runs on Python 2 and 3.
    train_bucket_sizes = [len(data_buckets[b]) for b in range(len(config.BUCKETS))]
    print("Number of samples in each bucket:\n", train_bucket_sizes)
    train_total_size = sum(train_bucket_sizes)
    # list of increasing numbers from 0 to 1 that we'll use to select a bucket.
    # A running total avoids re-summing the prefix for every bucket.
    train_buckets_scale = []
    samples_so_far = 0
    for bucket_size in train_bucket_sizes:
        samples_so_far += bucket_size
        train_buckets_scale.append(samples_so_far / train_total_size)
    print("Bucket scale:\n", train_buckets_scale)
    return test_buckets, data_buckets, train_buckets_scale
| 102 | + |
| 103 | +def _get_skip_step(iteration): |
| 104 | + """ How many steps should the model train before it saves all the weights. """ |
| 105 | + if iteration < 100: |
| 106 | + return 30 |
| 107 | + return 100 |
| 108 | + |
def _check_restore_parameters(sess, saver):
    """ Load previously saved parameters into the session when a checkpoint
    is found under config.CPT_PATH; otherwise just report a fresh start. """
    checkpoint_file = config.CPT_PATH + '/checkpoint'
    state = tf.train.get_checkpoint_state(os.path.dirname(checkpoint_file))
    if not (state and state.model_checkpoint_path):
        print("Initializing fresh parameters for the Chatbot")
        return
    print("Loading parameters for the Chatbot")
    saver.restore(sess, state.model_checkpoint_path)
| 117 | + |
def _eval_test_set(sess, model, test_buckets):
    """ Evaluate on the test set.

    For every bucket in config.BUCKETS, one batch of config.BATCH_SIZE is
    drawn from test_buckets and run forward-only; the loss and elapsed time
    are printed. Empty buckets are reported and skipped.
    """
    # range() instead of xrange() so this runs on Python 2 and 3.
    for bucket_id in range(len(config.BUCKETS)):
        if len(test_buckets[bucket_id]) == 0:
            print(" Test: empty bucket %d" % (bucket_id))
            continue
        start = time.time()
        encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(test_buckets[bucket_id],
                                                                       bucket_id,
                                                                       batch_size=config.BATCH_SIZE)
        # forward_only=True: only the loss is needed, no parameter update.
        _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs,
                                   decoder_masks, bucket_id, True)
        print('Test bucket {}: loss {}, time {}'.format(bucket_id, step_loss, time.time() - start))
| 131 | + |
def train():
    """ Train the bot.

    Builds the model with the backward path, restores parameters from a
    checkpoint when one exists, then trains on randomly chosen buckets.
    Every skip_step iterations the average loss is printed and a checkpoint
    is saved; every 10 * skip_step iterations the test set is evaluated.

    The loop has no exit condition, so this runs until it is interrupted.
    """
    test_buckets, data_buckets, train_buckets_scale = _get_buckets()
    # in train mode, we need to create the backward path, so forward_only is False
    model = ChatBotModel(False, config.BATCH_SIZE)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        print('Running session')
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)

        # Resume counting from the step stored in the (possibly restored) model.
        iteration = model.global_step.eval()
        total_loss = 0
        while True:
            skip_step = _get_skip_step(iteration)
            bucket_id = _get_random_bucket(train_buckets_scale)
            encoder_inputs, decoder_inputs, decoder_masks = data.get_batch(data_buckets[bucket_id],
                                                                           bucket_id,
                                                                           batch_size=config.BATCH_SIZE)
            # The timer is restarted every iteration, so the time printed
            # below covers only the most recent step.
            start = time.time()
            _, step_loss, _ = run_step(sess, model, encoder_inputs, decoder_inputs, decoder_masks, bucket_id, False)
            total_loss += step_loss
            iteration += 1

            if iteration % skip_step == 0:
                print('Iter {}: loss {}, time {}'.format(iteration, total_loss/skip_step, time.time() - start))
                total_loss = 0
                saver.save(sess, os.path.join(config.CPT_PATH, 'chatbot'), global_step=model.global_step)
                if iteration % (10 * skip_step) == 0:
                    # Run evals on development set and print their loss
                    _eval_test_set(sess, model, test_buckets)
                sys.stdout.flush()
| 170 | + |
def _get_user_input():
    """ Show the "> " prompt and return one raw line read from standard
    input; it will be transformed into encoder input later. """
    prompt = "> "
    sys.stdout.write(prompt)
    # Flush so the prompt is visible before blocking on the read.
    sys.stdout.flush()
    user_line = sys.stdin.readline()
    return user_line
| 176 | + |
def _find_right_bucket(length):
    """ Find the proper bucket for an encoder input based on its length.

    Returns the smallest index into config.BUCKETS whose encoder size
    (first element of the bucket) is at least `length`.
    Raises ValueError when no bucket is large enough.
    """
    # enumerate() instead of xrange(len(...)) so this runs on Python 2 and 3.
    return min(bucket_id for bucket_id, bucket in enumerate(config.BUCKETS)
               if bucket[0] >= length)
| 181 | + |
def _construct_response(output_logits, inv_dec_vocab):
    """ Construct a response to the user's encoder input.

    @output_logits: the outputs from sequence to sequence wrapper.
    output_logits is decoder_size np array, each of dim 1 x DEC_VOCAB
    @inv_dec_vocab: lookup from a token id to its token.

    This is a greedy decoder - outputs are just argmaxes of output_logits.
    Returns the decoded tokens joined by single spaces, truncated just
    before the first config.EOS_ID if one was produced.
    """
    # .item() extracts the single argmax as a Python scalar; calling int()
    # directly on a 1-element array is deprecated in recent NumPy releases.
    outputs = [int(np.argmax(logit, axis=1).item()) for logit in output_logits]
    # If there is an EOS symbol in outputs, cut them at that point.
    if config.EOS_ID in outputs:
        outputs = outputs[:outputs.index(config.EOS_ID)]
    # Build the sentence corresponding to outputs.
    return " ".join([tf.compat.as_str(inv_dec_vocab[output]) for output in outputs])
| 196 | + |
def chat():
    """ Chat with the bot interactively on standard input.

    The model is built forward-only with a batch size of 1 (in chat mode we
    don't need to create the backward path), parameters are restored from a
    checkpoint when one exists, and every exchange is appended to
    config.OUTPUT_FILE inside config.PROCESSED_PATH.

    An empty line (or end of input) ends the conversation. Inputs with more
    tokens than the largest bucket's encoder size are rejected with a
    message and the user is prompted again.
    """
    _, enc_vocab = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.enc'))
    inv_dec_vocab, _ = data.load_vocab(os.path.join(config.PROCESSED_PATH, 'vocab.dec'))

    model = ChatBotModel(True, batch_size=1)
    model.build_graph()

    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _check_restore_parameters(sess, saver)
        transcript_path = os.path.join(config.PROCESSED_PATH, config.OUTPUT_FILE)
        # Context manager so the transcript is closed even if decoding fails.
        with open(transcript_path, 'a+') as output_file:
            # Decode from standard input.
            max_length = config.BUCKETS[-1][0]
            print('Welcome to TensorBro. Say something. Enter to exit. Max length is', max_length)
            while True:
                line = _get_user_input()
                if line.endswith('\n'):
                    line = line[:-1]
                if line == '':
                    break
                output_file.write('HUMAN ++++ ' + line + '\n')
                # Get token-ids for the input sentence.
                token_ids = data.sentence2id(enc_vocab, str(line))
                if len(token_ids) > max_length:
                    print('Max length I can handle is:', max_length)
                    # Go straight back to the prompt at the top of the loop;
                    # reading here would throw away the user's next message.
                    continue
                # Which bucket does it belong to?
                bucket_id = _find_right_bucket(len(token_ids))
                # Get a 1-element batch to feed the sentence to the model.
                encoder_inputs, decoder_inputs, decoder_masks = data.get_batch([(token_ids, [])],
                                                                               bucket_id,
                                                                               batch_size=1)
                # Get output logits for the sentence.
                _, _, output_logits = run_step(sess, model, encoder_inputs, decoder_inputs,
                                               decoder_masks, bucket_id, True)
                response = _construct_response(output_logits, inv_dec_vocab)
                print(response)
                output_file.write('BOT ++++ ' + response + '\n')
            output_file.write('=============================================\n')
| 242 | + |
def main():
    """ Command-line entry point: parse --mode, make sure the processed data
    and the checkpoint folder exist, then run training or chat. """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        '--mode',
        choices={'train', 'chat'},
        default='train',
        help="mode. if not specified, it's in the train mode",
    )
    args = arg_parser.parse_args()

    # Build the processed data only when its folder is missing.
    needs_processing = not os.path.isdir(config.PROCESSED_PATH)
    if needs_processing:
        data.prepare_raw_data()
        data.process_data()
    print('Data ready!')
    # create checkpoints folder if there isn't one already
    data.make_dir(config.CPT_PATH)

    # argparse restricts --mode to these two values.
    handlers = {'train': train, 'chat': chat}
    handlers[args.mode]()


if __name__ == '__main__':
    main()
0 commit comments