TensorFlow Keras MNIST Image Model: Evaluation and Prediction

Model Evaluation and Prediction

Importing Libraries

import tensorflow as tf
from tensorflow.keras import layers

from tensorflow.keras import datasets 
import matplotlib.pyplot as plt

import numpy as np

%matplotlib inline

Reviewing the Training Process

Build Model

input_shape = (28, 28, 1)
num_classes = 10

learning_rate = 0.001

inputs = layers.Input(input_shape)
net = layers.Conv2D(32, (3, 3), padding='SAME')(inputs)
net = layers.Activation('relu')(net)
net = layers.Conv2D(32, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.MaxPooling2D(pool_size=(2, 2))(net)
net = layers.Dropout(0.5)(net)

net = layers.Conv2D(64, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.Conv2D(64, (3, 3), padding='SAME')(net)
net = layers.Activation('relu')(net)
net = layers.MaxPooling2D(pool_size=(2, 2))(net)
net = layers.Dropout(0.5)(net)

net = layers.Flatten()(net)
net = layers.Dense(512)(net)
net = layers.Activation('relu')(net)
net = layers.Dropout(0.5)(net)
net = layers.Dense(num_classes)(net)
net = layers.Activation('softmax')(net)

model = tf.keras.Model(inputs=inputs, outputs=net, name='Basic_CNN')

# Compile: Adam optimizer, sparse categorical cross-entropy (labels are integer class ids), accuracy metric
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate),
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
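As a sanity check on the architecture above, the feature-map sizes and parameter counts can be worked out by hand. The sketch below is plain Python and assumes the standard counting rules: a Conv2D layer has kernel_h × kernel_w × in_channels × filters weights plus one bias per filter, a Dense layer has in × out weights plus one bias per output, and Activation, MaxPooling2D, Dropout and Flatten have no weights. The helper names are hypothetical. 'SAME' padding keeps the 28×28 size and each 2×2 max-pool halves it, so Flatten sees 7 × 7 × 64 = 3136 values. The total should agree with the parameter count reported by `model.summary()`.

```python
# Hypothetical helpers: count weights for Conv2D and Dense layers.
def conv2d_params(kh, kw, in_ch, filters):
    # kernel weights + one bias per filter
    return kh * kw * in_ch * filters + filters


def dense_params(n_in, n_out):
    # weight matrix + one bias per output unit
    return n_in * n_out + n_out


# Spatial size: 'SAME' convs keep 28x28; each 2x2 max-pool halves it.
size = 28
size //= 2  # after the first MaxPooling2D -> 14
size //= 2  # after the second MaxPooling2D -> 7
flatten_dim = size * size * 64  # 7 * 7 * 64 channels

layer_params = {
    "conv1": conv2d_params(3, 3, 1, 32),
    "conv2": conv2d_params(3, 3, 32, 32),
    "conv3": conv2d_params(3, 3, 32, 64),
    "conv4": conv2d_params(3, 3, 64, 64),
    "dense1": dense_params(flatten_dim, 512),
    "dense2": dense_params(512, 10),
}
total_params = sum(layer_params.values())

print("flatten_dim:", flatten_dim)    # 3136
print("total params:", total_params)  # 1676266
```

Almost all of the weights (1,606,144) sit in the first Dense layer, which is typical for a Flatten-then-Dense head.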

Preprocess

Loading the dataset

(train_x, train_y), (test_x, test_y) = datasets.mnist.load_data()

# Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
train_x = train_x[..., tf.newaxis]
test_x = test_x[..., tf.newaxis]

# Scale pixel values from 0-255 to 0-1
train_x = train_x / 255.
test_x = test_x / 255.
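The two preprocessing steps only change shape and scale. The sketch below applies them to a small random array standing in for MNIST (the fake data is an assumption, so it runs without downloading anything). `tf.newaxis` is simply `None`, the same as `np.newaxis`, so `[..., np.newaxis]` appends a channel axis that matches `input_shape = (28, 28, 1)`, and dividing by `255.` turns the 0-255 integer pixels into floats in [0, 1].

```python
import numpy as np

# Stand-in for datasets.mnist.load_data(): 8 fake 28x28 grayscale images (uint8, 0-255).
rng = np.random.default_rng(0)
fake_x = rng.integers(0, 256, size=(8, 28, 28)).astype(np.uint8)

# Same two steps as above.
x = fake_x[..., np.newaxis]  # (8, 28, 28) -> (8, 28, 28, 1)
x = x / 255.                  # uint8 -> float64 in [0, 1]

print(x.shape, x.dtype, x.min(), x.max())
```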

Training

Training the model

num_epochs = 1
batch_size = 64

hist = model.fit(train_x, train_y,
                 batch_size=batch_size,
                 epochs=num_epochs,
                 shuffle=True)

hist.history
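`hist.history` is a plain dict mapping each compiled metric name (here `loss` and `accuracy`) to a list with one value per epoch, so with a single epoch each list has one entry. With `batch_size = 64`, one epoch over the 60,000 MNIST training images takes ceil(60000 / 64) steps, which is the step count shown in the Keras progress bar. The arithmetic below is a small worked check; the variable names are only for illustration.

```python
import math

num_train = 60000  # size of the MNIST training split
batch_size = 64

# A partial last batch still counts as one step.
steps_per_epoch = math.ceil(num_train / batch_size)
last_batch = num_train - (steps_per_epoch - 1) * batch_size

print(steps_per_epoch, last_batch)  # 938 32
```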

Evaluating

Checking the trained model on the test set

model.evaluate(test_x, test_y, batch_size=batch_size)  # returns [loss, accuracy]
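`model.evaluate` returns the compiled loss followed by the compiled metrics, here `[loss, accuracy]`. To make those two numbers concrete, the NumPy sketch below recomputes them from softmax outputs and integer labels: sparse categorical cross-entropy is the mean of -log(p[true class]), and accuracy is the fraction of rows whose argmax equals the label. The function name and the three-class probabilities are illustrative assumptions, and the small `eps` clip is only in the spirit of the clipping Keras applies to avoid log(0).

```python
import numpy as np


def sparse_ce_and_accuracy(probs, labels, eps=1e-7):
    """probs: (N, C) softmax outputs; labels: (N,) integer class ids."""
    probs = np.clip(probs, eps, 1.0 - eps)             # avoid log(0)
    true_class_prob = probs[np.arange(len(labels)), labels]
    loss = float(np.mean(-np.log(true_class_prob)))    # sparse categorical cross-entropy
    acc = float(np.mean(np.argmax(probs, axis=1) == labels))
    return loss, acc


# Toy batch: 3 samples, 3 classes (illustrative values only).
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.4, 0.3]])
labels = np.array([0, 1, 2])

loss, acc = sparse_ce_and_accuracy(probs, labels)
print(round(loss, 4), acc)
```

The third sample is misclassified (argmax 1, label 2), so accuracy is 2/3, and its low true-class probability (0.3) contributes most of the loss.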

Checking the Results

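test_image = test_x[0, :, :, 0]  # first test image, channel axis dropped for plotting
test_image.shape
# (28, 28)
plt.imshow(test_image, cmap='gray')
plt.title(test_y[0])
plt.show()
pred = model.predict(test_image.reshape(1, 28, 28, 1))  # add batch and channel axes back
pred.shape
# (1, 10)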
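`model.predict` always works on a batch, which is why the single image is reshaped to `(1, 28, 28, 1)` and why the result has shape `(1, 10)`: one row of ten class probabilities. The sketch below shows equivalent ways to get that shape, using zero arrays as stand-ins for `test_x` and `test_x[0, :, :, 0]`. Slicing with a range such as `[:1]` keeps the batch axis, so it is often the simplest option.

```python
import numpy as np

fake_test_x = np.zeros((10, 28, 28, 1))  # stand-in for test_x
image = fake_test_x[0, :, :, 0]          # (28, 28), as used for plotting

batch_a = image.reshape(1, 28, 28, 1)          # explicit reshape, as above
batch_b = image[np.newaxis, ..., np.newaxis]   # add batch axis (front) and channel axis (back)
batch_c = fake_test_x[:1]                      # range slice keeps the batch axis

print(batch_a.shape, batch_b.shape, batch_c.shape)
```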

np.argmax(pred)  # index of the largest probability, i.e. the predicted digit
# 7
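Because the last layer is a softmax, each row of `pred` sums to 1, so the value at the argmax index can be read as the model's confidence in its prediction. `np.argmax(pred)` without an axis works here only because `pred` has a single row (NumPy flattens the array first). The sketch below uses a hand-written `(1, 10)` row as a stand-in for `pred`; the numbers are illustrative, not this model's output.

```python
import numpy as np

# Illustrative stand-in for pred (shape (1, 10)); the row sums to 1 like a softmax output.
pred = np.array([[0.01, 0.00, 0.02, 0.01, 0.00, 0.00, 0.00, 0.95, 0.00, 0.01]])

predicted_digit = int(np.argmax(pred[0]))     # index of the largest probability
confidence = float(pred[0, predicted_digit])  # probability assigned to that class

print(predicted_digit, confidence)
```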

Test Batch

test_batch = test_x[:5000]  # first 5,000 test images
test_batch.shape
# (5000, 28, 28, 1)
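For a batch like `test_batch`, `model.predict` returns one row of probabilities per image, shape `(5000, 10)`, so the predicted classes come from `np.argmax(..., axis=1)`; a bare `np.argmax` would flatten all 50,000 values and return a single index. The sketch below shows this on random stand-ins for the predictions and labels (placeholders, not model output) and also collects the indices of misclassified images, a common next step when inspecting errors.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for model.predict(test_batch): 5000 rows of softmax probabilities.
logits = rng.normal(size=(5000, 10))
preds = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
labels = rng.integers(0, 10, size=5000)  # stand-in for test_y[:5000]

pred_classes = np.argmax(preds, axis=1)             # one predicted class per image, shape (5000,)
wrong_idx = np.flatnonzero(pred_classes != labels)  # indices of misclassified images
batch_acc = 1.0 - len(wrong_idx) / len(labels)

print(pred_classes.shape, len(wrong_idx), round(batch_acc, 3))
print(np.argmax(preds))  # no axis: a single index into the flattened 5000 * 10 array
```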
