在MATLAB中定义LSTM网络架构

这个例子展示了如何使用长短时记忆(LSTM)网络对序列数据进行分类。

若要训练深度神经网络对序列数据进行分类,您可以使用LSTM网络。LSTM网络使您能够将序列数据输入到网络中,并根据序列数据的单个时间步长进行预测。

本示例使用了日语元音数据集。这个例子训练一个LSTM网络来识别给定的代表两个连续日语元音的时间序列数据。训练数据包含了9名演讲者的时间序列数据。每个序列有12个特征,并且长度也有所不同。该数据集包含270个训练观察结果和370个测试观察。


Load Sequence Data

加载日语元音训练数据。

XTrain是一个包含270个长度12维序列的单元格阵列。Y是标签“1”,“2”,...,“9”的分类向量,对应于9个扬声者。XTrain中的条目是包含12行(每个特性为一行)和不同数量的列(每个时间步长为一列)的矩阵。

[XTrain,YTrain] = japaneseVowelsTrainData;
XTrain(1:5)

LSTM matlab定义 网络架构 lstm模型matlab_数据

Visualize the first time series in a plot. Each line corresponds to a feature.


figure
plot(XTrain{1}')
xlabel("Time Step")
title("Training Observation 1")
numFeatures = size(XTrain{1},1);
legend("Feature " + string(1:numFeatures),'Location','northeastoutside')

LSTM matlab定义 网络架构 lstm模型matlab_数据_02

Prepare Data for Padding

在训练期间,默认情况下,软件会将训练数据分成小批,并填充序列,使它们具有相同的长度。过多的填充物可能会对网络性能产生负面影响。为了防止训练过程添加过多的填充,可以按序列长度对训练数据进行排序,并选择一个小批的大小,以便小批中的序列具有相似的长度。

LSTM matlab定义 网络架构 lstm模型matlab_加载_03

Get the sequence lengths for each observation.

numObservations = numel(XTrain);
for i=1:numObservations
 sequence = XTrain{i};
 sequenceLengths(i) = size(sequence,2);
end

Sort the data by sequence length.

[sequenceLengths,idx] = sort(sequenceLengths);
XTrain = XTrain(idx);
YTrain = YTrain(idx);


View the sorted sequence lengths in a bar chart.

figure
bar(sequenceLengths)
ylim([0 30])
xlabel("Sequence")
ylabel("Length")
title("Sorted Data")

LSTM matlab定义 网络架构 lstm模型matlab_测试数据_04

选择一个27的小批量大小,以均匀地划分训练数据,并减少小批量中的填充量。


miniBatchSize = 27;

   

LSTM matlab定义 网络架构 lstm模型matlab_matlab_05


   

Define LSTM Network Architecture:

定义LSTM网络架构。将输入大小指定为12大小的序列(输入数据的尺寸)。指定一个包含100个隐藏单元的双向LSTM层,并输出序列的最后一个元素。最后,包含一个大小为9的全连接层,然后是一个softmax层和一个分类层。如果在预测时可以访问完整的序列,那么可以在网络中使用双向LSTM层。双向LSTM层在每个时间步长都从完整序列中学习。如果在预测时无法访问完整的序列,例如,如果正在预测值或一次预测一个时间步长,那么就使用LSTM层来代替。

inputSize = 12;
numHiddenUnits = 100;
numClasses = 9;
layers = [ ...
 sequenceInputLayer(inputSize)
 bilstmLayer(numHiddenUnits,'OutputMode','last')
 fullyConnectedLayer(numClasses)
 softmaxLayer
 classificationLayer]
 
 

layers =       
                5×1 Layer array with layers:       
                        4 
        1 '' Sequence Input Sequence input with 12 dimensions 
   
                        2 '' BiLSTM BiLSTM with 100 hidden units 
      
                        3 '' Fully Connected 9 fully connected layer 
      
                        4 '' Softmax softmax 
   
                        5 '' Classification Output crossentropyex
                        
                        



现在,指定培训选项。指定求解器为“adam”,梯度阈值为1,最大周期数为100。要减少小批量的填充量,请选择27。要将数据填充为与最长序列相同的长度,请指定序列长度为“longest”。要确保数据仍然按序列长度排序,请指定永远不要打乱数据。由于小批量处理很小,序列很短,所以训练更适合CPU。请将“ExecutionEnvironment”指定为“cpu”。若要在GPU上进行训练,如果可用,请将“ExecutionEnvironment”设置为“auto”(这是默认值)。

maxEpochs = 100;
miniBatchSize = 27;
options = trainingOptions('adam', ...
 'ExecutionEnvironment','cpu', ...
 'GradientThreshold',1, ...
 'MaxEpochs',maxEpochs, ...
 'MiniBatchSize',miniBatchSize, ...
 'SequenceLength','longest', ...
 'Shuffle','never', ...
 'Verbose',0, ...
 'Plots','training-progress');
 
 

Train LSTM Network

使用trainNetwork训练LSTM网络。

net = trainNetwork(XTrain,YTrain,layers,options);

Test LSTM Network

加载测试集,并将序列分类为扬声器。加载日语元音测试数据。

XTest是一个包含370个不同长度为12的序列的单元格阵列。YTest是标签“1”,“2”,...“9”的分类向量,对应于9个扬声器。

[XTest,YTest] = japaneseVowelsTestData;
XTest(1:3)

LSTM matlab定义 网络架构 lstm模型matlab_LSTM matlab定义 网络架构_06

LSTM网络网络使用相似长度的小批量序列进行训练。确保测试数据以相同的方式组织。按序列长度对测试数据进行排序。


numObservationsTest = numel(XTest);
for i=1:numObservationsTest
 sequence = XTest{i};
 sequenceLengthsTest(i) = size(sequence,2);
end
[sequenceLengthsTest,idx] = sort(sequenceLengthsTest);
XTest = XTest(idx);
YTest = YTest(idx);

对测试数据进行分类。为了减少分类过程中引入的填充量,请将小批量大小设置为27。要应用与训练数据相同的填充,请指定序列长度为'longest'


miniBatchSize = 27;
YPred = classify(net,XTest, ...
 'MiniBatchSize',miniBatchSize, ...
 'SequenceLength','longest');
 
 

计算预测的分类精度:


acc = sum(YPred == YTest)./numel(YTest)


免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删

QR Code
微信扫一扫,欢迎咨询~

联系我们
武汉格发信息技术有限公司
湖北省武汉市经开区科技园西路6号103孵化器
电话:155-2731-8020 座机:027-59821821
邮件:tanzw@gofarlic.com
Copyright © 2023 Gofarsoft Co.,Ltd. 保留所有权利
遇到许可问题?该如何解决!?
评估许可证实际采购量? 
不清楚软件许可证使用数据? 
收到软件厂商律师函!?  
想要少购买点许可证,节省费用? 
收到软件厂商侵权通告!?  
有正版license,但许可证不够用,需要新购? 
联系方式 155-2731-8020
预留信息,一起解决您的问题
* 姓名:
* 手机:

* 公司名称:

姓名不为空

手机不正确

公司不为空