最近这一段时间,TensorFlow 2.0发布,这是你从没有体验过的全新的版本,为了能够尽快接触和体验到2.0新版本的特性,AI柠檬博主从mnist手写数字识别Demo入手,开始学习TensorFlow 2.0版。由于tf2原生内置keras包,无需另外安装,本样例将以tf.keras代码实现,并且在这一过程中发现,原本的keras代码仅需极少数改动即可迁移到TensorFlow 2.0,这对于之前一直使用Keras的用户来说,可谓非常友好了。
整个程序的代码思路来源于GitHub 上别人用Keras实现的一个小项目,并使用TensorFlow 2.0框架复现了一遍:
https://gist.github.com/alexcpn/0683bb940cae510cf84d5976c1652abd
运行环境
所需软件
Python 3.7
依赖包
tensorflow或tensorflow-gpu (版本>=2.0)
numpy
matplotlib
代码讲解
依赖包导入
从这里我们可以看到原来的直接导入keras的方式,变成了导入tensorflow.keras的方式。
# 3. Import libraries and modules import numpy as np import tensorflow as tf from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten from tensorflow.keras.layers import Convolution2D, MaxPooling2D from tensorflow.keras.datasets import mnist
设置随机数种子
np.random.seed(123) # for reproducibility
导入数据
Keras自身内置了一个mnist数据集,但是由于其存放在亚马逊云服务器上,在中国大陆我们没办法下载下来,会提示下载失败。我这里有一份下载好的数据集:
def load_data(path='mnist.npz'): """Loads the MNIST dataset. # Arguments path: path where to cache the dataset locally (relative to ~/.keras/datasets). # Returns Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. """ #path = get_file(path, # origin='https://s3.amazonaws.com/img-datasets/mnist.npz', # file_hash='8a61469f7ea1b51cbae51d4f78837e45') f = np.load(path) x_train, y_train = f['x_train'], f['y_train'] x_test, y_test = f['x_test'], f['y_test'] f.close() return (x_train, y_train), (x_test, y_test) (X_train, y_train), (X_test, y_test) = load_data()
我们可以查看一下数据的格式,输入输出维度
print(X_train.shape,y_train.shape,X_test.shape,y_test.shape)
数据可视化
为了方便,或者更直观,我们往往需要将数据可视化,这里我们需要用到一个强大的绘图库matplotlib。
import matplotlib.pyplot as plt # plot 4 images as gray scale plt.subplot(221) print(y_train[4545],y_train[1],y_train[2],y_train[3]) plt.imshow(X_train[4545], cmap=plt.get_cmap('gray')) plt.subplot(222) plt.imshow(X_train[1], cmap=plt.get_cmap('gray')) plt.subplot(223) plt.imshow(X_train[2], cmap=plt.get_cmap('gray')) plt.subplot(224) plt.imshow(X_train[3], cmap=plt.get_cmap('gray')) # show the plot plt.show() # Reshape the Input for the backend X_train = X_train.reshape(X_train.shape[0], 1, 28, 28) X_test = X_test.reshape(X_test.shape[0], 1, 28, 28) plt.subplot(224) plt.imshow(X_train[4545][0], cmap=plt.get_cmap('gray')) plt.show()
数据预处理
# convert data type and normalize values X_train = X_train.astype('float32') X_test = X_test.astype('float32') X_train /= 255 X_test /= 255 print(y_train[4545]) plt.imshow(X_train[4545][0], cmap=plt.get_cmap('gray')) plt.show() print (y_train.shape)
进行one hot处理
可以将1 2 3等的数字标签正交化,以供计算机分类成10个类别。
Y_train = tf.one_hot(y_train, 10) Y_test = tf.one_hot(y_test, 10) print (Y_train.shape)
值得一提的是,在旧版keras中,可以使用keras.utils.np_utils.to_categorical()方法实现,在tf2.0中的keras里面去除了这一API,转而使用tf自身的方法实现。
建立神经网络模型
model = Sequential() # add a sequential layer # declare a input layer model.add(Convolution2D(32,(3,3),activation='relu',data_format='channels_first',input_shape=(1,28,28))) print (model.output_shape) model.add(Convolution2D(32, 3, 3, activation='relu')) model.add(MaxPooling2D(pool_size=(2,2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(10, activation='softmax'))# output 10 classes corresponds to 0 to 9 digits we need to find model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
加载已保存的模型文件
如果程序已经运行过训练,之后再次运行,要继续训练或者预测时,需要首先加载已经保存的模型文件,如果是首次训练,可以忽略这一项。
model.load_weights('mn2.model')
训练模型并保存
model.fit(X_train, Y_train,batch_size=32, epochs=1, verbose=1) model.save_weights('mn2.model')
模型评估和检验
score = model.evaluate(X_test, Y_test, verbose=0) print(score) k = np.array(X_train[4545]) #seven print(k.shape) y= k.reshape(1,1,28,28) print(y.shape) prediction = model.predict(y) print(prediction) class_pred = model.predict_classes(y) print(class_pred) plt.subplot(111) plt.imshow(X_train[4545][0], cmap=plt.get_cmap('gray')) plt.show()
如果没问题,就会看到分类是正确的。
版权声明本博客的文章除特别说明外均为原创,本人版权所有。欢迎转载,转载请注明作者及来源链接,谢谢。本文地址: https://blog.ailemon.net/2019/11/04/tensorflow-2-0-implements-mnist-handwritten-digit-recognition/ All articles are under Attribution-NonCommercial-ShareAlike 4.0 |