This repository was archived by the owner on Jan 1, 2021. It is now read-only.

Commit cb78f84

code for lecture 11
1 parent c8dfdd4 commit cb78f84

File tree

3 files changed: +40065 additions, −0 deletions

examples/11_char_rnn.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
""" A clean, no-frills character-level generative language model.

CS 20: "TensorFlow for Deep Learning Research"
cs20.stanford.edu
Danijar Hafner (mail@danijar.com)
& Chip Huyen (chiphuyen@cs.stanford.edu)
Lecture 11
"""
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import random
import sys
sys.path.append('..')
import time

import tensorflow as tf

import utils

def vocab_encode(text, vocab):
    return [vocab.index(x) + 1 for x in text if x in vocab]

def vocab_decode(array, vocab):
    return ''.join([vocab[x - 1] for x in array])

def read_data(filename, vocab, window, overlap):
    lines = [line.strip() for line in open(filename, 'r').readlines()]
    while True:
        random.shuffle(lines)

        for text in lines:
            text = vocab_encode(text, vocab)
            for start in range(0, len(text) - window, overlap):
                chunk = text[start: start + window]
                chunk += [0] * (window - len(chunk))
                yield chunk
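
# A quick illustration of the helpers above (hypothetical values, shown as
# comments so the file runs unchanged):
#   vocab_encode('ab', ' ab')   -> [2, 3]   (indices are shifted by 1; 0 is the pad)
#   vocab_decode([2, 3], ' ab') -> 'ab'
# read_data() slides a window of `window` characters over each shuffled line,
# stepping by `overlap` characters, zero-pads any short chunk to length
# `window`, and loops over the corpus forever.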

def read_batch(stream, batch_size):
    batch = []
    for element in stream:
        batch.append(element)
        if len(batch) == batch_size:
            yield batch
            batch = []
    yield batch

class CharRNN(object):
    def __init__(self, model):
        self.model = model
        self.path = 'data/' + model + '.txt'
        if 'trump' in model:
            self.vocab = ("$%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          " '\"_abcdefghijklmnopqrstuvwxyz{|}@#➡📈")
        else:
            self.vocab = (" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          "\\^_abcdefghijklmnopqrstuvwxyz{|}")

        self.seq = tf.placeholder(tf.int32, [None, None])
        self.temp = tf.constant(1.5)
        self.hidden_sizes = [128, 256]
        self.batch_size = 64
        self.lr = 0.0003
        self.skip_step = 1
        self.num_steps = 50  # number of steps the RNN is unrolled
        self.len_generated = 200
        self.gstep = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    def create_rnn(self, seq):
        layers = [tf.nn.rnn_cell.GRUCell(size) for size in self.hidden_sizes]
        cells = tf.nn.rnn_cell.MultiRNNCell(layers)
        batch = tf.shape(seq)[0]
        zero_states = cells.zero_state(batch, dtype=tf.float32)
        self.in_state = tuple([tf.placeholder_with_default(state, [None, state.shape[1]])
                               for state in zero_states])
        # compute the real length of each sequence; all sequences are padded
        # to the same length, num_steps
        length = tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1)
        self.output, self.out_state = tf.nn.dynamic_rnn(cells, seq, length, self.in_state)

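    # Note: tf.placeholder_with_default lets one graph serve both phases.
    # During training nothing is fed for the state, so the RNN starts from the
    # zero state; during generation, online_infer() feeds the previous
    # out_state back into in_state to continue the sequence one character at
    # a time.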
    def create_model(self):
        seq = tf.one_hot(self.seq, len(self.vocab))
        self.create_rnn(seq)
        self.logits = tf.layers.dense(self.output, len(self.vocab), None)
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits[:, :-1],
                                                       labels=seq[:, 1:])
        self.loss = tf.reduce_sum(loss)
        # sample the next character from a Boltzmann distribution (softmax
        # with temperature temp). It works equally well without tf.exp
        self.sample = tf.multinomial(tf.exp(self.logits[:, -1] / self.temp), 1)[:, 0]
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.gstep)

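    # The slicing above implements next-character prediction: the logit at
    # step t is scored against the one-hot character at step t+1.
    # tf.multinomial treats its input as unnormalized log-probabilities, so
    # passing self.logits[:, -1] / self.temp directly would be standard
    # temperature sampling (temp > 1 flattens the distribution, temp < 1
    # sharpens it); the extra tf.exp is the author's variant, which the
    # comment above says behaves similarly in practice.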
    def train(self):
        saver = tf.train.Saver()
        start = time.time()
        min_loss = None
        with tf.Session() as sess:
            writer = tf.summary.FileWriter('graphs/gist', sess.graph)
            sess.run(tf.global_variables_initializer())

            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/' + self.model + '/checkpoint'))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            iteration = self.gstep.eval()
            stream = read_data(self.path, self.vocab, self.num_steps, overlap=self.num_steps // 2)
            data = read_batch(stream, self.batch_size)
            while True:
                batch = next(data)
                batch_loss, _ = sess.run([self.loss, self.opt], {self.seq: batch})
                if (iteration + 1) % self.skip_step == 0:
                    print('Iter {}. Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))
                    self.online_infer(sess)
                    start = time.time()
                    checkpoint_name = 'checkpoints/' + self.model + '/char-rnn'
                    # checkpoint whenever the loss improves
                    if min_loss is None or batch_loss < min_loss:
                        saver.save(sess, checkpoint_name, iteration)
                        min_loss = batch_loss
                iteration += 1

    def online_infer(self, sess):
        """ Generate a sequence one character at a time, conditioned on the
        previous character.
        """
        for seed in ['Hillary', 'I', 'R', 'T', '@', 'N', 'M', '.', 'G', 'A', 'W']:
            sentence = seed
            state = None
            for _ in range(self.len_generated):
                batch = [vocab_encode(sentence[-1], self.vocab)]
                feed = {self.seq: batch}
                if state is not None:  # for the first decoder step, the state is None
                    for i in range(len(state)):
                        feed.update({self.in_state[i]: state[i]})
                index, state = sess.run([self.sample, self.out_state], feed)
                sentence += vocab_decode(index, self.vocab)
            print('\t' + sentence)

def main():
    model = 'trump_tweets'
    utils.safe_mkdir('checkpoints')
    utils.safe_mkdir('checkpoints/' + model)

    lm = CharRNN(model)
    lm.create_model()
    lm.train()

if __name__ == '__main__':
    main()
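
To exercise the windowing scheme without the course setup, here is a minimal standalone sketch; the vocab, text, window, and overlap values are hypothetical stand-ins, not the course data. (Running the full script additionally assumes a data/trump_tweets.txt corpus and the course's utils module one directory up.)

# Hypothetical standalone demo of read_data()'s windowing scheme.
vocab = " abcdefghijklmnopqrstuvwxyz"

def vocab_encode(text):
    return [vocab.index(x) + 1 for x in text if x in vocab]

text = vocab_encode("the quick brown fox jumps over the lazy dog")
window, overlap = 10, 5
chunks = []
for start in range(0, len(text) - window, overlap):
    chunk = text[start: start + window]
    chunk += [0] * (window - len(chunk))  # zero-pad short chunks
    chunks.append(chunk)
print(len(chunks), "overlapping windows of", window, "character indices")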
