In this notebook, we’re going to take the dataset we prepared and build a model to classify the images.

Introduction

What model should we use? A simple model, such as logistic regression, is able to do fairly well on a simple image classification task like MNIST or even cat vs non-cat, but won’t really cut it for this task. Because kangaroos and wallabies are so similar, even a fully connected neural network doesn’t score particularly well on this dataset.

So we’re just going to skip those and go straight to the most popular family of models in modern computer vision: convolutional neural networks. We’re going to take a large model known as Xception and apply it to our dataset. But we’re not just going to take the model architecture; we’re also going to take the weights it learned from being trained on the ImageNet dataset. Then we will freeze everything except the last layers and train those on our dataset. This is known as transfer learning. It lets us benefit from the highly tuned convolutional layers that are so good at extracting features, while still letting us tune the last layers specifically for our problem.

import os
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import applications, optimizers
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Model
tf.random.set_seed(0)

Let’s make sure we have a GPU available because this would take way too long on a CPU.

# Let TensorFlow allocate GPU memory as needed; set this before the GPU is first used.
os.environ["TF_FORCE_GPU_ALLOW_GROWTH"] = "true"
print("Num GPUs Available: ", len(tf.config.list_physical_devices("GPU")))
Num GPUs Available:  1

Prepare the Data

Now we’ll prepare our data using TensorFlow’s tf.data.Dataset API.

train_data_dir = Path(os.getenv("DATA")) / "WallabiesAndRoosFullSize/train"
val_data_dir = Path(os.getenv("DATA")) / "WallabiesAndRoosFullSize/val"
test_data_dir = Path(os.getenv("DATA")) / "WallabiesAndRoosFullSize/test"
train_files = tf.data.Dataset.list_files(str(train_data_dir / "*/*"), shuffle=True)
val_files = tf.data.Dataset.list_files(str(val_data_dir / "*/*"), shuffle=False)
test_files = tf.data.Dataset.list_files(str(test_data_dir / "*/*"), shuffle=False)
class_names = np.array(sorted([folder.name for folder in train_data_dir.glob('*')]))
num_classes = len(class_names)
print(class_names)
['kangaroo' 'wallaby']

We’ll set the image size we want to work with, the batch size, and the number of epochs.

img_height, img_width = 256, 256
BATCH_SIZE = 4
EPOCHS = 5
def parse_label(filename):
    parts = tf.strings.split(filename, sep=os.path.sep)
    one_hot_label = parts[-2] == class_names
    label = tf.argmax(one_hot_label)
    return label


def parse_image(filename):
    img = tf.io.read_file(filename)
    image_decoded = tf.io.decode_jpeg(img, channels=3)
    image = tf.image.convert_image_dtype(image_decoded, tf.float32)
    image = tf.image.resize(image, (img_height, img_width))
    return image


def parse_file_path(filename):
    image = parse_image(filename)
    label = parse_label(filename)
    return image, label
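
The label lookup in parse_label can be hard to picture on its own, so here is a minimal NumPy sketch of the same idea outside of TensorFlow. The file path and class list below are hypothetical stand-ins, not files from the actual dataset.

```python
import os

import numpy as np

# Hypothetical class list and file path, for illustration only.
class_names = np.array(["kangaroo", "wallaby"])
file_path = os.path.join("data", "train", "wallaby", "img_001.jpg")

# The class folder is the second-to-last component of the path.
parts = file_path.split(os.path.sep)

# Comparing the folder name against every class name gives a boolean mask
# with a single True at the matching class.
one_hot_label = class_names == parts[-2]

# argmax returns the position of that True, which is the integer label.
label = int(np.argmax(one_hot_label))
print(label)  # 1
```

Because the class names are sorted, kangaroo maps to 0 and wallaby maps to 1.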

Now let’s look at some of our images.

data_iter = iter(train_files)
file_path = next(data_iter)
image, label = parse_file_path(file_path)


def display_image(image, label):
    plt.figure()
    plt.imshow(image)
    plt.title(class_names[label])
    plt.axis("off")


display_image(image, label)

file_path = next(data_iter)
image, label = parse_file_path(file_path)
display_image(image, label)

Now we create the image data pipeline. This is the series of steps that takes us from a tf.data.Dataset of file paths to parsed and batched image and label pairs ready for training. We’ll also use TensorFlow’s prefetching to overlap data loading with training.

def one_hot(image, label):
    """
    Convert to one-hot encoding
    """
    label = tf.one_hot(tf.cast(label, tf.int32), num_classes)
    # Recasts it to Float32
    label = tf.cast(label, tf.float32)
    return image, label
train_dataset = train_files.map(parse_file_path)
train_dataset = train_dataset.batch(BATCH_SIZE)
train_dataset = train_dataset.map(one_hot)
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
val_dataset = val_files.map(parse_file_path)
val_dataset = val_dataset.batch(BATCH_SIZE)
val_dataset = val_dataset.map(one_hot)
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test_files.map(parse_file_path)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.map(one_hot)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
data_iter = iter(train_dataset)
examp_im, examp_label = next(data_iter)

Let’s visualize the data coming out of our tf.data.Dataset. I like to do this to ensure it’s what I’m expecting. The image should be a bunch of values between 0 and 1 and the labels should all be either 0 or 1.

plt.hist(examp_im.numpy().flatten());

np.unique(examp_label)
array([0., 1.], dtype=float32)

Looks good!
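
The same checks can also be written as assertions so they fail loudly if the pipeline changes. This is only a sketch: check_batch is a hypothetical helper, and the batch below is synthetic NumPy data standing in for examp_im and examp_label.

```python
import numpy as np


def check_batch(images, labels):
    """Raise AssertionError if a batch is not in the expected format."""
    images = np.asarray(images)
    labels = np.asarray(labels)
    # Pixel values should already be scaled to the range [0, 1].
    assert images.min() >= 0.0 and images.max() <= 1.0
    # One-hot labels contain only 0s and 1s, with exactly one 1 per row.
    assert set(np.unique(labels).tolist()) <= {0.0, 1.0}
    assert np.all(labels.sum(axis=1) == 1.0)
    return True


# Synthetic stand-in for one batch: 4 small RGB images and 2 classes.
rng = np.random.default_rng(0)
fake_images = rng.random((4, 8, 8, 3), dtype=np.float32)
fake_labels = np.eye(2, dtype=np.float32)[[0, 1, 1, 0]]

batch_ok = check_batch(fake_images, fake_labels)
```

With the real pipeline, the equivalent call would pass examp_im.numpy() and examp_label.numpy().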

Establish a Baseline

Before we build a complex model, we should develop a baseline so we can gauge our model’s performance. As we saw in the first post, there are more images of kangaroos than of wallabies, so we need to take that into consideration. For example, if our model correctly classified 75% of the images, that might sound good. But if simply guessing “kangaroo” every time got the same result, the model wouldn’t be learning much at all. So what’s the baseline? How many would we get right from pure guessing? To answer that, we’ll use a dummy classifier. Our dummy classifier will find the most common label and predict that for every image.

If we had many classes, we could use something like sklearn’s DummyClassifier with the most_frequent strategy, but we only have two classes so we can simplify the process. Let’s figure out which class has more examples and then guess that for every example in our test dataset.

Each class has its own folder in the test set, so we can count the files in each folder to see which class we have more of.

kang_dir = test_data_dir / "kangaroo"
wall_dir = test_data_dir / "wallaby"
num_kang_images = len(list(kang_dir.glob("*")))
num_wall_images = len(list(wall_dir.glob("*")))
print(f"Number of kangaroo images: {num_kang_images}")
print(f"Number of wallaby images: {num_wall_images}")
print(
    f"If we just guessed the most common class each time, we would be right {num_kang_images / (num_kang_images + num_wall_images):.2%} of the time."
)
Number of kangaroo images: 306
Number of wallaby images: 195
If we just guessed the most common class each time, we would be right 61.08% of the time.
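
For more than two classes, the same baseline can be computed generically. Below is a minimal pure-Python sketch of a most-frequent-class baseline, written as a stand-in for what a dummy classifier with a most-frequent strategy does. The majority_baseline helper is hypothetical, and reusing the test labels as the "training" labels is an assumption made only to keep the example short.

```python
from collections import Counter


def majority_baseline(train_labels, test_labels):
    """Accuracy obtained by always predicting the most common training label."""
    most_common_label, _ = Counter(train_labels).most_common(1)[0]
    hits = sum(1 for label in test_labels if label == most_common_label)
    return most_common_label, hits / len(test_labels)


# Label counts taken from the test-set output above.
# Assumption: the same labels are reused as the "training" labels here.
test_labels = ["kangaroo"] * 306 + ["wallaby"] * 195

label, accuracy = majority_baseline(test_labels, test_labels)
print(label, f"{accuracy:.2%}")  # kangaroo 61.08%
```

In practice the most common label should come from the training set, with the test set used only to score it.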

Prepare the Model

We have to provide some basic characteristics of the network and how we want to train it. As we said before, we’ll use ImageNet weights.

# Keras expects input_shape as (height, width, channels)
model_encoder = applications.Xception(weights="imagenet", include_top=False, input_shape=(img_height, img_width, 3))
# Add our own layers at the end
x = model_encoder.output
x = Flatten()(x)
x = tf.keras.layers.BatchNormalization()(x)
x = Dense(256, activation="relu")(x)
predictions = Dense(num_classes, activation="softmax")(x)

Create the model using the inputs from the pretrained model.

model = Model(inputs=model_encoder.input, outputs=predictions)

We’ve added our own layers at the end of the pretrained network. Next, we’ll freeze every layer except the last two, which are the Dense layers we just added, so only those are trained. For the final activation function, I’m going to use softmax. I could use a single sigmoid output instead, since the two are mathematically equivalent in the case of two classes, but softmax makes it easier to extend the model to more classes.
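
The two-class equivalence between softmax and sigmoid is easy to check numerically. Here is a small NumPy sketch: with logits z0 and z1, the softmax probability of class 1 equals the sigmoid of z1 - z0. The logit values are arbitrary examples, not outputs of our model.

```python
import numpy as np


def softmax(z):
    # Subtract the row maximum for numerical stability; the result is unchanged.
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Arbitrary two-class logits, one row per example.
logits = np.array([[2.0, -1.0], [0.3, 0.3], [-4.0, 1.5]])

p_softmax = softmax(logits)[:, 1]                  # P(class 1) from softmax
p_sigmoid = sigmoid(logits[:, 1] - logits[:, 0])   # sigmoid of the logit difference

equivalent = np.allclose(p_softmax, p_sigmoid)
print(equivalent)  # True
```

So the choice between them for two classes is about convenience, not modeling power.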

One problem we’re sure to run into with such a large and complex neural network is overfitting. This is when the model latches onto features that work well on the training set but don’t generalize to new images. For example, say all the images in the training set show the ears of the kangaroos and wallabies really well, and the model learns to distinguish them based on that. If the ears are hidden behind a branch in some other image, how is the model going to decide? We want the model to look at many aspects of the image and use them all to classify it.

Regularization is a common way to combat overfitting. There are several different types, but we’ll discuss just two of them: L1 and L2. They rely on the same concept: penalizing the network for weights that are too large. This pushes each individual weight toward small values and therefore discourages the model from relying too much on a single weight or feature. L1 regularization penalizes the sum of the absolute values of the weights, and L2 penalizes the sum of their squares. There are good reasons to use one over the other, which we won’t get into, but in this case we’ll use L2 because it’s the more common choice and often works well in practice.
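
To make the difference concrete, here is a minimal NumPy sketch of the two penalty terms. The weight vector and the regularization strength lam are made-up values for illustration. In Keras, the corresponding built-ins are tf.keras.regularizers.l1 and tf.keras.regularizers.l2.

```python
import numpy as np


def l1_penalty(weights, lam):
    """L1 penalty: lam times the sum of absolute weight values."""
    return lam * np.sum(np.abs(weights))


def l2_penalty(weights, lam):
    """L2 penalty: lam times the sum of squared weight values."""
    return lam * np.sum(np.square(weights))


# Made-up weights and regularization strength, for illustration only.
weights = np.array([0.5, -1.0, 2.0])
lam = 0.01

l1 = l1_penalty(weights, lam)  # 0.01 * (0.5 + 1.0 + 2.0), about 0.035
l2 = l2_penalty(weights, lam)  # 0.01 * (0.25 + 1.0 + 4.0), about 0.0525
print(l1, l2)
```

Notice that the largest weight (2.0) accounts for 4.0 of the 5.25 in the L2 sum but only 2.0 of the 3.5 in the L1 sum, which is why L2 is especially hard on a few large weights.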

for layer in model.layers[:-2]:
    layer.trainable = False
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
block1_conv1 (Conv2D)           (None, 127, 127, 32) 864         input_1[0][0]                    
__________________________________________________________________________________________________
block1_conv1_bn (BatchNormaliza (None, 127, 127, 32) 128         block1_conv1[0][0]               
__________________________________________________________________________________________________
block1_conv1_act (Activation)   (None, 127, 127, 32) 0           block1_conv1_bn[0][0]            
__________________________________________________________________________________________________
block1_conv2 (Conv2D)           (None, 125, 125, 64) 18432       block1_conv1_act[0][0]           
__________________________________________________________________________________________________
block1_conv2_bn (BatchNormaliza (None, 125, 125, 64) 256         block1_conv2[0][0]               
__________________________________________________________________________________________________
block1_conv2_act (Activation)   (None, 125, 125, 64) 0           block1_conv2_bn[0][0]            
__________________________________________________________________________________________________
block2_sepconv1 (SeparableConv2 (None, 125, 125, 128 8768        block1_conv2_act[0][0]           
__________________________________________________________________________________________________
block2_sepconv1_bn (BatchNormal (None, 125, 125, 128 512         block2_sepconv1[0][0]            
__________________________________________________________________________________________________
block2_sepconv2_act (Activation (None, 125, 125, 128 0           block2_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block2_sepconv2 (SeparableConv2 (None, 125, 125, 128 17536       block2_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block2_sepconv2_bn (BatchNormal (None, 125, 125, 128 512         block2_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 63, 63, 128)  8192        block1_conv2_act[0][0]           
__________________________________________________________________________________________________
block2_pool (MaxPooling2D)      (None, 63, 63, 128)  0           block2_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 63, 63, 128)  512         conv2d[0][0]                     
__________________________________________________________________________________________________
add (Add)                       (None, 63, 63, 128)  0           block2_pool[0][0]                
                                                                 batch_normalization[0][0]        
__________________________________________________________________________________________________
block3_sepconv1_act (Activation (None, 63, 63, 128)  0           add[0][0]                        
__________________________________________________________________________________________________
block3_sepconv1 (SeparableConv2 (None, 63, 63, 256)  33920       block3_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block3_sepconv1_bn (BatchNormal (None, 63, 63, 256)  1024        block3_sepconv1[0][0]            
__________________________________________________________________________________________________
block3_sepconv2_act (Activation (None, 63, 63, 256)  0           block3_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block3_sepconv2 (SeparableConv2 (None, 63, 63, 256)  67840       block3_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block3_sepconv2_bn (BatchNormal (None, 63, 63, 256)  1024        block3_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 256)  32768       add[0][0]                        
__________________________________________________________________________________________________
block3_pool (MaxPooling2D)      (None, 32, 32, 256)  0           block3_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 32, 32, 256)  1024        conv2d_1[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 32, 32, 256)  0           block3_pool[0][0]                
                                                                 batch_normalization_1[0][0]      
__________________________________________________________________________________________________
block4_sepconv1_act (Activation (None, 32, 32, 256)  0           add_1[0][0]                      
__________________________________________________________________________________________________
block4_sepconv1 (SeparableConv2 (None, 32, 32, 728)  188672      block4_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block4_sepconv1_bn (BatchNormal (None, 32, 32, 728)  2912        block4_sepconv1[0][0]            
__________________________________________________________________________________________________
block4_sepconv2_act (Activation (None, 32, 32, 728)  0           block4_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block4_sepconv2 (SeparableConv2 (None, 32, 32, 728)  536536      block4_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block4_sepconv2_bn (BatchNormal (None, 32, 32, 728)  2912        block4_sepconv2[0][0]            
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 16, 16, 728)  186368      add_1[0][0]                      
__________________________________________________________________________________________________
block4_pool (MaxPooling2D)      (None, 16, 16, 728)  0           block4_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 16, 16, 728)  2912        conv2d_2[0][0]                   
__________________________________________________________________________________________________
add_2 (Add)                     (None, 16, 16, 728)  0           block4_pool[0][0]                
                                                                 batch_normalization_2[0][0]      
__________________________________________________________________________________________________
block5_sepconv1_act (Activation (None, 16, 16, 728)  0           add_2[0][0]                      
__________________________________________________________________________________________________
block5_sepconv1 (SeparableConv2 (None, 16, 16, 728)  536536      block5_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv1_bn (BatchNormal (None, 16, 16, 728)  2912        block5_sepconv1[0][0]            
__________________________________________________________________________________________________
block5_sepconv2_act (Activation (None, 16, 16, 728)  0           block5_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block5_sepconv2 (SeparableConv2 (None, 16, 16, 728)  536536      block5_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv2_bn (BatchNormal (None, 16, 16, 728)  2912        block5_sepconv2[0][0]            
__________________________________________________________________________________________________
block5_sepconv3_act (Activation (None, 16, 16, 728)  0           block5_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block5_sepconv3 (SeparableConv2 (None, 16, 16, 728)  536536      block5_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block5_sepconv3_bn (BatchNormal (None, 16, 16, 728)  2912        block5_sepconv3[0][0]            
__________________________________________________________________________________________________
add_3 (Add)                     (None, 16, 16, 728)  0           block5_sepconv3_bn[0][0]         
                                                                 add_2[0][0]                      
__________________________________________________________________________________________________
block6_sepconv1_act (Activation (None, 16, 16, 728)  0           add_3[0][0]                      
__________________________________________________________________________________________________
block6_sepconv1 (SeparableConv2 (None, 16, 16, 728)  536536      block6_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv1_bn (BatchNormal (None, 16, 16, 728)  2912        block6_sepconv1[0][0]            
__________________________________________________________________________________________________
block6_sepconv2_act (Activation (None, 16, 16, 728)  0           block6_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block6_sepconv2 (SeparableConv2 (None, 16, 16, 728)  536536      block6_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv2_bn (BatchNormal (None, 16, 16, 728)  2912        block6_sepconv2[0][0]            
__________________________________________________________________________________________________
block6_sepconv3_act (Activation (None, 16, 16, 728)  0           block6_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block6_sepconv3 (SeparableConv2 (None, 16, 16, 728)  536536      block6_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block6_sepconv3_bn (BatchNormal (None, 16, 16, 728)  2912        block6_sepconv3[0][0]            
__________________________________________________________________________________________________
add_4 (Add)                     (None, 16, 16, 728)  0           block6_sepconv3_bn[0][0]         
                                                                 add_3[0][0]                      
__________________________________________________________________________________________________
block7_sepconv1_act (Activation (None, 16, 16, 728)  0           add_4[0][0]                      
__________________________________________________________________________________________________
block7_sepconv1 (SeparableConv2 (None, 16, 16, 728)  536536      block7_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv1_bn (BatchNormal (None, 16, 16, 728)  2912        block7_sepconv1[0][0]            
__________________________________________________________________________________________________
block7_sepconv2_act (Activation (None, 16, 16, 728)  0           block7_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block7_sepconv2 (SeparableConv2 (None, 16, 16, 728)  536536      block7_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv2_bn (BatchNormal (None, 16, 16, 728)  2912        block7_sepconv2[0][0]            
__________________________________________________________________________________________________
block7_sepconv3_act (Activation (None, 16, 16, 728)  0           block7_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block7_sepconv3 (SeparableConv2 (None, 16, 16, 728)  536536      block7_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block7_sepconv3_bn (BatchNormal (None, 16, 16, 728)  2912        block7_sepconv3[0][0]            
__________________________________________________________________________________________________
add_5 (Add)                     (None, 16, 16, 728)  0           block7_sepconv3_bn[0][0]         
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
block8_sepconv1_act (Activation (None, 16, 16, 728)  0           add_5[0][0]                      
__________________________________________________________________________________________________
block8_sepconv1 (SeparableConv2 (None, 16, 16, 728)  536536      block8_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv1_bn (BatchNormal (None, 16, 16, 728)  2912        block8_sepconv1[0][0]            
__________________________________________________________________________________________________
block8_sepconv2_act (Activation (None, 16, 16, 728)  0           block8_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block8_sepconv2 (SeparableConv2 (None, 16, 16, 728)  536536      block8_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv2_bn (BatchNormal (None, 16, 16, 728)  2912        block8_sepconv2[0][0]            
__________________________________________________________________________________________________
block8_sepconv3_act (Activation (None, 16, 16, 728)  0           block8_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block8_sepconv3 (SeparableConv2 (None, 16, 16, 728)  536536      block8_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block8_sepconv3_bn (BatchNormal (None, 16, 16, 728)  2912        block8_sepconv3[0][0]            
__________________________________________________________________________________________________
add_6 (Add)                     (None, 16, 16, 728)  0           block8_sepconv3_bn[0][0]         
                                                                 add_5[0][0]                      
__________________________________________________________________________________________________
block9_sepconv1_act (Activation (None, 16, 16, 728)  0           add_6[0][0]                      
__________________________________________________________________________________________________
block9_sepconv1 (SeparableConv2 (None, 16, 16, 728)  536536      block9_sepconv1_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv1_bn (BatchNormal (None, 16, 16, 728)  2912        block9_sepconv1[0][0]            
__________________________________________________________________________________________________
block9_sepconv2_act (Activation (None, 16, 16, 728)  0           block9_sepconv1_bn[0][0]         
__________________________________________________________________________________________________
block9_sepconv2 (SeparableConv2 (None, 16, 16, 728)  536536      block9_sepconv2_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv2_bn (BatchNormal (None, 16, 16, 728)  2912        block9_sepconv2[0][0]            
__________________________________________________________________________________________________
block9_sepconv3_act (Activation (None, 16, 16, 728)  0           block9_sepconv2_bn[0][0]         
__________________________________________________________________________________________________
block9_sepconv3 (SeparableConv2 (None, 16, 16, 728)  536536      block9_sepconv3_act[0][0]        
__________________________________________________________________________________________________
block9_sepconv3_bn (BatchNormal (None, 16, 16, 728)  2912        block9_sepconv3[0][0]            
__________________________________________________________________________________________________
add_7 (Add)                     (None, 16, 16, 728)  0           block9_sepconv3_bn[0][0]         
                                                                 add_6[0][0]                      
__________________________________________________________________________________________________
block10_sepconv1_act (Activatio (None, 16, 16, 728)  0           add_7[0][0]                      
__________________________________________________________________________________________________
block10_sepconv1 (SeparableConv (None, 16, 16, 728)  536536      block10_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv1_bn (BatchNorma (None, 16, 16, 728)  2912        block10_sepconv1[0][0]           
__________________________________________________________________________________________________
block10_sepconv2_act (Activatio (None, 16, 16, 728)  0           block10_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block10_sepconv2 (SeparableConv (None, 16, 16, 728)  536536      block10_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv2_bn (BatchNorma (None, 16, 16, 728)  2912        block10_sepconv2[0][0]           
__________________________________________________________________________________________________
block10_sepconv3_act (Activatio (None, 16, 16, 728)  0           block10_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block10_sepconv3 (SeparableConv (None, 16, 16, 728)  536536      block10_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block10_sepconv3_bn (BatchNorma (None, 16, 16, 728)  2912        block10_sepconv3[0][0]           
__________________________________________________________________________________________________
add_8 (Add)                     (None, 16, 16, 728)  0           block10_sepconv3_bn[0][0]        
                                                                 add_7[0][0]                      
__________________________________________________________________________________________________
block11_sepconv1_act (Activatio (None, 16, 16, 728)  0           add_8[0][0]                      
__________________________________________________________________________________________________
block11_sepconv1 (SeparableConv (None, 16, 16, 728)  536536      block11_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv1_bn (BatchNorma (None, 16, 16, 728)  2912        block11_sepconv1[0][0]           
__________________________________________________________________________________________________
block11_sepconv2_act (Activatio (None, 16, 16, 728)  0           block11_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block11_sepconv2 (SeparableConv (None, 16, 16, 728)  536536      block11_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv2_bn (BatchNorma (None, 16, 16, 728)  2912        block11_sepconv2[0][0]           
__________________________________________________________________________________________________
block11_sepconv3_act (Activatio (None, 16, 16, 728)  0           block11_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block11_sepconv3 (SeparableConv (None, 16, 16, 728)  536536      block11_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block11_sepconv3_bn (BatchNorma (None, 16, 16, 728)  2912        block11_sepconv3[0][0]           
__________________________________________________________________________________________________
add_9 (Add)                     (None, 16, 16, 728)  0           block11_sepconv3_bn[0][0]        
                                                                 add_8[0][0]                      
__________________________________________________________________________________________________
block12_sepconv1_act (Activatio (None, 16, 16, 728)  0           add_9[0][0]                      
__________________________________________________________________________________________________
block12_sepconv1 (SeparableConv (None, 16, 16, 728)  536536      block12_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv1_bn (BatchNorma (None, 16, 16, 728)  2912        block12_sepconv1[0][0]           
__________________________________________________________________________________________________
block12_sepconv2_act (Activatio (None, 16, 16, 728)  0           block12_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block12_sepconv2 (SeparableConv (None, 16, 16, 728)  536536      block12_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv2_bn (BatchNorma (None, 16, 16, 728)  2912        block12_sepconv2[0][0]           
__________________________________________________________________________________________________
block12_sepconv3_act (Activatio (None, 16, 16, 728)  0           block12_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
block12_sepconv3 (SeparableConv (None, 16, 16, 728)  536536      block12_sepconv3_act[0][0]       
__________________________________________________________________________________________________
block12_sepconv3_bn (BatchNorma (None, 16, 16, 728)  2912        block12_sepconv3[0][0]           
__________________________________________________________________________________________________
add_10 (Add)                    (None, 16, 16, 728)  0           block12_sepconv3_bn[0][0]        
                                                                 add_9[0][0]                      
__________________________________________________________________________________________________
block13_sepconv1_act (Activatio (None, 16, 16, 728)  0           add_10[0][0]                     
__________________________________________________________________________________________________
block13_sepconv1 (SeparableConv (None, 16, 16, 728)  536536      block13_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block13_sepconv1_bn (BatchNorma (None, 16, 16, 728)  2912        block13_sepconv1[0][0]           
__________________________________________________________________________________________________
block13_sepconv2_act (Activatio (None, 16, 16, 728)  0           block13_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block13_sepconv2 (SeparableConv (None, 16, 16, 1024) 752024      block13_sepconv2_act[0][0]       
__________________________________________________________________________________________________
block13_sepconv2_bn (BatchNorma (None, 16, 16, 1024) 4096        block13_sepconv2[0][0]           
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 8, 8, 1024)   745472      add_10[0][0]                     
__________________________________________________________________________________________________
block13_pool (MaxPooling2D)     (None, 8, 8, 1024)   0           block13_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 8, 8, 1024)   4096        conv2d_3[0][0]                   
__________________________________________________________________________________________________
add_11 (Add)                    (None, 8, 8, 1024)   0           block13_pool[0][0]               
                                                                 batch_normalization_3[0][0]      
__________________________________________________________________________________________________
block14_sepconv1 (SeparableConv (None, 8, 8, 1536)   1582080     add_11[0][0]                     
__________________________________________________________________________________________________
block14_sepconv1_bn (BatchNorma (None, 8, 8, 1536)   6144        block14_sepconv1[0][0]           
__________________________________________________________________________________________________
block14_sepconv1_act (Activatio (None, 8, 8, 1536)   0           block14_sepconv1_bn[0][0]        
__________________________________________________________________________________________________
block14_sepconv2 (SeparableConv (None, 8, 8, 2048)   3159552     block14_sepconv1_act[0][0]       
__________________________________________________________________________________________________
block14_sepconv2_bn (BatchNorma (None, 8, 8, 2048)   8192        block14_sepconv2[0][0]           
__________________________________________________________________________________________________
block14_sepconv2_act (Activatio (None, 8, 8, 2048)   0           block14_sepconv2_bn[0][0]        
__________________________________________________________________________________________________
flatten (Flatten)               (None, 131072)       0           block14_sepconv2_act[0][0]       
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 131072)       524288      flatten[0][0]                    
__________________________________________________________________________________________________
dense (Dense)                   (None, 256)          33554688    batch_normalization_4[0][0]      
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 2)            514         dense[0][0]                      
==================================================================================================
Total params: 54,940,970
Trainable params: 33,555,202
Non-trainable params: 21,385,768
__________________________________________________________________________________________________
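
As a sanity check, we can reproduce a few of the parameter counts in the summary by hand. This is a minimal sketch using only the shapes printed above; the variable names are ours. Notice that the reported trainable total matches the two dense layers alone, which suggests the new batch normalization layer's parameters are being counted as non-trainable.

```python
# A Dense layer has inputs * units weights plus one bias per unit.
flatten_features = 8 * 8 * 2048  # block14_sepconv2_act output, flattened
dense_params = flatten_features * 256 + 256
dense_1_params = 256 * 2 + 2

# BatchNormalization keeps four values per feature:
# gamma, beta, moving mean and moving variance.
batch_norm_params = 4 * flatten_features

print(flatten_features)               # 131072
print(dense_params)                   # 33554688
print(dense_1_params)                 # 514
print(batch_norm_params)              # 524288
print(dense_params + dense_1_params)  # 33555202
```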

Compile the Model

Now we compile the model. Compiling configures the model for training: we have to specify the loss function, the optimizer, and whatever metrics we’re interested in. TensorFlow takes care of placing the computation on the GPU for us.

model.compile(
    loss="binary_crossentropy",
    optimizer=optimizers.SGD(learning_rate=0.0001, momentum=0.9),  # `lr` is a deprecated alias
    metrics=["accuracy"],
)
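
To get a feel for what the loss function measures, here is a rough pure-Python sketch of binary cross-entropy. It is not Keras's implementation; the `eps` clipping value and the toy labels and probabilities are assumptions made for illustration.

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over paired labels and predicted probabilities."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)  # clip so we never take log(0)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)


# Toy predictions (made up): the same labels, predicted well and badly.
confident_right = binary_crossentropy([1, 0], [0.9, 0.1])
confident_wrong = binary_crossentropy([1, 0], [0.1, 0.9])

print(round(confident_right, 4))  # 0.1054
print(round(confident_wrong, 4))  # 2.3026
```

The loss is small when the model puts high probability on the right answer and grows quickly when it is confidently wrong.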

I always like to save the model so I can reload it later if needed.

results_path = Path(r"C:\Users\Julius\Documents\GitHub\cv\results\KangWall_" + datetime.now().strftime("%Y%m%d-%H%M%S"))

os.makedirs(results_path, exist_ok=True)
# Callbacks: stop early if validation accuracy stalls, save a checkpoint every epoch, and log to TensorBoard

checkpoint_path = results_path / "model-{epoch:02d}-{val_accuracy:.2f}.hdf5"
early = EarlyStopping(monitor='val_accuracy', min_delta=0,
                      patience=4, verbose=1, mode='auto')
checkpoint = ModelCheckpoint(str(checkpoint_path), monitor='val_accuracy', verbose=1, save_best_only=False, mode='auto', save_freq='epoch')
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=results_path, histogram_freq=0, write_graph=True, write_images=True, update_freq="epoch"
)
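
The placeholders in the checkpoint file name use Python's format-string syntax, so we can preview the names we expect with plain `str.format`. This is a small sketch; the epoch number and validation accuracy here are illustrative values, and we are assuming the callback fills in the fields the same way.

```python
template = "model-{epoch:02d}-{val_accuracy:.2f}.hdf5"

# {epoch:02d} zero-pads the epoch to two digits;
# {val_accuracy:.2f} rounds the metric to two decimal places.
name = template.format(epoch=1, val_accuracy=0.8289)
print(name)  # model-01-0.83.hdf5
```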

OK, now let’s train the model.

history = model.fit(
    train_dataset, epochs=EPOCHS, callbacks=[tensorboard_callback, early, checkpoint], validation_data=val_dataset
)
Epoch 1/5
914/914 [==============================] - 678s 735ms/step - loss: 0.2790 - accuracy: 0.8730 - val_loss: 0.6247 - val_accuracy: 0.8289

Epoch 00001: saving model to C:\Users\Julius\Documents\GitHub\cv\results\KangWall_20221017-211925\model-01-0.83.hdf5
Epoch 2/5
914/914 [==============================] - 558s 610ms/step - loss: 0.0592 - accuracy: 0.9836 - val_loss: 0.4672 - val_accuracy: 0.8501

Epoch 00002: saving model to C:\Users\Julius\Documents\GitHub\cv\results\KangWall_20221017-211925\model-02-0.85.hdf5
Epoch 3/5
914/914 [==============================] - 559s 612ms/step - loss: 0.0285 - accuracy: 0.9943 - val_loss: 0.6124 - val_accuracy: 0.8413

Epoch 00003: saving model to C:\Users\Julius\Documents\GitHub\cv\results\KangWall_20221017-211925\model-03-0.84.hdf5
Epoch 4/5
914/914 [==============================] - 566s 620ms/step - loss: 0.0168 - accuracy: 0.9975 - val_loss: 0.5485 - val_accuracy: 0.8413

Epoch 00004: saving model to C:\Users\Julius\Documents\GitHub\cv\results\KangWall_20221017-211925\model-04-0.84.hdf5
Epoch 5/5
914/914 [==============================] - 556s 608ms/step - loss: 0.0126 - accuracy: 1.0000 - val_loss: 0.7400 - val_accuracy: 0.8131

Epoch 00005: saving model to C:\Users\Julius\Documents\GitHub\cv\results\KangWall_20221017-211925\model-05-0.81.hdf5
model.save(str(results_path) + "final_model")
INFO:tensorflow:Assets written to: C:\Users\Julius\Documents\GitHub\cv\results\KangWall_20221017-211925final_model\assets
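
Notice that early stopping never kicked in. Here is a simplified sketch of a patience rule applied to the validation accuracies from the log above. It is not the Keras implementation, just the basic idea, and `first_stop_epoch` is a helper name we made up.

```python
def first_stop_epoch(values, patience, min_delta=0.0):
    """Return the 1-based epoch where a patience rule would stop, or None."""
    best = float("-inf")
    wait = 0
    for epoch, value in enumerate(values, start=1):
        if value - min_delta > best:
            # New best value: remember it and reset the counter.
            best = value
            wait = 0
        else:
            # No improvement: count it and stop once patience runs out.
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Validation accuracy per epoch, copied from the training log.
val_accuracy = [0.8289, 0.8501, 0.8413, 0.8413, 0.8131]

print(first_stop_epoch(val_accuracy, patience=4))  # None
print(first_stop_epoch(val_accuracy, patience=2))  # 4
```

With `patience=4` and only five epochs, the best epoch (the second) is followed by just three epochs without improvement, so training runs to the end. A smaller patience would have stopped it earlier.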

Note how much higher the training accuracy is than the validation accuracy: by the last epoch, training accuracy has reached 1.0000 while validation accuracy is only 0.8131. That means we’re overfitting the training data. We’ll go over how to correct for that in a future notebook.

plt.plot(history.history["accuracy"], label="accuracy")
plt.plot(history.history["val_accuracy"], label="val_accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.ylim([0.5, 1])
plt.legend(loc="lower right");

[Plot: training accuracy and validation accuracy by epoch]
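
We can also put numbers on the gap. This is a quick sketch using the accuracies copied from the training log; in the notebook you could read the same values from `history.history`.

```python
# Per-epoch accuracies copied from the training log.
train_accuracy = [0.8730, 0.9836, 0.9943, 0.9975, 1.0000]
val_accuracy = [0.8289, 0.8501, 0.8413, 0.8413, 0.8131]

# Gap between training and validation accuracy at each epoch.
gaps = [round(t - v, 4) for t, v in zip(train_accuracy, val_accuracy)]

# 1-based epoch with the highest validation accuracy.
best_epoch = max(range(len(val_accuracy)), key=val_accuracy.__getitem__) + 1

for epoch, gap in enumerate(gaps, start=1):
    print(f"epoch {epoch}: gap {gap:.4f}")
print(best_epoch)  # 2
```

The gap widens every epoch, and validation accuracy peaks at the second epoch, so the extra training mostly helps the model memorize the training set.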

Evaluating the Model

Note that the checkpoints are saved after every epoch (we set save_best_only=False), and the model we evaluate here is the one saved at the end of training. We have, however, been checking the accuracy on our validation set after every epoch. Since we keep checking the accuracy on our validation set, we could actually be overfitting the validation set as well. That’s why we reserved a test set that, so far, we haven’t even looked at. We’ll use that to compute the model’s accuracy. Accuracy isn’t the most comprehensive way to measure model quality, especially in the case of multiple classes, but it’s quick and simple, so we’ll use it.

model_eval = tf.keras.models.load_model(str(results_path) + "final_model")
model_eval.evaluate(test_dataset)
126/126 [==============================] - 82s 641ms/step - loss: 0.4245 - accuracy: 0.8463

[0.42446160316467285, 0.8463073968887329]
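
The two numbers returned are the loss and the accuracy. Since accuracy doesn’t tell the whole story, here is a small sketch of two other metrics, precision and recall, computed by hand. The labels and predictions are made up for illustration (they are not from our test set), and treating kangaroo as class 1 is an assumption.

```python
# Hypothetical labels for illustration only: 1 = kangaroo, 0 = wallaby.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1]

pairs = list(zip(y_true, y_pred))
tp = sum(1 for t, p in pairs if t == 1 and p == 1)  # kangaroos found
fp = sum(1 for t, p in pairs if t == 0 and p == 1)  # wallabies called kangaroos
fn = sum(1 for t, p in pairs if t == 1 and p == 0)  # kangaroos missed
tn = sum(1 for t, p in pairs if t == 0 and p == 0)  # wallabies found

accuracy = (tp + tn) / len(pairs)
precision = tp / (tp + fp)
recall = tp / (tp + fn)

print(accuracy)             # 0.7
print(round(precision, 3))  # 0.667
print(recall)               # 0.5
```

In this toy example the accuracy looks fine, but the recall shows that half of the kangaroos were missed.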

About 85%. That’s pretty good! But we’re not done yet. For one thing, you can see clear overfitting in the plot. In the next notebook, we’ll talk about data augmentation and how we can use that to improve the model’s performance on the test set.