企业项目管理、ORK、研发管理与敏捷开发工具平台

网站首页 > 精选文章 正文

47.人工智能——手写数字识别的模型搭建、训练、推理预测

wudianyun 2024-12-23 10:08:32 精选文章 54 ℃

手写数字数据集:MNIST数据集是从NIST的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。

本文将使用飞桨PaddlePaddle,来实现手写数字识别的模型搭建、训练、推理预测全部流程。

#引入库
import paddle
import paddle.vision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore",category=DeprecationWarning)

一、定义和查看框架自带的手写数据集

#定义一个归一化操作
transform=T.Normalize(mean=[127.5],std=[127.5])
#定义训练集和验证集
train_dataset=paddle.vision.datasets.MNIST(mode="train",transform=transform)
eval_dataset=paddle.vision.datasets.MNIST(mode="test",transform=transform)
#输出训练集和验证集的数量
print(len(train_dataset),len(eval_dataset))
print(train_dataset[0][0].shape,train_dataset[0][1].shape)

输出结果:

60000 10000

(1, 28, 28) (1,)

随机取一条训练集,查看和显示结果。每个手写数字使用(1,28,28)来存储的,显示图像的时候需要reshape((28,28))

rndinx=np.random.randint(len(train_dataset))
plt.imshow(train_dataset[rndinx][0].reshape((28,28)))
plt.show()
print(train_dataset[rndinx][1])


0对应的图像


7对应的图像

二、使用paddle定义网络模型,并查看模型

这里我们使用Sequential来组网,

MNIST=paddle.nn.Sequential(
    paddle.nn.Flatten(), #先Flatten展平数据,
    paddle.nn.Linear(784,512),#第一个全连接层输出512
    paddle.nn.ReLU(), #使用ReLU激活函数
    paddle.nn.Linear(512,10) #最后一个连接层输出10个分类
)
model=paddle.Model(MNIST)
model.summary((1,28,28))
#网络模型结构
---------------------------------------------------------------------------
 Layer (type)       Input Shape          Output Shape         Param #    
===========================================================================
   Flatten-3       [[1, 28, 28]]           [1, 784]              0       
   Linear-6          [[1, 784]]            [1, 512]           401,920    
    ReLU-2           [[1, 512]]            [1, 512]              0       
   Linear-7          [[1, 512]]            [1, 10]             5,130     
===========================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
---------------------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 1.55
Estimated Total Size (MB): 1.57
---------------------------------------------------------------------------
  
  {'total_params': 407050, 'trainable_params': 407050}

三、使用Paddle高级API来训练

1、模型训练

#定义模型参数
model.prepare(paddle.optimizer.Adam(learning_rate=0.001,parameters=MNIST.parameters()),
#定义损失函数
loss=paddle.nn.CrossEntropyLoss(),
#定义评估函数
metrics=paddle.metric.Accuracy())
#开启训练,迭代5次,批次大小64,
model.fit(train_dataset,
eval_dataset,
epochs=5,
batch_size=64,
verbose=1)

训练过程数据:这里只截取第5轮的训练过程

