Forked from nithishdivakar/tensorflow_checkpoint_to_graph.py
Created: March 4, 2018 01:13
-
-
Save qijiexu/5cb54e6246f5299c115138be5409996b to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Freeze a TensorFlow 1.x checkpoint into a single GraphDef (.pb) file.

The graph structure is rebuilt from the checkpoint's MetaGraph (``.meta``)
file and the variable values are restored from the checkpoint. Every variable
needed by the requested output nodes is baked in as a constant, and the
resulting GraphDef is serialized to disk.

Example:
    python tensorflow_checkpoint_to_graph.py --checkpoint model.ckpt --graph frozen.pb --output_nodes logits,probs
"""
import argparse


def split_node_names(spec):
    """Split a comma-separated node list into names, dropping blank entries.

    >>> split_node_names("logits, probs,")
    ['logits', 'probs']
    """
    return [name.strip() for name in spec.split(",") if name.strip()]


def with_pb_extension(path):
    """Return *path* with a ``.pb`` suffix appended unless it already has one."""
    return path if path.endswith(".pb") else path + ".pb"


def parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` reads ``sys.argv``).

    Returns:
        An ``argparse.Namespace`` with the following attributes:

        - ``checkpoint``: the checkpoint prefix.
        - ``graph``: the output path, always ending in ``.pb``.
        - ``output_nodes``: a non-empty list of node names.
        - ``meta_graph``: the MetaGraph path, defaulting to
          ``<checkpoint>.meta``.

    Raises:
        SystemExit: via ``parser.error`` when ``--output_nodes`` is missing
            or names no nodes.
    """
    parser = argparse.ArgumentParser(
        description='Convert a checkpoint to frozen graph')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default="model.ckpt",
        help='The checkpoint file to be converted')
    parser.add_argument(
        '--graph',
        type=str,
        default="graph.pb",
        help='Output graph name.')
    parser.add_argument(
        '--output_nodes',
        type=split_node_names,
        required=True,
        help='Comma-separated names of the output nodes to keep, '
             'e.g. "logits,probs".')
    parser.add_argument(
        '--meta_graph',
        type=str,
        default=None,
        help='MetaGraph file describing the graph structure '
             '(default: <checkpoint>.meta).')
    args = parser.parse_args(argv)
    if not args.output_nodes:
        parser.error("--output_nodes must name at least one node")
    # add pb extension if not present
    args.graph = with_pb_extension(args.graph)
    if args.meta_graph is None:
        args.meta_graph = args.checkpoint + ".meta"
    return args


def main(argv=None):
    """Freeze the checkpoint named on the command line and write the .pb file."""
    args = parse_args(argv)

    # Imported here so that --help and argument errors respond without
    # paying TensorFlow's import cost.
    import tensorflow as tf

    with tf.Graph().as_default() as graph, tf.Session(graph=graph) as sess:
        # A bare tf.train.Saver() on an empty graph has no variables to
        # save, so rebuild the graph from the MetaGraph first. The Saver
        # returned here is then able to restore the checkpoint's values.
        saver = tf.train.import_meta_graph(args.meta_graph, clear_devices=True)
        saver.restore(sess, args.checkpoint)
        # Replace variables with constants holding their restored values.
        # The output node names select the subgraph that is kept.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), args.output_nodes)

    # Strip training-only ops (e.g. Identity / CheckNumerics). The outputs
    # are protected so that an output which is itself an Identity op survives.
    output_graph_def = tf.graph_util.remove_training_nodes(
        output_graph_def, protected_nodes=args.output_nodes)

    # Finally we serialize and dump the output graph to the filesystem.
    with tf.gfile.GFile(args.graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("Frozen graph file {} created successfully".format(args.graph))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.