MATLAB机器学习入门:神经网络MNIST手写数字识别

代码讲解视频:

相关网站:MNIST数据集(http://yann.lecun.com/exdb/mnist/)

MNIST网站截图

运行界面1-部分数据展示

运行界面2-学习过程

运行界面Part1-部分数据展示

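说明:将MNIST数据集网站的4个文件解压后放到源代码所在目录下即可。代码按以下文件名读取:train-images.idx3-ubyte、train-labels.idx1-ubyte、t10k-images.idx3-ubyte、t10k-labels.idx1-ubyte;如果解压得到的文件名是 train-images-idx3-ubyte 这种(用连字符)形式,请重命名为上述名称。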

源代码:(工地英语注释QwQ)

clear;

% data read
% MNIST IDX files begin with big-endian uint32 header fields. Opening the
% file with the 'ieee-be' machine format decodes them correctly on any host
% (manual swapbytes only gave the right answer on little-endian machines).
fImg = fopen('train-images.idx3-ubyte', 'r', 'ieee-be');
if fImg < 0
    error('Cannot open train-images.idx3-ubyte: unzip the MNIST files into the script folder.');
end
header = fread(fImg, 4, 'uint32=>uint32'); % [magic number; image count; rows; cols]
if numel(header) < 4 || header(1) ~= 2051
    fclose(fImg);
    error('train-images.idx3-ubyte is not an IDX3 image file.');
end
magicNumber = header(1);
dataNumber = header(2);
rowN = header(3);
colN = header(4);
%%
% get all train data: one image per column, pixel values 0..255 stored as double
nPixels = double(rowN) * double(colN);
[dataImg, readCount] = fread(fImg, [nPixels, double(dataNumber)], 'uint8=>double');
fclose(fImg);
if readCount ~= nPixels * double(dataNumber)
    error('train-images.idx3-ubyte is truncated.');
end

fLab = fopen('train-labels.idx1-ubyte', 'r', 'ieee-be');
if fLab < 0
    error('Cannot open train-labels.idx1-ubyte: unzip the MNIST files into the script folder.');
end
labHeader = fread(fLab, 2, 'uint32=>uint32'); % [magic number; label count]
dataLab = fread(fLab, double(dataNumber), 'uint8=>uint8'); % one label byte per image
fclose(fLab);
% The label file must be an IDX1 file with exactly one label per image.
if ~isequal(labHeader, [2049; dataNumber]) || numel(dataLab) ~= double(dataNumber)
    error('train-labels.idx1-ubyte does not match train-images.idx3-ubyte.');
end
% One-hot labels: dataLabMat(d+1, k) = 1 when image k carries label d.
% sub2ind errors on a label outside 0..9 instead of silently writing into
% the neighbouring column.
dataLabMat = zeros(10, double(dataNumber));
dataLabMat(sub2ind(size(dataLabMat), double(dataLab).' + 1, 1:double(dataNumber))) = 1;
%%
% Preview: draw the first 10 training images one after another, each one
% titled with its label and its position in the preview, pausing 0.5 s
% between frames. Each column of dataImg is reshaped to [rowN colN] and
% transposed before drawing.
for i = 1:10
    digitPixels = reshape(dataImg(:, i), [rowN, colN]);
    imagesc(digitPixels.');
    title(['This image is ',num2str(dataLab(i)),'    (',num2str(i),'/10)']);
    pause(0.5);
end
%%
% Network model, step 1: random initialization.
% layerN lists the number of units per layer: input pixels, two hidden
% layers (64 and 16 units) and 10 outputs. Every weight and bias is drawn
% from rand (uniform on (0,1)), shifted to (-0.5, 0.5) and divided by a
% per-layer scale (512, 4, 16). The draws happen in the order
% w1, b1, w2, b2, w3, b3.
layerN = [rowN*colN 64 16 10];
uniformInit = @(nRows, nCols, scale) (rand(nRows, nCols) - 0.5) / scale;

w1 = uniformInit(layerN(1), layerN(2), 512);
b1 = uniformInit(layerN(2), 1, 512);

w2 = uniformInit(layerN(2), layerN(3), 4);
b2 = uniformInit(layerN(3), 1, 4);

w3 = uniformInit(layerN(3), layerN(4), 16);
b3 = uniformInit(layerN(4), 1, 16);

%%
% learning loop
% Mini-batch gradient descent on the network
%   input -> [w1,b1] -> Leaky ReLU -> [w2,b2] -> Leaky ReLU -> [w3,b3] -> softmax
% trained with softmax cross-entropy, whose gradient w.r.t. z3 is (a3 - y).
% "epoch" here means one mini-batch update, and "epochSize" is the mini-batch
% size. Each batch is drawn with replacement (randi).
% NOTE(review): inputs are raw pixel values (not normalized); this is
% presumably why the learning rate has to be this small.
lRate = 0.0004; % acts very badly when it is larger than 0.0004
epoch = 1000; % too large may over-train or put NaN/Inf in the model; less than 4k is safe
epochSize = 500; % larger batches smooth the training process but SLOW it down
leak = 0.01; % Leaky ReLU slope for negative inputs
los = zeros(1, epoch); % per-iteration loss (preallocated)
accuracy = zeros(1, epoch); % per-iteration batch accuracy (preallocated)

title('Training start!')
drawnow;
tic;
for i = 1:epoch
    % data select
    randomSelect = randi(double(dataNumber), 1, epochSize);
    epochImg = dataImg(:, randomSelect);
    epochLab = dataLabMat(:, randomSelect);

    % forward pass: each layer consumes the previous layer's ACTIVATION
    z1 = w1.' * epochImg + b1;
    a1 = max(z1, leak * z1); % Leaky ReLU

    z2 = w2.' * a1 + b2;
    a2 = max(z2, leak * z2);

    z3 = w3.' * a2 + b3;
    a30 = exp(z3 - max(z3, [], 1)); % subtract the column max so exp() cannot overflow
    a3 = a30 ./ sum(a30, 1); % softmax: each column sums to 1

    % gradient calculation (backpropagation)
    dz3 = a3 - epochLab;
    dw3 = a2 * dz3.' / epochSize;
    db3 = sum(dz3, 2) / epochSize;

    dg2 = ones(size(z2)); % Leaky ReLU derivative: 1 for z >= 0, leak for z < 0
    dg2(z2 < 0) = leak;
    dz2 = (w3 * dz3) .* dg2;
    dw2 = a1 * dz2.' / epochSize;
    db2 = sum(dz2, 2) / epochSize;

    dg1 = ones(size(z1));
    dg1(z1 < 0) = leak;
    dz1 = (w2 * dz2) .* dg1;
    dw1 = epochImg * dz1.' / epochSize;
    db1 = sum(dz1, 2) / epochSize;

    % gradient descent
    w1 = w1 - dw1 * lRate;
    b1 = b1 - db1 * lRate;
    w2 = w2 - dw2 * lRate;
    b2 = b2 - db2 * lRate;
    w3 = w3 - dw3 * lRate;
    b3 = b3 - db3 * lRate;

    % loss & accuracy: mean softmax cross-entropy (the objective the
    % gradients above minimize) and the fraction of the batch classified
    % correctly. realmin guards against log(0) when a probability underflows.
    los(i) = mean(-sum(epochLab .* log(max(a3, realmin)), 1));
    [~, predicted] = max(a3, [], 1);
    [~, actual] = max(epochLab, [], 1);
    accuracy(i) = mean(predicted == actual);

    % gradually slow down: exponential learning-rate decay
    lRate = lRate * (1 - 0.0005);

    %{
    subplot(2,1,1)
    plot(1:i,los(1:i))
     title('Los per epoch')
    subplot(2,1,2)
    plot(1:i,accuracy(1:i))
     title('Accuracy per epoch')
    drawnow;
    %}
end

toc
subplot(2,1,1)
plot(1:epoch,los)
title('Los per epoch')
subplot(2,1,2)
plot(1:epoch,accuracy)
title('Accuracy per epoch')
%%
% test set
% data read (IDX files: big-endian uint32 header, then uint8 payload)
fImg = fopen('t10k-images.idx3-ubyte', 'r', 'ieee-be');
if fImg < 0
    error('Cannot open t10k-images.idx3-ubyte: unzip the MNIST files into the script folder.');
end
header = fread(fImg, 4, 'uint32=>uint32'); % [magic number; image count; rows; cols]
if numel(header) < 4 || header(1) ~= 2051
    fclose(fImg);
    error('t10k-images.idx3-ubyte is not an IDX3 image file.');
end
magicNumber = header(1);
testNumber = header(2);
rowN = header(3);
colN = header(4);
nPixels = double(rowN) * double(colN);
if nPixels ~= size(w1, 1)
    fclose(fImg);
    error('t10k-images.idx3-ubyte image size does not match the trained network input.');
end

[testImg, readCount] = fread(fImg, [nPixels, double(testNumber)], 'uint8=>double');
fclose(fImg);
if readCount ~= nPixels * double(testNumber)
    error('t10k-images.idx3-ubyte is truncated.');
end

fLab = fopen('t10k-labels.idx1-ubyte', 'r', 'ieee-be');
if fLab < 0
    error('Cannot open t10k-labels.idx1-ubyte: unzip the MNIST files into the script folder.');
end
labHeader = fread(fLab, 2, 'uint32=>uint32'); % [magic number; label count]
testLab = fread(fLab, double(testNumber), 'uint8=>uint8'); % one label byte per image
fclose(fLab);
if ~isequal(labHeader, [2049; testNumber]) || numel(testLab) ~= double(testNumber)
    error('t10k-labels.idx1-ubyte does not match t10k-images.idx3-ubyte.');
end
% One-hot labels: testLabMat(d+1, k) = 1 when test image k carries label d.
testLabMat = zeros(10, double(testNumber));
testLabMat(sub2ind(size(testLabMat), double(testLab).' + 1, 1:double(testNumber))) = 1;

%%
% testing: forward pass over the whole test set with the trained weights,
% using exactly the same layers as the training loop
z1 = w1.' * testImg + b1;
a1 = max(z1, leak * z1); % Leaky ReLU

z2 = w2.' * a1 + b2;
a2 = max(z2, leak * z2);

z3 = w3.' * a2 + b3;
a30 = exp(z3 - max(z3, [], 1)); % numerically stable softmax
a3 = a30 ./ sum(a30, 1);

%%
% accuracy: fraction of test images whose most probable class equals the label
[~, predicted] = max(a3, [], 1);
[~, actual] = max(testLabMat, [], 1);
testAccu = mean(predicted == actual)

MATLAB

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空