A First Look at the MNIST Digit-Recognition Problem


Background

The MNIST handwritten-digit recognition problem. MNIST is a dataset containing 60,000 training images and 10,000 test images.

It is the classic first problem when getting started with TensorFlow.

First Look

Code

import tensorflow as tf
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)
val_loss, val_acc = model.evaluate(x_test, y_test)
print(val_loss)
print(val_acc)

Downloading the Dataset

Download https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz and place it at ~/.keras/datasets/mnist.npz.
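If the download is blocked, the same file can also be loaded directly with NumPy. This is a minimal sketch assuming mnist.npz sits in the current directory; the archive stores its four arrays under the keys used below:

import numpy as np

# mnist.npz is a NumPy archive containing four arrays
with np.load('mnist.npz') as data:
    x_train, y_train = data['x_train'], data['y_train']
    x_test, y_test = data['x_test'], data['y_test']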

Running

$ python mnist.py
Epoch 1/5
2019-02-08 23:53:20.367634: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
60000/60000 [==============================] - 6s 95us/step - loss: 0.1993 - acc: 0.9413
Epoch 2/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0786 - acc: 0.9758
Epoch 3/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0522 - acc: 0.9839
Epoch 4/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0356 - acc: 0.9886
Epoch 5/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0276 - acc: 0.9912
10000/10000 [==============================] - 0s 31us/step
0.06755815939957101
0.9804

Walkthrough

Number of Images

(x_train, y_train),(x_test, y_test) = mnist.load_data()
print(len(x_train))
print(len(y_train))
print(len(x_test))
print(len(y_test))

Result

60000
60000
10000
10000
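Beyond the counts, the array shapes confirm that each image is a 28 x 28 pixel matrix with one integer label per image (a quick check):

print(x_train.shape)  # (60000, 28, 28)
print(y_train.shape)  # (60000,) -- one label (0-9) per image
print(x_test.shape)   # (10000, 28, 28)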

The First Image

(x_train, y_train),(x_test, y_test) = mnist.load_data()
print(x_train[0])

from matplotlib import pyplot as plt
plt.imshow(x_train[0],cmap=plt.cm.binary) # render as a black-and-white image
plt.show()

Result

The first image is a 28 x 28 matrix:

[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   3  18  18  18 126 136
  175  26 166 255 247 127   0   0   0   0]
 [  0   0   0   0   0   0   0   0  30  36  94 154 170 253 253 253 253 253
  225 172 253 242 195  64   0   0   0   0]
 [  0   0   0   0   0   0   0  49 238 253 253 253 253 253 253 253 253 251
   93  82  82  56  39   0   0   0   0   0]
 [  0   0   0   0   0   0   0  18 219 253 253 253 253 253 198 182 247 241
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  80 156 107 253 253 205  11   0  43 154
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0  14   1 154 253  90   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0 139 253 190   2   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0  11 190 253  70   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  35 241 225 160 108   1
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0  81 240 253 253 119
   25   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  45 186 253 253
  150  27   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  16  93 252
  253 187   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 249
  253 249  64   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0  46 130 183 253
  253 207   2   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0  39 148 229 253 253 253
  250 182   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0  24 114 221 253 253 253 253 201
   78   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0  23  66 213 253 253 253 253 198  81   2
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0  18 171 219 253 253 253 253 195  80   9   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0  55 172 226 253 253 253 253 244 133  11   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0 136 253 253 253 212 135 132  16   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]
 [  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
    0   0   0   0   0   0   0   0   0   0]]

Rendered as a black-and-white image, it looks like this:

[image: the first training image rendered in black and white (a handwritten 5)]

Normalizing the Images

x_train = tf.keras.utils.normalize(x_train, axis=1)
x_test = tf.keras.utils.normalize(x_test, axis=1)

Alternatively, in this example, a plain rescale also works:

x_train, x_test = x_train / 255.0, x_test / 255.0

The matrix printed below is the first image after the tf.keras.utils.normalize call above (note that the values differ from a simple division by 255):

[[0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.00393124 0.02332955 0.02620568 0.02625207 0.17420356 0.17566281
  0.28629534 0.05664824 0.51877786 0.71632322 0.77892406 0.89301644
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.05780486 0.06524513 0.16128198 0.22713296
  0.22277047 0.32790981 0.36833534 0.3689874  0.34978968 0.32678448
  0.368094   0.3747499  0.79066747 0.67980478 0.61494005 0.45002403
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.12250613 0.45858525 0.45852825 0.43408872 0.37314701
  0.33153488 0.32790981 0.36833534 0.3689874  0.34978968 0.32420121
  0.15214552 0.17865984 0.25626376 0.1573102  0.12298801 0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.04500225 0.4219755  0.45852825 0.43408872 0.37314701
  0.33153488 0.32790981 0.28826244 0.26543758 0.34149427 0.31128482
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.1541463  0.28272888 0.18358693 0.37314701
  0.33153488 0.26569767 0.01601458 0.         0.05945042 0.19891229
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.0253731  0.00171577 0.22713296
  0.33153488 0.11664776 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.20500962
  0.33153488 0.24625638 0.00291174 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.01622378
  0.24897876 0.32790981 0.10191096 0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.04586451 0.31235677 0.32757096 0.23335172 0.14931733 0.00129164
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.10498298 0.34940902 0.3689874  0.34978968 0.15370495
  0.04089933 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.06551419 0.27127137 0.34978968 0.32678448
  0.245396   0.05882702 0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.02333517 0.12857881 0.32549285
  0.41390126 0.40743158 0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.32161793
  0.41390126 0.54251585 0.20001074 0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.06697006 0.18959827 0.25300993 0.32678448
  0.41390126 0.45100715 0.00625034 0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.05110617 0.19182076 0.33339444 0.3689874  0.34978968 0.32678448
  0.40899334 0.39653769 0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.04117838 0.16813739
  0.28960162 0.32790981 0.36833534 0.3689874  0.34978968 0.25961929
  0.12760592 0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.04431706 0.11961607 0.36545809 0.37314701
  0.33153488 0.32790981 0.36833534 0.28877275 0.111988   0.00258328
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.05298497 0.42752138 0.4219755  0.45852825 0.43408872 0.37314701
  0.33153488 0.25273681 0.11646967 0.01312603 0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.37491383 0.56222061
  0.66525569 0.63253163 0.48748768 0.45852825 0.43408872 0.359873
  0.17428513 0.01425695 0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.92705966 0.82698729
  0.74473314 0.63253163 0.4084877  0.24466922 0.22648107 0.02359823
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]
 [0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.         0.         0.
  0.         0.         0.         0.        ]]

[image: the first training image after normalization (the same handwritten 5)]
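The two schemes are not equivalent: tf.keras.utils.normalize applies L2 normalization along the given axis, while dividing by 255 simply rescales pixel values into [0, 1]. A quick sketch makes the difference explicit:

import numpy as np
import tensorflow as tf

(x_train, _), _ = tf.keras.datasets.mnist.load_data()

a = tf.keras.utils.normalize(x_train, axis=1)[0]  # L2-normalized columns
b = x_train[0] / 255.0                            # plain rescale to [0, 1]
print(np.allclose(a, b))  # False: different values, though both work here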

Building the Model

tf.keras.models.Sequential indicates that the Keras Sequential model is used here. This is a common model type: a plain stack of layers through which data flows forward in order.
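The same model can equivalently be built with successive add() calls, which makes the "layers applied in order" idea explicit:

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))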

Adding Layers

The full list of layers is:

[
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(512, activation=tf.nn.relu),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)
]

Let's go through them one by one.

  1. Input layer: flattening the image matrix

Since the network's input must be a one-dimensional vector, the image matrix needs to be flattened from 28 x 28 into a 1 x 784 vector.

This uses Keras's Flatten layer: tf.keras.layers.Flatten().

  2. Hidden layer

This uses the simplest kind of layer, a Dense (fully connected) layer in which every neuron is connected to all neurons in the adjacent layers: tf.keras.layers.Dense(512, activation=tf.nn.relu).

This fully connected layer has 512 units and uses the ReLU activation function.

  3. Dropout layer

This uses a Dropout layer to curb overfitting and improve the model's generalization: tf.keras.layers.Dropout(0.2).

The rate here is 0.2, meaning each element is dropped (zeroed out) with probability 0.2 during training, i.e. kept with probability 0.8.

  4. Output layer

The output layer is again a Dense layer: tf.keras.layers.Dense(10, activation=tf.nn.softmax).

The output layer has 10 units, one for each of the digits 0-9. The softmax activation turns the outputs into a relative probability distribution over the ten digits. (A shape and parameter summary of the full stack is sketched right after this list.)
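As a sanity check on the shapes described above, one can give Flatten an explicit input_shape and print a summary. This is a minimal sketch; the counts in the comments follow from units * inputs + biases:

import tensorflow as tf

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),       # 28*28 -> 784
  tf.keras.layers.Dense(512, activation=tf.nn.relu),   # 784*512 + 512 = 401,920 params
  tf.keras.layers.Dropout(0.2),                        # no trainable parameters
  tf.keras.layers.Dense(10, activation=tf.nn.softmax)  # 512*10 + 10 = 5,130 params
])
model.summary()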

Compiling the Model

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

optimizer: the optimizer; adam is used here.
loss: the loss function; sparse_categorical_crossentropy computes the cross-entropy of the classification results.
metrics: the metrics used to evaluate model performance; metrics=['accuracy'] is the typical usage.
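The string shortcuts resolve to concrete Keras objects; an equivalent, more explicit compile call would look like this (a sketch):

model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])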

Fitting the Model

model.fit(x_train, y_train, epochs=5)

The corresponding output is:

Epoch 1/5
60000/60000 [==============================] - 6s 95us/step - loss: 0.1993 - acc: 0.9413
Epoch 2/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0786 - acc: 0.9758
Epoch 3/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0522 - acc: 0.9839
Epoch 4/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0356 - acc: 0.9886
Epoch 5/5
60000/60000 [==============================] - 5s 89us/step - loss: 0.0276 - acc: 0.9912

As shown, the loss is decreasing and the accuracy acc is rising across epochs.
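fit() also returns a History object that records these per-epoch curves, which is handy for plotting instead of reading the console (note: depending on the TensorFlow version, the accuracy key is 'acc' or 'accuracy'):

history = model.fit(x_train, y_train, epochs=5)
print(history.history['loss'])  # per-epoch training loss
print(history.history['acc'])   # per-epoch training accuracy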

Evaluating the Model

Finally, let's evaluate the model to see how well it has learned.

model.evaluate(x_test, y_test)

The corresponding output is:

10000/10000 [==============================] - 0s 31us/step
0.06755815939957101
0.9804

That is, the loss is 0.0675 and the accuracy acc is 0.9804.
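evaluate() returns the loss followed by the metrics in the order given to compile(); pairing the values with metrics_names makes the printout self-describing (a small sketch):

results = model.evaluate(x_test, y_test)
print(dict(zip(model.metrics_names, results)))  # e.g. {'loss': 0.0675..., 'acc': 0.9804}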

Testing

Recognizing the first digit in the test set

predictions = model.predict(x_test)
print(predictions[0])

The result is:

[1.0469671e-09 1.9170661e-09 2.7298618e-08 2.8374097e-05 1.0196099e-12
 3.7379849e-10 7.5198276e-12 9.9997115e-01 3.0012792e-08 4.3838830e-07]

That is, the probability that the digit is 7 is 0.99997115.

Alternatively, we can use

import numpy as np
print(np.argmax(predictions[0]))

which prints

7

showing directly that the element at index 7 is the largest.
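Taking the argmax over all test predictions and comparing against the true labels reproduces the accuracy reported by evaluate() (a quick cross-check):

import numpy as np

predicted_labels = np.argmax(predictions, axis=1)  # one digit per test image
print(np.mean(predicted_labels == y_test))         # ~0.9804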

What does the actual image look like? Let's plot it.

from matplotlib import pyplot as plt
plt.imshow(x_test[0],cmap=plt.cm.binary)
plt.show()

The corresponding image is:

[image: the first test image (a handwritten 7)]

It is indeed a 7.

Recognizing the first digit in the training set

trains = model.predict(x_train)
print(trains[0])
import numpy as np
print(np.argmax(trains[0]))

The result is:

[4.2104886e-14 5.7015930e-11 2.9668783e-09 1.6136320e-01 2.0297861e-21
 8.3863670e-01 6.2727927e-15 1.4089062e-11 1.6384524e-10 1.3477487e-07]
5

That is, the probability that the digit is 5 is 0.8386367.

And indeed: among the ten digits 0-9, the image shown at the top of this post is closest to a 5, though it is not an especially clean-looking 5.
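To find other ambiguous examples like this one, one could rank images by the model's confidence, i.e. the maximum softmax output (a sketch; least_confident is a name introduced here):

import numpy as np

confidence = np.max(trains, axis=1)           # highest class probability per image
least_confident = np.argsort(confidence)[:5]  # indices of the 5 least confident
print(least_confident, confidence[least_confident])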

