Learning Rate Finder
Approximate the optimal learning rate quickly, without costly searches
- Why is it hard?
- Implement LRFinder
- MNIST Dataset
- Simple Model
- Train without optimal lr
- Train with optimal lr
- Conclusion
- References
- 1% Better Everyday
The learning rate is arguably the most important hyperparameter: it controls how much we adjust the weights of our network with respect to the loss gradient, that is, how much the model learns from each new mini-batch of training data. The higher the learning rate, the bigger the steps we take along the trajectory towards the minimum of the loss function, where the best model parameters lie.
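To make the effect of the step size concrete, here is a minimal sketch of a single plain gradient-descent update with made-up numbers (the weights and gradient below are illustrative, not taken from any model):
import numpy as np

w = np.array([0.5, -1.2])     # current weights (illustrative values)
grad = np.array([0.1, -0.4])  # gradient of the loss w.r.t. the weights

# The update w <- w - lr * grad scales the step along the gradient by the learning rate
for lr in (1e-3, 1e-1, 1.0):
    print(f"lr={lr:g}: step norm = {np.linalg.norm(lr * grad):.4f}, new w = {w - lr * grad}")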
Why is it hard?
The learning rate is a tricky hyperparameter to tune for a number of reasons:
- In most cases, domain knowledge or previous studies are of little help, for a learning rate that worked well for one problem might not be even half as good for another, even a closely-related one.
- Tuning learning rates via a grid search or a random search is typically costly, both in terms of time and computing power, especially for large networks.
- The optimal learning rate is tightly coupled with other hyperparameters. Hence, each time you change the amount of regularization or the network's architecture, you should re-tune the learning rate.
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import warnings

import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import Callback

warnings.simplefilter('ignore')
print(f"Using TensorFlow v{tf.__version__}")
Implement LRFinder
class MultiplicativeLearningRate(tf.keras.callbacks.Callback):
    """Multiplies the learning rate by a constant factor after every batch,
    recording the learning rate and the loss along the way."""

    def __init__(self, factor):
        self.factor = factor
        self.losses = []
        self.lrs = []

    def on_batch_end(self, batch, logs=None):
        self.lrs.append(K.get_value(self.model.optimizer.lr))
        self.losses.append(logs["loss"])
        K.set_value(self.model.optimizer.lr, self.model.optimizer.lr * self.factor)
min_lr = 1e-6
max_lr = 1e1
num_iter = 1000

# Multiplicative factor that takes the learning rate from min_lr to max_lr in num_iter steps
lr_factor = np.exp(np.log(max_lr / min_lr) / num_iter)
lrs = [min_lr * lr_factor**i for i in range(num_iter)]

# The sweep is exponential: a straight line on a log scale (left), a steep curve on a linear scale (right)
fig, axs = plt.subplots(1, 2, figsize=(12, 4), facecolor="#F0F0F0")
axs[0].plot(lrs)
axs[0].set_yscale("log")
axs[0].set_ylabel("learning rate")
axs[0].set_xlabel("iteration")
axs[1].plot(lrs)
axs[1].set_ylabel("learning rate")
axs[1].set_xlabel("iteration")
def find_lr(model, x, y, batch_size, min_lr=1e-6, max_lr=1e1):
    num_iter = len(x) // batch_size
    lr_factor = np.exp(np.log(max_lr / min_lr) / num_iter)

    # Train for 1 epoch, starting at the minimum learning rate and increasing it after every batch
    K.set_value(model.optimizer.lr, min_lr)
    lr_callback = MultiplicativeLearningRate(lr_factor)
    model.fit(x, y, epochs=1, batch_size=batch_size, callbacks=[lr_callback])

    # Plot loss vs log-scaled learning rate
    plot = sns.lineplot(x=lr_callback.lrs, y=lr_callback.losses)
    plot.set(xscale="log",
             xlabel="Learning Rate (log-scale)",
             ylabel="Training Loss",
             title="Optimal learning rate is slightly below minimum",
             facecolor="#F0F0F0")
class LRFinder(Callback):
    """Callback that exponentially increases the learning rate after each
    training batch between start_lr and end_lr, for at most max_steps
    batches. The loss and learning rate are recorded at each step, so a
    good learning rate can be picked visually via the plot method, as per
    https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
    """

    def __init__(self, start_lr: float = 1e-7, end_lr: float = 10,
                 max_steps: int = 100, smoothing=0.9):
        super(LRFinder, self).__init__()
        self.start_lr, self.end_lr = start_lr, end_lr
        self.max_steps = max_steps
        self.smoothing = smoothing
        self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0
        self.lrs, self.losses = [], []

    def on_train_begin(self, logs=None):
        self.step, self.best_loss, self.avg_loss, self.lr = 0, 0, 0, 0
        self.lrs, self.losses = [], []

    def on_train_batch_begin(self, batch, logs=None):
        # Set the learning rate for the upcoming batch
        self.lr = self.exp_annealing(self.step)
        K.set_value(self.model.optimizer.lr, self.lr)

    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        loss = logs.get('loss')
        step = self.step
        if loss:
            # Exponentially-smoothed (and bias-corrected) loss
            self.avg_loss = self.smoothing * self.avg_loss + (1 - self.smoothing) * loss
            smooth_loss = self.avg_loss / (1 - self.smoothing ** (self.step + 1))
            self.losses.append(smooth_loss)
            self.lrs.append(self.lr)
            if step == 0 or loss < self.best_loss:
                self.best_loss = loss
            # Stop early once the loss clearly diverges
            if smooth_loss > 4 * self.best_loss or tf.math.is_nan(smooth_loss):
                self.model.stop_training = True
        if step == self.max_steps:
            self.model.stop_training = True
        self.step += 1

    def exp_annealing(self, step):
        # Exponential schedule from start_lr to end_lr over max_steps batches
        return self.start_lr * (self.end_lr / self.start_lr) ** (step * 1. / self.max_steps)

    def plot(self):
        fig, ax = plt.subplots(1, 1, facecolor="#F0F0F0")
        ax.set_ylabel('Loss')
        ax.set_xlabel('Learning Rate')
        ax.set_xscale('log')
        ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))
        ax.plot(self.lrs, self.losses)
MNIST Dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
print(x_train.shape, x_test.shape, y_train.shape, y_test.shape)
IMG_SIZE = (28, 28, 1)
NCLASSES = 10
x_train = np.expand_dims(x_train.astype('float32') / 255.0, axis=-1)
x_test = np.expand_dims(x_test.astype('float32') / 255.0, axis=-1)
print('x_train shape:', x_train.shape)
print('x_test shape:', x_test.shape)
y_train = tf.keras.utils.to_categorical(y_train, NCLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NCLASSES)
print('y_train shape:', y_train.shape)
print('y_test shape:', y_test.shape)
Simple Model
def build_simple_model(input_shape, lr=None):
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu',
                                     input_shape=input_shape))
    model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
    model.add(tf.keras.layers.Dropout(0.25))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(128, activation='relu'))
    model.add(tf.keras.layers.Dropout(0.5))
    model.add(tf.keras.layers.Dense(10, activation='softmax'))
    opt = tf.keras.optimizers.SGD(learning_rate=lr if lr else 1e-2)
    model.compile(loss='categorical_crossentropy',
                  optimizer=opt,
                  metrics=['accuracy'])
    return model
Train without optimal lr
BATCH_SIZE = 64
EPOCHS = 5
STEPS = len(x_train) // BATCH_SIZE
model = build_simple_model(IMG_SIZE)
model.fit(x_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE)
score = model.evaluate(x_test, y_test, verbose=0, batch_size=BATCH_SIZE)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Train with optimal lr
We now run the LR range test on a fresh model. You can either call the find_lr helper or create an instance of the LRFinder class we built above and pass it as a callback to fit; the cell below uses find_lr, and the callback variant is sketched right after. The finder is very cheap in terms of compute: it takes at most one epoch, often less, to complete. We keep the default values of min_lr and max_lr, but you can change them if you want to.
model = build_simple_model(IMG_SIZE)
find_lr(model, x_train, y_train, batch_size=BATCH_SIZE)
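Equivalently, the same sweep can be run by passing an instance of the LRFinder callback defined above directly to fit. A minimal sketch (model_lrf and lr_finder are illustrative names introduced here):
lr_finder = LRFinder(start_lr=1e-7, end_lr=10, max_steps=len(x_train) // BATCH_SIZE)
model_lrf = build_simple_model(IMG_SIZE)
model_lrf.fit(x_train, y_train, epochs=1, batch_size=BATCH_SIZE, callbacks=[lr_finder])
lr_finder.plot()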
The recommended minimum learning rate is the value where the loss decreases the fastest (steepest negative gradient), while the recommended maximum learning rate is about 10 times less than the learning rate at which the loss is lowest. Why not the very minimum of the loss? Why 10 times less? Because what we actually plot is a smoothed version of the loss, and the learning rate corresponding to the minimum loss is likely to be already too large, making the loss diverge during training.
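The same rule of thumb can also be read off the recorded values programmatically. A rough sketch, assuming the lr_finder instance from the sketch above (any pair of recorded lrs/losses lists would do):
lrs = np.array(lr_finder.lrs)
losses = np.array(lr_finder.losses)

steepest_lr = lrs[np.argmin(np.gradient(losses))]  # where the smoothed loss falls fastest
lr_at_min = lrs[np.argmin(losses)]                 # learning rate at the loss minimum
print(f"steepest descent at lr = {steepest_lr:.2e}, suggested max lr = {lr_at_min / 10:.2e}")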
model3 = build_simple_model(IMG_SIZE, lr=8e-2)
model3.fit(x_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE)
score = model3.evaluate(x_test, y_test, verbose=0, batch_size=BATCH_SIZE)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
Conclusion
You can see that starting with a near-optimal learning rate lets the model converge much faster. A good start always pays off!
References
- Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates (https://arxiv.org/abs/1708.07120)
- Cyclical Learning Rates for Training Neural Networks (https://arxiv.org/abs/1506.01186)
- A Disciplined Approach to Neural Network Hyper-Parameters (https://arxiv.org/abs/1803.09820)
- How Do You Find A Good Learning Rate (https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html)
1% Better Everyday
Maybe we don't really need a dedicated class for the learning rate finder. Perhaps we can achieve the same goal with tf.keras.optimizers.schedules.ExponentialDecay plus TensorBoard.
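A rough sketch of that idea, assuming a growing schedule (a decay_rate greater than 1 turns ExponentialDecay into exponential growth) and a hypothetical log directory "logs/lr_finder"; model_sweep and steps_per_epoch are names introduced here. The per-batch losses land in TensorBoard, and the learning rate used at any step can be recovered from the schedule object itself:
# Sketch: LR range test via a growing ExponentialDecay schedule plus TensorBoard
steps_per_epoch = len(x_train) // BATCH_SIZE
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-6,                          # start of the sweep
    decay_steps=1,                                       # apply the factor after every batch
    decay_rate=(1e1 / 1e-6) ** (1.0 / steps_per_epoch))  # > 1: grow from 1e-6 to 1e1 in one epoch

model_sweep = build_simple_model(IMG_SIZE)
model_sweep.compile(loss='categorical_crossentropy',
                    optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule),
                    metrics=['accuracy'])  # re-compile so the optimizer follows the schedule

tb_callback = tf.keras.callbacks.TensorBoard(log_dir="logs/lr_finder", update_freq="batch")
model_sweep.fit(x_train, y_train, epochs=1, batch_size=BATCH_SIZE, callbacks=[tb_callback])

# The learning rate at step i is lr_schedule(i), so the TensorBoard loss curve can be mapped back to a learning rate
print(float(lr_schedule(0)), float(lr_schedule(steps_per_epoch)))
As with the callbacks above, the loss is expected to blow up once the learning rate gets too large; the informative part of the curve is everything before that point.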