Petals to the Metal
Getting Started with TPUs on Kaggle
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, re
import warnings, math, sys, json
import subprocess, pprint, pdb
import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score,precision_score, recall_score, confusion_matrix
# Silence all Python warnings for the rest of the session, then report which
# TensorFlow build is in use.
warnings.simplefilter('ignore')
print("Using TensorFlow v" + tf.__version__)
def seed_everything(seed=0):
    """Seed Python, NumPy and TensorFlow RNGs and set determinism env vars.

    Args:
        seed: integer seed shared by every generator (default 0).

    NOTE: PYTHONHASHSEED is read by the interpreter at start-up, so setting it
    here only affects processes launched afterwards, not this one.
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })
# Runtime detection: the string form of Colab's IPython shell mentions
# 'google.colab'; any other runtime is treated as a Kaggle kernel.
GOOGLE = 'google.colab' in str(get_ipython())
KAGGLE = not GOOGLE
print("Running on " + ("Google Colab" if GOOGLE else "Kaggle Kernel") + "!")
# Experiment configuration. The `#@param` annotations render these values as
# Colab form fields, so keep them on the same line as each assignment.
# NOTE(review): BASE_MODEL is only printed in the visible code; the backbone
# built later is hard-coded to EfficientNetB3.
BASE_MODEL = 'efficientnet_b3' #@param ["'efficientnet_b3'", "'efficientnet_b4'", "'efficientnet_b2'"] {type:"raw", allow-input: true}
HEIGHT = 300#@param {type:"number"}
WIDTH = 300#@param {type:"number"}
CHANNELS = 3#@param {type:"number"}
# Model input shape: (height, width, channels).
IMG_SIZE = (HEIGHT, WIDTH, CHANNELS)
EPOCHS = 50#@param {type:"number"}
# Global batch = 32 images per replica times the replica count.
# NOTE(review): `strategy` is not defined in the visible source; presumably the
# TPU distribution strategy created in a setup cell. Confirm.
BATCH_SIZE = 32 * strategy.num_replicas_in_sync #@param {type:"raw"}
print("Use {} with input size {}".format(BASE_MODEL, IMG_SIZE))
print("Train on batch size of {} for {} epochs".format(BATCH_SIZE, EPOCHS))
%%run_if {GOOGLE}
#@title {run: "auto", display-mode: "form"}
# Location of the competition TFRecords (the 'tfrecords-jpeg-512x512' variant) on GCS.
# NOTE(review): `%%run_if` is a custom magic not defined here; presumably this
# cell only runs when GOOGLE is true. GCS_PATH is not assigned anywhere else in
# the visible source, so on Kaggle it must come from another cell. Confirm.
GCS_PATH = 'gs://kds-c6b9829baa483a13a169c7cbe266341fb8c9b5ba36843af37a093a4c' #@param {type: "string"}
GCS_PATH += '/tfrecords-jpeg-512x512' #@param {type: "string"}
print(f"Sourcing images from {GCS_PATH}")
# Flower class names (104 entries). Integer class ids are treated as indices
# into this list: the classifier head has len(CLASSES) outputs and the
# confusion matrix uses range(len(CLASSES)) as its label ids.
# NOTE(review): 'cyclamen ' and 'hippeastrum ' carry trailing spaces; in the
# visible code the names are only used for counting and plot tick labels.
CLASSES = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea', 'wild geranium',
'tiger lily', 'moon orchid', 'bird of paradise', 'monkshood', 'globe thistle',
'snapdragon', "colt's foot", 'king protea', 'spear thistle', 'yellow iris',
'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower','giant white arum lily',
'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 'grape hyacinth',
'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 'sweet william',
'carnation', 'garden phlox', 'love in the mist', 'cosmos', 'alpine sea holly',
'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 'lenten rose',
'barberton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
'wallflower', 'marigold', 'buttercup', 'daisy', 'common dandelion',
'petunia', 'wild pansy', 'primula', 'sunflower', 'lilac hibiscus',
'bishop of llandaff', 'gaura', 'geranium', 'orange dahlia', 'pink-yellow dahlia',
'cautleya spicata', 'japanese anemone', 'black-eyed susan', 'silverbush', 'californian poppy',
'osteospermum', 'spring crocus', 'iris', 'windflower', 'tree poppy',
'gazania', 'azalea', 'water lily', 'rose', 'thorn apple',
'morning glory', 'passion flower', 'lotus', 'toad lily', 'anthurium',
'frangipani', 'clematis', 'hibiscus', 'columbine', 'desert-rose',
'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily',
'hippeastrum ', 'bee balm', 'pink quill', 'foxglove', 'bougainvillea',
'camellia', 'mallow', 'mexican petunia', 'bromelia', 'blanket flower',
'trumpet creeper', 'blackberry lily', 'common tulip', 'wild rose']
# Number of output classes.
NCLASSES = len(CLASSES)
print(f"Number of labels: {NCLASSES}")
def decode_image(image_data):
    """Decode JPEG bytes into a float32 image randomly cropped to IMG_SIZE.

    Pixel values stay in [0, 255] on Colab (GOOGLE) and are scaled to [0, 1]
    on every other runtime.

    NOTE(review): every split decoded through this function (train, validation
    and test) gets a *random* crop, so evaluation and test predictions depend
    on the crop drawn. Confirm this is intended.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image_data, channels=CHANNELS), tf.float32)
    if not GOOGLE:
        pixels = pixels / 255.0
    return tf.image.random_crop(pixels, IMG_SIZE)
def collate_labeled_tfrecord(example):
    """Parse one serialized labeled TFRecord into an ``(image, label)`` pair.

    The image bytes are decoded by ``decode_image``; the int64 class id is
    narrowed to int32.
    """
    features = tf.io.parse_single_example(
        example,
        {
            "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytestring
            "class": tf.io.FixedLenFeature([], tf.int64),   # scalar class id
        },
    )
    return decode_image(features['image']), tf.cast(features['class'], tf.int32)
def process_unlabeled_tfrecord(example):
    """Parse one serialized unlabeled TFRecord into an ``(image, id)`` pair.

    The image bytes are decoded by ``decode_image``; the id is returned as the
    parsed scalar string tensor.
    """
    spec = {
        "image": tf.io.FixedLenFeature([], tf.string),  # JPEG bytestring
        "id": tf.io.FixedLenFeature([], tf.string),     # scalar id string
    }
    parsed = tf.io.parse_single_example(example, spec)
    return decode_image(parsed['image']), parsed['id']
# Matches the record count embedded in shard names, e.g. "00-512x512-798.tfrec".
_ITEM_COUNT_RE = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total number of records encoded in TFRecord file names.

    Each file name must contain its record count as the first ``-<digits>.``
    group, e.g. ``00-512x512-798.tfrec`` -> 798.

    Args:
        filenames: iterable of file paths or names.

    Returns:
        The summed count as a Python ``int``. An empty iterable yields ``0``,
        not the float ``0.0`` that ``np.sum([])`` would give.

    Raises:
        ValueError: if a file name contains no ``-<digits>.`` group.
    """
    total = 0
    for filename in filenames:
        match = _ITEM_COUNT_RE.search(filename)
        if match is None:
            raise ValueError(f"Cannot find a record count in file name {filename!r}")
        total += int(match.group(1))
    return total
# Discover the TFRecord shards of each split under GCS_PATH.
train_filenames, valid_filenames, test_filenames = (
    tf.io.gfile.glob(GCS_PATH + shard_glob)
    for shard_glob in ('/train/*.tfrec', '/val/*.tfrec', '/test/*.tfrec')
)
def _affine_gather(image, dim, matrix):
    """Resample one square image through a 3x3 homogeneous transform.

    For every destination pixel, ``matrix`` maps its centred (row, col, 1)
    coordinate back onto the source image. The mapped coordinates are cast
    (truncated) to int and clipped to the image border, then the source pixel
    is gathered.

    Args:
        image: a single image of shape [dim, dim, C], not a batch.
        dim: side length of the square image.
        matrix: float32 tensor of shape [3, 3].

    Returns:
        The resampled image, shape [dim, dim, C].
    """
    xdim = dim % 2  # fix for odd sizes such as 331
    # List destination pixel indices, centred on the image middle.
    x = tf.repeat(tf.range(dim // 2, -dim // 2, -1), dim)
    y = tf.tile(tf.range(-dim // 2, dim // 2), [dim])
    z = tf.ones([dim * dim], dtype='int32')
    idx = tf.stack([x, y, z])
    # Map destination pixels back onto origin pixels.
    idx2 = tf.linalg.matmul(matrix, tf.cast(idx, dtype='float32'))
    idx2 = tf.cast(idx2, dtype='int32')
    idx2 = tf.clip_by_value(idx2, -dim // 2 + xdim + 1, dim // 2)
    # Find origin pixel values.
    idx3 = tf.stack([dim // 2 - idx2[0, ], dim // 2 - 1 + idx2[1, ]])
    d = tf.gather_nd(image, tf.transpose(idx3))
    channels = image.shape[-1]
    if channels is None:  # channel count unknown statically: read it at run time
        channels = tf.shape(image)[-1]
    return tf.reshape(d, [dim, dim, channels])


def transform_shear(image, height, shear):
    """Shear one square image by a random angle in [0, ``shear``) degrees.

    Args:
        image: a single image of shape [height, height, C], not a batch.
        height: side length of the square image.
        shear: maximum shear angle in degrees.

    Returns:
        The randomly sheared image, same shape as ``image``.
    """
    shear = shear * tf.random.uniform([1], dtype='float32')
    shear = math.pi * shear / 180.  # degrees -> radians
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(
        tf.concat([one, s2, zero, zero, c2, zero, zero, zero, one], axis=0), [3, 3])
    return _affine_gather(image, height, shear_matrix)


def transform_shift(image, height, h_shift, w_shift):
    """Shift one square image by random offsets of up to (h_shift, w_shift) px.

    Args:
        image: a single image of shape [height, height, C], not a batch.
        height: side length of the square image.
        h_shift: maximum shift along the first (row) axis, in pixels.
        w_shift: maximum shift along the second (column) axis, in pixels.

    Returns:
        The randomly shifted image, same shape as ``image``.
    """
    height_shift = h_shift * tf.random.uniform([1], dtype='float32')
    width_shift = w_shift * tf.random.uniform([1], dtype='float32')
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    shift_matrix = tf.reshape(
        tf.concat([one, zero, height_shift, zero, one, width_shift, zero, zero, one], axis=0),
        [3, 3])
    return _affine_gather(image, height, shift_matrix)


def transform_rotation(image, height, rotation):
    """Rotate one square image by a random angle in [0, ``rotation``) degrees.

    Args:
        image: a single image of shape [height, height, C], not a batch.
        height: side length of the square image.
        rotation: maximum rotation angle in degrees.

    Returns:
        The randomly rotated image, same shape as ``image``.
    """
    rotation = rotation * tf.random.uniform([1], dtype='float32')
    rotation = math.pi * rotation / 180.  # degrees -> radians
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    rotation_matrix = tf.reshape(
        tf.concat([c1, s1, zero, -s1, c1, zero, zero, zero, one], axis=0), [3, 3])
    return _affine_gather(image, height, rotation_matrix)
def data_augment(image, label):
    """Randomly augment one training image; the label passes through unchanged.

    Each augmentation family fires independently, based on its own uniform draw:
      * flips (80%): random left/right and up/down flips;
      * right-angle rotation: 90, 180 or 270 degrees (25% each), else none;
      * rotation by [0, 45) deg, shift by up to 15 px, shear by [0, 20) deg
        (70% each);
      * crop: 30% central crop keeping 70/80/90% of the side (10% each),
        30% random square crop with side in [0.7*HEIGHT, HEIGHT), 40% none.
        The image is always resized back to [HEIGHT, WIDTH];
      * pixel-level (80%): one of saturation, contrast, brightness or gamma.

    Args:
        image: one decoded image tensor of shape [HEIGHT, WIDTH, CHANNELS].
        label: the image's label (returned as-is).

    Returns:
        ``(image, label)`` with the augmented image.
    """
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shift = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    if p_spatial >= .2:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

    # Right-angle rotations
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3)  # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2)  # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1)  # rotate 90º

    # Affine warps
    if p_rotation >= .3:
        image = transform_rotation(image, height=HEIGHT, rotation=45.)
    if p_shift >= .3:
        image = transform_shift(image, height=HEIGHT, h_shift=15., w_shift=15.)
    if p_shear >= .3:
        image = transform_shear(image, height=HEIGHT, shear=20.)

    # Crops. Test the central-crop band (> .7) before the random-crop band
    # (> .4); otherwise every draw above .7 is taken by the random crop and the
    # central crops become unreachable.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT * .7), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    # Pixel-level transforms
    if p_pixel >= .2:
        if p_pixel >= .8:
            image = tf.image.random_saturation(image, lower=0, upper=2)
        elif p_pixel >= .6:
            image = tf.image.random_contrast(image, lower=.8, upper=2)
        elif p_pixel >= .4:
            image = tf.image.random_brightness(image, max_delta=.2)
        else:
            image = tf.image.adjust_gamma(image, gamma=.6)
    return image, label
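`experimental_deterministic` controls whether the dataset must produce its elements in a deterministic order (default: `True`). Setting it to `False` allows `tf.data` to return elements out of order, trading ordering for throughput.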
# Options that let tf.data emit elements out of order (used by the
# evaluation-only pipelines below).
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False

# NOTE(review): AUTOTUNE is not defined in the visible source; presumably
# tf.data.experimental.AUTOTUNE assigned in a setup cell. Confirm.

# Training: parse -> augment -> repeat indefinitely -> shuffle -> batch -> prefetch.
train_ds = tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
train_ds = train_ds.map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.map(data_augment, num_parallel_calls=AUTOTUNE)
train_ds = train_ds.repeat()
train_ds = train_ds.shuffle(2048)
train_ds = train_ds.batch(BATCH_SIZE)
train_ds = train_ds.prefetch(AUTOTUNE)

# Validation: unordered parse -> batch -> cache the batches -> prefetch.
valid_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
valid_ds = valid_ds.with_options(option_no_order)
valid_ds = valid_ds.map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
valid_ds = valid_ds.batch(BATCH_SIZE)
valid_ds = valid_ds.cache()
valid_ds = valid_ds.prefetch(AUTOTUNE)

# Test: unordered parse into (image, id) -> batch -> prefetch.
test_ds = tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
test_ds = test_ds.with_options(option_no_order)
test_ds = test_ds.map(process_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
test_ds = test_ds.batch(BATCH_SIZE)
test_ds = test_ds.prefetch(AUTOTUNE)
Augmentation can be applied in two ways:
- using the Keras preprocessing layers (see the commented-out example below), or
- using the `tf.image` functions, as done in `data_augment`.
#batch_augment = tf.keras.Sequential(
# [
# tf.keras.layers.experimental.preprocessing.RandomCrop(*IMG_SIZE),
# tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
# tf.keras.layers.experimental.preprocessing.RandomRotation(0.25),
# tf.keras.layers.experimental.preprocessing.RandomZoom((-0.2, 0)),
# tf.keras.layers.experimental.preprocessing.RandomContrast((0.2,0.2))
# ]
#)
#func = lambda x,y: (batch_augment(x), y)
#x = (train_ds
# .take(1)
# .map(func, num_parallel_calls=AUTOTUNE))
Now we're ready to create a neural network for classifying images! We'll use what's known as transfer learning. With transfer learning, you reuse the body of a pretrained model and replace its head (the final, task-specific layers) with custom layers that suit the problem you are solving.
For this tutorial, we'll use EfficientNet-B3, which is pretrained on ImageNet. Later, I might experiment with other models (Xception wouldn't be a bad choice).
We create the model inside `strategy.scope()`. This context manager tells TensorFlow how to divide the work of training among the eight TPU cores. When using TensorFlow with a TPU, it's important to define your model within the `strategy.scope()` context.
%%run_if {GOOGLE}
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.applications import VGG16
def build_model(base_model, num_class):
    """Stack a dropout + softmax classification head on top of ``base_model``.

    Args:
        base_model: feature-extractor model applied to an IMG_SIZE input.
        num_class: number of output classes.

    Returns:
        A ``tf.keras.Model`` mapping an IMG_SIZE image to class probabilities.
    """
    image_in = tf.keras.layers.Input(shape=IMG_SIZE)
    features = tf.keras.layers.Dropout(0.4)(base_model(image_in))
    probs = tf.keras.layers.Dense(num_class, activation="softmax", name="pred")(features)
    return tf.keras.models.Model(inputs=image_in, outputs=probs)
# Create the backbone and the full model under the distribution strategy's scope.
with strategy.scope():
    # NOTE(review): `TRAIN` is not defined in the visible source; presumably a
    # setup flag. ImageNet weights are requested only when it is truthy.
    backbone_weights = 'imagenet' if TRAIN else None
    # Headless EfficientNetB3; global average pooling yields one feature vector per image.
    efficientnet = EfficientNetB3(weights=backbone_weights,
                                  include_top=False,
                                  input_shape=IMG_SIZE,
                                  pooling='avg')
    efficientnet.trainable = True  # fine-tune the whole backbone, not just the head
    model = build_model(base_model=efficientnet, num_class=len(CLASSES))
I use the `CosineDecayRestarts` learning-rate schedule implemented in `tf.keras`, as it seemed promising and I struggled to find the right settings (if there are any) for `ReduceLROnPlateau`.
# Total number of optimizer steps over the whole run (steps per epoch x EPOCHS).
STEPS = math.ceil(count_data_items(train_filenames) / BATCH_SIZE) * EPOCHS
LR_START = 1e-4 #@param {type: "number"}
LR_START *= strategy.num_replicas_in_sync
LR_MIN = 1e-5 #@param {type: "number"}
N_RESTARTS = 5#@param {type: "number"}
T_MUL = 2.0 #@param {type: "number"}
M_MUL = 1#@param {type: "number"}
# Size the first cosine period so that the N_RESTARTS + 1 periods
# (lengths STEPS_START * T_MUL**i) exactly fill STEPS: this is the sum of a
# geometric series. With T_MUL == 1 all periods are equal and the closed form
# would evaluate 0 / 0.
if T_MUL == 1:
    STEPS_START = math.ceil(STEPS / (N_RESTARTS + 1))
else:
    STEPS_START = math.ceil((T_MUL-1)/(T_MUL**(N_RESTARTS+1)-1) * STEPS)
schedule = tf.keras.experimental.CosineDecayRestarts(
    first_decay_steps=STEPS_START,
    initial_learning_rate=LR_START,
    # `alpha` is the floor as a *fraction* of initial_learning_rate, so divide
    # to make the minimum learning rate equal to LR_MIN.
    alpha=LR_MIN / LR_START,
    m_mul=M_MUL,
    t_mul=T_MUL)

# Plot the learning rate at every step of the run.
x = [i for i in range(STEPS)]
y = [schedule(s) for s in range(STEPS)]
_, ax = plt.subplots(1, 1, figsize=(8, 5), facecolor='#F0F0F0')
ax.plot(x, y)
ax.set_facecolor('#F8F8F8')
ax.set_xlabel('iteration')
ax.set_ylabel('learning rate')
print('{:d} total epochs and {:d} steps per epoch'
      .format(EPOCHS, STEPS // EPOCHS))
print(schedule.get_config())
# Save the best model (lowest val_loss) to disk, and stop once val_loss has not
# improved for 10 epochs, rolling back to the best weights.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='001_best_model.h5', monitor='val_loss', save_best_only=True)
early_stop_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', mode='min', patience=10,
    restore_best_weights=True, verbose=1)
callbacks = [checkpoint_cb, early_stop_cb]

# Adam driven by the LR schedule; integer labels -> sparse categorical loss/metric.
model.compile(
    optimizer=tf.keras.optimizers.Adam(schedule),
    loss='sparse_categorical_crossentropy',
    metrics=['sparse_categorical_accuracy'],
)
model.summary()
%%run_if {GOOGLE}
def generate_norm_image(example):
    """Parse a labeled TFRecord and return only its image divided by 255.

    Parsing and decoding are delegated to ``collate_labeled_tfrecord``, so the
    feature spec lives in a single place; the label is discarded. The output
    feeds the normalization layer's ``adapt`` step.
    """
    image, _ = collate_labeled_tfrecord(example)
    return image / 255.0
%%run_if {GOOGLE}
if not os.path.exists("000_normalization.h5"):
    # No cached weights yet: fit the backbone's 'normalization' layer to the
    # training images (divided by 255 in generate_norm_image), then cache.
    # NOTE(review): adapt() overwrites whatever statistics that layer held from
    # the pretrained weights. Confirm this is intended.
    adapt_ds = tf.data.TFRecordDataset(train_filenames, num_parallel_reads=AUTOTUNE)
    adapt_ds = adapt_ds.map(generate_norm_image, num_parallel_calls=AUTOTUNE)
    adapt_ds = adapt_ds.shuffle(2048).batch(BATCH_SIZE).prefetch(AUTOTUNE)
    backbone = model.get_layer('efficientnetb3')
    backbone.get_layer('normalization').adapt(adapt_ds)
    model.save_weights("000_normalization.h5")
else:
    # Reuse the weights saved after a previous adapt run.
    model.load_weights("000_normalization.h5")
# Train on the infinitely repeated train_ds. STEPS is the *total* step count
# (ceil(n_train / BATCH_SIZE) * EPOCHS), so one epoch is STEPS // EPOCHS steps,
# a full pass over the training set, matching the span the LR schedule was
# sized for. (STEPS // BATCH_SIZE trained only a fraction of the data per epoch.)
history = model.fit(
    x=train_ds,
    validation_data=valid_ds,
    epochs=EPOCHS,
    steps_per_epoch=STEPS // EPOCHS,
    callbacks=callbacks,
    verbose=2,
)
def show_history(history):
    """Plot loss and accuracy curves side by side.

    Args:
        history: mapping of metric name -> per-epoch values, e.g. a Keras
            ``History.history`` dict. Each key containing 'loss' or 'accuracy'
            is drawn on the matching panel; keys starting with ``val_`` are
            labelled 'valid', all others 'train'.
    """
    topics = ['loss', 'accuracy']
    _, axs = plt.subplots(1, 2, figsize=(15, 6), facecolor='#F0F0F0')
    for topic, ax in zip(topics, axs.flatten()):
        for key, values in history.items():
            if topic not in key:
                continue
            # Label each curve from its own key rather than relying on dict order.
            ax.plot(values, label='valid' if key.startswith('val_') else 'train')
        ax.set_facecolor('#F8F8F8')
        ax.set_title(f'{topic} over epochs')
        ax.set_xlabel('epoch')
        ax.set_ylabel(topic)
        ax.legend(loc='best')


show_history(history.history)
def show_confusion_matrix(cmat, score, precision, recall):
    """Plot a confusion matrix annotated with summary metrics.

    Class-name tick labels are drawn only when there are at most 10 classes;
    otherwise the axes are hidden.

    Args:
        cmat: square confusion matrix to display.
        score: F1 score to print, or None to omit it.
        precision: precision to print, or None to omit it.
        recall: recall to print, or None to omit it.
    """
    _, ax = plt.subplots(1, 1, figsize=(12, 12), facecolor='#F0F0F0')
    ax.matshow(cmat, cmap='Blues')
    if len(CLASSES) <= 10:
        ax.set_xticks(range(len(CLASSES)))
        ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
        plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
        ax.set_yticks(range(len(CLASSES)))
        ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
        plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    else:
        ax.axis('off')  # too many classes for readable tick labels
    # Compare against None so a legitimate 0.0 metric is still shown, and join
    # with newlines so no blank leading line appears when one metric is missing.
    lines = []
    if precision is not None:
        lines.append('precision = {:.3f} '.format(precision))
    if recall is not None:
        lines.append('recall = {:.3f} '.format(recall))
    if score is not None:
        lines.append('f1 = {:.3f} '.format(score))
    if lines:
        props = dict(boxstyle='round', facecolor='wheat', alpha=0.2)
        ax.text(0.75, 0.95, '\n'.join(lines), transform=ax.transAxes, fontsize=14,
                verticalalignment='top', bbox=props)
    plt.show()
# Validation pipeline rebuilt WITHOUT option_no_order, so the images seen by
# predict() and the labels gathered into y_true come out in the same order.
ordered_valid_ds = tf.data.TFRecordDataset(valid_filenames, num_parallel_reads=AUTOTUNE)
ordered_valid_ds = (ordered_valid_ds
                    .map(collate_labeled_tfrecord, num_parallel_calls=AUTOTUNE)
                    .batch(BATCH_SIZE)
                    .cache()
                    .prefetch(AUTOTUNE))
x_valid_ds = ordered_valid_ds.map(lambda x, y: x, num_parallel_calls=AUTOTUNE)
y_valid_ds = ordered_valid_ds.map(lambda x, y: y, num_parallel_calls=AUTOTUNE)
# All validation labels as one array: a single batch the size of the split.
y_true = (y_valid_ds
          .unbatch()
          .batch(count_data_items(valid_filenames))
          .as_numpy_iterator()
          .next())
y_probs = model.predict(x_valid_ds)
y_preds = np.argmax(y_probs, axis=-1)
label_ids = range(len(CLASSES))
cmatrix = confusion_matrix(y_true, y_preds, labels=label_ids)
# Row-normalize so each row sums to 1. A class with no true samples has a zero
# row total; leave that row at 0 instead of filling it with NaN.
row_totals = cmatrix.sum(axis=1, keepdims=True)
cmatrix = np.divide(cmatrix, row_totals,
                    out=np.zeros(cmatrix.shape, dtype=float),
                    where=row_totals != 0)
You might be familiar with metrics like F1-score, precision and recall. This cell computes these metrics and displays them alongside a plot of the confusion matrix. (These metrics are defined in the scikit-learn module `sklearn.metrics`; we imported them at the top of the notebook.)
# Macro-averaged scores over every class id, displayed with the confusion matrix.
metric_kwargs = dict(labels=label_ids, average='macro')
precision = precision_score(y_true, y_preds, **metric_kwargs)
recall = recall_score(y_true, y_preds, **metric_kwargs)
score = f1_score(y_true, y_preds, **metric_kwargs)
show_confusion_matrix(cmatrix, score, precision, recall)
# Test pipeline rebuilt without option_no_order, so predictions line up with
# the ids later read from the same dataset.
test_ds = (tf.data.TFRecordDataset(test_filenames, num_parallel_reads=AUTOTUNE)
           .map(process_unlabeled_tfrecord, num_parallel_calls=AUTOTUNE)
           .batch(BATCH_SIZE)
           .prefetch(AUTOTUNE))
x_test_ds = test_ds.map(lambda image, idnum: image)
y_probs = model.predict(x_test_ds)
y_preds = y_probs.argmax(axis=-1)
Let's generate the file `submission.csv`. This file is what you'll submit to get your score on the leaderboard.
id_test_ds = test_ds.map(lambda image,idnum: idnum)
id_test_ds = (id_test_ds.unbatch()
.batch(count_data_items(test_filenames))
.as_numpy_iterator()
.next()
.astype('U'))
np.savetxt('submission.csv',
np.rec.fromarrays([id_test_ds, y_preds]),
fmt=['%s', '%d'],
delimiter=',',
header='id,label',
comments='')
!head submission.csv