|
62 | 62 | init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()) # the local var is for accuracy_op
|
63 | 63 | sess.run(init_op) # initialize var in graph
|
64 | 64 |
|
| 65 | +# following function (plot_with_labels) is for visualization, can be ignored if not interested |
| 66 | +from matplotlib import cm |
| 67 | +try: |
| 68 | + from sklearn.manifold import TSNE |
| 69 | + HAS_SK = True |
| 70 | +except: |
| 71 | + HAS_SK = False |
| 72 | +def plot_with_labels(lowDWeights, labels): |
| 73 | + plt.cla() |
| 74 | + X, Y = lowDWeights[:, 0], lowDWeights[:, 1] |
| 75 | + for x, y, s in zip(X, Y, labels): |
| 76 | + c = cm.rainbow(int(255 * s / 9)) |
| 77 | + plt.text(x, y, s, backgroundcolor=c, fontsize=9) |
| 78 | + plt.xlim(X.min(), X.max()) |
| 79 | + plt.ylim(Y.min(), Y.max()) |
| 80 | + plt.title('Visualize last layer') |
| 81 | + plt.show() |
| 82 | + plt.pause(0.01) |
| 83 | + |
| 84 | +plt.ion() |
| 85 | + |
65 | 86 | for step in range(600):
|
66 | 87 | b_x, b_y = mnist.train.next_batch(BATCH_SIZE)
|
67 | 88 | _, loss_ = sess.run([train_op, loss], {tf_x: b_x, tf_y: b_y})
|
68 | 89 | if step % 50 == 0:
|
69 |
| - accuracy_ = sess.run(accuracy, {tf_x: test_x, tf_y: test_y}) |
70 |
| - print('train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_) |
| 90 | + accuracy_, flat_representation = sess.run([accuracy, flat], {tf_x: test_x, tf_y: test_y}) |
| 91 | + print('Step:', step, '| train loss: %.4f' % loss_, '| test accuracy: %.2f' % accuracy_) |
| 92 | + |
| 93 | + if HAS_SK: |
| 94 | + # Visualization of trained flatten layer (T-SNE) |
| 95 | + tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000) |
| 96 | + plot_only = 500 |
| 97 | + low_dim_embs = tsne.fit_transform(flat_representation[:plot_only, :]) |
| 98 | + labels = np.argmax(test_y, axis=1)[:plot_only] |
| 99 | + plot_with_labels(low_dim_embs, labels) |
| 100 | +plt.ioff() |
71 | 101 |
|
72 | 102 | # print 10 predictions from test data
|
73 | 103 | test_output = sess.run(output, {tf_x: test_x[:10]})
|
|
0 commit comments