1+ """ A clean, no_frills character-level generative language model.
2+
3+ CS 20: "TensorFlow for Deep Learning Research"
4+ cs20.stanford.edu
5+ Danijar Hafner (mail@danijar.com)
6+ & Chip Huyen (chiphuyen@cs.stanford.edu)
7+ Lecture 11
8+ """
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import random
import sys
sys.path.append('..')
import time

import tensorflow as tf

import utils

def vocab_encode(text, vocab):
    # map each character to its 1-based index in vocab; 0 is reserved for padding
    return [vocab.index(x) + 1 for x in text if x in vocab]

def vocab_decode(array, vocab):
    # inverse of vocab_encode: shift indices back down by 1 before the lookup
    return ''.join([vocab[x - 1] for x in array])
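# For example, with the hypothetical toy vocab 'ab':
#   vocab_encode('ba', 'ab') == [2, 1] and vocab_decode([2, 1], 'ab') == 'ba';
# characters that are not in the vocab are silently dropped by vocab_encode.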

def read_data(filename, vocab, window, overlap):
    with open(filename, 'r') as f:
        lines = [line.strip() for line in f]
    while True:
        random.shuffle(lines)

        for text in lines:
            text = vocab_encode(text, vocab)
            # slide a window of `window` characters over each line; despite its
            # name, `overlap` is the stride, so consecutive chunks share
            # window - overlap characters
            for start in range(0, len(text) - window, overlap):
                chunk = text[start: start + window]
                chunk += [0] * (window - len(chunk))  # pad short chunks with 0
                yield chunk
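# For illustration (toy values, not the ones used below): with window=4 and
# overlap=2, the encoded line [5, 3, 9, 2, 7, 1, 4, 8] yields the chunks
# [5, 3, 9, 2] and [9, 2, 7, 1].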

def read_batch(stream, batch_size):
    batch = []
    for element in stream:
        batch.append(element)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch  # flush the final partial batch if the stream is finite
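# A minimal sketch of how the two generators compose (the argument values
# mirror the defaults set on CharRNN below):
#   stream = read_data('data/trump_tweets.txt', vocab, window=50, overlap=25)
#   batches = read_batch(stream, batch_size=64)
#   batch = next(batches)  # a list of 64 chunks of 50 ints each, 0-padded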

class CharRNN(object):
    def __init__(self, model):
        self.model = model
        self.path = 'data/' + model + '.txt'
        if 'trump' in model:
            self.vocab = ("$%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          " '\"_abcdefghijklmnopqrstuvwxyz{|}@#➡📈")
        else:
            self.vocab = (" $%'()+,-./0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ"
                          "\\^_abcdefghijklmnopqrstuvwxyz{|}")

        self.seq = tf.placeholder(tf.int32, [None, None])
        self.temp = tf.constant(1.5)
        self.hidden_sizes = [128, 256]
        self.batch_size = 64
        self.lr = 0.0003
        self.skip_step = 1
        self.num_steps = 50  # length of the window the RNN is unrolled over
        self.len_generated = 200
        self.gstep = tf.Variable(0, dtype=tf.int32, trainable=False, name='global_step')

    def create_rnn(self, seq):
        layers = [tf.nn.rnn_cell.GRUCell(size) for size in self.hidden_sizes]
        cells = tf.nn.rnn_cell.MultiRNNCell(layers)
        batch = tf.shape(seq)[0]
        zero_states = cells.zero_state(batch, dtype=tf.float32)
        self.in_state = tuple([tf.placeholder_with_default(state, [None, state.shape[1]])
                               for state in zero_states])
        # recover the real length of each sequence in the batch: all sequences
        # are padded to the same length (num_steps), and padded positions are
        # all-zero rows in the one-hot encoding, so their sign maxes out at 0
        length = tf.reduce_sum(tf.reduce_max(tf.sign(seq), 2), 1)
        self.output, self.out_state = tf.nn.dynamic_rnn(cells, seq, length, self.in_state)
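        # Because in_state uses placeholder_with_default, callers can either
        # rely on the all-zero initial state (as train() does) or feed a
        # previously returned out_state back in (as online_infer() does).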

    def create_model(self):
        seq = tf.one_hot(self.seq, len(self.vocab))
        self.create_rnn(seq)
        self.logits = tf.layers.dense(self.output, len(self.vocab), None)
        # predict character t + 1 from characters up to t: logits are shifted
        # one step back relative to the labels
        loss = tf.nn.softmax_cross_entropy_with_logits(logits=self.logits[:, :-1],
                                                       labels=seq[:, 1:])
        self.loss = tf.reduce_sum(loss)
        # sample the next character from the Boltzmann (softmax) distribution
        # with temperature temp. tf.multinomial treats its input as unnormalized
        # log-probabilities, so it works equally well without the tf.exp
        self.sample = tf.multinomial(tf.exp(self.logits[:, -1] / self.temp), 1)[:, 0]
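        # With the plain logits / temp input (the no-exp variant noted above),
        # temp > 1 flattens the softmax toward uniform (more diverse samples)
        # and temp < 1 sharpens it toward greedy argmax decoding.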
        self.opt = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.gstep)

    def train(self):
        saver = tf.train.Saver()
        start = time.time()
        min_loss = None
        with tf.Session() as sess:
            writer = tf.summary.FileWriter('graphs/gist', sess.graph)
            sess.run(tf.global_variables_initializer())

            # restore the most recent checkpoint for this model, if any
            ckpt = tf.train.get_checkpoint_state(os.path.dirname('checkpoints/' + self.model + '/checkpoint'))
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)

            iteration = self.gstep.eval()
            stream = read_data(self.path, self.vocab, self.num_steps, overlap=self.num_steps // 2)
            data = read_batch(stream, self.batch_size)
            while True:
                batch = next(data)
                batch_loss, _ = sess.run([self.loss, self.opt], {self.seq: batch})
                if (iteration + 1) % self.skip_step == 0:
                    print('Iter {}. \n    Loss {}. Time {}'.format(iteration, batch_loss, time.time() - start))
                    self.online_infer(sess)
                    start = time.time()
                    checkpoint_name = 'checkpoints/' + self.model + '/char-rnn'
                    # only keep a checkpoint when the loss improves
                    if min_loss is None or batch_loss < min_loss:
                        saver.save(sess, checkpoint_name, iteration)
                        min_loss = batch_loss
                iteration += 1

    def online_infer(self, sess):
        """ Generate sequences, one character at a time, conditioned on the previous character
        """
        for seed in ['Hillary', 'I', 'R', 'T', '@', 'N', 'M', '.', 'G', 'A', 'W']:
            sentence = seed
            state = None
            for _ in range(self.len_generated):
                batch = [vocab_encode(sentence[-1], self.vocab)]
                feed = {self.seq: batch}
                if state is not None:  # for the first decoder step, the state is None
                    for i in range(len(state)):
                        feed.update({self.in_state[i]: state[i]})
                index, state = sess.run([self.sample, self.out_state], feed)
                sentence += vocab_decode(index, self.vocab)
            print('\t' + sentence)

def main():
    model = 'trump_tweets'
    utils.safe_mkdir('checkpoints')
    utils.safe_mkdir('checkpoints/' + model)

    lm = CharRNN(model)
    lm.create_model()
    lm.train()

if __name__ == '__main__':
    main()