import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, re
import warnings, math, sys, json
import subprocess, pprint, pdb

import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score,precision_score, recall_score, confusion_matrix

# Silence library warnings for the whole notebook and report the TF build.
warnings.filterwarnings('ignore')
print("Using TensorFlow v{}".format(tf.__version__))
Using TensorFlow v2.4.1

Tip: Setting a seed helps reproduce results. Setting the debug parameter will run the model for a smaller number of epochs to validate the architecture.

def seed_everything(seed=0):
    """Seed the random generators used by this notebook so runs repeat.

    Seeds Python's ``random`` module, NumPy's global generator and
    TensorFlow's global seed, then exports ``PYTHONHASHSEED`` and
    ``TF_DETERMINISTIC_OPS=1`` into the process environment.

    Args:
        seed: integer seed applied to every generator (default 0).

    Note:
        ``PYTHONHASHSEED`` is only read when a Python interpreter starts, so
        it does not change hash randomisation of the already-running kernel.
        ``TF_DETERMINISTIC_OPS`` is presumably read by TensorFlow to request
        deterministic kernels; confirm for the installed TF version.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    os.environ.update(PYTHONHASHSEED=str(seed), TF_DETERMINISTIC_OPS='1')

# Detect the runtime from the IPython shell's string form: assumes a Colab
# shell mentions 'google.colab'; anything else is treated as a Kaggle kernel.
GOOGLE = str(get_ipython()).find('google.colab') != -1
KAGGLE = not GOOGLE

print(f"Running on {'Google Colab' if GOOGLE else 'Kaggle Kernel'}!")
Running on Google Colab!

Hyperparameters

# Colab form parameters. The trailing `#@param` annotations are read by the
# Colab form UI, so each must stay on the same line as its assignment.
# BASE_MODEL is only printed below; the model cell instantiates EfficientNetB3
# directly. NOTE(review): keep the two in sync if the backbone changes.
BASE_MODEL = 'efficientnet_b3' #@param ["'efficientnet_b3'", "'efficientnet_b4'", "'efficientnet_b2'"] {type:"raw", allow-input: true}
HEIGHT = 300#@param {type:"number"}
WIDTH = 300#@param {type:"number"}
CHANNELS = 3#@param {type:"number"}
IMG_SIZE = (HEIGHT, WIDTH, CHANNELS)  # network input shape; decode_image crops to it
EPOCHS =  50#@param {type:"number"}
# Global batch size = 32 per replica. `strategy` is a tf.distribute strategy
# created in a cell not shown here. NOTE(review): confirm it is defined first.
BATCH_SIZE = 32 * strategy.num_replicas_in_sync #@param {type:"raw"}

print("Use {} with input size {}".format(BASE_MODEL, IMG_SIZE))
print("Train on batch size of {} for {} epochs".format(BATCH_SIZE, EPOCHS))
Use efficientnet_b3 with input size (300, 300, 3)
Train on batch size of 256 for 50 epochs

Data

Loading data

%%run_if {GOOGLE}
#@title {run: "auto", display-mode: "form"}
# Colab-only cell (the custom %%run_if magic, defined elsewhere, presumably
# skips it when GOOGLE is false). Points at the GCS bucket holding the
# competition TFRecords, using the 512x512 JPEG variant; decode_image later
# random-crops these to IMG_SIZE.
GCS_PATH = 'gs://kds-c6b9829baa483a13a169c7cbe266341fb8c9b5ba36843af37a093a4c' #@param {type: "string"}
GCS_PATH += '/tfrecords-jpeg-512x512' #@param {type: "string"}
print(f"Sourcing images from {GCS_PATH}")
Sourcing images from gs://kds-c6b9829baa483a13a169c7cbe266341fb8c9b5ba36843af37a093a4c/tfrecords-jpeg-512x512

# The 104 flower class names. Position in this list is treated as the integer
# class id: the model has len(CLASSES) outputs and predictions are argmax
# indices. The trailing spaces in 'cyclamen ' and 'hippeastrum ' are kept
# as-is; in this notebook the names are only used for display.
CLASSES = ['pink primrose',        'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',
           'tiger lily',           'moon orchid',               'bird of paradise', 'monkshood',     'globe thistle',        
           'snapdragon',           "colt's foot",               'king protea',      'spear thistle', 'yellow iris',
           'globe-flower',         'purple coneflower',         'peruvian lily',    'balloon flower','giant white arum lily',
           'fire lily',            'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',
           'corn poppy',           'prince of wales feathers',  'stemless gentian', 'artichoke',     'sweet william',        
           'carnation',            'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',
           'ruby-lipped cattleya', 'cape flower',               'great masterwort', 'siam tulip',    'lenten rose',          
           'barberton daisy',      'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',
           'wallflower',           'marigold',                  'buttercup',        'daisy',         'common dandelion',     
           'petunia',              'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',
           'bishop of llandaff',   'gaura',                     'geranium',         'orange dahlia', 'pink-yellow dahlia',   
           'cautleya spicata',     'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy',
           'osteospermum',         'spring crocus',             'iris',             'windflower',    'tree poppy',           
           'gazania',              'azalea',                    'water lily',       'rose',          'thorn apple',
           'morning glory',        'passion flower',            'lotus',            'toad lily',     'anthurium',
           'frangipani',           'clematis',                  'hibiscus',         'columbine',     'desert-rose',
           'tree mallow',          'magnolia',                  'cyclamen ',        'watercress',    'canna lily',           
           'hippeastrum ',         'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',
           'camellia',             'mallow',                    'mexican petunia',  'bromelia',      'blanket flower',       
           'trumpet creeper',      'blackberry lily',           'common tulip',     'wild rose']

# Number of output classes for the classification head.
NCLASSES = len(CLASSES)
print(f"Number of labels: {NCLASSES}")
Number of labels: 104

def decode_image(image_data):
    """Decode one JPEG byte string and randomly crop it to ``IMG_SIZE``.

    The pixels are cast to float32. On Colab (``GOOGLE``) they stay on the
    0-255 scale; otherwise they are divided by 255 into [0, 1].

    NOTE(review): the crop is random and this helper is also used for the
    validation and test pipelines, so evaluation inputs vary between runs.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image_data, channels=CHANNELS), tf.float32)
    if not GOOGLE:
        pixels = pixels / 255.0
    return tf.image.random_crop(pixels, IMG_SIZE)
    
def collate_labeled_tfrecord(example):
    """Parse one serialized labeled TFRecord into ``(image, label)``.

    The record holds a JPEG byte string under ``"image"`` and a scalar int64
    class id under ``"class"``. The image is decoded with ``decode_image``
    and the label is narrowed to int32.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),  # raw JPEG bytes
        "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class id
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']), tf.cast(parsed['class'], tf.int32)

def process_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled (test) TFRecord into ``(image, id)``.

    The record holds a JPEG byte string under ``"image"`` and a scalar
    string identifier under ``"id"``. The image is decoded with
    ``decode_image`` and the id is passed through unchanged.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),  # raw JPEG bytes
        "id": tf.io.FixedLenFeature([], tf.string),     # scalar record id
    }
    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']), parsed['id']

def count_data_items(filenames):
    """Return the total number of records across TFRecord shards.

    Each shard name is expected to embed its record count as the first
    ``-<digits>.`` group, e.g. ``00-512x512-798.tfrec`` -> 798. The counts
    are summed, so no shard has to be opened.

    Args:
        filenames: iterable of shard paths or names.

    Returns:
        int: the summed record count (0 for an empty iterable).

    Raises:
        ValueError: if a name contains no ``-<digits>.`` group.
    """
    # Compile once, outside the loop. `+` (not `*`) so an empty digit run
    # such as "-." is never matched and passed to int('').
    pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        match = pattern.search(filename)
        if match is None:
            raise ValueError(f"cannot read a record count from {filename!r}")
        total += int(match.group(1))
    return total
# Shard lists for the three splits; the bucket sub-folders are train/val/test.
train_filenames, valid_filenames, test_filenames = (
    tf.io.gfile.glob(f"{GCS_PATH}/{split}/*.tfrec")
    for split in ('train', 'val', 'test')
)
Number of train set: 12753
Number of valid set: 3712
Number of test set:  7382

Data augmentation

Note: The following data augmentation functions are adapted from "Data Augmentation using GPU/TPU for Maximum Speed!" by @cdeotte.

def transform_shear(image, height, shear):
    """Randomly shear one square image.

    The shear angle is ``shear * u`` degrees with ``u`` drawn from
    ``tf.random.uniform`` (default range [0, 1)). NOTE(review): the angle is
    therefore never negative, so the shear always leans one way.

    Args:
        image: a single image, assumed ``[height, height, 3]`` (not a batch).
        height: side length of the (assumed square) image.
        shear: maximum shear angle in degrees.

    Returns:
        A ``[height, height, 3]`` tensor. Each output pixel is fetched from
        the source position given by the shear matrix. Coordinates are
        truncated to integers (no interpolation), and out-of-range
        coordinates are clipped, so border pixels are repeated.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly sheared
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    # Random angle in [0, shear) degrees, converted to radians.
    shear = shear * tf.random.uniform([1],dtype='float32')
    shear = math.pi * shear / 180.
        
    # SHEAR MATRIX (3x3 homogeneous, row-major)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(tf.concat([one,s2,zero, zero,c2,zero, zero,zero,one],axis=0),[3,3])    

    # LIST DESTINATION PIXEL INDICES
    # Centred coordinates: x runs DIM//2 down to -DIM//2+1 (rows, top first),
    # y runs -DIM//2 up to DIM//2-1 (columns), z = 1 is the homogeneous term.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(shear_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')  # truncate to integer pixel coordinates
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)  # keep inside the image bounds
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert centred coordinates back to (row, col) array indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform_shift(image, height, h_shift, w_shift):
    """Randomly translate one square image.

    The vertical and horizontal offsets are ``h_shift * u1`` and
    ``w_shift * u2`` pixels, where ``u1`` and ``u2`` are independent draws
    from ``tf.random.uniform`` (default range [0, 1)). NOTE(review): the
    offsets are never negative, so the shift always goes the same way.

    Args:
        image: a single image, assumed ``[height, height, 3]`` (not a batch).
        height: side length of the (assumed square) image.
        h_shift: maximum vertical shift in pixels.
        w_shift: maximum horizontal shift in pixels.

    Returns:
        A ``[height, height, 3]`` tensor. Coordinates are truncated to
        integers and clipped to the image, so border pixels fill the
        uncovered area.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly shifted
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    height_shift = h_shift * tf.random.uniform([1],dtype='float32') 
    width_shift = w_shift * tf.random.uniform([1],dtype='float32') 
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
        
    # SHIFT MATRIX (3x3 homogeneous translation, row-major)
    shift_matrix = tf.reshape(tf.concat([one,zero,height_shift, zero,one,width_shift, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    # Centred coordinates: x = rows (top first), y = columns, z = 1.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(shift_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')  # truncate to integer pixel coordinates
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)  # keep inside the image bounds
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert centred coordinates back to (row, col) array indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def transform_rotation(image, height, rotation):
    """Randomly rotate one square image about its centre.

    The angle is ``rotation * u`` degrees with ``u`` drawn from
    ``tf.random.uniform`` (default range [0, 1)). NOTE(review): the angle is
    never negative, so all rotations go the same direction.

    Args:
        image: a single image, assumed ``[height, height, 3]`` (not a batch).
        height: side length of the (assumed square) image.
        rotation: maximum rotation angle in degrees.

    Returns:
        A ``[height, height, 3]`` tensor. Coordinates are truncated to
        integers and clipped to the image, so the corners are filled with
        repeated border pixels.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly rotated
    DIM = height
    XDIM = DIM%2 #fix for size 331
    
    rotation = rotation * tf.random.uniform([1],dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.
    
    # ROTATION MATRIX (3x3 homogeneous, row-major)
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    rotation_matrix = tf.reshape(tf.concat([c1,s1,zero, -s1,c1,zero, zero,zero,one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES
    # Centred coordinates: x = rows (top first), y = columns, z = 1.
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM )
    y = tf.tile( tf.range(-DIM//2,DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # ROTATE DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(rotation_matrix,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')  # truncate to integer pixel coordinates
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)  # keep inside the image bounds
    
    # FIND ORIGIN PIXEL VALUES 
    # Convert centred coordinates back to (row, col) array indices and gather.
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    d = tf.gather_nd(image, tf.transpose(idx3))
        
    return tf.reshape(d,[DIM,DIM,3])

def data_augment(image, label):
    """Apply random spatial and pixel-level augmentation to one training image.

    Meant to be mapped over an ``(image, label)`` dataset; ``label`` passes
    through unchanged. Each augmentation family draws its own uniform
    probability:

    * flips (80%): random left-right and up-down flip;
    * quarter turns (75%): 90, 180 or 270 degrees, 25% each;
    * rotation / shift / shear (70% each): the ``transform_*`` helpers;
    * crop: 30% central crop (fraction .7, .8 or .9), 30% random square crop
      of 70-100% of HEIGHT, 40% none. The image is then resized back to
      ``[HEIGHT, WIDTH]``;
    * pixel (80%): one of saturation, contrast, brightness or gamma.

    Args:
        image: float32 ``[HEIGHT, WIDTH, CHANNELS]`` tensor as produced by
            ``decode_image`` (0-255 scale when ``GOOGLE``, 0-1 otherwise).
        label: class label, returned untouched.

    Returns:
        ``(augmented_image, label)``.
    """
    # Draw order is kept stable so seeded runs stay comparable.
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shift = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    if p_spatial >= .2:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

    # Quarter turns (HEIGHT == WIDTH here, so the shape is preserved)
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3)  # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2)  # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1)  # rotate 90º

    # Affine warps; the transform_* helpers assume a square HEIGHT x HEIGHT image.
    if p_rotation >= .3:
        image = transform_rotation(image, height=HEIGHT, rotation=45.)
    if p_shift >= .3:
        image = transform_shift(image, height=HEIGHT, h_shift=15., w_shift=15.)
    if p_shear >= .3:
        image = transform_shear(image, height=HEIGHT, shear=20.)

    # Crops. The central-crop band (p_crop > .7) must be tested first. With
    # the `> .4` test first, every value above .7 was already caught by the
    # random-crop branch, so the central crops could never run.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.7), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])

    # Crops change the spatial size; bring every image back to the model input.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    # Pixel-level transforms. decode_image keeps Colab images on a 0-255 scale
    # (0-1 elsewhere). The brightness delta and gamma curve assume a 0-1
    # range: x**.6 on 0-255 data collapses the image to roughly 0-28, and a
    # 0.2 delta is invisible. Express both relative to the actual pixel range.
    pixel_max = 255.0 if GOOGLE else 1.0
    if p_pixel >= .2:
        if p_pixel >= .8:
            image = tf.image.random_saturation(image, lower=0, upper=2)
        elif p_pixel >= .6:
            image = tf.image.random_contrast(image, lower=.8, upper=2)
        elif p_pixel >= .4:
            image = tf.image.random_brightness(image, max_delta=.2 * pixel_max)
        else:
            image = pixel_max * tf.image.adjust_gamma(image / pixel_max, gamma=.6)

    return image, label

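Tip: `experimental_deterministic` controls whether the dataset must produce its outputs in a deterministic order (default: True). Setting it to False lets tf.data return elements as soon as they are ready, which can speed up the input pipeline when order does not matter.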
# Dataset option that lets tf.data emit elements out of order for throughput.
# Only applied to valid_ds/test_ds below, where per-element order is not needed.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False
# AUTOTUNE is defined in a cell not shown here.
# NOTE(review): presumably tf.data.experimental.AUTOTUNE.
train_ds = tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
train_ds = (train_ds
            .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)  # parse + decode + random crop
            .map(data_augment, num_parallel_calls=AUTOTUNE)  # random augmentation per image
            .repeat()  # endless stream; model.fit's steps_per_epoch bounds each epoch
            .shuffle(2048)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE))
valid_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
valid_ds = (valid_ds
            .with_options(option_no_order)
            .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            # Cached after batching: presumably later epochs reuse the first
            # pass's (randomly cropped) batches instead of re-decoding.
            .cache()
            .prefetch(AUTOTUNE))
# Unordered test pipeline. Ids and predictions must come from the same pass
# over it; the prediction section later rebuilds test_ds without this option.
test_ds = tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
test_ds = (test_ds
            .with_options(option_no_order)
            .map(process_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(AUTOTUNE))

Model

Batch augmentation

Augmentation can be applied in two ways.

Important: The Keras preprocessing layers are still experimental, and they do not appear to have TPU OpKernels yet.

#batch_augment = tf.keras.Sequential(
#    [
#     tf.keras.layers.experimental.preprocessing.RandomCrop(*IMG_SIZE),
#     tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
#     tf.keras.layers.experimental.preprocessing.RandomRotation(0.25),
#     tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0)),
#     tf.keras.layers.experimental.preprocessing.RandomContrast((0.2,0.2))
#    ]
#)

#func = lambda x,y: (batch_augment(x), y)
#x = (train_ds
#     .take(1)
#     .map(func, num_parallel_calls=AUTOTUNE))

Building a model

Now we're ready to create a neural network for classifying images! We'll use what's known as transfer learning. With transfer learning, you reuse the body of a pretrained model and replace its head (or tail) with custom layers suited to the problem we are solving.

For this tutorial, we'll use EfficientNetB3, pretrained on ImageNet. Later, I might experiment with other models (Xception wouldn't be a bad choice).

Important: The distribution strategy we created earlier provides a context manager, `strategy.scope()`. This context manager tells TensorFlow how to divide the work of training among the eight TPU cores. When using TensorFlow with a TPU, it's important to define your model inside the `strategy.scope()` context.
%%run_if {GOOGLE}
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications import VGG16
def build_model(base_model, num_class):
    """Stack a dropout + softmax classification head on ``base_model``.

    Args:
        base_model: Keras model mapping an ``IMG_SIZE`` image to a feature
            vector (in this notebook, EfficientNetB3 with average pooling).
        num_class: number of output classes.

    Returns:
        An uncompiled ``tf.keras.Model`` from an ``IMG_SIZE`` input to
        ``num_class`` softmax probabilities (output layer named ``"pred"``).
    """
    layers = tf.keras.layers
    image_input = layers.Input(shape=IMG_SIZE)
    features = base_model(image_input)
    features = layers.Dropout(0.4)(features)
    probabilities = layers.Dense(num_class, activation="softmax", name="pred")(features)
    return tf.keras.models.Model(inputs=image_input, outputs=probabilities)
with strategy.scope():
    # Create the backbone and head under the distribution strategy. ImageNet
    # weights are downloaded only when TRAIN is set (TRAIN and strategy come
    # from cells not shown here).
    efficientnet = EfficientNetB3(
        include_top=False,
        weights='imagenet' if TRAIN else None,
        pooling='avg',
        input_shape=IMG_SIZE,
    )
    efficientnet.trainable = True  # fine-tune the whole backbone
    model = build_model(base_model=efficientnet, num_class=len(CLASSES))

Optimizer

Important: I always wanted to try the new CosineDecayRestarts schedule implemented in tf.keras, as it seemed promising, and I struggled to find the right settings (if there were any) for ReduceLROnPlateau.
STEPS = math.ceil(count_data_items(train_filenames) / BATCH_SIZE) * EPOCHS  # total optimizer steps
# Peak learning rate, scaled linearly with the number of replicas.
LR_START = 1e-4 #@param {type: "number"}
LR_START *= strategy.num_replicas_in_sync
# Absolute floor the schedule decays to at the end of every cycle.
LR_MIN = 1e-5 #@param {type: "number"}
N_RESTARTS =  5#@param {type: "number"}
T_MUL = 2.0 #@param {type: "number"}
M_MUL =  1#@param {type: "number"}
# Length of the first cosine cycle, chosen so that N_RESTARTS + 1 cycles whose
# lengths grow geometrically by T_MUL add up to STEPS:
#     STEPS_START * (T_MUL**(N_RESTARTS+1) - 1) / (T_MUL - 1) == STEPS
# With T_MUL == 1 all cycles have equal length (and the formula is 0/0).
if T_MUL == 1:
    STEPS_START = math.ceil(STEPS / (N_RESTARTS + 1))
else:
    STEPS_START = math.ceil((T_MUL-1)/(T_MUL**(N_RESTARTS+1)-1) * STEPS)

schedule = tf.keras.experimental.CosineDecayRestarts(
    first_decay_steps=STEPS_START,
    initial_learning_rate=LR_START,
    # `alpha` is the floor as a *fraction* of initial_learning_rate, not an
    # absolute learning rate. Passing LR_MIN directly made the floor
    # LR_MIN * LR_START (~8e-9); convert so the schedule bottoms out at LR_MIN.
    alpha=LR_MIN / LR_START,
    m_mul=M_MUL,
    t_mul=T_MUL)

# Plot the learning rate at every training step.
x = list(range(STEPS))
y = [schedule(s) for s in x]

_,ax = plt.subplots(1,1,figsize=(8,5),facecolor='#F0F0F0')
ax.plot(x, y)
ax.set_facecolor('#F8F8F8')
ax.set_xlabel('iteration')
ax.set_ylabel('learning rate')

print('{:d} total epochs and {:d} steps per epoch'
        .format(EPOCHS, STEPS // EPOCHS))
print(schedule.get_config())
50 total epochs and 50 steps per epoch
{'initial_learning_rate': 0.0008, 'first_decay_steps': 40, 't_mul': 2.0, 'm_mul': 1, 'alpha': 1e-05, 'name': None}

Callbacks

callbacks = [
    # Keep the weights with the lowest validation loss on disk.
    tf.keras.callbacks.ModelCheckpoint(
        filepath='001_best_model.h5',
        save_best_only=True,
        monitor='val_loss'),
    # Stop after 10 epochs without a val_loss improvement and roll the model
    # back to the best weights seen.
    tf.keras.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=10,
        restore_best_weights=True,
        verbose=1),
]

# Integer labels -> sparse categorical loss/metric; Adam follows `schedule`.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=tf.keras.optimizers.Adam(schedule),
    metrics=['sparse_categorical_accuracy'],
)
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 300, 300, 3)]     0         
_________________________________________________________________
efficientnetb3 (Functional)  (None, 1536)              10783535  
_________________________________________________________________
dropout (Dropout)            (None, 1536)              0         
_________________________________________________________________
pred (Dense)                 (None, 104)               159848    
=================================================================
Total params: 10,943,383
Trainable params: 10,856,080
Non-trainable params: 87,303
_________________________________________________________________

Training

Train the normalization layer

%%run_if {GOOGLE}
def generate_norm_image(example):
    LABELED_TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string), # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),  # shape [] means single element
    }
    example = tf.io.parse_single_example(example, LABELED_TFREC_FORMAT)
    image = decode_image(example['image'])
    image = image / 255.0
    return image
%%run_if {GOOGLE}
if os.path.exists("000_normalization.h5"):
    model.load_weights("000_normalization.h5")
else:
    adapt_ds = (tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
                    .map(generate_norm_image, num_parallel_calls=AUTOTUNE)
                    .shuffle(2048)
                    .batch(BATCH_SIZE)
                    .prefetch(AUTOTUNE))
    model.get_layer('efficientnetb3').get_layer('normalization').adapt(adapt_ds)
    model.save_weights("000_normalization.h5")

Train all

# train_ds repeats forever, so steps_per_epoch is what bounds each epoch.
# STEPS is the total step count over all epochs, so STEPS // EPOCHS is one
# full pass over the training shards. The run then covers exactly the STEPS
# steps the LR schedule was laid out over. The previous STEPS // BATCH_SIZE
# gave 9 steps per epoch, using ~18% of the data and a fraction of the schedule.
history = model.fit(
    x=train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS // EPOCHS,
    callbacks=callbacks,
    verbose=2
)

Epoch 1/50
9/9 - 156s - loss: 1.9010 - sparse_categorical_accuracy: 0.5043 - val_loss: 13.0768 - val_sparse_categorical_accuracy: 0.2594
Epoch 2/50
9/9 - 6s - loss: 1.6419 - sparse_categorical_accuracy: 0.5647 - val_loss: 6.8306 - val_sparse_categorical_accuracy: 0.3257
Epoch 3/50
9/9 - 6s - loss: 1.5901 - sparse_categorical_accuracy: 0.5864 - val_loss: 3.9058 - val_sparse_categorical_accuracy: 0.4119
Epoch 4/50
9/9 - 6s - loss: 1.5140 - sparse_categorical_accuracy: 0.5877 - val_loss: 2.9674 - val_sparse_categorical_accuracy: 0.4723
Epoch 5/50
9/9 - 10s - loss: 1.5115 - sparse_categorical_accuracy: 0.5959 - val_loss: 2.2275 - val_sparse_categorical_accuracy: 0.5275
Epoch 6/50
9/9 - 7s - loss: 1.4713 - sparse_categorical_accuracy: 0.6050 - val_loss: 1.6515 - val_sparse_categorical_accuracy: 0.6051
Epoch 7/50
9/9 - 7s - loss: 1.4460 - sparse_categorical_accuracy: 0.6155 - val_loss: 1.3832 - val_sparse_categorical_accuracy: 0.6360
Epoch 8/50
9/9 - 7s - loss: 1.3661 - sparse_categorical_accuracy: 0.6337 - val_loss: 1.3275 - val_sparse_categorical_accuracy: 0.6509
Epoch 9/50
9/9 - 6s - loss: 1.3910 - sparse_categorical_accuracy: 0.6215 - val_loss: 1.2825 - val_sparse_categorical_accuracy: 0.6608
Epoch 10/50
9/9 - 6s - loss: 1.2811 - sparse_categorical_accuracy: 0.6562 - val_loss: 1.2094 - val_sparse_categorical_accuracy: 0.6797
Epoch 11/50
9/9 - 6s - loss: 1.2926 - sparse_categorical_accuracy: 0.6484 - val_loss: 1.1464 - val_sparse_categorical_accuracy: 0.6950
Epoch 12/50
9/9 - 6s - loss: 1.2898 - sparse_categorical_accuracy: 0.6302 - val_loss: 1.1089 - val_sparse_categorical_accuracy: 0.7050
Epoch 13/50
9/9 - 6s - loss: 1.2204 - sparse_categorical_accuracy: 0.6784 - val_loss: 1.0893 - val_sparse_categorical_accuracy: 0.7096
Epoch 14/50
9/9 - 6s - loss: 1.2152 - sparse_categorical_accuracy: 0.6623 - val_loss: 1.0783 - val_sparse_categorical_accuracy: 0.7142
Epoch 15/50
9/9 - 6s - loss: 1.2462 - sparse_categorical_accuracy: 0.6576 - val_loss: 1.0587 - val_sparse_categorical_accuracy: 0.7217
Epoch 16/50
9/9 - 6s - loss: 1.2206 - sparse_categorical_accuracy: 0.6688 - val_loss: 0.9961 - val_sparse_categorical_accuracy: 0.7314
Epoch 17/50
9/9 - 7s - loss: 1.2518 - sparse_categorical_accuracy: 0.6636 - val_loss: 1.0047 - val_sparse_categorical_accuracy: 0.7368
Epoch 18/50
9/9 - 7s - loss: 1.1942 - sparse_categorical_accuracy: 0.6727 - val_loss: 0.9873 - val_sparse_categorical_accuracy: 0.7384
Epoch 19/50
9/9 - 9s - loss: 1.1119 - sparse_categorical_accuracy: 0.7010 - val_loss: 0.9598 - val_sparse_categorical_accuracy: 0.7395
Epoch 20/50
9/9 - 7s - loss: 1.1494 - sparse_categorical_accuracy: 0.6944 - val_loss: 0.9371 - val_sparse_categorical_accuracy: 0.7532
Epoch 21/50
9/9 - 7s - loss: 1.1256 - sparse_categorical_accuracy: 0.6879 - val_loss: 0.9200 - val_sparse_categorical_accuracy: 0.7627
Epoch 22/50
9/9 - 6s - loss: 1.0482 - sparse_categorical_accuracy: 0.7201 - val_loss: 0.8943 - val_sparse_categorical_accuracy: 0.7694
Epoch 23/50
9/9 - 6s - loss: 1.0809 - sparse_categorical_accuracy: 0.7153 - val_loss: 0.9161 - val_sparse_categorical_accuracy: 0.7648
Epoch 24/50
9/9 - 6s - loss: 1.0814 - sparse_categorical_accuracy: 0.7114 - val_loss: 0.9213 - val_sparse_categorical_accuracy: 0.7619
Epoch 25/50
9/9 - 6s - loss: 1.0078 - sparse_categorical_accuracy: 0.7214 - val_loss: 0.8947 - val_sparse_categorical_accuracy: 0.7675
Epoch 26/50
9/9 - 7s - loss: 0.9709 - sparse_categorical_accuracy: 0.7391 - val_loss: 0.8600 - val_sparse_categorical_accuracy: 0.7753
Epoch 27/50
9/9 - 6s - loss: 1.0287 - sparse_categorical_accuracy: 0.7270 - val_loss: 0.8354 - val_sparse_categorical_accuracy: 0.7872
Epoch 28/50
9/9 - 6s - loss: 0.9749 - sparse_categorical_accuracy: 0.7396 - val_loss: 0.8162 - val_sparse_categorical_accuracy: 0.7920
Epoch 29/50
9/9 - 6s - loss: 0.9431 - sparse_categorical_accuracy: 0.7383 - val_loss: 0.8065 - val_sparse_categorical_accuracy: 0.7939
Epoch 30/50
9/9 - 7s - loss: 0.9715 - sparse_categorical_accuracy: 0.7422 - val_loss: 0.7995 - val_sparse_categorical_accuracy: 0.7936
Epoch 31/50
9/9 - 7s - loss: 0.9164 - sparse_categorical_accuracy: 0.7530 - val_loss: 0.7948 - val_sparse_categorical_accuracy: 0.7947
Epoch 32/50
9/9 - 6s - loss: 0.9203 - sparse_categorical_accuracy: 0.7487 - val_loss: 0.8642 - val_sparse_categorical_accuracy: 0.7683
Epoch 33/50
9/9 - 7s - loss: 0.9915 - sparse_categorical_accuracy: 0.7313 - val_loss: 0.8313 - val_sparse_categorical_accuracy: 0.7823
Epoch 34/50
9/9 - 6s - loss: 1.0157 - sparse_categorical_accuracy: 0.7296 - val_loss: 0.8204 - val_sparse_categorical_accuracy: 0.7842
Epoch 35/50
9/9 - 7s - loss: 1.0600 - sparse_categorical_accuracy: 0.7188 - val_loss: 0.8281 - val_sparse_categorical_accuracy: 0.7866
Epoch 36/50
9/9 - 6s - loss: 0.9717 - sparse_categorical_accuracy: 0.7361 - val_loss: 0.8413 - val_sparse_categorical_accuracy: 0.7761
Epoch 37/50
9/9 - 7s - loss: 0.9517 - sparse_categorical_accuracy: 0.7326 - val_loss: 0.8223 - val_sparse_categorical_accuracy: 0.7839
Epoch 38/50
9/9 - 7s - loss: 0.8959 - sparse_categorical_accuracy: 0.7626 - val_loss: 0.7921 - val_sparse_categorical_accuracy: 0.7874
Epoch 39/50
9/9 - 7s - loss: 1.0345 - sparse_categorical_accuracy: 0.7196 - val_loss: 0.7694 - val_sparse_categorical_accuracy: 0.8036
Epoch 40/50
9/9 - 7s - loss: 0.9000 - sparse_categorical_accuracy: 0.7465 - val_loss: 0.7594 - val_sparse_categorical_accuracy: 0.8036
Epoch 41/50
9/9 - 7s - loss: 0.8503 - sparse_categorical_accuracy: 0.7626 - val_loss: 0.7567 - val_sparse_categorical_accuracy: 0.8012
Epoch 42/50
9/9 - 6s - loss: 0.8809 - sparse_categorical_accuracy: 0.7530 - val_loss: 0.7373 - val_sparse_categorical_accuracy: 0.8082
Epoch 43/50
9/9 - 6s - loss: 0.9019 - sparse_categorical_accuracy: 0.7539 - val_loss: 0.7463 - val_sparse_categorical_accuracy: 0.8106
Epoch 44/50
9/9 - 6s - loss: 0.8453 - sparse_categorical_accuracy: 0.7674 - val_loss: 0.7092 - val_sparse_categorical_accuracy: 0.8155
Epoch 45/50
9/9 - 7s - loss: 0.7987 - sparse_categorical_accuracy: 0.7687 - val_loss: 0.7131 - val_sparse_categorical_accuracy: 0.8152
Epoch 46/50
9/9 - 6s - loss: 0.8338 - sparse_categorical_accuracy: 0.7769 - val_loss: 0.7005 - val_sparse_categorical_accuracy: 0.8171
Epoch 47/50
9/9 - 7s - loss: 0.7969 - sparse_categorical_accuracy: 0.7799 - val_loss: 0.6921 - val_sparse_categorical_accuracy: 0.8230
Epoch 48/50
9/9 - 6s - loss: 0.7685 - sparse_categorical_accuracy: 0.7891 - val_loss: 0.6948 - val_sparse_categorical_accuracy: 0.8198
Epoch 49/50
9/9 - 7s - loss: 0.7290 - sparse_categorical_accuracy: 0.7986 - val_loss: 0.6804 - val_sparse_categorical_accuracy: 0.8254
Epoch 50/50
9/9 - 7s - loss: 0.7546 - sparse_categorical_accuracy: 0.7925 - val_loss: 0.6556 - val_sparse_categorical_accuracy: 0.8349

Training curve

def show_history(history):
    """Plot loss and accuracy curves side by side.

    Args:
        history: mapping of metric name -> per-epoch values, e.g. the
            ``History.history`` dict returned by ``model.fit``. Every key
            containing ``'loss'`` goes on the left axes and every key
            containing ``'accuracy'`` on the right. Lines are drawn in key
            order and labelled ``train`` then ``valid``.
    """
    _, axes = plt.subplots(1, 2, figsize=(15, 6), facecolor='#F0F0F0')
    for ax, topic in zip(axes.flatten(), ('loss', 'accuracy')):
        for key, values in history.items():
            if topic in key:
                ax.plot(values)
        ax.set_facecolor('#F8F8F8')
        ax.set_title(f'{topic} over epochs')
        ax.set_xlabel('epoch')
        ax.set_ylabel(topic)
        ax.legend(['train', 'valid'], loc='best')
show_history(history.history)

def show_confusion_matrix(cmat, score, precision, recall):
    """Display a confusion matrix with summary metrics overlaid.

    Args:
        cmat: 2-D array; with sklearn's ``confusion_matrix`` the rows are
            true classes and the columns are predicted classes.
        score: F1 score, or None to omit it.
        precision: precision, or None to omit it.
        recall: recall, or None to omit it.

    Class-name tick labels are drawn only when there are at most 10 classes;
    otherwise the axes are hidden. The provided metrics are printed in a box
    near the upper-right corner.
    """
    _, ax = plt.subplots(1, 1, figsize=(12, 12), facecolor='#F0F0F0')
    ax.matshow(cmat, cmap='Blues')
    if len(CLASSES) <= 10:
        ax.set_xticks(range(len(CLASSES)))
        ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
        plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
        ax.set_yticks(range(len(CLASSES)))
        ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
        plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    else:
        ax.axis('off')

    # Test against None, not truthiness, so a genuine 0.0 metric is still
    # shown. Joining lines means an omitted metric leaves no blank line.
    metric_lines = [
        '{} = {:.3f} '.format(name, value)
        for name, value in (('precision', precision), ('recall', recall), ('f1', score))
        if value is not None
    ]
    if metric_lines:
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.2)
        ax.text(0.75, 0.95, '\n'.join(metric_lines), transform=ax.transAxes, fontsize=14,
                verticalalignment='top', bbox=props)
    plt.show()
# Rebuild the validation pipeline without `option_no_order`, so records come
# out in a fixed order. Labels and predictions are collected in separate
# passes below and must line up.
ordered_valid_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
ordered_valid_ds = (ordered_valid_ds
            .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
            .batch(BATCH_SIZE)
            .cache()
            .prefetch(AUTOTUNE))

# Image-only and label-only views over the same pipeline.
x_valid_ds = ordered_valid_ds.map(lambda x,y : x, num_parallel_calls=AUTOTUNE)
y_valid_ds = ordered_valid_ds.map(lambda x,y : y, num_parallel_calls=AUTOTUNE)
# All labels as one array: unbatch, re-batch with the total record count from
# the shard names, and take that single batch.
y_true = (y_valid_ds
        .unbatch()
        .batch(count_data_items(valid_filenames))
        .as_numpy_iterator()
        .next())
# NOTE(review): decode_image random-crops, so these predictions can differ
# between runs; the labels are unaffected.
y_probs = model.predict(x_valid_ds)
y_preds = np.argmax(y_probs, axis=-1)  # predicted class index per image
label_ids = range(len(CLASSES))
cmatrix = confusion_matrix(y_true, y_preds, labels=label_ids)
# Row-normalise so each true-class row sums to 1. NOTE(review): a class with
# no validation samples gives a 0/0 row of NaN.
cmatrix = (cmatrix.T / cmatrix.sum(axis=1)).T # normalize

You might be familiar with metrics like F1-score, precision and recall. This cell computes these metrics and displays them with a plot of the confusion matrix. (These metrics are defined in the scikit-learn module sklearn.metrics; we imported them at the top of this notebook.)

# Macro averages: every class counts equally, regardless of its frequency.
precision, recall, score = (
    metric(y_true, y_preds, labels=label_ids, average='macro')
    for metric in (precision_score, recall_score, f1_score)
)
show_confusion_matrix(cmatrix, score, precision, recall)

Prediction

Once you're satisfied with everything, you're ready to make predictions on the test set.

# Rebuild the test pipeline without `option_no_order`, so records come out in
# a fixed order; the ids are read later from a separate pass over test_ds.
test_ds = (tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
           .map(process_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
           .batch(BATCH_SIZE)
           .prefetch(AUTOTUNE))

x_test_ds = test_ds.map(lambda image, idnum: image)  # drop ids for predict()
y_probs = model.predict(x_test_ds)
y_preds = y_probs.argmax(axis=-1)  # predicted class index per test image

Let's generate a file submission.csv. This file is what you'll submit to get your score on the leaderboard.

id_test_ds = test_ds.map(lambda image,idnum: idnum)
id_test_ds = (id_test_ds.unbatch()
           .batch(count_data_items(test_filenames))
           .as_numpy_iterator()
           .next()
           .astype('U'))
np.savetxt('submission.csv',
           np.rec.fromarrays([id_test_ds, y_preds]),
           fmt=['%s', '%d'],
           delimiter=',',
           header='id,label',
           comments='')
!head submission.csv
id,label
252d840db,67
1c4736dea,28
c37a6f3e9,83
00e4f514e,103
59d1b6146,46
8d808a07b,53
aeb67eefb,52
53cfc6586,71
aaa580243,85