08. 使用训练好的CNN手写数字识别器

  训练好卷积神经网络模型net后,我们可以将其保存到本地磁盘中,这样在以后想要使用这个模型的时候,不用重复训练,可以直接加载到内存中使用。

1. 期望目标:

① 保存并加载训练好的模型

② 识别自己手写的数字图像

2. 模型的保存与加载

2-1. 保存模型

# 保存模型
path = './模型/3.MNIST/MNIST.model' #想要保存模型到本地的位置及其文件名(需要加后缀,如前面的'.model')
torch.save(net, path) #保存

  这样可以将模型的参数和结构全部保存下来,使用时只需加载自己定义的类即可。

2-2. 加载模型

import torch # 加载用PyTorch的函数保存的模型文件
import import_ipynb #将ipynb文件加载到内存中
from 006 import ConvNet #这行按需自定义,其中'006'表示ipynb的文件名,'ConvNet'表示训练模型时自己定义的类,如有多个可以用英文逗号隔开,当然也可以用*代替

# 加载模型
path = './模型/3.MNIST/MNIST.model'
net = torch.load(path)

  成功执行后,net就是我们加载进来的模型了,可以直接使用。

3. 识别自己的手写数字图像

  找一张白纸,用黑笔写下想要识别的数字,手动将其转化为灰度图,大小为28×28像素,本例的手写数字为:

quZ9fA.jpg

import matplotlib.pyplot as plt

# 打印出手写的全部图像
plt.figure(figsize=(15, 7))
for i in range(10):
    str = './数据集/3.MNIST/测试/new_' + '%s' % i + '.bmp'
    img = plt.imread(str)
    plt.subplot(1, 10, i + 1)
    plt.imshow(img)

q1PGj0.png

  想要用加载的net模型识别手写数字,我们需要将输入转变为4维张量,上述数字共10个,因此我们需要将这些像素值全部叠加到一起(当然可以不叠加,一个一个输入进行识别也是没有问题的)。

# 将所有图像拼接到一起直接输入网络来测试效果
for i in range(10):
    str = './数据集/3.MNIST/测试/new_' + '%s' % i + '.bmp' #文件名
    img = plt.imread(str)
    test = torch.FloatTensor(img)
    test = test.reshape(1,1,28,28)
    
    if i == 0:
        new = test
    else:
        new = torch.cat((new, test), 0) #在第 0 维将张量拼接起来

# 观察一下拼接后的四维张量尺寸
new.size()
torch.Size([10, 1, 28, 28])

查看识别后的结果:

# 查看识别结果
out = net(new) #获得模型的分类输出
pred = torch.max(out.data, 1)[1] #得到分类值
print('识别的结果为:{}'.format(pred[:]))
识别的结果为:tensor([1, 4, 8, 6, 2, 0, 3, 6, 3, 5])

  在Jupyter Notebook中,结果显示为:

q1iUMt.png

  这几个数字全部预测正确了,但是,据说69不太容易被识别,让我们多写几个 6 和 9 来试试:

qMd6L6.jpg

# 将所有图像拼接到一起直接输入网络来测试效果
for i in range(11, 22):
    str = './数据集/3.MNIST/测试/new_' + '%s' % i + '.bmp' #文件名
    img = plt.imread(str)
    test = torch.FloatTensor(img)
    test = test.reshape(1,1,28,28)
    
    if i == 11:
        new = test
    else:
        new = torch.cat((new, test), 0) #在第 0 维将张量拼接起来

# 查看结果
out = net(new) #获得模型的分类输出
pred = torch.max(out.data, 1)[1] #得到分类值
print('识别的结果为:{}'.format(pred[:]))
识别的结果为:tensor([6, 5, 6, 4, 9, 4, 9, 9, 6, 8, 3])

q1Flyq.png

  这次的结果为:11个里面对了6个,感觉不太理想,原因是什么呢?

  我们训练是国外整理的MNIST数据集,里面包含60000张训练图像,10000张测试图像,西方在手写6和9时,画的圈比较小、整体靠近图像的中心,并且他们写9的最后一笔不像我们一样这么直,所以9的识别率确实是比较低。要想提高模型的精度,需要加入符合东方习惯的图像数据集才能有本质突破。

打赏
文章目录