博客专栏

EEPW首页 > 博客 > 热文 | 卷积神经网络入门案例,轻松实现花朵分类(2)

热文 | 卷积神经网络入门案例,轻松实现花朵分类(2)

发布人:AI科技大本营 时间:2021-05-15 来源:工程师 发布文章

构建模型

常见卷积神经网络(CNN),主要由几个 卷积层Conv2D 和 池化层MaxPooling2D 层组成。卷积层与池化层的叠加实现对输入数据的特征提取,最后连接全连接层实现分类。

特征提取——卷积层与池化层

实现分类——全连接层

CNN 的输入是张量 (Tensor) 形式的 (image_height, image_width, color_channels),包含了图像高度、宽度及颜色信息。通常图像使用 RGB 色彩模式,color_channels 为 (R,G,B) 分别对应 RGB 的三个颜色通道,即:image_height 和 image_width 根据图像的像素高度、宽度决定color_channels是3,对应RGB的3通道。

花朵数据集中的图片,形状是 (180, 180, 3),我们可以在声明第一层时将形状赋值给参数 input_shape 。

num_classes = 5
model = Sequential([
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(num_classes)
])

该模型由三个卷积块组成,每个卷积块中包括2D卷积层+最大池化层。最后有一个全连接层,有128个单元,可以通过relu激活功能激活该层。

编译模型

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

查看一下网络模型:tf.keras.utils.plot_model(model) ,或者用这样方式看看:model.summary()

3.png

训练模型

这里我们输入准备好的训练集数据(包括图像、对应的标签),测试集的数据(包括图像、对应的标签),模型一共训练10次。

epochs=10
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)

下图是训练过程的截图:

4.png

通常loss越小越好,对了解释下什么是loss;简单来说是 模型预测值 和 真实值 的相差的值,反映模型预测的结果和真实值的相差程度;通常准确度accuracy 越高,模型效果越好。

评估模型

在训练和验证集上创建损失和准确性图。

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

5.png

从图中可以看出,训练精度和验证精度相差很大,模型仅在验证集上获得了约60%的精度。

训练精度随时间增长,而验证精度在训练过程中停滞在60%左右。训练和验证准确性之间的准确性差异很明显,这是过拟合的标志。

可能过拟合出现的原因 :当训练示例数量很少时,像这次的只有3000多张图片,该模型有时会从训练示例中的噪音或不必要的细节中学习,从而模型在新示例上的性能产生负面影响。 

使用模型

通常使用 model.predict( )  函数进行预测。 

优化模型、重新构建模型、训练模型、使用模型

这里的优化模型,主要是针对模型出现“过拟合”的问题。

过拟合

模型将过度拟合训练数据,在训练集上达到较高的准确性,但在未见的数据(测试集)上得到比较低的准确性;模型的“泛化能力”不足。

我们训练模型的主要目的,也是希望模型在未见数据的预测上能有较高的准确性;解决过拟合问题是比较重要的。

解决过拟合的思路

使用更完整的训练数据。(最好的解决方案)

使用正则化之类的技术。

简化神经网络结构。

使用更完整的训练数据,数据集应涵盖模型应处理的所有输入范围。仅当涉及新的有趣案例时,其他数据才有用。

比如:在训练集的花朵图片都是近距离拍摄的,测试集的花朵有部分是远距离拍摄,训练出来的模型,自然在测试集的准确度不高了;如果一开始在训练集也包含部分远距离的花朵图片,那么模型在测试集时准确度会较高,基本和训练集的准确度接近。

使用正规化等技术,这些限制了模型可以存储的信息的数量和类型。如果一个网络只能记住少量的模式,优化过程将迫使它专注于最突出的模式,这些模式更有可能很好地概括。

简化神经网络结构,如果训练集比较小,网络结构又十分复杂,使得模型过度拟合训练数据,这时需要考虑简化模型了。减少一些神经元数量,或减少一些网络层。

结合上面的例子,使用数据增强和正则化技术,来优化网络。

数据增强

通过对已有的训练集图片 随机转换(反转、旋转、缩放等),来生成其它训练数据。这有助于将模型暴露在数据的更多方面,并更好地概括。

这里使用 tf.layers.experimental.preprocessing 层实现数据增强。

data_augmentation = keras.Sequential(
  [
    layers.experimental.preprocessing.RandomFlip("horizontal", 
                                                 input_shape=(img_height, 
                                                              img_width,
                                                              3)),
    layers.experimental.preprocessing.RandomRotation(0.1),
    layers.experimental.preprocessing.RandomZoom(0.1),
  ]
)

RandomFlip("horizontal", input_shape=(img_height,  img_width, 3)) 指定输入图片,并对图片进行随机水平反转

RandomRotation(0.1) 对图片进行随机旋转

RandomZoom(0.1)     对图片进行随机缩放

通过将数据增强应用到同一图像中几次来可视化几个增强示例的外观:

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
  for i in range(9):
    augmented_images = data_augmentation(images)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_images[0].numpy().astype("uint8"))
    plt.axis("off")

7.png

*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。

单片机相关文章:单片机教程


单片机相关文章:单片机视频教程


单片机相关文章:单片机工作原理


汽车防盗机相关文章:汽车防盗机原理


关键词: AI

相关推荐

技术专区

关闭