Custom Everything with TPU
Getting started with customizing everything on TPU
import numpy as np
import pandas as pd
import seaborn as sns
import albumentations as A
import matplotlib.pyplot as plt
import os, gc, cv2, random, re, datetime
import warnings, math, sys, json
import subprocess, pprint, pdb
from functools import partial
import tensorflow as tf
from tensorflow.keras import backend as K
import tensorflow_hub as hub
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix
warnings.simplefilter('ignore')
print(f"Using TensorFlow v{tf.__version__}")
def seed_everything(seed=0):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and TensorFlow, then exports
    ``PYTHONHASHSEED`` and ``TF_DETERMINISTIC_OPS`` into the environment.

    Args:
        seed: Value handed to each seeding function (default 0).
    """
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)
    os.environ.update({
        'PYTHONHASHSEED': str(seed),
        'TF_DETERMINISTIC_OPS': '1',
    })
# Platform detection: GOOGLE is True when the string form of the IPython shell
# object contains 'google.colab'; every other environment is treated as Kaggle.
# NOTE(review): `get_ipython` only exists inside an IPython/Jupyter session.
GOOGLE = 'google.colab' in str(get_ipython())
KAGGLE = not GOOGLE
print("Running on {}!".format(
   "Google Colab" if GOOGLE else "Kaggle Kernel"
))
# Input geometry and training hyper-parameters. The trailing `#@param` markers
# are notebook form annotations and must stay on the same line as the value.
HEIGHT = 224#@param {type:"number"}
WIDTH = 224#@param {type:"number"}
CHANNELS = 3#@param {type:"number"}
IMG_SIZE = (HEIGHT, WIDTH, CHANNELS)
EPOCHS =  8#@param {type:"number"}
# BATCH_SIZE is the per-replica batch; GLOBAL_BATCH_SIZE is that value times the replica count.
# NOTE(review): `strategy` is not defined in the visible part of the file — presumably created in an earlier cell.
BATCH_SIZE = 8#@param {type:"raw"}
GLOBAL_BATCH_SIZE = BATCH_SIZE * strategy.num_replicas_in_sync
print("Input size: {}".format(IMG_SIZE))
print("Train on batch size of {} with {} replicas for {} epochs".format(
    BATCH_SIZE, strategy.num_replicas_in_sync, EPOCHS))
%%run_if {GOOGLE}
#@title {run: "auto", display-mode: "form"}
# reference: https://www.kaggle.com/austinyhc/custom-training-with-tpu?scriptVersionId=51687595
# Root GCS location of the dataset.
GCS_DS_PATH = 'gs://kds-f7aaa241d2ceea308646ba649de83022e9089736f446906c81c4f8a0' #@param {type: "string"}
# Maps an image edge length to the folder holding the TFRecords at that resolution;
# it is indexed with HEIGHT when the file lists are globbed.
GCS_PATH_SELECT = {
    192: GCS_DS_PATH + '/tfrecords-jpeg-192x192',
    224: GCS_DS_PATH + '/tfrecords-jpeg-224x224',
    331: GCS_DS_PATH + '/tfrecords-jpeg-331x331',
    512: GCS_DS_PATH + '/tfrecords-jpeg-512x512'
}
print(f"Sourcing images from")
for v in GCS_PATH_SELECT.values(): print(f"\t{v}")
# Human-readable class names. The list is indexed by integer class id elsewhere in
# the file (plot titles, confusion-matrix tick labels); NCLASSES is its length.
CLASSES = ['pink primrose',        'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',
           'tiger lily',           'moon orchid',               'bird of paradise', 'monkshood',     'globe thistle',        
           'snapdragon',           "colt's foot",               'king protea',      'spear thistle', 'yellow iris',
           'globe-flower',         'purple coneflower',         'peruvian lily',    'balloon flower','giant white arum lily',
           'fire lily',            'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',
           'corn poppy',           'prince of wales feathers',  'stemless gentian', 'artichoke',     'sweet william',        
           'carnation',            'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',
           'ruby-lipped cattleya', 'cape flower',               'great masterwort', 'siam tulip',    'lenten rose',          
           'barberton daisy',      'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',
           'wallflower',           'marigold',                  'buttercup',        'daisy',         'common dandelion',     
           'petunia',              'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',
           'bishop of llandaff',   'gaura',                     'geranium',         'orange dahlia', 'pink-yellow dahlia',   
           'cautleya spicata',     'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy',
           'osteospermum',         'spring crocus',             'iris',             'windflower',    'tree poppy',           
           'gazania',              'azalea',                    'water lily',       'rose',          'thorn apple',
           'morning glory',        'passion flower',            'lotus',            'toad lily',     'anthurium',
           'frangipani',           'clematis',                  'hibiscus',         'columbine',     'desert-rose',
           'tree mallow',          'magnolia',                  'cyclamen ',        'watercress',    'canna lily',           
           'hippeastrum ',         'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',
           'camellia',             'mallow',                    'mexican petunia',  'bromelia',      'blanket flower',       
           'trumpet creeper',      'blackberry lily',           'common tulip',     'wild rose']
# NCLASSES sizes the softmax heads and the one-hot labels used by the loss.
with strategy.scope():
    NCLASSES = len(CLASSES)
print(f"Number of labels: {NCLASSES}")
def count_data_items(filenames):
    """Return the total number of examples encoded in TFRecord file names.

    Each file name is expected to carry its example count as the run of
    digits between a ``-`` and the ``.`` that immediately follows it, e.g.
    ``00-224x224-798.tfrec`` contributes 798.

    Args:
        filenames: Iterable of TFRecord file names or paths.

    Returns:
        The summed count as a plain ``int``; ``0`` for an empty iterable.
    """
    # Compile once rather than once per file name.
    count_pattern = re.compile(r"-([0-9]*)\.")
    # Built-in sum keeps the result an int even for an empty input
    # (np.sum of an empty list yields the float 0.0).
    return sum(int(count_pattern.search(filename).group(1))
               for filename in filenames)
def inspect_tfrecord(tfrec):
    """Print the first record of a TFRecord source as a ``tf.train.Example``.

    Args:
        tfrec: File name(s) passed straight to ``tf.data.TFRecordDataset``.
    """
    raw_dataset = tf.data.TFRecordDataset(tfrec)
    for raw_record in raw_dataset.take(1):
        example = tf.train.Example()
        example.ParseFromString(raw_record.numpy())
        print(example)
        return
    # Only reached when the source yields no record at all; report it instead
    # of failing on an unbound `example`.
    print(f"No records found in {tfrec}")
# List the TFRecord shards of each split for the resolution selected by HEIGHT,
# then print the first training record to show the stored feature layout.
train_filenames = tf.io.gfile.glob(GCS_PATH_SELECT[HEIGHT] + '/train/*.tfrec')
valid_filenames = tf.io.gfile.glob(GCS_PATH_SELECT[HEIGHT] + '/val/*.tfrec')
test_filenames  = tf.io.gfile.glob(GCS_PATH_SELECT[HEIGHT] + '/test/*.tfrec') 
inspect_tfrecord(train_filenames)
def decode_image(image_string):
    """Turn an encoded JPEG string into a float image tensor.

    The JPEG is decoded with 3 channels, cast to float32, divided by 255
    and reshaped to ``IMG_SIZE``.
    """
    pixels = tf.cast(tf.image.decode_jpeg(image_string, channels=3), tf.float32)
    return tf.reshape(pixels / 255.0, IMG_SIZE)
def read_tfrecord(example, labeled=True):
    """Parse one serialized example into an ``(image, target)`` pair.

    Args:
        example: Serialized example taken from a TFRecord dataset.
        labeled: When True the record must contain a ``class`` feature, which
            is returned as an int32 label; otherwise the ``id`` string is
            returned in its place.

    Returns:
        ``(image, label)`` if ``labeled`` else ``(image, id)``.
    """
    # "image" and "id" are present in both record flavours.
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "id": tf.io.FixedLenFeature([], tf.string),
    }
    if labeled:
        feature_spec["class"] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed['image'])
    if labeled:
        return image, tf.cast(parsed['class'], tf.int32)
    return image, parsed['id']
def _apply_affine(image, dim, matrix):
    """Resample a square image through a 3x3 matrix applied to pixel coordinates.

    Shared by the shear, shift and rotation transforms, which differ only in
    the matrix they build.

    Args:
        image: One image of size [dim, dim, 3] (not a batch).
        dim: Edge length of the image.
        matrix: 3x3 float32 tensor applied to homogeneous pixel coordinates.

    Returns:
        The resampled image reshaped to [dim, dim, 3].
    """
    xdim = dim % 2  # fix for size 331

    # LIST DESTINATION PIXEL INDICES
    x = tf.repeat(tf.range(dim//2, -dim//2, -1), dim)
    y = tf.tile(tf.range(-dim//2, dim//2), [dim])
    z = tf.ones([dim*dim], dtype='int32')
    idx = tf.stack([x, y, z])

    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS
    idx2 = K.dot(matrix, tf.cast(idx, dtype='float32'))
    idx2 = K.cast(idx2, dtype='int32')
    # clip so every source coordinate stays inside the image
    idx2 = K.clip(idx2, -dim//2+xdim+1, dim//2)

    # FIND ORIGIN PIXEL VALUES
    idx3 = tf.stack([dim//2-idx2[0,], dim//2-1+idx2[1,]])
    d = tf.gather_nd(image, tf.transpose(idx3))

    return tf.reshape(d, [dim, dim, 3])


def transform_shear(image, height, shear):
    """Randomly shear one image.

    Args:
        image: One image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        height: Edge length ``dim`` of the image.
        shear: Maximum shear angle in degrees; the applied angle is this value
            scaled by a uniform random number in [0, 1).

    Returns:
        The sheared image of size [dim, dim, 3].
    """
    shear = shear * tf.random.uniform([1], dtype='float32')
    # CONVERT DEGREES TO RADIANS
    shear = math.pi * shear / 180.

    # SHEAR MATRIX
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    c2 = tf.math.cos(shear)
    s2 = tf.math.sin(shear)
    shear_matrix = tf.reshape(
        tf.concat([one, s2, zero, zero, c2, zero, zero, zero, one], axis=0), [3, 3])

    return _apply_affine(image, height, shear_matrix)


def transform_shift(image, height, h_shift, w_shift):
    """Randomly shift one image.

    Args:
        image: One image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        height: Edge length ``dim`` of the image.
        h_shift: Maximum shift along the first coordinate; scaled by a uniform
            random number in [0, 1).
        w_shift: Maximum shift along the second coordinate; scaled by an
            independent uniform random number in [0, 1).

    Returns:
        The shifted image of size [dim, dim, 3].
    """
    height_shift = h_shift * tf.random.uniform([1], dtype='float32')
    width_shift = w_shift * tf.random.uniform([1], dtype='float32')
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')

    # SHIFT MATRIX
    shift_matrix = tf.reshape(
        tf.concat([one, zero, height_shift, zero, one, width_shift, zero, zero, one], axis=0),
        [3, 3])

    return _apply_affine(image, height, shift_matrix)


def transform_rotation(image, height, rotation):
    """Randomly rotate one image.

    Args:
        image: One image of size [dim, dim, 3], not a batch of [b, dim, dim, 3].
        height: Edge length ``dim`` of the image.
        rotation: Maximum rotation angle in degrees; the applied angle is this
            value scaled by a uniform random number in [0, 1).

    Returns:
        The rotated image of size [dim, dim, 3].
    """
    rotation = rotation * tf.random.uniform([1], dtype='float32')
    # CONVERT DEGREES TO RADIANS
    rotation = math.pi * rotation / 180.

    # ROTATION MATRIX
    c1 = tf.math.cos(rotation)
    s1 = tf.math.sin(rotation)
    one = tf.constant([1], dtype='float32')
    zero = tf.constant([0], dtype='float32')
    rotation_matrix = tf.reshape(
        tf.concat([c1, s1, zero, -s1, c1, zero, zero, zero, one], axis=0), [3, 3])

    return _apply_affine(image, height, rotation_matrix)
def data_augment(image, label):
    """Randomly augment a single training image; the label passes through.

    One independent uniform number in [0, 1) is drawn per augmentation family
    and compared against fixed thresholds to decide what is applied: flips,
    90-degree rotations, free rotation / shift / shear, a crop, and one
    pixel-level transform. The image is resized back to ``[HEIGHT, WIDTH]``
    after cropping.

    Args:
        image: One image tensor (not a batch).
        label: Target associated with the image; returned unchanged.

    Returns:
        ``(augmented_image, label)``.
    """
    p_rotation = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_spatial = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_rotate = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_pixel = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shear = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_shift = tf.random.uniform([], 0, 1.0, dtype=tf.float32)
    p_crop = tf.random.uniform([], 0, 1.0, dtype=tf.float32)

    # Flips
    if p_spatial >= .2:
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)

    # Rotates
    if p_rotate > .75:
        image = tf.image.rot90(image, k=3) # rotate 270º
    elif p_rotate > .5:
        image = tf.image.rot90(image, k=2) # rotate 180º
    elif p_rotate > .25:
        image = tf.image.rot90(image, k=1) # rotate 90º

    if p_rotation >= .3: # Rotation
        image = transform_rotation(image, height=HEIGHT, rotation=45.)
    if p_shift >= .3: # Shift
        image = transform_shift(image, height=HEIGHT, h_shift=15., w_shift=15.)
    if p_shear >= .3: # Shear
        image = transform_shear(image, height=HEIGHT, shear=20.)

    # Crops
    # The stricter threshold has to be tested first: behind a `> .4` test the
    # `> .7` central-crop branch could never be reached.
    if p_crop > .7:
        if p_crop > .9:
            image = tf.image.central_crop(image, central_fraction=.7)
        elif p_crop > .8:
            image = tf.image.central_crop(image, central_fraction=.8)
        else:
            image = tf.image.central_crop(image, central_fraction=.9)
    elif p_crop > .4:
        crop_size = tf.random.uniform([], int(HEIGHT*.7), HEIGHT, dtype=tf.int32)
        image = tf.image.random_crop(image, size=[crop_size, crop_size, CHANNELS])

    # Restore the model's input resolution after any crop.
    image = tf.image.resize(image, size=[HEIGHT, WIDTH])

    # Pixel-level transforms
    if p_pixel >= .2:
        if p_pixel >= .8:
            image = tf.image.random_saturation(image, lower=0, upper=2)
        elif p_pixel >= .6:
            image = tf.image.random_contrast(image, lower=.8, upper=2)
        elif p_pixel >= .4:
            image = tf.image.random_brightness(image, max_delta=.2)
        else:
            image = tf.image.adjust_gamma(image, gamma=.6)
    return image, label
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file names to read.
        labeled: Forwarded to ``read_tfrecord``.
        ordered: When False, deterministic ordering is switched off on the
            dataset options.

    Returns:
        A ``tf.data.Dataset`` yielding whatever ``read_tfrecord`` returns.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False
    parse_fn = partial(read_tfrecord, labeled=labeled)
    return (tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTOTUNE)
            .with_options(options)
            .map(parse_fn, num_parallel_calls=AUTOTUNE))
def get_train_dataset(filenames, bs=GLOBAL_BATCH_SIZE):
    """Return the training pipeline: parse, augment, shuffle, batch, prefetch.

    Args:
        filenames: TFRecord file names of the training split.
        bs: Batch size; incomplete final batches are dropped.
    """
    return (
        load_dataset(filenames, labeled=True)
        .map(data_augment, num_parallel_calls=AUTOTUNE)
        .shuffle(2048)
        .batch(bs, drop_remainder=True)
        # fetch the next batch while the current one is being consumed
        .prefetch(AUTOTUNE)
    )
 
def get_valid_dataset(filenames, bs=GLOBAL_BATCH_SIZE, ordered=False):
    """Return the validation pipeline: parse, batch, cache, prefetch.

    Args:
        filenames: TFRecord file names of the validation split.
        bs: Batch size; incomplete final batches are dropped.
        ordered: Forwarded to ``load_dataset``.
    """
    return (
        load_dataset(filenames, labeled=True, ordered=ordered)
        .batch(bs, drop_remainder=True)
        .cache()
        # fetch the next batch while the current one is being consumed
        .prefetch(AUTOTUNE)
    )
def get_test_dataset(filenames, bs=GLOBAL_BATCH_SIZE, ordered=False):
    """Return the test pipeline: parse unlabeled records, batch, prefetch.

    Args:
        filenames: TFRecord file names of the test split.
        bs: Batch size; the final partial batch is kept.
        ordered: Forwarded to ``load_dataset``.
    """
    return (
        load_dataset(filenames, labeled=False, ordered=ordered)
        .batch(bs)
        # fetch the next batch while the current one is being consumed
        .prefetch(AUTOTUNE)
    )
# Build the three input pipelines with the default (global) batch size.
train_ds = get_train_dataset(train_filenames)
valid_ds = get_valid_dataset(valid_filenames)
test_ds  = get_test_dataset(test_filenames)
def show_images(ds):
    """Plot the first nine ``(image, label)`` elements of ``ds`` on a 3x3 grid.

    Each image is scaled by 255 and shown as uint8; the title is the class
    name looked up in ``CLASSES``.
    """
    _, axes = plt.subplots(3, 3, figsize=(16, 16))
    samples = ds.take(9)
    for (image, label), axis in zip(samples, axes.flatten()):
        pixels = (image.numpy() * 255).astype(np.uint8)
        axis.imshow(pixels)
        axis.set_title(CLASSES[label])
        axis.axis('off')
from tensorflow.keras.applications import Xception
from tensorflow.keras.applications import DenseNet121, DenseNet169, DenseNet201
from tensorflow.keras.applications import ResNet50V2, ResNet101V2, ResNet152V2
from tensorflow.keras.applications import InceptionV3
from tensorflow.keras.applications import InceptionResNetV2
class Flower_Classifier(tf.keras.models.Model):
    """Weighted ensemble of three ImageNet-pretrained backbones.

    Every backbone (Xception, ResNet152V2, InceptionResNetV2) feeds a shared
    global-average pooling layer, then its own layer normalization and softmax
    head. The three per-model distributions are blended with weights produced
    by a small softmax Dense layer (``prob_dist_weight``) fed a constant input.
    """

    def __init__(self):
        super().__init__()
        backbone_kwargs = dict(weights='imagenet',
                               include_top=False,
                               input_shape=IMG_SIZE)
        self.image_embedding_layers = [
            Xception(**backbone_kwargs),
            ResNet152V2(**backbone_kwargs),
            InceptionResNetV2(**backbone_kwargs),
        ]
        self.pooling_layer = tf.keras.layers.GlobalAveragePooling2D()

        # One normalization layer and one classification head per backbone.
        self.layer_normalization_layers = []
        self.prob_dist_layers = []
        n_models = len(self.image_embedding_layers)
        for model_idx in range(n_models):
            self.layer_normalization_layers.append(
                tf.keras.layers.LayerNormalization(epsilon=1E-6))
            self.prob_dist_layers.append(
                tf.keras.layers.Dense(NCLASSES, activation='softmax',
                                      name=f'prob_dist_{model_idx}'))

        # Fixed starting point for the blending layer's kernel and bias.
        kernel_init = tf.constant_initializer(
            np.array([0.86690587, 1.0948032, 1.1121726]))
        bias_init = tf.constant_initializer(
            np.array([-0.13309559, 0.09480964, 0.11218266]))
        self.prob_dist_weight = tf.keras.layers.Dense(
            n_models, activation="softmax",
            kernel_initializer=kernel_init,
            bias_initializer=bias_init,
            name='prob_dist_weight')

    def call(self, inputs, training=False):
        """Return the blended class distribution for a batch of images."""
        per_model_probs = []
        branches = zip(self.image_embedding_layers,
                       self.layer_normalization_layers,
                       self.prob_dist_layers)
        for backbone, layer_norm, head in branches:
            features = backbone(inputs, training=training)
            pooled = self.pooling_layer(features, training=training)
            normalized = layer_norm(pooled, training=training)
            per_model_probs.append(head(normalized, training=training))
        # Stack the per-model outputs along a new axis 1.
        stacked = tf.stack(per_model_probs, axis=1)
        # The blending layer sees a constant input, so its output acts as a
        # learned set of mixing weights.
        mixing = self.prob_dist_weight(tf.constant(1, shape=(1,1)), training=training)
        blended = tf.linalg.matmul(mixing, stacked)
        return blended[:, 0, :]

    def model(self):
        """Wrap ``call`` in a functional ``tf.keras.Model`` built on a symbolic input."""
        x = tf.keras.Input(shape=IMG_SIZE)
        return tf.keras.Model(inputs=[x], outputs=self.call(x))
# Instantiate the model inside the distribution strategy scope.
with strategy.scope():
    model = Flower_Classifier()
To display the summary() of a subclassed model, use the trick introduced in this video.
#model.model().summary()
class CustomCyclicSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warm-up followed by cosine decay.

    For the first ``int(n_step * pct_start)`` steps the learning rate rises
    linearly from ``lr_max / div_final`` to ``lr_max``; for the remaining
    steps it follows a cosine decay from ``lr_max`` back down to
    ``lr_max / div_final``.

    Args:
        n_step: Total number of optimizer steps.
        lr_max: Peak learning rate.
        div: Divisor used to compute ``lr_start = lr_max / div``.
        div_final: Divisor used to compute the floor ``lr_min``.
        pct_start: Fraction of ``n_step`` spent in the rising phase.
        staircase: Accepted for interface compatibility; currently unused.
        cycle: Stored as a boolean tensor; currently unused by ``__call__``.
    """

    def __init__(self, n_step, lr_max, div=25.0, div_final=1e5,
                 pct_start=0.25, staircase=False, cycle=False):
        super().__init__()
        # Keep the constructor arguments so get_config can reproduce them.
        self.n_step = n_step
        self.div = div
        self.div_final = div_final
        self.pct_start = pct_start
        self.staircase = staircase
        self._cycle_flag = cycle

        # NOTE(review): lr_start is computed but the ramp below starts at
        # lr_min — confirm whether the warm-up was meant to start at lr_start.
        self.lr_start = lr_max / div
        self.lr_max = lr_max
        self.lr_min = lr_max / div_final
        self.rising_steps = int(n_step * pct_start)
        # max(..., 1) avoids a ZeroDivisionError when the rising phase has one step.
        self.rising_a = (lr_max-self.lr_min) / max(self.rising_steps-1, 1)
        self.rising_b = self.lr_min
        self.falling_steps = int(n_step - self.rising_steps)
        self.falling_rate = self.lr_min / lr_max
        self.cycle = tf.constant(cycle, dtype=tf.bool)
        # CosineDecay's `alpha` is a fraction of the initial learning rate, so
        # pass the ratio lr_min / lr_max rather than the absolute lr_min.
        self.decay_fn = tf.keras.experimental.CosineDecay(
            initial_learning_rate = lr_max,
            decay_steps = self.falling_steps,
            alpha = self.falling_rate)

    def __call__(self, step):
        """ `step` is actually the step index, starting at 0. """
        lr = tf.cond(step < self.rising_steps,
                     lambda : self.rising_a*tf.cast(step, tf.float32) + self.rising_b,
                     lambda : self.decay_fn(tf.cast(step-self.rising_steps, tf.int32)))
        return tf.cast(lr, tf.float32)

    def get_config(self):
        """Return the constructor arguments so the schedule can be serialized."""
        return {
            "n_step": self.n_step,
            "lr_max": self.lr_max,
            "div": self.div,
            "div_final": self.div_final,
            "pct_start": self.pct_start,
            "staircase": self.staircase,
            "cycle": self._cycle_flag,
        }
