Converting and using the MNIST dataset as TFRecords

TFRecords are TensorFlow’s native binary data format and the recommended way to store data for streaming. Reading them back with a TFRecordReader is also a convenient way to get these records into your model.
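Under the hood, a TFRecord file is a sequence of length-prefixed records. Each record is guarded by masked CRC-32C checksums of its length and its payload. The pure-Python sketch below follows that documented framing to show what a TFRecord writer produces. It is an illustration, not a replacement for tf.python_io.TFRecordWriter. The function names are hypothetical, and the payloads are arbitrary bytes rather than serialized Example protos.

```python
import io
import struct

# CRC-32C (Castagnoli) lookup table, reflected polynomial 0x82F63B78
_CRC32C_TABLE = []
for _i in range(256):
    _crc = _i
    for _ in range(8):
        _crc = (_crc >> 1) ^ 0x82F63B78 if _crc & 1 else _crc >> 1
    _CRC32C_TABLE.append(_crc)


def crc32c(data: bytes) -> int:
    """Compute the CRC-32C checksum of data."""
    crc = 0xFFFFFFFF
    for byte in data:
        crc = _CRC32C_TABLE[(crc ^ byte) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """Rotate the CRC right by 15 bits and add a constant, as TFRecord does."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(stream, payload: bytes) -> None:
    """Frame one record: uint64 length, CRC of length, payload, CRC of payload."""
    length = struct.pack('<Q', len(payload))
    stream.write(length)
    stream.write(struct.pack('<I', masked_crc(length)))
    stream.write(payload)
    stream.write(struct.pack('<I', masked_crc(payload)))


def read_records(stream):
    """Yield each payload back out, verifying both checksums."""
    while True:
        header = stream.read(8)
        if not header:
            return
        (length,) = struct.unpack('<Q', header)
        (length_crc,) = struct.unpack('<I', stream.read(4))
        if length_crc != masked_crc(header):
            raise ValueError('corrupt length checksum')
        payload = stream.read(length)
        (payload_crc,) = struct.unpack('<I', stream.read(4))
        if payload_crc != masked_crc(payload):
            raise ValueError('corrupt payload checksum')
        yield payload


buffer = io.BytesIO()
for payload in [b'first example', b'second example']:
    write_record(buffer, payload)
buffer.seek(0)
print(list(read_records(buffer)))  # [b'first example', b'second example']
```

Each record adds 16 bytes of framing: 8 for the length and 4 for each checksum. The checksums let a reader detect a truncated or corrupted file instead of silently parsing garbage.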

The data

We will use the well-known MNIST dataset of handwritten digits as a sample. It is easily retrieved from TensorFlow via:

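import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# '/path/to/mnist' is a placeholder download directory; reshape=False keeps
# the images as (num_examples, 28, 28, 1) so their shape can be unpacked below
mnist = input_data.read_data_sets('/path/to/mnist', reshape=False)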

We then have mnist.validation, mnist.train and mnist.test data sets.

Creating TFRecords

A TFRecords file contains one Example instance per data point, and each Example contains a set of Features. To write these out to disk we use a TFRecordWriter. Let’s use a single MNIST sample as an example:

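image = mnist.train.images[0]
image_label = mnist.train.labels[0]
_, rows, cols, depth = mnist.train.images.shape

# filename is the path of the output .tfrecords file; the two _*_feature
# helpers are defined just below
with tf.python_io.TFRecordWriter(filename) as writer:
    image_raw = image.tobytes()  # raw float32 pixel bytes
    example = tf.train.Example(features=tf.train.Features(feature={
        'height': _int64_feature(rows),
        'width': _int64_feature(cols),
        'depth': _int64_feature(depth),
        'label': _int64_feature(int(image_label)),
        'image_raw': _bytes_feature(image_raw)
    }))
    # Serialize the Example proto and append it to the file as one record
    writer.write(example.SerializeToString())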

Feature entries are protobuf instances, and the TF documentation doesn’t go into much detail about them. The feature.proto definitions, however, show that a Feature holds exactly one of a BytesList, FloatList or Int64List. Since we only use bytes and int features here, two small helper functions make record creation simpler:

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
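If your records mix integer, float and byte features, one option is a single helper that picks the list type from the Python value. The sketch below is hypothetical (feature_field is our name, not a TensorFlow API). It only decides which of the three Feature fields a value belongs in, so it runs without TensorFlow.

```python
def feature_field(value):
    """Return (Feature field name, list of values) for a scalar or list value.

    Mirrors the oneof in feature.proto: a Feature sets exactly one of
    bytes_list, float_list or int64_list.
    """
    values = list(value) if isinstance(value, (list, tuple)) else [value]
    if values and all(isinstance(v, bytes) for v in values):
        return 'bytes_list', values
    # Only plain Python ints count; NumPy scalars fall through, which is
    # why the conversion code wraps labels in int(...)
    if values and all(isinstance(v, int) for v in values):
        return 'int64_list', values
    if values and all(isinstance(v, (int, float)) for v in values):
        return 'float_list', [float(v) for v in values]
    # str must be encoded to bytes first; empty lists have no inferable type
    raise TypeError(f'cannot infer a Feature type for {value!r}')


print(feature_field(28))           # ('int64_list', [28])
print(feature_field(b'\x00\x01'))  # ('bytes_list', [b'\x00\x01'])
print(feature_field([0.5, 1]))     # ('float_list', [0.5, 1.0])
```

With TensorFlow available, the pair can become a real Feature through tf.train.Feature(**{field: list_type(value=values)}), where list_type is the matching tf.train.BytesList, tf.train.FloatList or tf.train.Int64List.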

The single example highlights the key code; here is the entire sample, which also demonstrates sharding the data across several files:

import os
import sys

import numpy as np
import tensorflow as tf


def _data_path(data_directory: str, name: str) -> str:
    """Path of a TFRecords file (assumed naming: <name>.tfrecords)"""
    return os.path.join(data_directory, f'{name}.tfrecords')


def convert_to(data_set, name: str, data_directory: str, num_shards: int = 1):
    """Convert the dataset into TFRecords on disk

    Args:
        data_set:       The MNIST data set to convert
        name:           The name of the data set
        data_directory: The directory where records will be stored
        num_shards:     The number of files on disk to separate records into
    """
    num_examples, rows, cols, depth = data_set.images.shape

    data_set = list(zip(data_set.images, data_set.labels))

    def _process_examples(example_dataset, filename: str):
        print(f'Processing {filename} data')
        dataset_length = len(example_dataset)
        with tf.python_io.TFRecordWriter(filename) as writer:
            for index, (image, label) in enumerate(example_dataset):
                sys.stdout.write(f"\rProcessing sample {index+1} of {dataset_length}")

                image_raw = image.tobytes()
                example = tf.train.Example(features=tf.train.Features(feature={
                    'height': _int64_feature(rows),
                    'width': _int64_feature(cols),
                    'depth': _int64_feature(depth),
                    'label': _int64_feature(int(label)),
                    'image_raw': _bytes_feature(image_raw)
                }))
                writer.write(example.SerializeToString())
        print()  # end the progress line

    if num_shards == 1:
        _process_examples(data_set, _data_path(data_directory, name))
    else:
        # Split index ranges rather than the (image, label) pairs themselves,
        # so NumPy never has to build an array from ragged tuples
        for shard, indices in enumerate(np.array_split(np.arange(num_examples), num_shards)):
            shard_examples = [data_set[i] for i in indices]
            _process_examples(shard_examples, _data_path(data_directory, f'{name}-{shard+1}'))


data_directory = '/path/to/data'  # placeholder output directory
os.makedirs(data_directory, exist_ok=True)
convert_to(mnist.validation, 'validation', data_directory)
convert_to(mnist.train, 'train', data_directory, num_shards=10)
convert_to(mnist.test, 'test', data_directory)
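np.array_split does not need the number of examples to divide evenly by num_shards. The first len % num_shards shards each receive one extra element. The stdlib-only sketch below reproduces that rule to show which example indices land in which shard file. The function name shard_bounds is hypothetical, and the "train-N" labels simply mirror the f'{name}-{shard+1}' naming above.

```python
def shard_bounds(num_examples: int, num_shards: int):
    """Return (start, stop) index ranges, split the way np.array_split does."""
    base, extra = divmod(num_examples, num_shards)
    bounds = []
    start = 0
    for shard in range(num_shards):
        # The first `extra` shards each take one additional example
        stop = start + base + (1 if shard < extra else 0)
        bounds.append((start, stop))
        start = stop
    return bounds


for shard, (start, stop) in enumerate(shard_bounds(10, 3)):
    print(f'train-{shard + 1}: examples {start}..{stop - 1}')
# train-1: examples 0..3
# train-2: examples 4..6
# train-3: examples 7..9
```

Shard sizes therefore differ by at most one example. When the count divides evenly, as with 55,000 training images in 10 shards, every file holds the same number of records.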

Reading the data

If you are using the recommended Dataset API, a TFRecordDataset can read in one or more TFRecord files, as shown in the example below. The main difference from any other use of the Dataset API is how each sample is parsed. We tell tf.parse_single_example which features and types we want retrieved, and then convert them into an appropriate format for our model.

import tensorflow as tf


def data_input_fn(filenames, batch_size=1000, shuffle=False):
    def _parse(record):
        # Only the features the model needs are parsed; height/width/depth are ignored
        features = {
            'label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)
        }
        parsed_record = tf.parse_single_example(record, features)
        # Flat vector of height*width*depth values; the images were written as float32
        image = tf.decode_raw(parsed_record['image_raw'], tf.float32)

        label = tf.cast(parsed_record['label'], tf.int32)

        return image, label

    def _input_fn():
        # Read every record from the given file(s) and parse it into (image, label)
        dataset = tf.data.TFRecordDataset(filenames).map(_parse)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=10_000)

        dataset = dataset.repeat(None)  # Infinite iterations: let experiment determine num_epochs
        dataset = dataset.batch(batch_size)
        iterator = dataset.make_one_shot_iterator()
        features, labels = iterator.get_next()
        return features, labels

    return _input_fn
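tf.decode_raw only reinterprets the stored bytes. The dtype passed to it must therefore match the dtype of the array that was serialized with tobytes(). The images from input_data.read_data_sets are float32 by default, which is why tf.float32 is used above. The stdlib sketch below mimics that round trip for a 28 x 28 x 1 image. Python's array module stands in for NumPy and TensorFlow, so this is an analogy rather than the TensorFlow code path.

```python
from array import array

rows, cols, depth = 28, 28, 1
# Stand-in for one MNIST float32 image, flattened to rows * cols * depth pixels
pixels = array('f', [i / 783 for i in range(rows * cols * depth)])

raw = pixels.tobytes()         # what image.tobytes() stores in 'image_raw'
print(len(raw))                # 3136: 784 pixels * 4 bytes each

decoded = array('f')
decoded.frombytes(raw)         # analogous to tf.decode_raw(..., tf.float32)
print(len(decoded), decoded == pixels)  # 784 True

wrong = array('B')
wrong.frombytes(raw)           # analogous to tf.decode_raw(..., tf.uint8)
print(len(wrong))              # 3136: wrong dtype, yet no error is raised
```

A dtype mismatch shows up only as a wrongly sized (and meaningless) vector. Also note that the decoded image is flat, so reshape it (for example with tf.reshape) if your model expects a 28 x 28 x 1 input.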

And then calling this:

import glob

train_input_fn = data_input_fn(glob.glob('/path/to/data/train-*.tfrecords'), shuffle=True)
validation_input_fn = data_input_fn('/path/to/data/validation.tfrecords')
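With shuffle=True, the pipeline calls Dataset.shuffle(buffer_size=10_000). That call does not shuffle all the files at once. It keeps a buffer of buffer_size elements, emits a randomly chosen one, and refills the freed slot from the input. The pure-Python sketch below illustrates that buffering technique. buffered_shuffle is a hypothetical name, and this is not TensorFlow's implementation.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a locally shuffled order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError('buffer_size must be at least 1')
    rng = random.Random(seed)
    buffer = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a random buffered element and free its slot for the next input
            index = rng.randrange(len(buffer))
            buffer[index], buffer[-1] = buffer[-1], buffer[index]
            yield buffer.pop()
    # Input exhausted: drain whatever is left in random order
    rng.shuffle(buffer)
    yield from buffer


order = list(buffered_shuffle(range(20), buffer_size=3, seed=0))
print(sorted(order) == list(range(20)))  # True: every item appears exactly once
print(order)
```

The k-th output can only come from the first k + buffer_size inputs. With a buffer much smaller than the dataset, an example therefore moves only a limited distance from its position in the files. A buffer_size of 1 does no shuffling at all.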

Full code available at