ADDOPS团队籍鑫璞 360云计算
学习一门编程语言,都是从最基本的“hello world”开始的,作为一个编程模式,tensorflow也有自己的“hello world”–MNIST。它是计算机视觉的最基础且最常见的数据集,由70000张28像素28像素的手写数字组成,每个图片对应一个标签,分别是0~9之间的一个数字。这70000个样本包含55000个训练样本,10000个测试样本以及5000个验证样本。我们随机选取了该数据集中的四个样本,如下图,它们对应的标签分别是5,0,4和1。
大家知道,电脑识别图片是从最基本的单元–像素开始的,为了将上面的手写数字转换成电脑能够识别的信号,我们引入灰度概念,将2828个像素点转换成灰度值(0到1区间内的一个值)。下图是将标签为1的图片转换成28乘以28的矩阵。
数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。同时,训练的数据将是一个55000*10的张量,第一个维度数字用来索引图片,第二个维度是label的向量。比如数字0,对应的label就是[1,0,0,0,0,0,0,0,0,0]。
这篇文章的主要任务是建立一个模型,能够识别手写的数字图片,得到0~9之间的值,即进行分类。目前实现MNIST模型有很多种,有的准确率超过90%以上,但是我们需要从最简单的Softmax Regression来开始。
我们知道MNIST的每一张图片都表示一个数字,从0到9。我们希望得到给定图片代表每个数字的概率。比如说,我们的模型可能推测一张包含9的图片代表数字9的概率是80%但是判断它是其他数字的概率比80%的概率要低。softmax模型可以用来给不同的对象分配概率,即使在之后,我们训练更加精细的模型时,最后一步也需要用softmax来分配概率。
Softmax Regression的工作原理很简单,将特征相加,然后将这些特征转化为判定是这类的概率。比如,某个像素的灰度值大代表很可能是数字n时,这个像素的权值就会很大;相反,如果不太可能是n,则权重有可能是负数。
我们可以将这些特征写出如下的公式,i代表第i类(0到9的值),j代表一张图片的第j个像素,bi是一个误差项。
得到所有像素的和以后,我们用softmax函数可以把这些证据转换成概率y。
我们将得到某个图片标签是0到9的概率值,我们取概率值最大的标签,即为该图片表示的数字。
接下来我们用tensorflow来实现一个softmax regression。在tensorflow中,定义一个softmax模型比较简单,只需要下面五行。
登录后复制
import tensorflow as tf
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
为了训练我们的模型,我们首先需要定义一个指标来评估这个模型是好的。其实,在机器学习,我们通常定义指标来表示一个模型是坏的,这个指标称为成本(cost)或损失(loss),然后尽量最小化这个指标。
一个非常常见的成本函数是“交叉熵”(cross-entropy)。交叉熵产生于信息论里面的信息压缩编码技术,但是它后来演变成为从博弈论到机器学习等其他领域里的重要技术手段。它的定义如下:
其中,y 是我们预测的概率分布, y’ 是实际的分布。
tensorflow中定义“交叉熵”也比较容易,代码如下:
登录后复制
y_ = tf.placeholder("float", [None,10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
现在我们有了算法softmax regression,又有了损失函数cross-entropy的定义,只需要定义一个优化算法即可开始训练。我们采用常见的“随机梯度下降”来定义该优化方法,也比较容易,代码如下:
登录后复制
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
最后,我们需要定义一个会话来运行上面的计算图,代码如下:
登录后复制
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for i in range(1000):
batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
本篇文章介绍了如何使用tensorflow来构建SoftMax Regression模型,从而实现对MNIST数据集的识别,希望对tensorflow的使用有个简单的认识。
免责声明:本文系网络转载或改编,未找到原创作者,版权归原作者所有。如涉及版权,请联系删