
Chris Rawles

Machine learning, Google Cloud Professional Services
Feb 26 · 5 min read

How to use Batch Normalization with TensorFlow and tf.keras to train deep neural networks faster

Training deep neural networks can be time consuming. In particular, training can be significantly impeded by vanishing gradients, which occur when a network stops updating because the gradients, particularly in earlier layers, have approached zero values. Incorporating Xavier weight initialization and ReLU activation functions helps counter the vanishing gradient problem. These techniques also help with the opposite, yet closely related, issue of exploding gradients, where the gradients become extremely large, preventing the model from updating.

Perhaps the most powerful tool for combating the vanishing and exploding gradients issue is Batch Normalization. Batch Normalization works like this: for each unit in a given layer, first compute the z score, and then apply a linear transformation using two trained variables, γ and β. Batch Normalization is typically done prior to the non-linear activation function (see the figure below), however applying it after the activation function can also be beneficial. Check out this lecture for more detail on how the technique works.
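To make that concrete, here is a minimal NumPy sketch of the transformation for a single mini-batch; the function name, epsilon value, and shapes are illustrative rather than taken from TensorFlow's implementation:

import numpy as np

def batch_norm(x, gamma, beta, eps=1e-3):
    # x: (batch_size, n_units) activations; gamma, beta: (n_units,) learned parameters
    mu = x.mean(axis=0)                     # per-unit batch mean
    var = x.var(axis=0)                     # per-unit batch variance
    x_hat = (x - mu) / np.sqrt(var + eps)   # z score
    return gamma * x_hat + beta             # learned scale and shift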

During backpropagation, gradients tend to get smaller at lower layers, slowing down weight updates and thus training. Batch Normalization helps combat this so-called vanishing gradient problem.

Batch Normalization can be implemented in three ways in TensorFlow, using:

1. tf.keras.layers.BatchNormalization

2. tf.layers.batch_normalization

3. tf.nn.batch_normalization

The tf.keras module became part of the core TensorFlow API in version 1.4 and provides a high-level API for building TensorFlow models, so I will show you how to do it in Keras. The tf.layers.batch_normalization function has similar functionality, but Keras often proves to be an easier way to write model functions in TensorFlow.
in_training_mode = tf.placeholder(tf.bool)
hidden = tf.keras.layers.Dense(n_units,
                               activation=None)(X)  # no activation yet
batch_normed = tf.keras.layers.BatchNormalization()(hidden,
                                                    training=in_training_mode)
output = tf.keras.activations.relu(batch_normed)  # ReLU is typically done after Batch Normalization

# optimizer code here …
Note the training variable in the Batch Normalization function. This is required because Batch Normalization operates differently during training vs. the application stage: during training the z score is computed using the batch mean and variance, while during inference it is computed using a mean and variance estimated from the entire training set.
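With the graph built this way, that flag is simply fed through the placeholder at run time; the session, input tensors, and train_op below are illustrative:

# training: normalize with the statistics of the current mini-batch
sess.run(train_op, feed_dict={X: x_batch, in_training_mode: True})

# inference: normalize with the moving averages accumulated during training
preds = sess.run(output, feed_dict={X: x_test, in_training_mode: False})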

In TensorFlow, Batch Normalization can be implemented as an additional layer using tf.keras.layers.

The second piece of code, which wires the ops collected in tf.GraphKeys.UPDATE_OPS into the training step, is equally important.
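A minimal sketch of that pattern, assuming an Adam optimizer and a loss tensor defined elsewhere:

extra_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)  # moving-average update ops added by BatchNormalization
optimizer = tf.train.AdamOptimizer(learning_rate)       # the optimizer choice is illustrative
with tf.control_dependencies(extra_ops):                # run the updates with every training step
    train_op = optimizer.minimize(loss)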


When using tf.keras.layers.BatchNormalization, TensorFlow continually estimates the mean and variance of each unit's activations over the training dataset, and the ops that maintain these moving averages are collected in tf.GraphKeys.UPDATE_OPS. After training, the stored values are used to apply Batch Normalization at prediction time. The training-set mean and variance for each layer can be observed by printing extra_ops, which contains an entry for each layer in the network:

print(extra_ops)

[<tf.Tensor 'batch_normalization/AssignMovingAvg:0' shape=(500,) dtype=float32_ref>,  # layer 1 mean values
 <tf.Tensor 'batch_normalization/AssignMovingAvg_1:0' shape=(500,) dtype=float32_ref>,  # layer 1 variances
 ...]

While Batch Normalization is also available in the tf.nn module, it requires extra bookkeeping, as the mean and variance are required arguments for the function. The user therefore has to manually compute and track the mean and variance at both the batch level and the training-set level. It is thus at a lower level of abstraction than tf.keras.layers or tf.layers; avoid the tf.nn implementation.
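For illustration, the batch-level bookkeeping with tf.nn looks roughly like this; the variable names and epsilon are my own, and the moving averages needed at inference time are not shown:

batch_mean, batch_var = tf.nn.moments(hidden, axes=[0])     # per-unit batch statistics
gamma = tf.Variable(tf.ones([n_units]))                     # learned scale
beta = tf.Variable(tf.zeros([n_units]))                     # learned shift
batch_normed = tf.nn.batch_normalization(hidden, batch_mean, batch_var,
                                         offset=beta, scale=gamma,
                                         variance_epsilon=1e-3)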

Batch Normalization on MNIST
Below, I apply Batch Normalization to the prominent MNIST dataset using TensorFlow. Check out the code here. MNIST is an easy dataset to analyze and doesn't require many layers to achieve low classification error. However, we can still build a deep network and observe how Batch Normalization affects convergence.

Let's build a custom estimator using the tf.estimator API. First we build the model:
def dnn_custom_estimator(features, labels, mode, params):
    in_training = mode == tf.estimator.ModeKeys.TRAIN
    use_batch_norm = params['batch_norm']

    net = tf.feature_column.input_layer(features, params['features'])
    for i, n_units in enumerate(params['hidden_units']):
        net = build_fully_connected(net, n_units=n_units, training=in_training,
                                    batch_normalization=use_batch_norm,
                                    activation=params['activation'],
                                    name='hidden_layer' + str(i))

    logits = output_layer(net, 10, batch_normalization=use_batch_norm,
                          training=in_training)

    predicted_classes = tf.argmax(logits, 1)
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
    accuracy = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                   predictions=predicted_classes,
                                   name='acc_op')
    tf.summary.scalar('accuracy', accuracy[1])  # for visualization in TensorBoard

    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                          eval_metric_ops={'accuracy': accuracy})
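build_fully_connected and output_layer are helpers from the accompanying code; their exact implementation isn't shown here, but a sketch consistent with the Dense → Batch Normalization → activation pattern above might look like this:

def build_fully_connected(inputs, n_units, training, batch_normalization,
                          activation, name):
    net = tf.keras.layers.Dense(n_units, activation=None, name=name)(inputs)
    if batch_normalization:
        net = tf.keras.layers.BatchNormalization()(net, training=training)
    return activation(net)   # e.g. tf.nn.relu, applied after Batch Normalization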

After we define our model function, let's build the custom estimator and train and evaluate our model:

def train_and_evaluate(output_dir):
    features = [tf.feature_column.numeric_column(key='image',
                                                 shape=(28 * 28,))]  # flattened MNIST pixels
    # hyperparameter values below are parsed from command-line flags in the full code
    classifier = tf.estimator.Estimator(model_fn=dnn_custom_estimator,
                                        model_dir=output_dir,
                                        params={'features': features,
                                                'batch_norm': USE_BATCH_NORMALIZATION,
                                                'activation': ACTIVATION,
                                                'hidden_units': HIDDEN_UNITS,
                                                'learning_rate': LEARNING_RATE})

    train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn,
                                        max_steps=NUM_STEPS)
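The remainder of the function follows the standard tf.estimator pattern; a minimal sketch of how it would typically finish, assuming an eval_input_fn defined alongside train_input_fn:

    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn)
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)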

Let's test how Batch Normalization impacts models of varying depths. After we wrap our code into a Python package, we can fire off multiple experiments in parallel using Cloud ML Engine:

# ml-engine job submission function
submitMLEngineJob() {
    gcloud ml-engine jobs submit training $JOBNAME \
        --package-path=$(pwd)/mnist_classifier/trainer \
        --module-name trainer.task \
        --region $REGION \
        --staging-bucket=gs://$BUCKET \
        --scale-tier=BASIC \
        --runtime-version=1.4 \
        -- \
        --outdir $OUTDIR \
        --hidden_units $net \
        --num_steps 1000 \
        $batchNorm
}

# launch jobs in parallel
export PYTHONPATH=${PYTHONPATH}:${PWD}/mnist_classifier
for batchNorm in '' '--use_batch_normalization'
do
    net=''
    for layer in 500 400 300 200 100 50 25
    do
        net="${net}${net:+,}${layer}"   # add one layer per experiment (JOBNAME/OUTDIR are set per job in the full script)
        submitMLEngineJob
    done
done

The plot below shows the number of training iterations (1 iteration contains a batch size of 500) required to reach 90% testing accuracy, an easy target, as a function of network depth. It's evident that Batch Normalization significantly speeds up training for the deeper networks. Without Batch Normalization, the number of training steps increases with each subsequent layer, but with it, the number of training steps is nearly constant. And in practice, on more difficult datasets, more layers are a prerequisite for success.
Without Batch Normalization, the number of training iterations required to hit 90% accuracy increases with the number of layers, likely due to the vanishing gradient effect.

Similarly, as shown below, for a fully connected network with 7 hidden layers, the convergence time without Batch Normalization is significantly slower.

The above experiments use the commonly used ReLU activation function. Though obviously not immune to the vanishing gradient effect, as shown above, the ReLU activation fares much better than the sigmoid or tanh activation functions. The vulnerability of the sigmoid activation function to vanishing gradients is rather intuitive to understand. At larger magnitude (very positive or very negative) values, the sigmoid function "saturates", i.e. the derivative of the sigmoid function approaches zero. When many nodes saturate, the number of updates decreases, and the network stops training.
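The derivative of the sigmoid is σ′(x) = σ(x)(1 − σ(x)), which shrinks rapidly away from zero; a quick illustration with arbitrary input values:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

for x in [0.0, 2.0, 5.0, 10.0]:
    s = sigmoid(x)
    print(x, s * (1.0 - s))   # 0.25, ~0.105, ~0.0066, ~0.000045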
The same 7-layer network trains significantly slower using sigmoid activation functions without Batch Normalization. With Batch Normalization, the network converges in a similar number of iterations as when using ReLU.

On the other hand, other activation functions, such as the exponential linear unit (ELU) or the leaky ReLU, can help combat the vanishing gradient issue, as they have non-zero derivatives for large positive and negative inputs.
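Both are available in TensorFlow, so swapping one into the hidden-layer code is a one-line change; the alpha value below is illustrative:

hidden = tf.keras.layers.Dense(n_units, activation=None)(X)
output = tf.nn.leaky_relu(hidden, alpha=0.2)   # or tf.nn.elu(hidden)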

Finally, it is important to note that Batch Normalization incurs an extra time cost during training. Though Batch Normalization typically decreases the number of training steps needed to reach convergence, each step becomes more expensive, because it introduces an additional operation and two new trained parameters per unit.
For the MNIST classification problem (using a 1080 GTX GPU), Batch Normalization converges in (top) fewer iterations; however, the time per iteration is slower. Ultimately, the Batch Normalization version still converges faster (bottom), but the improvement is less pronounced when total training time is taken into account.

Incorporating XLA and fused Batch Normalization (the fused argument in tf.layers.batch_normalization) could help speed up the Batch Normalization operation by combining several individual operations into a single kernel.
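If you want to try it, the flag is part of the existing layer API; whether it helps will depend on your hardware and graph:

batch_normed = tf.layers.batch_normalization(hidden, fused=True,
                                             training=in_training_mode)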

Regardless, Batch Normalization can be a very valuable tool for speeding up the training of deep neural networks. As always with training deep neural networks, the best way to figure out whether an approach will help on your problem is to try it!
