目录
效果展示
任务目标:使用 Oxford Flowers 102 数据集训练一个模型识别花的种类(共102类)
第 1 步:打开并设置 Colab 环境
第 2 步:导入必要的库
第 3 步:加载 Oxford Flowers 102 数据集
第 4 步:可视化几张图像
第 5 步:数据预处理函数
第 6 步:构建模型(使用预训练 MobileNetV2)
第 7 步:编译模型
第 8 步:训练模型
训练结果分析
1. 前几轮进步很快
2. 训练集精度飞升,验证集缓慢上升
第 9 步:训练过程可视化
第 10 步:评估模型在测试集上的准确率
(可选)第 11 步:单张图片预测
你现在已经完成了花卉图像识别的完整模型训练!
(1)保存训练好的模型并下载到本地。
🔸 第一步:保存模型为文件(在 Colab 中)
🔸 第二步:下载模型文件到本地
(2)调用模型
本地完整 Python 代码:加载模型并识别文件夹中所有图片
注意事项
MobileNetV2模型介绍
1. 倒残差结构(Inverted Residuals)
2. 线性瓶颈(Linear Bottleneck)
3. 高效计算设计
模型结构示例(简化版)
性能对比(ImageNet)
效果展示
任务目标:使用 Oxford Flowers 102 数据集训练一个模型识别花的种类(共102类)
第 1 步:打开并设置 Colab 环境
进入 https://colab.research.google.com
创建一个新的 Python 3 Notebook
为防止断线,可以点击菜单栏:代码执行器 -> 更改运行时类型 -> GPU(建议选 GPU)
这一步是为了准备训练的“工作环境”。Colab 是谷歌提供的在线 Python 运行平台,使用它可以不用自己配置环境,而且还能免费用到 GPU(图形处理器),这在训练图像识别模型时会让速度快很多。打开 Colab 后建议设置一下运行类型,选 GPU,这样后面训练会更流畅。
第 2 步:导入必要的库
我们需要用到一些库,比如 TensorFlow(做深度学习用的),还有 TensorFlow Datasets(用来方便加载公开数据集的工具),Matplotlib 和 Numpy 也是常用的可视化和数值处理工具。简单说,这一步就是把后面会用到的“工具箱”准备好。
在 Colab 中运行以下代码:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
第 3 步:加载 Oxford Flowers 102 数据集
这一步是整个任务的基础:我们要做花卉识别,就需要有图片数据。Oxford Flowers 102 是一个公开的数据集,里面有 102 种花,每种花有几十张图片。我们用 TensorFlow Datasets 提供的接口来加载这个数据集,并且会自动帮我们把数据按训练集、验证集和测试集分开,后面模型训练、验证和评估都要用到这些。
# 加载数据集并自动分成训练/验证/测试集
(ds_train, ds_val, ds_test), ds_info = tfds.load(
'oxford_flowers102',
split=['train', 'validation', 'test'],
shuffle_files=True,
as_supervised=True, # 返回 (image, label)
with_info=True
)
第 4 步:可视化几张图像
在正式训练之前,我们最好先看看这个数据集里的图片是不是真的是我们需要的。比如看看图像尺寸是否一致、颜色正不正常、标签有没有问题。这一步就做一个简单的可视化,确保我们拿到的是正确的数据,并且对后面的预处理有点印象。
for image, label in ds_train.take(6):
plt.figure()
plt.imshow(image)
plt.title(f"Label: {label.numpy()}")
plt.axis("off")
第 5 步:数据预处理函数
原始图片尺寸大小不一,而且像素值的范围也不是我们训练模型需要的格式,所以这一步的任务是统一图像大小,并把像素值转换成 0 到 1 之间的浮点数。这样一来数据就“干净整齐”了,模型也更容易训练。处理完之后,我们还会把数据打成“批次”,方便模型每次读取一组图片来学习。
IMG_SIZE = 224 # 输入图像尺寸
def preprocess(image, label):
image = tf.image.resize(image, [IMG_SIZE, IMG_SIZE])
image = tf.cast(image, tf.float32) / 255.0 # 归一化
return image, label
应用预处理并打包成批次:
BATCH_SIZE = 32
train_ds = ds_train.map(preprocess).shuffle(1000).batch(BATCH_SIZE).prefetch(1)
val_ds = ds_val.map(preprocess).batch(BATCH_SIZE).prefetch(1)
test_ds = ds_test.map(preprocess).batch(BATCH_SIZE).prefetch(1)
第 6 步:构建模型(使用预训练 MobileNetV2)
我们选用 MobileNetV2 【在文章末尾有详细的介绍】作为模型的骨架,这是一个在图像识别中表现不错的轻量模型,而且已经在大规模数据集上训练过了,我们可以直接拿来用,节省训练时间。在它的基础上加一点自己的分类层,这样模型最后就可以输出我们需要的 102 个花卉类别了。这一步的目标就是:把识别“花的种类”这个问题翻译成一个神经网络可以解决的问题。
base_model = tf.keras.applications.MobileNetV2(
input_shape=(IMG_SIZE, IMG_SIZE, 3),
include_top=False,
weights='imagenet'
)
base_model.trainable = False # 冻结预训练层
# 构建模型
model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(102, activation='softmax') # 102类花
])
第 7 步:编译模型
有了模型之后,还需要告诉它怎么学习,也就是配置神经网络模型的训练参数。这一步就是设置优化器(比如用 Adam)、损失函数(我们用交叉熵,它适合分类问题)和我们关心的评价指标(准确率)。这些设置相当于告诉模型“你要怎么去学习”和“怎么判断你学得好不好”。
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
第 8 步:训练模型
这一步就像把之前的准备全部连接起来,让模型真正开始“学习”数据了。我们把处理好的训练数据送进模型里,让它一轮轮地学,同时用验证集来检查它有没有过拟合、是否在变好。这一步模型会不断地更新参数,从而逐步提高识别花的能力。
history = model.fit(
train_ds,
validation_data=val_ds,
epochs=10
)
你可以根据训练速度选择是否加长训练周期(例如 20-30 轮)。
根据这里的数据来看,这点量对于这张T4显卡来说太简单了
训练结果分析
1. 前几轮进步很快
第一轮 accuracy 是 6%,val_accuracy 就达到了 35%,说明模型一开始是靠着迁移学习的权重快速建立起了基础。
第二轮直接跳到 61%,说明学习率起作用了,模型快速适应了你的花卉数据集。
2. 训练集精度飞升,验证集缓慢上升
从第三轮开始,训练集精度很快接近 100%,但验证集从 72% 慢慢到 79%,上升变得缓慢。
最后几轮训练集完全拟合了,但验证集没有太大提升(甚至有轻微震荡),说明模型在你训练的这部分数据上已经学得太好了,但泛化能力不足。
但是由于我们只是教学作用,所以我在这里就不精调模型了,如果大家想看的话,欢迎在评论区评论,如果想看的人多的话,我会发表这篇文章。
第 9 步:训练过程可视化
如上图。
训练完后,我们可以把模型每一轮的准确率画成图,来直观地看它有没有进步。比如训练准确率是一直升高的,但验证准确率开始上升后又下降了,那可能就是过拟合了。通过这一步,我们可以对模型的表现有个更清楚的了解。
plt.plot(history.history['accuracy'], label='Train Acc')
plt.plot(history.history['val_accuracy'], label='Val Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training vs Validation Accuracy')
plt.show()
第 10 步:评估模型在测试集上的准确率
模型在训练和验证集上表现好不代表它就真的学会了花的种类,所以我们还需要用“没见过的数据”来测试它,这就是测试集。我们看看模型在测试集上的准确率如何,这个结果才更接近于实际使用时的效果。
test_loss, test_acc = model.evaluate(test_ds)
print(f"Test accuracy: {test_acc:.2f}")
还是和刚才的结论差不多,结果并不太好,但是如果只是自己玩玩的话,还是很不错了。
(可选)第 11 步:单张图片预测
作为收尾,我们可以从测试集中拿一张图片,看看模型是否能正确识别。这就像是我们在检验它是否真的“看懂”了花。这一步也可以自己上传一张图片来测试,更加直观地验证模型的预测能力。
class_names = [
"pink primrose", "hard-leaved pocket orchid", "canterbury bells", "sweet pea", "english marigold",
"tiger lily", "moon orchid", "bird of paradise", "monkshood", "globe thistle",
"snapdragon", "colt's foot", "king protea", "spear thistle", "yellow iris",
"globe-flower", "purple coneflower", "peruvian lily", "balloon flower", "giant white arum lily",
"fire lily", "pincushion flower", "fritillary", "red ginger", "grape hyacinth",
"corn poppy", "prince of wales feathers", "stemless gentian", "artichoke", "sweet william",
"carnation", "garden phlox", "love in the mist", "mexican aster", "alpine sea holly",
"ruby-lipped cattleya", "cape flower", "great masterwort", "siam tulip", "tree poppy",
"gazania", "azalea", "water lily", "rose", "tree marigold",
"anthurium", "frangipani", "geranium", "orange", "primula",
"bishop's hat", "gerbera", "cornflower", "zinnia", "lily of the valley",
"poppy", "toad lily", "anemone", "watercress", "canna lily",
"hippeastrum", "bird of paradise", "cineraria", "cyclamen", "daffodil",
"dahlia", "freesia", "gladiolus", "iris", "lilac",
"lily", "lupine", "magnolia", "marigold", "narcissus",
"orchid", "pansy", "peony", "petunia", "phlox",
"poppy", "rose", "salvia", "sunflower", "sweet pea",
"tulip", "violet", "water lily", "wisteria", "yellow iris"
]
# 假设已经定义了class_names列表和模型model,并且有测试集ds_test
for image, label in ds_test.take(1):
# 预处理图片(调整大小、归一化、扩展batch维度)
img = tf.image.resize(image, [IMG_SIZE, IMG_SIZE]) / 255.0
img = tf.expand_dims(img, axis=0)
# 模型预测
prediction = model.predict(img)
predicted_class = tf.argmax(prediction, axis=1).numpy()[0]
# 获取真实花名和预测花名
true_name = class_names[label.numpy()]
pred_name = class_names[predicted_class]
# 显示图片和标题(显示花名)
plt.imshow(image)
plt.title(f"True: {true_name}, Predicted: {pred_name}")
plt.axis("off")
plt.show()
你现在已经完成了花卉图像识别的完整模型训练!
然后我们准备使用这个模型在本地。
(1)保存训练好的模型并下载到本地。
🔸 第一步:保存模型为文件(在 Colab 中)
我们已经完成训练,我们将其保存为 .h5 文件:
model.save("flower_model.h5")
这一步会在 Colab 的当前工作目录下生成一个文件:flower_model.h5。
🔸 第二步:下载模型文件到本地
在 Colab 中运行下面的代码,把刚才保存的文件下载到你的电脑:
from google.colab import files
files.download("flower_model.h5")
执行后会弹出下载窗口,你点击保存即可。
但是这个方法通常很慢,建议采用这篇博客的方法会快超级多。
如何快速的从Google colab 中下载文件(亲测好用)-CSDN博客
这样你就得到了训练好的模型文件,接下来就可以把它用于本地识别图片的 Python 代码中了。
(2)调用模型
好的,下面是在本地调用模型并批量识别图片的完整 Python 代码。我们已经从 Colab 下载了 .h5 格式的模型文件,并将其保存在本地(flower_model.h5),同时你有一个文件夹(比如叫 test_images/,看你自己的了)里面放的是要识别的花的图片。
本地完整 Python 代码:加载模型并识别文件夹中所有图片
import os
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.mobilenet_v2 import preprocess_input
from collections import Counter
# ===== 第一步:加载模型 =====
model_path = 'D:\\IDE\\PyCharmIDE\\projectPosition\\HandCatch\\all\\BOKE_MODEL\\flower_model.h5'
model = load_model(model_path)
print("模型加载成功。")
# ===== 第二步:设置图片文件夹路径 =====
folder_path = './images' # 图片文件夹路径
image_size = (224, 224) # MobileNetV2 的默认输入尺寸
# ===== 第三步:加载标签名 =====
class_names = [
'pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'english marigold',
'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle', 'snapdragon',
'colt\'s foot', 'king protea', 'spear thistle', 'yellow iris', 'globe-flower', 'purple coneflower',
'peruvian lily', 'balloon flower', 'giant white arum lily', 'fire lily', 'pincushion flower',
'fritillary', 'red ginger', 'grape hyacinth', 'corn poppy', 'prince of wales feathers',
'stemless gentian', 'artichoke', 'sweet william', 'carnation', 'garden phlox', 'love in the mist',
'mexican aster', 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort',
'siam tulip', 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia', 'wild pansy',
'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia',
'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush',
'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower', 'tree poppy',
'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 'passion flower',
'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
'tree mallow', 'magnolia', 'cyclamen', 'watercress', 'canna lily', 'hippeastrum', 'bee balm',
'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 'mexican petunia', 'bromelia',
'blanket flower', 'trumpet creeper', 'blackberry lily'
]
# ===== 第四步:处理图像函数(保留原图) =====
def preprocess_image(img_path):
img = image.load_img(img_path, target_size=image_size)
original_array = image.img_to_array(img).astype('uint8') # 原始像素值
img_array = np.expand_dims(original_array.copy(), axis=0) # 复制用于处理
img_array = preprocess_input(img_array) # 归一化供模型使用
return original_array, img_array
# ===== 第五步:遍历文件夹,预测并显示图片 =====
results = []
for filename in os.listdir(folder_path):
if filename.lower().endswith(('.jpg', '.jpeg', '.png')):
img_path = os.path.join(folder_path, filename)
original_img, img_input = preprocess_image(img_path)
prediction = model.predict(img_input)
predicted_index = np.argmax(prediction[0])
predicted_label = class_names[predicted_index]
results.append(predicted_label)
# 显示原图和预测结果
plt.figure(figsize=(4, 4))
plt.imshow(original_img.astype('uint8')) # 显示原图像
plt.title(f"{filename}\n预测: {predicted_label}", fontsize=10)
plt.axis('off')
plt.tight_layout()
plt.show()
# ===== 第六步:预测汇总统计 =====
summary = Counter(results)
print("\n预测汇总结果:")
for label, count in summary.items():
print(f"{label}: {count} 张")
注意事项
class_names:上面用的是假名 "class_0" 到 "class_101"。如果你在训练时保存了真实花卉名字的列表,可以直接加载进来替换;
图片尺寸要和训练时一致(MobileNetV2 是 224x224);
如果有中文路径或中文文件名,可能需要处理编码问题,推荐路径中尽量使用英文;
可进一步保存结果为 CSV 或 Excel 文件,如果需要我可以加上。
MobileNetV2模型介绍
MobileNetV2是Google在2018年提出的轻量级卷积神经网络,主要面向移动端和嵌入式设备的计算机视觉任务。其核心创新体现在以下三个方面:
1. 倒残差结构(Inverted Residuals)
传统残差块的中间层通道数较少,而MobileNetV2采用反向设计:
其中扩张层通过 t 倍通道扩展(典型 t=6 ),例如输入通道经过扩张变为
2. 线性瓶颈(Linear Bottleneck)
在压缩层采用线性激活而非ReLU,避免低维空间的信息丢失。数学证明:当输入流形维度足够高时,ReLU变换才能保持信息完整性。
3. 高效计算设计
深度可分离卷积的计算量对比常规卷积: 其中为输出通道数,为卷积核尺寸
模型结构示例(简化版)
class InvertedResidual(nn.Module):
def __init__(self, in_ch, out_ch, stride, expand_ratio):
super().__init__()
hidden_ch = in_ch * expand_ratio
self.conv = nn.Sequential(
# 扩张层
nn.Conv2d(in_ch, hidden_ch, 1, bias=False),
nn.BatchNorm2d(hidden_ch),
nn.ReLU6(),
# 深度卷积
nn.Conv2d(hidden_ch, hidden_ch, 3, stride, 1, groups=hidden_ch, bias=False),
nn.BatchNorm2d(hidden_ch),
nn.ReLU6(),
# 压缩层
nn.Conv2d(hidden_ch, out_ch, 1, bias=False),
nn.BatchNorm2d(out_ch)
)
性能对比(ImageNet)
模型Top-1 AccParams(M)FLOPs(B)MobileNetV170.6%4.20.575MobileNetV272.0%3.40.300
该模型广泛应用于:
移动端图像分类(输入尺寸时仅需300M FLOPs)实时目标检测(SSD-MobileNetV2)语义分割(DeepLabv3+骨干网络)人脸识别等资源敏感场景
其设计思想通过空间维度与通道维度的解耦操作,在保证精度的同时实现计算效率的显著提升。