Speeding up TensorFlow development and debugging in the terminal

For most of my development I use Jupyter notebooks, which are fantastic for iterative work, but running in managed environments such as Google’s ML Engine requires Python scripts. You can of course run these locally with python in the terminal or your IDE, but the loop of debug, terminate, change and re-run is rather slow (as far as I can tell, because of the time it takes to import tensorflow and the other packages). I wanted to keep these imports in memory (as Jupyter does) and re-run just a single function during development. This is my workflow and setup for doing that.

Reloading code/modules

Running a Python import statement a second time is basically a no-op, because the module is already cached in sys.modules. This is great for import tensorflow as tf, but not so good when we want to pick up changes to our own code/modules. If we had run import trainer.train as t to import our code, any edit to those functions would require an explicit reload:

from importlib import reload
reload(t)
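
To see the difference between a repeated import and a reload without touching a real project, here is a minimal sketch. It assumes a throwaway module named demo_trainer (a hypothetical stand-in for trainer.train) written to a temporary directory; none of these names come from my actual setup.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Skip writing .pyc files so the demo always re-reads the edited source
sys.dont_write_bytecode = True

# Hypothetical stand-in for trainer/train.py, created in a temp directory
tmp_dir = tempfile.mkdtemp()
module_path = Path(tmp_dir) / "demo_trainer.py"
module_path.write_text("def message():\n    return 'version 1'\n")
sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()

import demo_trainer as t
first = t.message()

# Edit the source on disk, as you would in your editor
module_path.write_text("def message():\n    return 'version 2 (edited)'\n")

# A second import finds the module in sys.modules and does not re-execute it
import demo_trainer as t
still_old = t.message()

# reload() re-executes the module source in the existing module object
importlib.reload(t)
after_reload = t.message()

print(first, '|', still_old, '|', after_reload)
```

Two caveats worth knowing: reload() only re-executes the module you pass in, not the modules it imports, and names pulled in earlier with from module import name keep pointing at the old objects. Calling functions through the module alias (t.message()) avoids the second problem.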

Command Line Arguments

The next stage is to move the main program logic out of the __main__ block and into its own function, leaving the argparse logic where it is so that execution can also be triggered by calling the function directly. For example, to run training that accepts a data directory parameter, I would use:

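from argparse import ArgumentParser

def run_train(data_dir, **kwargs):
    # **kwargs absorbs any other parsed arguments this function does not use
    print('Loading data from', data_dir)

if __name__ == '__main__':
    # Only the argument parsing stays in the __main__ block
    parser = ArgumentParser()
    parser.add_argument('--data-dir')  # stored as args.data_dir
    args = parser.parse_args()

    run_train(**vars(args))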

With this setup, we can call our script from the command line as well as from an interactive terminal. In the run_train(**vars(args)) line, vars() turns the parsed arguments into a dictionary and ** unpacks that dictionary into keyword arguments. Note that parser.parse_known_args() returns a tuple rather than a single namespace, so with it the call becomes run_train(**vars(args[0])).
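
The following sketch shows what vars() and parse_known_args() actually return. The argument lists are passed in explicitly so it can run inside an interactive session; the --epochs and --job-dir flags are hypothetical additions for illustration, and run_train returns its message here instead of printing it.

```python
from argparse import ArgumentParser

def run_train(data_dir, **kwargs):
    # Extra parsed arguments (here: epochs) land in kwargs
    return 'Loading data from ' + str(data_dir)

parser = ArgumentParser()
parser.add_argument('--data-dir')                      # becomes the key 'data_dir'
parser.add_argument('--epochs', type=int, default=1)   # hypothetical extra flag

# parse_args() returns a Namespace; vars() exposes it as a plain dict
args = parser.parse_args(['--data-dir', '/data'])
arg_dict = vars(args)
result = run_train(**arg_dict)

# parse_known_args() returns (namespace, list_of_unrecognised_arguments)
known, unknown = parser.parse_known_args(
    ['--data-dir', '/data', '--job-dir', 'out'])
result_known = run_train(**vars(known))

print(arg_dict)
print(unknown)
```

Note that argparse converts the dash in --data-dir to an underscore, which is why the function parameter is named data_dir.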

Debugging Workflow

With these tidbits, I can now edit and re-run a single function quickly. A typical session might look like this:

python3

from importlib import reload
import trainer.train as t

reload(t); t.run_train(data_dir='/data')
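
If you get tired of typing the reload-and-call line, it can be wrapped in a small helper. This is only a sketch: rerun is a hypothetical name of my own, and sketch_train is a temporary stand-in module so the example runs without a real trainer package.

```python
import importlib
import sys
import tempfile
from pathlib import Path

def rerun(module, func_name, **kwargs):
    """Reload `module`, then call its freshly loaded `func_name`."""
    module = importlib.reload(module)
    return getattr(module, func_name)(**kwargs)

# Hypothetical stand-in for trainer/train.py, written to a temp directory
sys.dont_write_bytecode = True  # always re-read the edited source
tmp_dir = tempfile.mkdtemp()
source = Path(tmp_dir) / "sketch_train.py"
source.write_text(
    "def run_train(data_dir, **kwargs):\n    return 'v1 ' + data_dir\n")
sys.path.insert(0, tmp_dir)
importlib.invalidate_caches()

import sketch_train as t

before = rerun(t, 'run_train', data_dir='/data')

# Simulate an edit made in your editor between runs
source.write_text(
    "def run_train(data_dir, **kwargs):\n    return 'v2 edited ' + data_dir\n")
after = rerun(t, 'run_train', data_dir='/data')

print(before, '|', after)
```

The function is looked up by name after the reload, so the call always goes to the newly loaded version.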