LeNet-5模型初体验

  |   0 评论   |   21 浏览

背景

MNIST数字识别问题初体验后,接着要掌握的就是卷积神经网络(Convolutional Neural Network, CNN)。

CNN中最经典的一个入门例子就是LeNet-5模型。

本文演示一下,如何使用LeNet-5模型,来求解MNIST数字识别问题。

初体验

代码

# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()

# 重构数据至4维(样本,像素X,像素Y,通道)
x_train=x_train.reshape(x_train.shape+(1,))
x_test=x_test.reshape(x_test.shape+(1,))

x_train, x_test = x_train / 255.0, x_test / 255.0

# 数据标签
label_train = tf.keras.utils.to_categorical(y_train, 10)
label_test = tf.keras.utils.to_categorical(y_test, 10)

# 建立LeNet-5模型
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'),
    tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
    tf.keras.layers.Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'),
    tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(120, activation='tanh'),
    tf.keras.layers.Dense(84, activation='tanh'),
    tf.keras.layers.Dense(10, activation='softmax'),
])

# 编译模型,使用SGD优化器
model.compile(optimizer='SGD',
              loss=tf.keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

# 学习20轮,使用20%数据交叉验证
records = model.fit(x_train, label_train, epochs=20, validation_split=0.2)

# 预测
y_pred = np.argmax(model.predict(x_test), axis=1)
print("prediction accuracy: {}".format(1.0*sum(y_pred==y_test)/len(y_test)))

# 绘制结果
plt.plot(records.history['loss'],label='training set loss')
plt.plot(records.history['val_loss'],label='validation set loss')
plt.ylabel('categorical cross-entropy'); plt.xlabel('epoch')
plt.legend()
plt.show()

结果

运行结果

Train on 48000 samples, validate on 12000 samples
Epoch 1/20
48000/48000 [==============================] - 10s 199us/step - loss: 0.7060 - acc: 0.8112 - val_loss: 0.3408 - val_acc: 0.9048
Epoch 2/20
48000/48000 [==============================] - 9s 196us/step - loss: 0.3173 - acc: 0.9082 - val_loss: 0.2564 - val_acc: 0.9274
Epoch 3/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.2454 - acc: 0.9284 - val_loss: 0.2049 - val_acc: 0.9403
Epoch 4/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.1989 - acc: 0.9410 - val_loss: 0.1705 - val_acc: 0.9515
Epoch 5/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.1651 - acc: 0.9509 - val_loss: 0.1449 - val_acc: 0.9573
Epoch 6/20
48000/48000 [==============================] - 9s 193us/step - loss: 0.1404 - acc: 0.9586 - val_loss: 0.1261 - val_acc: 0.9636
Epoch 7/20
48000/48000 [==============================] - 9s 194us/step - loss: 0.1217 - acc: 0.9642 - val_loss: 0.1127 - val_acc: 0.9681
Epoch 8/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.1080 - acc: 0.9684 - val_loss: 0.1034 - val_acc: 0.9703
Epoch 9/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0970 - acc: 0.9712 - val_loss: 0.0942 - val_acc: 0.9736
Epoch 10/20
48000/48000 [==============================] - 9s 196us/step - loss: 0.0882 - acc: 0.9740 - val_loss: 0.0891 - val_acc: 0.9747
Epoch 11/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0812 - acc: 0.9760 - val_loss: 0.0827 - val_acc: 0.9769
Epoch 12/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0749 - acc: 0.9783 - val_loss: 0.0799 - val_acc: 0.9772
Epoch 13/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0702 - acc: 0.9801 - val_loss: 0.0742 - val_acc: 0.9788
Epoch 14/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0655 - acc: 0.9810 - val_loss: 0.0707 - val_acc: 0.9796
Epoch 15/20
48000/48000 [==============================] - 9s 194us/step - loss: 0.0620 - acc: 0.9820 - val_loss: 0.0689 - val_acc: 0.9807
Epoch 16/20
48000/48000 [==============================] - 9s 196us/step - loss: 0.0584 - acc: 0.9832 - val_loss: 0.0674 - val_acc: 0.9808
Epoch 17/20
48000/48000 [==============================] - 9s 196us/step - loss: 0.0555 - acc: 0.9845 - val_loss: 0.0638 - val_acc: 0.9815
Epoch 18/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0528 - acc: 0.9847 - val_loss: 0.0628 - val_acc: 0.9813
Epoch 19/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0503 - acc: 0.9858 - val_loss: 0.0636 - val_acc: 0.9818
Epoch 20/20
48000/48000 [==============================] - 9s 195us/step - loss: 0.0479 - acc: 0.9866 - val_loss: 0.0602 - val_acc: 0.9819

预测正确率:98.2%

prediction accuracy: 0.982

imagepng

过程

imagepng

第一层:卷积层C1

C1层是卷积层,过滤器尺寸为5*5;深度为6,即有6个特征图谱。每个特征图谱内共用一个卷积核,有5*5个连接参数,加上1个偏置参数,共5*5+1=26个参数。6个特征图谱,总共有26*6=156个训练参数。

卷积区域每次滑动一个像素,这样每个特征图谱的大小为28*28

所以C1层总共的连接数为:156*28*28=122304

第二层:池化层S2

S2层是一个下采样层。

对C1层的每个特征图谱,进行2*2为单位的下采样,得到一个14*14的图。每个特征图谱使用一个下采样核。共有6*14*14*5=5880个连接。

第三层:卷积层C3

C3层和C1层一样,也是一个卷积层。但是C3层的每个节点,和S2层中的多个图相连。C3层的深度为16,共有1610*10的图。

C3层的每个图,与S2层的连接采用的是不对称的组合方式。

每层有(5*5*3 + 1)*6 + (5*5*4 + 1)*3 + (5*5*4 + 1)*6 + (5*5*6 + 1)*1 = 1516个训练参数,共有1516*10*10=151600个连接。详情参考[5]。

第四层:池化层S4

S4层也是一个下采样层。C3层的1610*10的图分别进行以2*2为单位的下抽样,得到165*5的图。16*5*5*5=2000个连接。

第五层:全连接层C5

C5层是一个全连接层,有120个节点。每个节点与S4层的16个图相连,共有16*5*5+1个参数。120个节点,共有120*(16*5*5+1)=48120个参数,同样有48120个连接。

第六层:全连接层F6

F6层是全连接层。F6层有84个节点,对应于一个7*12的比特图,该层的训练参数和连接数都是84*(120+1)=10164

第七层:全连接层Output

Output层也是全连接层,共有10个节点,分别代表数字0到9。

以上是LeNet-5的卷积神经网络的完整结构,共约有60,840个训练参数,340,908个连接。

基础知识

卷积神经网络

卷积神经网络是仿造生物的视知觉(visual perception)机制构建的。

CNN结构

输入层

输入层可以处理多维数据

隐藏层

包含:

  • 卷积层(convolutional layer): 特征提取
  • 池化层(pooling layer): 特征选择和信息过滤
  • 全连接层(fully-connected layer): 即传统前馈神经网络中的隐藏层。特征图在此层中会由原来的多维结构展开为向量,通过激励函数传递至下一层。

其中,由多个卷积层和池化层可以组成一个更高级的模块单元,称为Inception模块(Inception module)。Inception模块最早在GoogLeNet中使用的。

卷积层

  • 卷积核(convolutional kernel): 对输入数据进行特征提取
  • 卷积层参数: 大小,步长,填充
  • 激活函数(activation function)

池化层

池化函数: 将特征图中单个点的结果替换为其相邻区域的特征图统计量

常见的池化函数有:

  • Lp池化(Lp pooling)
  • 随机(stochastic pooling)/混合池化(mixed pooling)
  • 谱池化(spectral pooling)

输出层

输出层通常在全连接层后面,作用与传统前馈神经网络中输出层是一样的。

CNN分类

常见的卷积神经网络有:

  • TDNN(1987): 时间延迟网络(Time Delay Neural Network),一维卷积核,用于语音识别问题。
  • LeNet-5(LeCun et al, 1998): 用于图像分类问题的二维卷积神经网络。
  • AlexNet(2012): 2012年ILSVRC图像分类和物体识别算法优胜者。
  • ZFNet(2013): 2013年ILSVRC图像分类算法的优胜者
  • VGGNet(2014): 视觉几何团队(Visual Geometry Group, VGG)。VGG-16是2014年ILSVRC物体识别算法的优胜者。
  • GoogLeNet(2014): 2014年ILSVRC图像分类算法的优胜者。
  • ResNet(2015): 残差神经网络(Residual Network, ResNet),来自微软。2015年ILSVRC图像分类和物体识别算法的优胜者。
  • WaveNet(Van Den Oord et al, 2016): 一维卷积神经网络,用于语音建模,采用扩张卷积和跳跃连接提升了神经网络对长距离依赖的学习能力。Google Assistant使用。

CNN软件

  • TensorFlow
  • Keras
  • Thenao
  • Microsoft-CNTK
  • MATLAB

参考

  1. 《Tensorflow 实战Google深度学习框架》:这是一本书
  2. 卷积神经网络: 原理介绍,及上面的代码来源
  3. Tensorflow实例:(卷积神经网络)LeNet-5模型
  4. TensorFlow实现mnist数字识别——CNN LeNet-5模型: LeNet-5详细介绍
  5. 深度学习 CNN卷积神经网络 LeNet-5详解
  6. TensorFlow高层封装-Slim,Keras,Estimator
  7. 如何比较Keras, TensorLayer, TFLearn ?: 学习:Keras >> TFLearn >> Tensorlayer, 工作:Tensorlayer >> TFLearn

评论

发表评论

validate