MATLAB手写数字识别实战:MNIST数据集‌

制作数据集

手写体数字(MNIST)的基本信息在上一篇专栏( Pytorch 手写数字识别MNIST)里介绍过,这里只做简要说明

官网:   yann.lecun.com/exdb/mnist/

MNIST数据集官网

该数据集下载下来的二进制格式文件无法直接打开预览

MNIST数据集文件

这里主要介绍数据集的下载、解压和保存为标准的mat文件格式。

下载地址

url1 = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz" ;  %training set images

url2 = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz" ;  %training set labels

url3 = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz"  ;  %test set images

url4 = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz"  ;  %test set labels

保存文件到本地

filepath1 = websave("train-images-idx3-ubyte.gz", url1);

filepath2 = websave("train-labels-idx1-ubyte.gz", url2);

filepath3 = websave("t10k-images-idx3-ubyte.gz", url3);

filepath4 = websave("t10k-labels-idx1-ubyte.gz",url4);

解压所有压缩文件

files =  gunzip('*.gz'); %解压gz文件

为了方便后面的文件读写,需要按字节(1Byte = 8bit)转化为10进制数

function y = Byte2Dec(data)

    bin8 = dec2bin(data,8);  %按字节

    byte = [bin8(1,:),bin8(2,:),bin8(3,:),bin8(4,:)];

    y = bin2dec(byte);

end

%%上面是转化函数

制作测试集

fid1 = fopen(files{1});

m1 = fread(fid1,4);

n1 = fread(fid1,4);

r1 = fread(fid1,4);

c1 = fread(fid1,4);

m1 = Byte2Dec(m1);

n1 = Byte2Dec(n1);

r1 = Byte2Dec(r1);

c1 = Byte2Dec(c1);

test_imgs = cell(n1,1);

for i = 1:n1

   temp = fread(fid1,r1*c1);

   temp = reshape(temp,[r1,c1]);

   test_imgs{i} = temp';

end

fclose(fid1);

fid2 =  fopen(files{2}) ;

m2 = fread(fid2,4);      

n2 = fread(fid2,4);

m2 = Byte2Dec(m2);

n2 = Byte2Dec(n2);

test_labels = zeros(n2,1);

for i = 1:n2

   test_labels(i) = fread(fid2,1);

end

fclose(fid2);

for index = 1:10000

   img = test_imgs{index};

   label = num2str(test_labels(index));

   path = fullfile('D:\mnist','testdata',label,['img',label,num2str(index),'.png']);

   imwrite(img,path);

end

制作训练集

fid3 = fopen(files{3});

m1 = fread(fid3,4);

n1 = fread(fid3,4);

r1 = fread(fid3,4);

c1 = fread(fid3,4);

m1 = Byte2Dec(m1);

n1 = Byte2Dec(n1);

r1 = Byte2Dec(r1);

c1 = Byte2Dec(c1);

train_imgs = cell(n1,1);

for i = 1:n1

   temp = fread(fid3,r1*c1);

   temp = reshape(temp,[r1,c1]);

   train_imgs{i} = temp';

end

fclose(fid3);

fid4 =  fopen(files{4}) ;

m2 = fread(fid4,4);      

n2 = fread(fid4,4);

m2 = Byte2Dec(m2);

n2 = Byte2Dec(n2);

train_labels = zeros(n2,1);

for i = 1:n2

   train_labels(i) = fread(fid4,1);

end

fclose(fid4);

for index = 1:60000

   img = train_imgs{index};

   label = num2str(train_labels(index));

   path = fullfile('D:\mnist','traindata',label,['img',label,num2str(index),'.png']);

   imwrite(img,path);

end

保存标准mat文件(变量)

train_labels = categorical(train_labels);

test_labels = categorical(test_labels);

save minist.mat train_imgs train_labels test_imgs test_labels 

以上就是MATLAB下导入一般的数据集文件,可能不是一般意义上做实验、数据标定来制作数据集。

注:以上代码运行可能会在imwrite函数下报错,可以手动提前准备文件目录,也可以使用绝对路径并添加到MATLAB的搜索路径下

cut-off

导入数据集(实际上面已经有了,这里假设刚开始只有mat文件)

load minist.mat

traindata = table(train_imgs,train_labels);

testdata = table(test_imgs,test_labels);

搭建网络,训练模型,进行预测

layers = [

    imageInputLayer([28,28,1])

    convolution2dLayer(3,16,'Padding','same')

    batchNormalizationLayer

    reluLayer

    maxPooling2dLayer(2,'Stride',2)

    fullyConnectedLayer(10)

    softmaxLayer

    classificationLayer

];

options = trainingOptions('adam',...

    'ExecutionEnvironment', 'gpu', ...

    'InitialLearnRate',0.01,...

    'MiniBatchSize',100,...

    'MaxEpochs',2,...

    'Shuffle','every-epoch',...

    'ValidationData',testdata,...

    'ValidationFrequency',50,...

    'Verbose',false,...

    'Plots','training-progress');

net = trainNetwork(traindata,layers,options);

Training Progress

pred_labels = classify(net,table(test_imgs));

accuracy = sum(pred_labels == test_labels)/length(test_labels)

plotconfusion(test_labels,pred_labels)  %不推荐使用在categories类的标签的分类问题上

ConfusionMatrix

figure

cm = confusionchart(test_labels,pred_labels);  %推荐使用

cm.ColumnSummary = 'column-normalized';

cm.RowSummary = 'row-normalized';

cm.Title = 'MNIST Confusion Matrix';

由上图可知,这样一个简单的神经网络的预测分类准确率高达97%,不算太高,一些典型神经网络在图像分类问题准确率高达99%以上。上一篇Pytorch训练的CNN就是这样的。

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空