Skip to content
This repository was archived by the owner on Jan 1, 2021. It is now read-only.

Commit 51c01e5

Browse files
committed
list comprehension
1 parent 25e7c90 commit 51c01e5

2 files changed

Lines changed: 10 additions & 16 deletions

File tree

assignments/chatbot/model.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -34,22 +34,16 @@ def __init__(self, forward_only, batch_size):
3434
def _create_placeholders(self):
3535
# Feeds for inputs. It's a list of placeholders
3636
print('Create placeholders')
37-
self.encoder_inputs = []
38-
self.decoder_inputs = []
39-
self.decoder_masks = []
40-
for i in xrange(config.BUCKETS[-1][0]): # Last bucket is the biggest one.
41-
self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
42-
name='encoder{}'.format(i)))
43-
for i in xrange(config.BUCKETS[-1][1] + 1):
44-
self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
45-
name='decoder{}'.format(i)))
46-
self.decoder_masks.append(tf.placeholder(tf.float32, shape=[None],
47-
name='mask{}'.format(i)))
37+
self.encoder_inputs = [tf.placeholder(tf.int32, shape=[None], name='encoder{}'.format(i))
38+
for i in xrange(config.BUCKETS[-1][0])]
39+
self.decoder_inputs = [tf.placeholder(tf.int32, shape=[None], name='decoder{}'.format(i))
40+
for i in xrange(config.BUCKETS[-1][1] + 1)]
41+
self.decoder_masks = [tf.placeholder(tf.float32, shape=[None], name='mask{}'.format(i))
42+
for i in xrange(config.BUCKETS[-1][1] + 1)]
4843

4944
# Our targets are decoder inputs shifted by one (to ignore <s> symbol)
50-
self.targets = [self.decoder_inputs[i + 1]
51-
for i in xrange(len(self.decoder_inputs) - 1)]
52-
45+
self.targets = self.decoder_inputs[1:]
46+
5347
def _inference(self):
5448
print('Create inference')
5549
# If we use sampled softmax, we need an output projection.

examples/11_char_rnn_gist.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -81,12 +81,12 @@ def training(vocab, seq, loss, optimizer, global_step, temp, sample, in_state, o
8181
batch_loss, _ = sess.run([loss, optimizer], {seq: batch})
8282
if (iteration + 1) % SKIP_STEP == 0:
8383
print('Iter {}. \n Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))
84-
online_intference(sess, vocab, seq, sample, temp, in_state, out_state)
84+
online_inference(sess, vocab, seq, sample, temp, in_state, out_state)
8585
start = time.time()
8686
saver.save(sess, 'checkpoints/arvix/char-rnn', iteration)
8787
iteration += 1
8888

89-
def online_intference(sess, vocab, seq, sample, temp, in_state, out_state, seed='T'):
89+
def online_inference(sess, vocab, seq, sample, temp, in_state, out_state, seed='T'):
9090
""" Generate sequence one character at a time, based on the previous character
9191
"""
9292
sentence = seed

0 commit comments

Comments (0)