Skip to content

Commit 432d7b0

Browse files
MorvanZhouMorvan Zhou
authored and
Morvan Zhou
committed
update
1 parent ec9543c commit 432d7b0

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ All methods mentioned below have their video and text tutorial in Chinese. Visit
5353
<img class="course-image" src="https://morvanzhou.github.io/static/results/torch/1-1-3.gif">
5454
</a>
5555

56+
### [CNN](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/401_CNN.py)
57+
<a href="https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/401_CNN.py">
58+
<img class="course-image" src="https://morvanzhou.github.io/static/results/torch/4-1-2.gif" >
59+
</a>
60+
5661
### [RNN](https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/403_RNN_regression.py)
5762

5863
<a href="https://github.com/MorvanZhou/Tensorflow-Tutorial/blob/master/tutorial-contents/403_RNN_regression.py">

tutorial-contents/401_CNN.py

+32-2
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,42 @@
6262
init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # the local var is for accuracy_op
6363
sess.run(init_op) # initialize var in graph
6464

65+
# following function (plot_with_labels) is for visualization, can be ignored if not interested
66+
from matplotlib import cm
67+
try:
68+
from sklearn.manifold import TSNE
69+
HAS_SK = True
70+
except:
71+
HAS_SK = False
72+
def plot_with_labels(lowDWeights, labels):
73+
plt.cla()
74+
X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
75+
for x, y, s in zip(X, Y, labels):
76+
c = cm.rainbow(int(255 * s / 9))
77+
plt.text(x, y, s, backgroundcolor=c, fontsize=9)
78+
plt.xlim(X.min(), X.max())
79+
plt.ylim(Y.min(), Y.max())
80+
plt.title('Visualize last layer')
81+
plt.show()
82+
plt.pause(0.01)
83+
84+
plt.ion()
85+
6586
for step in range(600):
6687
b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
6788
_, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
6889
if step % 50 == 0:
69-
accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y})
70-
print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)
90+
accuracy_, flat_representation = sess.run([accuracy, flat], {tf_x: test_x, tf_y: test_y})
91+
print('Step:', step, '| train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_)
92+
93+
if HAS_SK:
94+
# Visualization of trained flatten layer (T-SNE)
95+
tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
96+
plot_only = 500
97+
low_dim_embs = tsne.fit_transform(flat_representation[:plot_only, :])
98+
labels = np.argmax(test_y, axis=1)[:plot_only]
99+
plot_with_labels(low_dim_embs, labels)
100+
plt.ioff()
71101

72102
# print 10 predictions from test data
73103
test_output = sess.run(output, {tf_x: test_x[:10]})

0 commit comments

Comments
 (0)