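Building the dataset
Basic information about the handwritten digit dataset (MNIST) was covered in the previous post in this column (PyTorch handwritten digit recognition on MNIST), so only a brief summary is given here.
Official site: yann.lecun.com/exdb/mnist/
Figure: the MNIST dataset website
The dataset is distributed as binary files that cannot be opened and previewed directly.
Figure: the MNIST dataset files
This section covers downloading the dataset, decompressing it, and saving it as a standard MAT-file.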
Download URLs
url1 = "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz" ; %training set images
url2 = "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz" ; %training set labels
url3 = "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz" ; %test set images
url4 = "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz" ; %test set labels
Save the files locally
filepath1 = websave("train-images-idx3-ubyte.gz", url1);
filepath2 = websave("train-labels-idx1-ubyte.gz", url2);
filepath3 = websave("t10k-images-idx3-ubyte.gz", url3);
filepath4 = websave("t10k-labels-idx1-ubyte.gz",url4);
Decompress all archives
files = gunzip('*.gz'); % decompress the .gz files; returns the names of the extracted files
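The code further down indexes files{1} through files{4} by position, which appears to assume that gunzip returns the names in alphabetical order (t10k files first, then train files). A minimal sketch of picking each file by name instead is shown here; demo_files is a hypothetical stand-in for the gunzip output and the variable names are illustrative only.

```matlab
% Hypothetical stand-in for the cell array returned by gunzip('*.gz'),
% deliberately shuffled to show that position does not matter.
demo_files = {'train-labels-idx1-ubyte', 't10k-images-idx3-ubyte', ...
              'train-images-idx3-ubyte', 't10k-labels-idx1-ubyte'};

% Select a file by a distinctive part of its name (assumes exactly one match).
pick = @(key) demo_files{contains(demo_files, key)};

test_img_file  = pick('t10k-images');
test_lbl_file  = pick('t10k-labels');
train_img_file = pick('train-images');
train_lbl_file = pick('train-labels');

disp(test_img_file)
disp(train_lbl_file)
```

With this lookup the rest of the script would call fopen(test_img_file) and so on, regardless of the order in which the files were extracted.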
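The header fields of each file are stored as groups of four bytes (1 byte = 8 bits), most significant byte first. To make the later reads easier, each group is converted into a decimal integer.
function y = Byte2Dec(data)
    % data: 4-by-1 vector of byte values (0-255) as returned by fread(fid,4)
    bin8 = dec2bin(data,8); % one 8-character binary string per byte (4-by-8 char array)
    byte = [bin8(1,:),bin8(2,:),bin8(3,:),bin8(4,:)]; % concatenate into a 32-bit string
    y = bin2dec(byte); % interpret the 32 bits as one unsigned integer
end
% The function above performs the conversion. Save it in its own file, Byte2Dec.m, or place it at the end of the script; older MATLAB releases do not accept function definitions in the middle of a script.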
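To make the conversion concrete, here is a small sketch of the same arithmetic without the detour through binary strings. The helper name be32 is my own, and the second variant assumes a little-endian host, which I believe holds on the platforms MATLAB currently supports. The byte values are those of the MNIST headers: 2051 is the magic number of an image file and 60000 is the size of the training set.

```matlab
% Weight each byte by its place value, most significant byte first.
% b(:)' makes the helper accept both row vectors and the column
% vectors that fread returns.
be32 = @(b) sum(256 .^ (3:-1:0) .* double(b(:)'));

magic = be32([0; 0; 8; 3]);       % 8*256 + 3 = 2051
count = be32([0; 0; 234; 96]);    % 234*256 + 96 = 60000

% Variant: reinterpret the four bytes as one uint32 and swap the byte
% order (assumes the host stores integers little-endian).
magic2 = swapbytes(typecast(uint8([0 0 8 3]), 'uint32'));

fprintf('%d %d %d\n', magic, count, magic2);
```

Both variants should agree with Byte2Dec on any four-byte header field; the place-value form is the easier one to check by hand.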
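Building the test set
% files{1} is assumed to be t10k-images-idx3-ubyte (gunzip output in alphabetical order)
fid1 = fopen(files{1});
m1 = fread(fid1,4); % magic number
n1 = fread(fid1,4); % number of images
r1 = fread(fid1,4); % rows per image
c1 = fread(fid1,4); % columns per image
m1 = Byte2Dec(m1);
n1 = Byte2Dec(n1);
r1 = Byte2Dec(r1);
c1 = Byte2Dec(c1);
test_imgs = cell(n1,1);
for i = 1:n1
    temp = fread(fid1,r1*c1); % one image; pixel values 0-255 returned as double
    temp = reshape(temp,[c1,r1]); % the file stores pixels row by row, reshape fills column by column
    test_imgs{i} = temp'; % transpose to get the upright r1-by-c1 image
end
fclose(fid1);
% files{2} is assumed to be t10k-labels-idx1-ubyte
fid2 = fopen(files{2});
m2 = fread(fid2,4); % magic number
n2 = fread(fid2,4); % number of labels
m2 = Byte2Dec(m2);
n2 = Byte2Dec(n2);
test_labels = fread(fid2,n2); % one byte per label, values 0-9
fclose(fid2);
% Write every image to a PNG file in a folder named after its label
for index = 1:numel(test_imgs)
    img = test_imgs{index};
    label = num2str(test_labels(index));
    outfile = fullfile('D:\mnist','testdata',label,['img',label,num2str(index),'.png']);
    imwrite(uint8(img),outfile); % convert to uint8: imwrite treats double data as lying in [0,1]
end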
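The per-image loop and the transpose above can also be expressed as one read followed by a reshape and a permute. The sketch below demonstrates this on a tiny synthetic file (two images of 2 rows by 3 columns) rather than on the real data, so the temporary file and its contents are made up for illustration. It also opens the file in big-endian mode, which lets fread decode the header integers without a helper such as Byte2Dec.

```matlab
% Write a tiny synthetic IDX3-style file: magic, count, rows, cols, pixels.
tmp = [tempname, '.idx3'];
fid = fopen(tmp, 'w', 'ieee-be');        % big-endian, like the MNIST files
fwrite(fid, [2051 2 2 3], 'int32');      % header: 2 images of 2 x 3 pixels
fwrite(fid, uint8(1:12), 'uint8');       % pixel data, stored row by row
fclose(fid);

% Read the header and all pixels with one fread call each.
fid = fopen(tmp, 'r', 'ieee-be');
hdr = fread(fid, 4, 'int32');            % [magic; count; rows; cols]
n = hdr(2); rows = hdr(3); cols = hdr(4);
pix = fread(fid, rows*cols*n, 'uint8=>uint8');
fclose(fid);
delete(tmp);

% reshape fills column by column, so build [cols rows n] first and then
% swap the first two dimensions to obtain rows x cols x n.
imgs = permute(reshape(pix, [cols, rows, n]), [2 1 3]);

disp(imgs(:,:,1))
disp(imgs(:,:,2))
```

On the real files the same pattern would yield a 28 x 28 x 10000 uint8 array; num2cell or a short loop could turn it back into the cell array used in the rest of this post.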
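Building the training set
% files{3} is assumed to be train-images-idx3-ubyte (gunzip output in alphabetical order)
fid3 = fopen(files{3});
m1 = fread(fid3,4); % magic number
n1 = fread(fid3,4); % number of images
r1 = fread(fid3,4); % rows per image
c1 = fread(fid3,4); % columns per image
m1 = Byte2Dec(m1);
n1 = Byte2Dec(n1);
r1 = Byte2Dec(r1);
c1 = Byte2Dec(c1);
train_imgs = cell(n1,1);
for i = 1:n1
    temp = fread(fid3,r1*c1); % one image; pixel values 0-255 returned as double
    temp = reshape(temp,[c1,r1]); % the file stores pixels row by row, reshape fills column by column
    train_imgs{i} = temp'; % transpose to get the upright r1-by-c1 image
end
fclose(fid3);
% files{4} is assumed to be train-labels-idx1-ubyte
fid4 = fopen(files{4});
m2 = fread(fid4,4); % magic number
n2 = fread(fid4,4); % number of labels
m2 = Byte2Dec(m2);
n2 = Byte2Dec(n2);
train_labels = fread(fid4,n2); % one byte per label, values 0-9
fclose(fid4);
% Write every image to a PNG file in a folder named after its label
for index = 1:numel(train_imgs)
    img = train_imgs{index};
    label = num2str(train_labels(index));
    outfile = fullfile('D:\mnist','traindata',label,['img',label,num2str(index),'.png']);
    imwrite(uint8(img),outfile); % convert to uint8: imwrite treats double data as lying in [0,1]
end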
Save a standard MAT-file (variables)
train_labels = categorical(train_labels);
test_labels = categorical(test_labels);
save minist.mat train_imgs train_labels test_imgs test_labels
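Before relying on the converted labels, it can help to look at what categorical actually produces. The sketch below uses a short made-up label vector, not the MNIST labels, to show the category names and the per-class counts; on the real data the same two calls would give a quick check of class balance.

```matlab
% Hypothetical raw labels, in the same numeric form as train_labels
% before conversion.
labels = [3; 0; 3; 1];

c = categorical(labels);      % numeric values become category names
cats = categories(c);         % cell array of names, sorted by value
counts = countcats(c);        % number of samples per category

disp(cats')
disp(counts')
```

Here the categories are '0', '1' and '3', with counts 1, 1 and 2.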
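The steps above show how to import an ordinary dataset file in MATLAB. This is probably not "building a dataset" in the usual sense of running experiments and labeling data.
Note: imwrite does not create missing folders, so the code above may fail at the imwrite call unless the output directories (for example D:\mnist\testdata\0 through D:\mnist\testdata\9) already exist. Prepare the directory tree beforehand; alternatively, use absolute paths and add them to the MATLAB search path.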
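One way to prepare the directory tree in code is sketched below. The root folder under tempdir is a hypothetical choice so that the snippet can run anywhere; for the code in this post the root would be 'D:\mnist'.

```matlab
% Hypothetical root folder for the demo; the post itself uses 'D:\mnist'.
root = fullfile(tempdir, 'mnist_demo');

splits = {'testdata', 'traindata'};
for s = 1:numel(splits)
    for digit = 0:9
        folder = fullfile(root, splits{s}, num2str(digit));
        if ~isfolder(folder)      % isfolder requires R2017b or later
            mkdir(folder);        % also creates missing parent folders
        end
    end
end

disp(isfolder(fullfile(root, 'traindata', '9')))
```

Running this once before the imwrite loops should be enough; the isfolder check makes it safe to run again.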
Loading the dataset (the variables already exist from the steps above; here we assume we start with only the MAT-file)
load minist.mat
traindata = table(train_imgs,train_labels);
testdata = table(test_imgs,test_labels);
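As an alternative to the tables, the images can be stacked into a single 4-D numeric array of size height x width x channels x samples. The sketch below shows the stacking on three made-up 28 x 28 images; as far as I know, trainNetwork also accepts such an array together with a categorical label vector, in the form trainNetwork(X, labels, layers, options), but treat that as an assumption to verify against your MATLAB release.

```matlab
% Three hypothetical grayscale images, stored in a cell array in the
% same way as train_imgs.
imgs = {zeros(28,28); ones(28,28); 2*ones(28,28)};
labels = categorical([0; 1; 2]);

% Concatenate along the 4th dimension; the 3rd (channel) dimension
% becomes a singleton because the images are grayscale.
X = cat(4, imgs{:});

disp(size(X))
```

The result has size 28 x 28 x 1 x 3, with X(:,:,1,k) holding the k-th image.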
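Building the network, training the model, and making predictions
layers = [
    imageInputLayer([28,28,1]) % 28x28 grayscale input
    convolution2dLayer(3,16,'Padding','same') % 16 filters of size 3x3, output stays 28x28
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2) % halves the spatial size to 14x14
    fullyConnectedLayer(10) % one output per digit class
    softmaxLayer
    classificationLayer
];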
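To get a feel for how small this network is, the sketch below works out the feature-map size after pooling and the number of learnable parameters by hand. It is plain arithmetic based on the layer definitions above, not a query of the actual layer objects, and the variable names are my own.

```matlab
inputSize  = [28 28 1];   % height, width, channels
k          = 3;           % convolution kernel size
numFilters = 16;
numClasses = 10;

% Convolution: k*k*channels weights per filter, plus one bias per filter.
convParams = k*k*inputSize(3)*numFilters + numFilters;

% Batch normalization: one learnable scale and one offset per channel.
bnParams = 2*numFilters;

% 'same' padding keeps 28x28; 2x2 max pooling with stride 2 halves it.
pooledSize = [inputSize(1)/2, inputSize(2)/2, numFilters];

% Fully connected layer: one weight per input element per class, plus biases.
fcParams = prod(pooledSize)*numClasses + numClasses;

totalParams = convParams + bnParams + fcParams;
fprintf('conv %d, bn %d, fc %d, total %d\n', ...
    convParams, bnParams, fcParams, totalParams);
```

Almost all of the roughly 31.6 thousand parameters sit in the fully connected layer, which sees the flattened 14 x 14 x 16 feature map.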
options = trainingOptions('adam',...
    'ExecutionEnvironment','gpu',... % needs a supported GPU; use 'auto' to fall back to the CPU
    'InitialLearnRate',0.01,...
    'MiniBatchSize',100,...
    'MaxEpochs',2,...
    'Shuffle','every-epoch',...
    'ValidationData',testdata,... % the test set doubles as validation data here
    'ValidationFrequency',50,...
    'Verbose',false,...
    'Plots','training-progress');
net = trainNetwork(traindata,layers,options);
Figure: training progress plot
pred_labels = classify(net,table(test_imgs));
accuracy = sum(pred_labels == test_labels)/length(test_labels) % fraction of correct predictions
plotconfusion(test_labels,pred_labels) % not recommended for classification problems with categorical labels
Figure: confusion matrix
figure
cm = confusionchart(test_labels,pred_labels); % recommended
cm.ColumnSummary = 'column-normalized';
cm.RowSummary = 'row-normalized';
cm.Title = 'MNIST Confusion Matrix';
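The overall accuracy and the per-class figures in the row summary can be reproduced by hand on a small example. The sketch below uses made-up labels and predictions, not the network's output, and assumes the rows of the chart correspond to the true classes; the row summary then reports, for each true class, the share of samples that were predicted correctly.

```matlab
% Hypothetical true labels and predictions for five samples.
truth = categorical([0; 1; 2; 2; 1]);
pred  = categorical([0; 1; 2; 1; 1]);

% Overall accuracy: fraction of samples where prediction equals truth.
acc = mean(pred == truth);

% Per-class accuracy: restrict the comparison to samples of one true class.
classes = categories(truth);
perClass = zeros(numel(classes), 1);
for k = 1:numel(classes)
    mask = truth == classes{k};
    perClass(k) = mean(pred(mask) == truth(mask));
end

fprintf('accuracy %.2f\n', acc);
disp(perClass')
```

In this toy case four of the five predictions are correct, and the only error is a sample of class 2 predicted as class 1, so class 2 scores 0.5 while the other classes score 1.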
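As the figure above shows, this simple neural network reaches a classification accuracy of about 97%. That is decent but not especially high: typical neural networks for image classification exceed 99% on this problem, as did the CNN trained with PyTorch in the previous post.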