Skip to content

Instantly share code, notes, and snippets.

@nithishdivakar
Created January 19, 2018 11:32
Show Gist options
  • Select an option

  • Save nithishdivakar/c50696c5304555253b6a1a6aeff28d55 to your computer and use it in GitHub Desktop.

Select an option

Save nithishdivakar/c50696c5304555253b6a1a6aeff28d55 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import argparse
# Convert a TensorFlow 1.x checkpoint into a frozen inference graph (.pb).
#
# Usage:
#   python freeze.py --checkpoint model.ckpt --graph graph.pb \
#       --output_nodes logits,probabilities


def ensure_pb_extension(path):
    """Return *path*, appending ``.pb`` if it does not already end with it."""
    return path if path.endswith(".pb") else path + ".pb"


def split_node_names(value):
    """Split a comma-separated string of node names into a list.

    Surrounding whitespace and empty entries are dropped, so ``"a, b,"``
    becomes ``["a", "b"]``. Used as an argparse ``type=`` converter.

    Raises:
        argparse.ArgumentTypeError: if no node name remains after splitting.
    """
    names = [name.strip() for name in value.split(",") if name.strip()]
    if not names:
        raise argparse.ArgumentTypeError(
            "expected at least one output node name")
    return names


def parse_args(argv=None):
    """Parse the converter's command-line options.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``checkpoint`` (str), ``graph`` (str, always
        ending in ``.pb``) and ``output_nodes`` (non-empty list of str).
    """
    parser = argparse.ArgumentParser(
        description='Convert a checkpoint to frozen graph')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default="model.ckpt",
        help='The checkpoint file to be converted')
    parser.add_argument(
        '--graph',
        type=str,
        default="graph.pb",
        help='Output graph name.')
    # Previously a hard-coded placeholder list in the source; the script could
    # not succeed without editing it, so the names are now a required option.
    parser.add_argument(
        '--output_nodes',
        type=split_node_names,
        required=True,
        help='Comma-separated names of the output nodes to keep, '
             'e.g. "logits,probabilities".')
    args = parser.parse_args(argv)
    # add pb extension if not present
    args.graph = ensure_pb_extension(args.graph)
    return args


def freeze_checkpoint(checkpoint, output_node_names, output_graph):
    """Restore *checkpoint*, fold its variables into constants and write a .pb.

    Args:
        checkpoint: checkpoint prefix (e.g. ``model.ckpt``). The graph
            structure is read from ``<checkpoint>.meta``; this assumes the
            checkpoint was saved together with its MetaGraph (TODO confirm
            for checkpoints produced without one).
        output_node_names: names of the nodes the frozen graph must keep.
        output_graph: path of the serialized GraphDef file to write.
    """
    graph = tf.Graph()
    with graph.as_default(), tf.Session(graph=graph) as sess:
        # A bare tf.train.Saver() on an empty graph fails with "No variables
        # to save": the graph and its variables must first be rebuilt from
        # the .meta file. clear_devices drops device pins recorded at save
        # time so the graph can be restored on a machine without them.
        saver = tf.train.import_meta_graph(
            checkpoint + ".meta", clear_devices=True)
        # import_meta_graph returns None when the graph has no variables;
        # there is then nothing to restore, but the graph can still be frozen.
        if saver is not None:
            saver.restore(sess, checkpoint)
        # Replace the variables the outputs depend on with constants holding
        # their restored values; the output names select the nodes to keep.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            graph.as_graph_def(),
            output_node_names)
    # Strip nodes only needed for training (this removes Identity ops, among
    # others). Protect the requested outputs, which may themselves be
    # Identity ops and would otherwise vanish from the frozen graph.
    output_graph_def = tf.graph_util.remove_training_nodes(
        output_graph_def, protected_nodes=output_node_names)
    # Finally serialize and dump the output graph to the filesystem.
    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())


def main(argv=None):
    """Command-line entry point: freeze the checkpoint named in *argv*."""
    args = parse_args(argv)
    freeze_checkpoint(args.checkpoint, args.output_nodes, args.graph)
    print("Frozen graph file {} created successfully".format(args.graph))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment