MNIST with TensorFlow Experiments and Estimators

An MNIST classifier is the go-to introduction to machine learning, and TensorFlow is no different: its tutorial evolves into Deep MNIST for Experts, which adds convolution, max pooling, dense layers and dropout, giving a good overview of the layers used for image problems. The downside is that it doesn't make use of TensorFlow's newer tf.estimator high-level API, which gives you a lot of functionality for free that the usual TensorFlow tutorials online build by hand. The tf.estimator Quickstart gives a good reason to use it:

TensorFlow’s high-level machine learning API (tf.estimator) makes it easy to configure, train, and evaluate a variety of machine learning models

The Advantages of Estimators page gives a more extensive list. A few points of particular interest:

  • Run on CPU, GPU or TPU without recoding your model
  • A safe distributed training loop that builds the graph, initializes variables, starts queues, creates checkpoints and saves summaries for TensorBoard
  • Export for serving

These are just a few. With that, let's build the same network as Deep MNIST for Experts, but laid out as a TensorFlow Estimator.

Estimator minimum layout

At a high level, you’ll need:

  • An input function: returns features and labels for training
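  • A model function: receives the features, labels and a few other parameters, and contains the model code that turns them into predictions, losses, metrics, etc.
  • An experiment object: runs the estimator in a configured state (in the code below, tf.estimator.train_and_evaluate with a TrainSpec and an EvalSpec plays this role)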
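To make the division of labour concrete, here is a framework-free sketch in plain Python. Nothing in it is TensorFlow API: toy_input_fn, toy_model_fn and run_estimator are hypothetical stand-ins. They only mirror the calling convention, in which the runner calls the input function with no arguments and passes the result to the model function along with a mode and a params dict.

```python
from typing import Any, Callable, Dict, List, Tuple

Features = List[List[float]]
Labels = List[int]


def toy_input_fn() -> Tuple[Features, Labels]:
    # Hypothetical input function: takes no arguments, returns (features, labels)
    features = [[0.0, 1.0], [1.0, 0.0]]
    labels = [1, 0]
    return features, labels


def toy_model_fn(features: Features, labels: Labels, mode: str,
                 params: Dict[str, Any]) -> Dict[str, Any]:
    # Hypothetical "model": predict the index of the largest feature value
    predictions = [row.index(max(row)) for row in features]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"mode": mode, "predictions": predictions,
            "accuracy": correct / len(labels)}


def run_estimator(input_fn: Callable[[], Tuple[Features, Labels]],
                  model_fn: Callable[..., Dict[str, Any]],
                  mode: str, params: Dict[str, Any]) -> Dict[str, Any]:
    # The runner, not your code, decides when to pull data and call the model
    features, labels = input_fn()
    return model_fn(features, labels, mode, params)


result = run_estimator(toy_input_fn, toy_model_fn, "eval", {"dropout_rate": 0.4})
print(result)  # {'mode': 'eval', 'predictions': [1, 0], 'accuracy': 1.0}
```

The point of the split is that the same model function can be driven by different input functions, and different runners, without edits.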

Data input function

I wrote another article, Convert and use the MNIST dataset as TFRecords, and we will use the TFRecord files produced there as input here. The data_input_fn defined below takes a list of filenames, then constructs and returns an input_fn that is compatible with the Estimator API.

import tensorflow as tf  # TensorFlow 1.x API


def data_input_fn(filenames, batch_size=1000, shuffle=False):

    def _parser(record):
        # Schema of each serialized tf.train.Example in the TFRecord files
        features = {
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
        parsed_record = tf.parse_single_example(record, features)
        # The raw bytes hold the pixel values as float32
        image = tf.decode_raw(parsed_record['image_raw'], tf.float32)

        label = tf.cast(parsed_record['label'], tf.int32)

        # The model function expects one-hot labels (softmax_cross_entropy, argmax)
        return image, tf.one_hot(label, depth=10)

    def _input_fn():
        dataset = tf.data.TFRecordDataset(filenames)
        dataset = dataset.map(_parser)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=10_000)

        dataset = dataset.repeat(None)  # Infinite iterations: let max_steps/steps decide when to stop
        dataset = dataset.batch(batch_size)

        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()

        return features, labels
    return _input_fn

The _parser function decodes a single record from the file into an (image, label) tuple. The inner _input_fn is what data_input_fn returns and what the estimator calls to retrieve data. It builds the dataset and defines how the data is served: the batch size, whether to shuffle, and how many times to repeat the data (commonly referred to as the number of epochs).
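The order of shuffle, repeat and batch matters, so here is a plain-Python sketch of the same pipeline. It is an illustration, not tf.data. make_input_fn is a hypothetical helper that copies data_input_fn's closure pattern. It does a full per-epoch shuffle rather than tf.data's 10,000-element buffer, and it yields Python lists where the real _input_fn returns next-batch tensors from get_next().

```python
import itertools
import random
from typing import Callable, Iterator, List, Sequence, Tuple

Record = Tuple[List[float], int]
Batch = Tuple[List[List[float]], List[int]]


def make_input_fn(records: Sequence[Record], batch_size: int,
                  shuffle: bool = False,
                  seed: int = 0) -> Callable[[], Iterator[Batch]]:
    # Like data_input_fn: capture the settings now, return a function to call later
    def _input_fn() -> Iterator[Batch]:
        rng = random.Random(seed)

        def _repeat_forever() -> Iterator[Record]:
            # repeat(None): loop over the records endlessly, reshuffling each pass
            while True:
                order = list(records)
                if shuffle:
                    rng.shuffle(order)  # full shuffle, not a fixed-size buffer
                yield from order

        stream = _repeat_forever()
        while True:
            # batch(batch_size): group consecutive records into one batch
            batch = list(itertools.islice(stream, batch_size))
            yield [image for image, _ in batch], [label for _, label in batch]

    return _input_fn


records = [([float(i)], i) for i in range(5)]
batches = make_input_fn(records, batch_size=2)()
print(next(batches))  # ([[0.0], [1.0]], [0, 1])
print(next(batches))  # ([[2.0], [3.0]], [2, 3])
print(next(batches))  # ([[4.0], [0.0]], [4, 0]): the batch crosses an epoch boundary
```

Because repeat comes before batch, batches flow across epoch boundaries instead of leaving a short final batch. That behaviour suits a runner that counts steps rather than epochs.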

CNN Model Function

The model function should have the signature:

  • features: the first item returned from input_fn
  • labels: the second item returned from input_fn
  • mode: the mode the estimator is running in, one of tf.estimator.ModeKeys.TRAIN, EVAL or PREDICT (training, evaluation or prediction)
  • params: an optional dictionary of hyperparameters
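The mode argument is what lets one function serve training, evaluation and prediction. The sketch below uses a hypothetical Mode enum standing in for tf.estimator.ModeKeys. It records which pieces a model function shaped like cnn_model_fn builds in each mode. It is a summary of the control flow, not TensorFlow code.

```python
from enum import Enum
from typing import List, Tuple


class Mode(Enum):
    # Hypothetical stand-in for tf.estimator.ModeKeys; not the TensorFlow class
    TRAIN = "train"
    EVAL = "eval"
    PREDICT = "predict"


def plan_model_fn(mode: Mode) -> Tuple[bool, List[str]]:
    """Return (is_training, outputs built) for a model function shaped like cnn_model_fn."""
    is_training = mode is Mode.TRAIN  # drives dropout in the real model
    outputs = ["predictions"]
    if mode is Mode.PREDICT:
        # No labels at serving time: return before any loss or metric is built
        outputs.append("export_outputs")
        return is_training, outputs
    outputs += ["loss", "eval_metric_ops"]
    if is_training:
        outputs.append("train_op")  # only training needs an optimizer step
    return is_training, outputs


for mode in Mode:
    print(mode.name, plan_model_fn(mode))
```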

In the function below, the code up to the 10-unit dense layer that produces the logits builds the same network as the Deep MNIST for Experts tutorial. We go over the remainder of the function after the listing, so each explanation sits close to the code it describes.
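The shapes through this network can be checked by hand. The plain-Python sketch below walks a 28x28 image through it and counts the trainable parameters. It assumes stride-1 5x5 convolutions with 32 and then 64 filters, and 2x2 pooling with stride 2, all with padding='same', as in Deep MNIST for Experts. The commented-out 7 * 7 * 64 reshape in the code agrees with this. The helper names are hypothetical.

```python
import math
from typing import Tuple

Shape = Tuple[int, int, int]  # (height, width, channels)


def conv_same(shape: Shape, filters: int) -> Shape:
    # Stride 1 with padding='same' keeps height and width; channels become `filters`
    h, w, _ = shape
    return (h, w, filters)


def max_pool_same(shape: Shape, stride: int = 2) -> Shape:
    # padding='same': each spatial size becomes ceil(size / stride)
    h, w, c = shape
    return (math.ceil(h / stride), math.ceil(w / stride), c)


def conv_params(kernel: int, in_ch: int, out_ch: int) -> int:
    return kernel * kernel * in_ch * out_ch + out_ch  # weights + biases


def dense_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out  # weights + biases


shape: Shape = (28, 28, 1)
shape = max_pool_same(conv_same(shape, 32))  # (14, 14, 32)
shape = max_pool_same(conv_same(shape, 64))  # (7, 7, 64)
flat = shape[0] * shape[1] * shape[2]        # 3136, i.e. 7 * 7 * 64

total = (conv_params(5, 1, 32) + conv_params(5, 32, 64)
         + dense_params(flat, 1024) + dense_params(1024, 10))
print(shape, flat, total)  # (7, 7, 64) 3136 3274634
```

Under these assumptions, the 1024-unit dense layer holds about 98% of the parameters (3,212,288 of 3,274,634). That is the layer dropout is applied to.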

def cnn_model_fn(features, labels, mode, params):
    """Model function for CNN."""

    is_training = mode == tf.estimator.ModeKeys.TRAIN

    with tf.name_scope('Input'):
        # Input Layer
        input_layer = tf.reshape(features, [-1, 28, 28, 1], name='input_reshape')
        tf.summary.image('input', input_layer)

    with tf.name_scope('Conv_1'):
        # Convolutional Layer #1
        conv1 = tf.layers.conv2d(
            inputs=input_layer,
            filters=32,
            kernel_size=(5, 5),
            padding='same',
            activation=tf.nn.relu,
            trainable=is_training)

        # Pooling Layer #1
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=(2, 2), strides=2, padding='same')

    with tf.name_scope('Conv_2'):
        # Convolutional Layer #2 and Pooling Layer #2
        conv2 = tf.layers.conv2d(
            inputs=pool1,
            filters=64,
            kernel_size=(5, 5),
            padding='same',
            activation=tf.nn.relu,
            trainable=is_training)

        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=(2, 2), strides=2, padding='same')

    with tf.name_scope('Dense_Dropout'):
        # Dense Layer
        # pool2_flat = tf.reshape(pool2, [-1, 7 * 7 * 64])
        pool2_flat = tf.contrib.layers.flatten(pool2)
        dense = tf.layers.dense(inputs=pool2_flat, units=1024, activation=tf.nn.relu, trainable=is_training)
        dropout = tf.layers.dropout(inputs=dense, rate=params['dropout_rate'], training=is_training)

    with tf.name_scope('Predictions'):
        # Logits Layer
        logits = tf.layers.dense(inputs=dropout, units=10, trainable=is_training)

    predicted_logit = tf.argmax(input=logits, axis=1, output_type=tf.int32)
    scores = tf.nn.softmax(logits, name='softmax_tensor')

    # Generate Predictions
    predictions = {
        'classes': predicted_logit,
        'probabilities': scores
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        export_outputs = {
            # ClassificationOutput needs string classes; tf.cast cannot convert int32 to string
            'prediction': tf.estimator.export.ClassificationOutput(
                classes=tf.as_string(predicted_logit))
        }

        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions, export_outputs=export_outputs)

    # TRAIN and EVAL
    loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

    accuracy = tf.metrics.accuracy(tf.argmax(labels, axis=1), predicted_logit)
    eval_metric = { 'accuracy': accuracy }

    # Configure the Training Op (for TRAIN mode)
    if mode == tf.estimator.ModeKeys.TRAIN:
        tf.summary.scalar('accuracy', accuracy[0])
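        train_op = tf.contrib.layers.optimize_loss(
            loss=loss,
            global_step=tf.train.get_or_create_global_step(),
            learning_rate=params['learning_rate'],
            optimizer='Adam')
    else:
        train_op = None

    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric)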

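After the final dense layer, the mode parameter starts deciding which operations get built; until then it has only toggled dropout and trainable variables through is_training. In PREDICT mode we return early with the predictions and the export outputs used for serving. For all other modes we calculate: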

  • The network's loss, using tf.losses.softmax_cross_entropy
  • The accuracy, using tf.metrics.accuracy, which compares the class with the highest score against the argmax of the one-hot label
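Both quantities are simple to write out. The plain-Python sketch below computes softmax cross-entropy averaged over a batch (with unit weights, tf.losses.softmax_cross_entropy reduces to this mean) and the fraction of correct predictions. One difference: tf.metrics.accuracy is a streaming metric. It accumulates a running total and count across evaluation batches and returns a (value, update_op) pair, which is why the code logs accuracy[0]. The sketch covers a single batch only.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def softmax_cross_entropy(onehot_labels: Sequence[Sequence[float]],
                          logits: Sequence[Sequence[float]]) -> float:
    # Per example: -sum(y * log(softmax(z))); then the mean over the batch
    losses = [-sum(y * math.log(p) for y, p in zip(ys, softmax(zs)))
              for ys, zs in zip(onehot_labels, logits)]
    return sum(losses) / len(losses)


def argmax(xs: Sequence[float]) -> int:
    return max(range(len(xs)), key=lambda i: xs[i])


def batch_accuracy(onehot_labels: Sequence[Sequence[float]],
                   logits: Sequence[Sequence[float]]) -> float:
    # Compare the predicted class with the argmax of the one-hot label
    hits = sum(argmax(ys) == argmax(zs) for ys, zs in zip(onehot_labels, logits))
    return hits / len(onehot_labels)


labels = [[0.0, 1.0], [1.0, 0.0]]
logits = [[0.0, 2.0], [0.0, 2.0]]  # both examples predict class 1
print(softmax_cross_entropy(labels, logits))
print(batch_accuracy(labels, logits))  # 0.5
```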

Then, if we are training, we set train_op to optimize the loss with the Adam optimizer. The two dictionaries, predictions and eval_metric, report what the network outputs and how accurate it is. Finally, we return a tf.estimator.EstimatorSpec, which the Estimator uses to build and run our custom model.
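optimize_loss with optimizer='Adam' hides the actual update rule. The plain-Python sketch below implements the textbook Adam update (Kingma and Ba) for a single scalar parameter. adam_minimize is a hypothetical helper. TensorFlow's implementation rearranges the same update algebraically, so results can differ slightly in where epsilon enters. The defaults below (beta1=0.9, beta2=0.999, eps=1e-8) are the commonly used ones.

```python
import math
from typing import Callable, List


def adam_minimize(grad: Callable[[float], float], x: float, lr: float,
                  steps: int, beta1: float = 0.9, beta2: float = 0.999,
                  eps: float = 1e-8) -> List[float]:
    """Textbook Adam on one scalar parameter; returns x after each step."""
    m = v = 0.0
    history = []
    for t in range(1, steps + 1):
        g = grad(x)
        m = beta1 * m + (1 - beta1) * g       # running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g   # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)          # bias correction for the zero init
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
        history.append(x)
    return history


# Minimise f(x) = (x - 3)^2, whose gradient is 2 * (x - 3)
xs = adam_minimize(lambda x: 2 * (x - 3), x=0.0, lr=0.1, steps=200)
print(round(xs[0], 6))  # 0.1: the first step has size ~lr whatever the gradient scale
print(xs[-1])
```

Dividing by the root of the squared-gradient average makes the step size roughly lr regardless of the loss scale. That is one reason a single learning_rate such as 1e-3 tends to work across many models.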

The Estimator

There are a few built-in estimators that don't require defining a custom model function and can be configured as desired; you can see the classes on the tf.estimator Overview page. For our custom estimator we will use the Estimator base class. Here we set up how and where the estimator saves checkpoints and summaries, as described before, plus any parameters we'd like to pass to our model function's params argument. Finally, we construct our custom mnist_classifier with our model function, config and parameters.
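The params dict reaches cnn_model_fn unchanged, which is how dropout_rate ends up in tf.layers.dropout. The plain-Python sketch below shows what rate=0.4 with training=is_training means. tf.layers.dropout uses inverted dropout: during training, each unit is zeroed with probability rate and the survivors are scaled by 1 / (1 - rate), while outside training the layer passes values through unchanged. inverted_dropout is a hypothetical helper, not the TensorFlow function.

```python
import random
from typing import List, Sequence


def inverted_dropout(x: Sequence[float], rate: float, training: bool,
                     rng: random.Random) -> List[float]:
    if not training:
        return list(x)  # EVAL/PREDICT: identity, nothing dropped or rescaled
    keep_prob = 1.0 - rate
    # Zero each unit with probability `rate`; scale survivors by 1 / keep_prob
    return [xi / keep_prob if rng.random() < keep_prob else 0.0 for xi in x]


params = {"learning_rate": 1e-3, "dropout_rate": 0.4}  # mirrors hparams
activations = [1.0] * 10_000
rng = random.Random(0)
train_out = inverted_dropout(activations, params["dropout_rate"], True, rng)
eval_out = inverted_dropout(activations, params["dropout_rate"], False, rng)
print(sum(train_out) / len(train_out))  # close to 1.0: the expected value is preserved
print(eval_out == activations)          # True
```

Scaling at training time rather than at inference time means the exported serving graph needs no special handling for dropout.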

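import os

run_config = tf.estimator.RunConfig(
    model_dir='/tmp/mnisttraining')  # where checkpoints and summaries are written

hparams = {
    'learning_rate': 1e-3,
    'dropout_rate': 0.4,
    # args comes from the script's argparse parser (not shown here)
    'data_directory': os.path.expanduser(args.data_directory)
}

mnist_classifier = tf.estimator.Estimator(
    model_fn=cnn_model_fn,
    config=run_config,
    params=hparams)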

Finally running the estimator

Up until this point, no code has actually run. We have defined everything we need, but it is tf.estimator.train_and_evaluate, taking the role of the experiment, that will run our estimator as desired. In the code below, we pass the paths to our TFRecord files to data_input_fn to construct input functions for the training and validation data. We then wrap them in a TrainSpec and an EvalSpec and tell our MNIST CNN estimator to train_and_evaluate. No epoch loops, no watching the steps and running accuracy checks, no summary writers, etc.

import glob
import os

train_batch_size = 1000

train_input_fn = data_input_fn(glob.glob(os.path.join(hparams['data_directory'], 'train-*.tfrecords')), batch_size=train_batch_size)
eval_input_fn = data_input_fn(os.path.join(hparams['data_directory'], 'validation.tfrecords'), batch_size=100)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=40)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, steps=100, start_delay_secs=0)

tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)
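Because both datasets repeat forever, max_steps and steps alone decide how much data is used. The worked arithmetic below uses the values from the TrainSpec and EvalSpec code above. The split sizes of 55,000 training and 5,000 validation images are an assumption, taken from the split produced by TensorFlow's classic MNIST loader. Substitute the sizes of your own TFRecord files.

```python
# Values from the TrainSpec/EvalSpec code above
TRAIN_BATCH_SIZE, MAX_STEPS = 1000, 40
EVAL_BATCH_SIZE, EVAL_STEPS = 100, 100

# Assumption: the standard 55,000 / 5,000 MNIST train/validation split
TRAIN_SIZE, VALIDATION_SIZE = 55_000, 5_000

train_examples = TRAIN_BATCH_SIZE * MAX_STEPS   # examples seen during training
train_epochs = train_examples / TRAIN_SIZE      # fraction of one pass over the data
eval_examples = EVAL_BATCH_SIZE * EVAL_STEPS    # examples read per evaluation
eval_passes = eval_examples / VALIDATION_SIZE   # repeat(None) wraps around the file

print(train_examples, round(train_epochs, 3))  # 40000 0.727
print(eval_examples, eval_passes)              # 10000 2.0
```

Under that assumption, max_steps=40 covers less than one epoch, which suits a quick smoke test. Each evaluation also counts every validation image twice. Raising max_steps, or setting steps to the validation size divided by the batch size, gives a full training run and a single-pass evaluation.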

At this point, checkpoints and TensorBoard summaries will be saved into ‘/tmp/mnisttraining’, since that is the model directory we defined in the RunConfig above. Checkpoints can be used to re-run validation, to continue training, or to create a TensorFlow Serving instance. The summaries can be viewed in TensorBoard by running tensorboard --logdir=/tmp/mnisttraining from the terminal, where the metrics we defined in our model_fn will show up.

Spread over a blog post, the Estimator API can seem like more effort than it's worth. But remember, many of these elements are reusable across any experiment you create or model you choose to implement. And when you take the code that runs on your laptop and run it as-is on multiple GPUs, you will find that everything ‘just works’. In a following post I will look at how to take this experiment and, with only a few changes, create a Serving endpoint we can host in the cloud.

All the code from this sample can be found at