# Total number of optimizer steps across all epochs.
NSTEPS = math.ceil(count_data_items(train_filenames) / GLOBAL_BATCH_SIZE) * EPOCHS
LR_MAX = 1e-4 #@param {type: "number"}
# Scale the peak learning rate with the number of replicas.
LR_MAX *= strategy.num_replicas_in_sync
DIV = 25.0 #@param {type: "number"}
DIV_FINAL = 1e5 #@param {type: "number"}
PCT_START = 0.25#@param {type: "number"}
with strategy.scope():
    schedule = CustomCyclicSchedule(
        n_step=NSTEPS,
        lr_max=LR_MAX,
        div=DIV,
        div_final=DIV_FINAL,
        # Forward the form parameter; otherwise editing PCT_START has no effect.
        pct_start=PCT_START,
        cycle=False)
# Evaluate the schedule at every step and plot learning rate against iteration.
xps = tf.range(NSTEPS)
yps = [schedule(x) for x in xps]
fig,ax = plt.subplots(1,1,figsize=(8,5),facecolor='#F0F0F0')
ax.plot(xps, yps)
ax.set_facecolor('#F8F8F8')
ax.set_xlabel('iteration')
ax.set_ylabel('learning rate')
print('{:d} total epochs and {:d} steps per epoch'
        .format(EPOCHS, NSTEPS // EPOCHS))
# Adam driven by the schedule, created inside the strategy scope.
with strategy.scope():
    opt = tf.keras.optimizers.Adam(schedule)
For why we set reduction to 'none', please check this tutorial. In particular, read this paragraph: "If using tf.keras.losses classes (as in the example below), the loss reduction needs to be explicitly specified to be one of NONE or SUM. AUTO and SUM_OVER_BATCH_SIZE are disallowed when used with tf.distribute.Strategy."
For why we use tf.nn.compute_average_loss, please check this tutorial.
While training with BATCH_SIZE = 8 * strategy.num_replicas_in_sync, I got NaN values. Since we pass a probability distribution to CategoricalCrossentropy with from_logits = False, which has a numerical instability issue, we use the same trick as the source code to avoid such instability.
from tensorflow.python.ops import clip_ops
from tensorflow.python.framework import constant_op
While training with GLOBAL_BATCH_SIZE = 8 * replicas, I got NaN values. Since we pass a probability distribution to CategoricalCrossentropy with from_logits = False, which has a numerical instability issue, we use the same trick as the source code to avoid such instability.
def _constant_to_tensor(x, dtype):
    """Wrap the value ``x`` in a constant tensor of the given ``dtype``."""
    tensor = constant_op.constant(x, dtype=dtype)
    return tensor
with strategy.scope():
    # Per-example losses are requested (reduction='none'); the averaging over
    # the global batch is done explicitly in loss_function.
    loss_object = tf.keras.losses.CategoricalCrossentropy(
                from_logits=False, reduction='none', label_smoothing=0.1)
    def loss_function(labels, prob_dists, sample_weights=1.0):
        """Label-smoothed cross-entropy averaged over the global batch.

        Args:
            labels: Integer class ids; one-hot encoded to ``NCLASSES`` here.
            prob_dists: Predicted class probabilities (not logits).
            sample_weights: Optional weights forwarded to
                ``tf.nn.compute_average_loss``; ``None`` disables weighting.

        Returns:
            The sum of per-example losses divided by ``GLOBAL_BATCH_SIZE``.
        """
        # Keep probabilities inside [eps, 1 - eps] before the loss is computed.
        epsilon_ = _constant_to_tensor(tf.keras.backend.epsilon(), prob_dists.dtype.base_dtype)
        prob_dists = clip_ops.clip_by_value(prob_dists, epsilon_, 1 - epsilon_)
        labels = tf.keras.backend.one_hot(labels, NCLASSES)
        loss = loss_object(labels, prob_dists)
        # Honor the caller-supplied weights instead of silently dropping them.
        loss = tf.nn.compute_average_loss(loss,
                                          sample_weight=sample_weights,
                                          global_batch_size=GLOBAL_BATCH_SIZE)
        return loss
def get_metrics(name):
    """Create the ``(loss, accuracy)`` metric pair for one data split.

    Args:
        name: Prefix for the metric names (``<name>_loss`` and ``<name>_acc``).

    Returns:
        A ``Mean`` metric and a ``SparseCategoricalAccuracy`` metric.
    """
    return (
        tf.keras.metrics.Mean(name=f'{name}_loss', dtype=tf.float32),
        tf.keras.metrics.SparseCategoricalAccuracy(name=f'{name}_acc', dtype=tf.float32),
    )
# Metric objects for both splits, created inside the strategy scope.
with strategy.scope():
    train_loss_obj, train_acc_obj = get_metrics('train')
    valid_loss_obj, valid_acc_obj = get_metrics('valid')
Distributed datasets
There are two APIs to create a tf.distribute.DistributedDataset object: tf.distribute.Strategy.experimental_distribute_dataset(dataset) and tf.distribute.Strategy.distribute_datasets_from_function(dataset_fn).
When to use which?
When you have a tf.data.Dataset instance, and the regular batch splitting (i.e. re-batching the input tf.data.Dataset instance with a new batch size that is equal to the global batch size divided by the number of replicas in sync) and autosharding (i.e. the tf.data.experimental.AutoShardPolicy options) work for you, use the former API. Otherwise, if you are not using a canonical tf.data.Dataset instance, or you would like to customize the batch splitting or sharding, you can wrap this logic in a dataset_fn and use the latter API. Both APIs handle prefetching to the device for the user. For more details and examples, follow the links to the APIs.
# First API: distribute already-built, globally-batched datasets.
train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
valid_ds = get_valid_dataset(valid_filenames, ordered=True)
valid_dist_ds = strategy.experimental_distribute_dataset(valid_ds)
# The test pipeline yields (image, id); keep only the image.
test_ds = get_test_dataset(test_filenames, ordered=True).map(lambda img,lbl: img)
test_dist_ds = strategy.experimental_distribute_dataset(test_ds)
# Second API: build the dataset inside a function, batched with the per-replica
# BATCH_SIZE. These assignments replace train_dist_ds / valid_dist_ds from above
# with short subsets limited by `.take(...)`.
train_dist_ds = strategy.distribute_datasets_from_function(
    dataset_fn = lambda _: (get_train_dataset(train_filenames, bs=BATCH_SIZE)
                                .take(10*strategy.num_replicas_in_sync)))
valid_dist_ds = strategy.distribute_datasets_from_function(
    dataset_fn = lambda _: (get_valid_dataset(valid_filenames, ordered=True, bs=BATCH_SIZE)
                                .take(10*strategy.num_replicas_in_sync)))
# Fixed signature for the per-replica steps: a float32 rank-4 image batch
# and an int32 label vector.
train_input_signature = [
    tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32),
    tf.TensorSpec(shape=(None,), dtype=tf.int32)
]
@tf.function(input_signature=train_input_signature)
def train_step(images, labels):
    """Run one per-replica optimization step and return its loss.

    Gradients are clipped to a global norm of 1.0 before being applied, and
    the training accuracy metric is updated with the batch predictions.
    """
    with tf.GradientTape() as tape:
        prob_dists = model(images, training=True)
        loss = loss_function(labels, prob_dists)
    variables = model.trainable_variables
    clipped_grads, _ = tf.clip_by_global_norm(
        tape.gradient(loss, variables), clip_norm=1.0)
    opt.apply_gradients(zip(clipped_grads, variables))
    train_acc_obj(labels, prob_dists)
    return loss
@tf.function
def distributed_train_step(inputs):
    """Run ``train_step`` on every replica and return the summed loss."""
    images, labels = inputs
    replica_losses = strategy.run(train_step, args=(images, labels))
    total_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
    return total_loss
# Smoke test: run at most 10 distributed training steps and show running metrics.
metrics_names = ['train_loss','train_acc'] 
pbar = tf.keras.utils.Progbar(10, width=15, stateful_metrics=metrics_names)
train_loss_obj.reset_states()
train_acc_obj.reset_states()
# NOTE(review): train_iter is not used by the loop below, which iterates the dataset directly.
train_iter = iter(train_dist_ds)
for i,train_inputs in enumerate(train_dist_ds):
    if i == 10: break
    train_loss = distributed_train_step(train_inputs)
    # accumulate the running mean of the step losses
    train_loss_obj(train_loss)
    values=[('train_loss', train_loss_obj.result()),
            ('train_acc', train_acc_obj.result())]
    pbar.add(1, values=values)
# Validation batches have the same structure as training batches.
valid_input_signature = train_input_signature
@tf.function(input_signature=valid_input_signature)
def valid_step(images, labels):
    """Evaluate one per-replica batch.

    Updates the validation accuracy metric and returns ``(loss, prob_dists)``.
    """
    prob_dists = model(images, training=False)
    batch_loss = loss_function(labels, prob_dists, sample_weights=None)
    valid_acc_obj(labels, prob_dists)
    return batch_loss, prob_dists
@tf.function
def distributed_valid_step(inputs):
    """Run ``valid_step`` on every replica.

    Returns:
        ``(summed_loss, prob_dists)`` where ``prob_dists`` is the unreduced
        output of ``strategy.run``.
    """
    images, labels = inputs
    replica_losses, prob_dists = strategy.run(valid_step, args=(images, labels))
    summed_loss = strategy.reduce(
        tf.distribute.ReduceOp.SUM, replica_losses, axis=None)
    return summed_loss, prob_dists
# Smoke test: at most 10 training steps followed by at most 10 validation steps,
# sharing one progress bar of 20 ticks.
metrics_names = ['train_loss','train_acc','valid_loss','valid_acc']
pbar = tf.keras.utils.Progbar(20, width=15, stateful_metrics=metrics_names)

train_loss_obj.reset_states()
train_acc_obj.reset_states()
valid_loss_obj.reset_states()
valid_acc_obj.reset_states()
for i,train_inputs in enumerate(train_dist_ds):
    if i == 10: break
    train_loss = distributed_train_step(train_inputs)
    train_loss_obj(train_loss)
    values=[('train_loss', train_loss_obj.result()),
            ('train_acc', train_acc_obj.result())]
    pbar.add(1, values=values)
for i,valid_inputs in enumerate(valid_dist_ds):
    if i == 10: break
    # distributed_valid_step returns (loss, prob_dists); only the loss feeds the metric.
    valid_loss, _ = distributed_valid_step(valid_inputs)
    valid_loss_obj(valid_loss)
    values=[('valid_loss', valid_loss_obj.result()),
            ('valid_acc', valid_acc_obj.result())]
    pbar.add(1, values=values)
# Test batches carry images only.
test_input_signature = [
    tf.TensorSpec(shape=(None, None, None, None), dtype=tf.float32)]

@tf.function(input_signature=test_input_signature)
def test_step(images):
    """Return the model's class probabilities for one per-replica image batch."""
    return model(images, training=False)
@tf.function
def distributed_test_step(inputs):
    """Run ``test_step`` on every replica and return its unreduced output.

    Args:
        inputs: One distributed batch of images.
    """
    images = inputs
    prob_dists = strategy.run(test_step, args=(images,))
    # Return the value computed above (the name `prob_dist` is not defined).
    return prob_dists
# Steps per epoch for each split; floor division matches pipelines that drop
# the incomplete final batch.
N_TRAIN_STEPS = count_data_items(train_filenames) // GLOBAL_BATCH_SIZE
N_VALID_STEPS = count_data_items(valid_filenames) // GLOBAL_BATCH_SIZE
N_TOTAL_STEPS = N_TRAIN_STEPS + N_VALID_STEPS
# One value per epoch is appended to each list by the training loop.
history = {
    "train_loss": [], "valid_loss": [],
    "train_acc": [],  "valid_acc": []
}
# Full training run: every epoch trains over the whole training set, then
# evaluates on the (ordered) validation set and records the epoch metrics.
valid_ds = get_valid_dataset(valid_filenames, ordered=True)
valid_dist_ds = strategy.experimental_distribute_dataset(valid_ds)
for epoch in range(EPOCHS):
    # Training data must come from the training pipeline (augmentation and
    # shuffling), not from the cached, un-augmented validation pipeline.
    train_ds = get_train_dataset(train_filenames)
    train_dist_ds = strategy.experimental_distribute_dataset(train_ds)

    metrics_names = ['train_loss','train_acc','valid_loss','valid_acc']
    pbar = tf.keras.utils.Progbar(N_TOTAL_STEPS, width=15, stateful_metrics=metrics_names)
    # Metrics accumulate within an epoch only.
    train_loss_obj.reset_states()
    train_acc_obj.reset_states()
    valid_loss_obj.reset_states()
    valid_acc_obj.reset_states()
    for train_inputs in train_dist_ds:
        train_loss = distributed_train_step(train_inputs)
        train_loss_obj(train_loss)
        values=[('train_loss', train_loss_obj.result()),
                ('train_acc', train_acc_obj.result())]
        pbar.add(1, values=values)
    history['train_loss'].append(train_loss_obj.result())
    history['train_acc'].append(train_acc_obj.result())
    for valid_inputs in valid_dist_ds:
        valid_loss, _ = distributed_valid_step(valid_inputs)
        valid_loss_obj(valid_loss)
        values=[('valid_loss', valid_loss_obj.result()),
                ('valid_acc', valid_acc_obj.result())]
        pbar.add(1, values=values)
    history['valid_loss'].append(valid_loss_obj.result())
    history['valid_acc'].append(valid_acc_obj.result())
def show_history(history):
    """Plot loss and accuracy curves side by side.

    Args:
        history: Mapping from metric name to a per-epoch sequence. Every entry
            whose key contains ``'loss'`` is drawn on the first axis and every
            entry whose key contains ``'acc'`` on the second, in mapping order.
    """
    topics = ('loss', 'acc')
    _, axes = plt.subplots(1, 2, figsize=(15, 6), facecolor='#F0F0F0')
    for topic, axis in zip(topics, axes.flatten()):
        for key, series in history.items():
            if topic in key:
                axis.plot(series)
        axis.set_facecolor('#F8F8F8')
        axis.set_title(f'{topic} over epochs')
        axis.set_xlabel('epoch')
        axis.set_ylabel(topic)
        axis.legend(['train', 'valid'], loc='best')
show_history(history)
# Collect the predicted distributions for the whole validation set.
all_valid_preds = []
pbar = tf.keras.utils.Progbar(N_VALID_STEPS, width=15)
for valid_inputs in valid_dist_ds:
    _, valid_preds = distributed_valid_step(valid_inputs)
    # experimental_local_results yields a tuple of per-replica tensors and also
    # works when strategy.run returns a plain tensor, which has no `.values`.
    replica_preds = strategy.experimental_local_results(valid_preds)
    all_valid_preds.append(tf.concat(replica_preds, axis=0).numpy())
    pbar.add(1)
all_valid_preds = np.concatenate(all_valid_preds, axis=0)
# Ground-truth labels: strip the images, undo the batching, and pull all labels
# out as a single array.
cm_trues = (valid_ds.map(lambda im,lbl: lbl)
                    .unbatch()
                    .batch(count_data_items(valid_filenames))
                    .as_numpy_iterator()
                    .next())
# Predicted class = index of the highest probability.
cm_preds = np.argmax(all_valid_preds, axis=-1)
cmat = confusion_matrix(cm_trues, cm_preds, labels=range(len(CLASSES)))
# Macro-averaged scores computed over every class id.
f1 = f1_score(cm_trues, cm_preds,
              labels=range(NCLASSES), average='macro')
precision = precision_score(cm_trues, cm_preds,
              labels=range(NCLASSES), average='macro')
recall = recall_score(cm_trues, cm_preds,
              labels=range(NCLASSES), average='macro')    
def show_confusion_matrix(cmat, f1, precision, recall):
    """Draw the confusion matrix with class-name ticks and a score summary.

    Args:
        cmat: Confusion matrix to display.
        f1: F1 score to print, or ``None`` to omit it.
        precision: Precision to print, or ``None`` to omit it.
        recall: Recall to print, or ``None`` to omit it.
    """
    plt.figure(figsize=(15,15))
    ax = plt.gca()
    ax.matshow(cmat, cmap='Blues')
    ax.set_xticks(range(len(CLASSES)))
    ax.set_xticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_xticklabels(), rotation=45, ha="left", rotation_mode="anchor")
    ax.set_yticks(range(len(CLASSES)))
    ax.set_yticklabels(CLASSES, fontdict={'fontsize': 7})
    plt.setp(ax.get_yticklabels(), rotation=45, ha="right", rotation_mode="anchor")
    titlestring = ""
    # Compare against None so a legitimate score of exactly 0.0 is still shown.
    if precision is not None: titlestring += 'precision = {:.3f} '.format(precision)
    if recall is not None: titlestring += '\nrecall = {:.3f} '.format(recall)
    if f1 is not None: titlestring += '\nf1 = {:.3f} '.format(f1)
    if titlestring:
        ax.text(101, 1, titlestring,fontdict={'fontsize': 18, 'horizontalalignment':'right',
                                              'verticalalignment':'top', 'color':'#804040'})
    plt.show()
# Render the validation confusion matrix together with the macro scores.
show_confusion_matrix(cmat, f1, precision, recall)