Epoch 5/5
step 938/938 [==============================] - loss: 0.1843 - acc: 0.9719 - ETA: 43s - 47ms/ste - loss: 0.0350 - acc: 0.9766 - ETA: 36s - 40ms/ste - loss: 0.0999 - acc: 0.9776 - ETA: 32s - 35ms/ste - loss: 0.0189 - acc: 0.9754 - ETA: 30s - 33ms/ste - loss: 0.0468 - acc: 0.9753 - ETA: 28s - 32ms/ste - loss: 0.0500 - acc: 0.9766 - ETA: 27s - 31ms/ste - loss: 0.0347 - acc: 0.9768 - ETA: 26s - 30ms/ste - loss: 0.0873 - acc: 0.9777 - ETA: 24s - 29ms/ste - loss: 0.0981 - acc: 0.9769 - ETA: 24s - 28ms/ste - loss: 0.1161 - acc: 0.9766 - ETA: 23s - 28ms/ste - loss: 0.1683 - acc: 0.9766 - ETA: 22s - 28ms/ste - loss: 0.0192 - acc: 0.9764 - ETA: 22s - 27ms/ste - loss: 0.1066 - acc: 0.9762 - ETA: 21s - 27ms/ste - loss: 0.0823 - acc: 0.9760 - ETA: 21s - 27ms/ste - loss: 0.0243 - acc: 0.9761 - ETA: 20s - 27ms/ste - loss: 0.0939 - acc: 0.9772 - ETA: 20s - 27ms/ste - loss: 0.0643 - acc: 0.9776 - ETA: 20s - 27ms/ste - loss: 0.2955 - acc: 0.9777 - ETA: 20s - 27ms/ste - loss: 0.0273 - acc: 0.9780 - ETA: 20s - 27ms/ste - loss: 0.0798 - acc: 0.9776 - ETA: 20s - 27ms/ste - loss: 0.0583 - acc: 0.9769 - ETA: 20s - 28ms/ste - loss: 0.0584 - acc: 0.9773 - ETA: 19s - 28ms/ste - loss: 0.0297 - acc: 0.9773 - ETA: 19s - 28ms/ste - loss: 0.1333 - acc: 0.9768 - ETA: 19s - 28ms/ste - loss: 0.0753 - acc: 0.9769 - ETA: 19s - 28ms/ste - loss: 0.0751 - acc: 0.9763 - ETA: 18s - 28ms/ste - loss: 0.1177 - acc: 0.9763 - ETA: 18s - 28ms/ste - loss: 0.1016 - acc: 0.9766 - ETA: 18s - 27ms/ste - loss: 0.0961 - acc: 0.9769 - ETA: 17s - 27ms/ste - loss: 0.0294 - acc: 0.9772 - ETA: 17s - 27ms/ste - loss: 0.0291 - acc: 0.9773 - ETA: 17s - 27ms/ste - loss: 0.0328 - acc: 0.9775 - ETA: 16s - 27ms/ste - loss: 0.0428 - acc: 0.9773 - ETA: 16s - 27ms/ste - loss: 0.0356 - acc: 0.9774 - ETA: 16s - 27ms/ste - loss: 0.0474 - acc: 0.9775 - ETA: 15s - 27ms/ste - loss: 0.0067 - acc: 0.9776 - ETA: 15s - 27ms/ste - loss: 0.0155 - acc: 0.9774 - ETA: 15s - 27ms/ste - loss: 0.0635 - acc: 0.9774 - ETA: 15s - 27ms/ste - loss: 0.0074 - acc: 0.9773 - ETA: 14s - 27ms/ste - loss: 0.0566 - acc: 0.9773 - ETA: 14s - 27ms/ste - loss: 0.0539 - acc: 0.9772 - ETA: 14s - 27ms/ste - loss: 0.2149 - acc: 0.9771 - ETA: 14s - 27ms/ste - loss: 0.2084 - acc: 0.9773 - ETA: 13s - 27ms/ste - loss: 0.0059 - acc: 0.9776 - ETA: 13s - 27ms/ste - loss: 0.0306 - acc: 0.9774 - ETA: 13s - 27ms/ste - loss: 0.0793 - acc: 0.9773 - ETA: 12s - 27ms/ste - loss: 0.1044 - acc: 0.9772 - ETA: 12s - 27ms/ste - loss: 0.0695 - acc: 0.9771 - ETA: 12s - 27ms/ste - loss: 0.0191 - acc: 0.9772 - ETA: 12s - 27ms/ste - loss: 0.0266 - acc: 0.9771 - ETA: 11s - 27ms/ste - loss: 0.1265 - acc: 0.9771 - ETA: 11s - 27ms/ste - loss: 0.0032 - acc: 0.9771 - ETA: 11s - 27ms/ste - loss: 0.0248 - acc: 0.9771 - ETA: 10s - 27ms/ste - loss: 0.0965 - acc: 0.9772 - ETA: 10s - 27ms/ste - loss: 0.0646 - acc: 0.9771 - ETA: 10s - 27ms/ste - loss: 0.0271 - acc: 0.9770 - ETA: 10s - 27ms/ste - loss: 0.1045 - acc: 0.9770 - ETA: 9s - 27ms/ste - loss: 0.0128 - acc: 0.9771 - ETA: 9s - 27ms/st - loss: 0.0825 - acc: 0.9771 - ETA: 9s - 27ms/st - loss: 0.0293 - acc: 0.9772 - ETA: 9s - 27ms/st - loss: 0.0090 - acc: 0.9773 - ETA: 8s - 27ms/st - loss: 0.0606 - acc: 0.9775 - ETA: 8s - 27ms/st - loss: 0.0867 - acc: 0.9773 - ETA: 8s - 27ms/st - loss: 0.0108 - acc: 0.9774 - ETA: 8s - 27ms/st - loss: 0.0836 - acc: 0.9773 - ETA: 7s - 27ms/st - loss: 0.0461 - acc: 0.9773 - ETA: 7s - 27ms/st - loss: 0.0333 - acc: 0.9771 - ETA: 7s - 27ms/st - loss: 0.0724 - acc: 0.9769 - ETA: 6s - 27ms/st - loss: 0.1333 - acc: 0.9769 - ETA: 6s - 27ms/st - loss: 0.1932 - acc: 0.9768 - ETA: 6s - 27ms/st - loss: 0.0654 - acc: 0.9767 - ETA: 6s - 27ms/st - loss: 0.0901 - acc: 0.9768 - ETA: 5s - 27ms/st - loss: 0.1853 - acc: 0.9766 - ETA: 5s - 27ms/st - loss: 0.0443 - acc: 0.9767 - ETA: 5s - 27ms/st - loss: 0.0390 - acc: 0.9766 - ETA: 5s - 27ms/st - loss: 0.0352 - acc: 0.9767 - ETA: 4s - 27ms/st - loss: 0.0375 - acc: 0.9767 - ETA: 4s - 27ms/st - loss: 0.0631 - acc: 0.9767 - ETA: 4s - 27ms/st - loss: 0.0109 - acc: 0.9767 - ETA: 4s - 27ms/st - loss: 0.0789 - acc: 0.9767 - ETA: 3s - 27ms/st - loss: 0.0593 - acc: 0.9767 - ETA: 3s - 27ms/st - loss: 0.0050 - acc: 0.9766 - ETA: 3s - 27ms/st - loss: 0.0182 - acc: 0.9766 - ETA: 2s - 27ms/st - loss: 0.0493 - acc: 0.9765 - ETA: 2s - 27ms/st - loss: 0.0556 .9764 - ETA: 2s - 27ms/st - loss: 0.0436 - acc: 0.9764 - ETA: 2s - 27ms/st - loss: 0.0222 - acc: 0.9764 - ETA: 1s - 27ms/st - loss: 0.0419 - acc: 0.9765 - ETA: 1s - 27ms/st - loss: 0.1046 - acc: 0.9764 - ETA: 1s - 27ms/st - loss: 0.1193 - acc: 0.9765 - ETA: 1s - 27ms/st - loss: 0.1115 - acc: 0.9765 - ETA: 0s - 27ms/st - loss: 0.0299 - acc: 0.9765 - ETA: 0s - 27ms/st - loss: 0.0296 - - acc: 0acc: 0.9765 - ETA: 0s - 27ms/st - loss: 0.1620 - acc: 0.9766 - 27ms/step      

2、评估模型

result=model.evaluate(eval_dataset,verbose=100)
print(result)

评估结果:迭代5次,准确率达97.3%以上。效果还是可以的。

Eval begin...
Eval samples: 10000
{'loss': [4.768373e-07], 'acc': 0.9731}

四、保存模型

1、保存模型(训练格式),生成mnist.pdopt和mnist.pdparams两个文件,

model.save("finetuning/mnist")

2、保存模型(预测格式),生成mnist.pdiparams mnist.pdiparams.info和mnist.pdmodel三个文件。加一个training=False参数

model.save('infer/mnist', training=False)

五、预测推理

使用预测格式模型,进行预测推理

import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
#定义配置
config = Config("infer/mnist.pdmodel", "infer/mnist.pdiparams")
config.disable_gpu()

# 创建PaddlePredictor
predictor = create_predictor(config)

# 获取输入的名称
input_names = predictor.get_input_names()
input_handle = predictor.get_input_handle(input_names[0])

# 设置输入,随机取一条测试数据
rndinx=np.random.randint(1,len(eval_dataset))
fake_input=train_dataset[rndinx][0].reshape((1,1,28,28))
print("真实标签数字:",train_dataset[rndinx][1])

input_handle.reshape([1, 1, 28, 28])
input_handle.copy_from_cpu(fake_input)

# 运行predictor
predictor.run()

# 获取输出
output_names = predictor.get_output_names()
output_handle = predictor.get_output_handle(output_names[0])
output_data = output_handle.copy_to_cpu() # numpy.ndarray类型

res=np.argmax(output_data)
print("预测标签数字:",res)

输出:

真实标签数字: [8]

预测标签数字: 8

Tags:

最近发表
标签列表