MNIST手写数字数据集

作者: 发布时间:1970-01-01 08:00

MNIST手写数字数据集

下载地址:

链接: https://pan.baidu.com/s/11vcMAgey55qD3A7yJ5RbYQ 提取码: xbqb

MNIST是一个很有名的手写数字识别数据集(基本可以算是“Hello World”级别的了吧),我们要了解的情况是,对于每张图片,存储的方式是一个 28 * 28 的矩阵,但是我们在导入数据进行使用的时候会自动展平成 1 * 784(28 * 28)的向量,这在TensorFlow导入很方便。

TensorFlow 对MNIST数据集的操作:

from tensorflow.examples.tutorials.mnist import input_data
# 第一次运行会自动下载到代码所在的路径下

mnist = input_data.read_data_sets('location', one_hot=True)
# location 是保存的文件夹的名称

打印MNIST数据集的一些信息,通过这些我们就可以知道这些数据大致如何使用了

# 打印 mnist 的一些信息

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

print("type of 'mnist is %s'" % (type(mnist)))
print("number of train data is %d" % mnist.train.num_examples)
print("number of test data is %d" % mnist.test.num_examples)

# 将所有的数据加载为这样的四个数组 方便之后的使用
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels

print("Type of training is %s" % (type(trainimg)))
print("Type of trainlabel is %s" % (type(trainlabel)))
print("Type of testing is %s" % (type(testimg)))
print("Type of testing is %s" % (type(testlabel)))

输出结果:

type of 'mnist is '
number of train data is 55000    # 训练集共有55000条数据
number of test data is 10000     # 训练集有10000条数据
Type of training is     # 四个都是Numpy数组的类型
Type of trainlabel is 
Type of testing is 
Type of testing is 

如果我们想看一看每条数据保存的图片是什么样子,可以使用 matplot()函数

# 接上面的代码

nsmaple = 5
randidx = np.random.randint(trainimg.shape[0], size=nsmaple)

for i in randidx:
    curr_img = np.reshape(trainimg[i,:], (28, 28))  # 数据中保存的是 1*784 先reshape 成 28*28
    curr_label = np.argmax(trainlabel[i, :])
    plt.matshow(curr_img, cmap=plt.get_cmap('gray'))
    plt.show()

通过上面的代码可以看出数据集中的一些特点,下面建立一个简单的模型来识别这些数字。


标签:
Copyright © 2020 万物律动 旗下 AI算法狮 京ICP备20010037号-1
本站内容来源于网络开放内容的收集整理,并且仅供学习交流使用;
如有侵权,请联系删除相关内容;