CNN网络实现垃圾分类
作者:王镇
1. 背景
自2019年7月1日起,随着《上海市生活垃圾管理条例》正式实施,垃圾分类工作在全国由点到面逐步推开。垃圾分类可以最大限度的实现垃圾资源利用,减少垃圾处置量,改善生存环境质量,降低垃圾对于地下水的污染。由于垃圾分类条例刚开始实施,很多居民还没有足够强的垃圾分类意识,生活中垃圾分类并没能得到很好的落实。因此垃圾收集站依然有很强的垃圾自动分类需求。本文通过搭建一个简单的CNN网络实现对垃圾进行自动分类。
2. 数据集
本文使用的数据集来自kaggle上的垃圾分类数据集,共2527张图片,分为六类生活垃圾,分别为cardboard 403张,glass 501张,metal 410张,paper 594张,plastic 482张,trash 137张。该数据集的图片具有相同的规格尺寸,且要检测的垃圾大多数位于图片中央,因此非常适合于训练深度学习模型。
3. 数据预处理
由于该数据集是一个较小的数据集,仅有两千多张图片,但通常训练深度学习模型都要求有上万张乃至更多的图片。去网上采集更多的图片扩充数据集显然是最理想的方法,但那样会花费大量的时间与精力,因此通过数据增强扩充数据集成了我们的选择。
ImageDataGenerator()
是keras.preprocessing.image
模块中的图片生成器,同时也可以在batch中对数据进行增强,扩充数据集大小,增强模型的泛化能力。比如对图片进行平移,旋转,缩放,翻转等。在将图片输入模型之前,为了提高模型的准确率,还需将图片进行归一化处理,即将每个像素的值映射到(0,1)之间。
实验中我们随机选择90%的图片作为训练集,对训练集进行数据增强操作,剩下10%的图片作为测试集,测试集不做数据增强处理。
4. 模型搭建
该网络使用了四层的卷积层,每一卷积层后接一最大池化层,最后紧跟两层Dense层,将输出转化为6×1的向量。网络结构如下所示:
model = Sequential([
Conv2D(filters=32, kernel_size=3, padding='same', activation='relu', input_shape=(300, 300, 3)),
MaxPooling2D(pool_size=2),
Conv2D(filters=64, kernel_size=3, padding='same', activation='relu'),
MaxPooling2D(pool_size=2),
Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
MaxPooling2D(pool_size=2),
Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),
MaxPooling2D(pool_size=2),
Flatten(),
Dense(64, activation='relu'),
Dense(6, activation='softmax')
])
5. 结果展示
图1绘制出了该模型训练过程的学习率曲线。从图中可见模型在epoch达到40时便开始趋向收敛,之后随着训练次数的增加,模型效果也没有提升,在整个训练过程中测试集上取得的最好的准确率为79.4%。而学习率曲线波动很大原因是测试集太小,仅含两百多涨图片存在较大的偶然性。
图1 CNN的学习率曲线
图2 部分分类结果展示
图2展示了部分分类结果,其中pred为模型所判断的每张图片的所属垃圾类别,truth为每张图片真实的所属垃圾类别。从图中可见,模型对大部分的垃圾图片分类结果都比较准确,仅第四张将一个玻璃瓶误判为了金属,这也情有可原,因为该图片中玻璃瓶的反光因素使得该玻璃瓶确实看起来有点像银白色金属。
项目源码地址:https://momodel.cn/explore/5e006c3d744bdda4f67a2bc7?type=app
6. 参考文献
- kaggle数据集. https://www.kaggle.com/asdasdasasdas/garbage-classification
- https://www.kesci.com/home/project/5d26a62b688d36002c58a627/code
关于我们
Mo(网址:https://momodel.cn) 是一个支持 Python的人工智能在线建模平台,能帮助你快速开发、训练并部署模型。
Mo人工智能俱乐部 是由 Mo 的研发与产品团队发起、致力于降低人工智能开发与使用门槛的俱乐部。团队具备大数据处理分析、可视化与数据建模经验,已承担多领域智能项目,具备从底层到前端的全线设计开发能力。主要研究方向为大数据管理分析与人工智能技术,并以此来促进数据驱动的科学研究。
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。