Handwritten digit dataset: the MNIST dataset was constructed from NIST's Special Database 3 (SD-3) and Special Database 1 (SD-1).
This article uses the PaddlePaddle framework to walk through the full workflow of a handwritten digit recognition model: building the network, training it, and running inference.
# Import libraries
import paddle
import paddle.vision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
I. Define and inspect the framework's built-in handwritten digit dataset
# Define a normalization transform
transform = T.Normalize(mean=[127.5], std=[127.5])
# Define the training and validation datasets
train_dataset = paddle.vision.datasets.MNIST(mode="train", transform=transform)
eval_dataset = paddle.vision.datasets.MNIST(mode="test", transform=transform)
# Print the sizes of the training and validation sets
print(len(train_dataset), len(eval_dataset))
print(train_dataset[0][0].shape, train_dataset[0][1].shape)
Output:
60000 10000
(1, 28, 28) (1,)
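As a quick sanity check (a minimal sketch, not part of the original code): with mean=127.5 and std=127.5, the Normalize transform maps pixel values from [0, 255] to roughly [-1, 1], which can be verified on a single sample:
# Sketch: verify the value range produced by the Normalize transform
sample, label = train_dataset[0]
print(sample.min(), sample.max())  # expected to be close to -1.0 and 1.0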
Take a random training sample and display it. Each handwritten digit is stored as a (1, 28, 28) array, so it needs to be reshaped to (28, 28) before being displayed as an image.
rndinx = np.random.randint(len(train_dataset))
plt.imshow(train_dataset[rndinx][0].reshape((28, 28)))
plt.show()
print(train_dataset[rndinx][1])

(Figure: the displayed image for digit 0)

(Figure: the displayed image for digit 7)
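Several random samples and their labels can also be shown in a single figure; the following is only an illustrative sketch reusing the variables defined above:
# Sketch: show a few random training samples with their labels in one figure
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for ax in axes:
    idx = np.random.randint(len(train_dataset))
    img, label = train_dataset[idx]
    ax.imshow(img.reshape((28, 28)), cmap="gray")
    ax.set_title(str(label[0]))  # the label is stored as a (1,)-shaped array
    ax.axis("off")
plt.show()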
II. Define the network model with Paddle and inspect it
Here we use Sequential to assemble the network:
MNIST = paddle.nn.Sequential(
    paddle.nn.Flatten(),         # first flatten the input
    paddle.nn.Linear(784, 512),  # first fully connected layer, 512 outputs
    paddle.nn.ReLU(),            # ReLU activation
    paddle.nn.Linear(512, 10)    # final fully connected layer, 10 classes
)
model = paddle.Model(MNIST)
model.summary((1, 28, 28))
# Network structure
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #
===========================================================================
  Flatten-3        [[1, 28, 28]]          [1, 784]               0
  Linear-6          [[1, 784]]            [1, 512]            401,920
   ReLU-2           [[1, 512]]            [1, 512]               0
  Linear-7          [[1, 512]]            [1, 10]              5,130
===========================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 1.55
Estimated Total Size (MB): 1.57
---------------------------------------------------------------------------
{'total_params': 407050, 'trainable_params': 407050}
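The same network could also be defined by subclassing paddle.nn.Layer. The following is only an equivalent sketch; the Sequential model above is what is actually trained in this article:
# Sketch: an equivalent definition of the same network as a Layer subclass
class MNISTNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.flatten = paddle.nn.Flatten()
        self.fc1 = paddle.nn.Linear(784, 512)
        self.relu = paddle.nn.ReLU()
        self.fc2 = paddle.nn.Linear(512, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

# model = paddle.Model(MNISTNet())  # drop-in replacement for the Sequential model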
III. Train with Paddle's high-level API
1. Model training
# Configure the model: optimizer, loss function, and evaluation metric
model.prepare(paddle.optimizer.Adam(learning_rate=0.001, parameters=MNIST.parameters()),
              # loss function
              loss=paddle.nn.CrossEntropyLoss(),
              # evaluation metric
              metrics=paddle.metric.Accuracy())
# Start training: 5 epochs, batch size 64
model.fit(train_dataset,
          eval_dataset,
          epochs=5,
          batch_size=64,
          verbose=1)
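Optionally, fit() also accepts callbacks, for example for periodic checkpointing and VisualDL logging. The directory names below are only illustrative, and this variant is not what produced the log that follows:
# Sketch: the same fit() call with checkpoint and VisualDL callbacks
callbacks = [
    paddle.callbacks.ModelCheckpoint(save_freq=1, save_dir="checkpoints"),
    paddle.callbacks.VisualDL(log_dir="visualdl_log"),
]
model.fit(train_dataset,
          eval_dataset,
          epochs=5,
          batch_size=64,
          verbose=1,
          callbacks=callbacks)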
Training log (only the fifth epoch is shown):
Epoch 5/5
step 938/938 [==============================] - loss: 0.1620 - acc: 0.9766 - 27ms/step
2. Model evaluation
result = model.evaluate(eval_dataset, verbose=1)
print(result)
Evaluation result: after 5 epochs, the accuracy on the test set exceeds 97.3%, which is a decent result.
Eval begin...
Eval samples: 10000
{'loss': [4.768373e-07], 'acc': 0.9731}
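Per-sample predictions can also be obtained with the high-level predict() API and the accuracy recomputed by hand. The sketch below assumes predict() returns one list of per-batch logits for this single-output model:
# Sketch: recompute test accuracy from per-sample predictions
preds = model.predict(eval_dataset, batch_size=64)
logits = np.concatenate(preds[0])                      # shape (10000, 10)
pred_labels = logits.argmax(axis=1)
true_labels = np.array([label[0] for _, label in eval_dataset])
print("manual accuracy:", (pred_labels == true_labels).mean())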
IV. Save the model
1. Save the model in training format; this generates two files, mnist.pdopt and mnist.pdparams.
model.save("finetuning/mnist")
2. Save the model in inference format by adding the training=False argument; this generates three files: mnist.pdiparams, mnist.pdiparams.info, and mnist.pdmodel.
model.save('infer/mnist', training=False)
V. Inference
Use the inference-format model to run prediction:
import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
# Create the inference config
config = Config("infer/mnist.pdmodel", "infer/mnist.pdiparams")
config.disable_gpu()
# Create the predictor
predictor = create_predictor(config)
# Get the input names
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])
# Set the input: pick a random test sample
rndinx = np.random.randint(len(eval_dataset))
fake_input = eval_dataset[rndinx][0].reshape((1, 1, 28, 28))
print("Ground-truth label:", eval_dataset[rndinx][1])
input_handle.reshape([1, 1, 28, 28])
input_handle.copy_from_cpu(fake_input)
# Run the predictor
predictor.run()
# Get the output
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu()  # numpy.ndarray
res = np.argmax(output_data)
print("Predicted label:", res)
Output:
Ground-truth label: [8]
Predicted label: 8
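Alternatively, the inference-format model can be loaded with paddle.jit.load and called directly in dynamic-graph mode. This is only a sketch, using the same path prefix as above:
# Sketch: run inference on the saved model via paddle.jit.load
loaded = paddle.jit.load("infer/mnist")
loaded.eval()
img, label = eval_dataset[0]
logits = loaded(paddle.to_tensor(img.reshape((1, 1, 28, 28))))
print("Predicted label:", paddle.argmax(logits, axis=1).numpy()[0])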