Skip to content
This repository was archived by the owner on Jan 1, 2021. It is now read-only.

Commit a6c4a6d

Browse files
authored
Fix lazy loading
1 parent 642ba95 commit a6c4a6d

1 file changed

Lines changed: 8 additions & 5 deletions

File tree

examples/03_logistic_regression_mnist_sol.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,18 @@
7373
print('Optimization Finished!') # should be around 0.35 after 25 epochs
7474

7575
# test the model
76+
77+
preds = tf.nn.softmax(logits)
78+
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y, 1))
79+
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32)) # need numpy.count_nonzero(boolarr) :(
80+
7681
n_batches = int(mnist.test.num_examples/batch_size)
7782
total_correct_preds = 0
83+
7884
for i in range(n_batches):
7985
X_batch, Y_batch = mnist.test.next_batch(batch_size)
80-
loss_batch, logits_batch = sess.run([loss, logits], feed_dict={X: X_batch, Y:Y_batch})
81-
preds = tf.nn.softmax(logits_batch)
82-
correct_preds = tf.equal(tf.argmax(preds, 1), tf.argmax(Y_batch, 1))
83-
accuracy = tf.reduce_sum(tf.cast(correct_preds, tf.float32)) # need numpy.count_nonzero(boolarr) :(
84-
total_correct_preds += sess.run(accuracy)
86+
accuracy_batch = sess.run([accuracy], feed_dict={X: X_batch, Y:Y_batch})
87+
total_correct_preds += accuracy_batch
8588

8689
print('Accuracy {0}'.format(total_correct_preds/mnist.test.num_examples))
8790

0 commit comments

Comments
 (0)