Skip to content

Instantly share code, notes, and snippets.

@qfgaohao
Created January 31, 2018 21:41
Show Gist options
  • Select an option

  • Save qfgaohao/214a2366ea6e0043903badfa8cd9c830 to your computer and use it in GitHub Desktop.

Select an option

Save qfgaohao/214a2366ea6e0043903badfa8cd9c830 to your computer and use it in GitHub Desktop.

Revisions

  1. qfgaohao created this gist Jan 31, 2018.
    49 changes: 49 additions & 0 deletions draw_tensorflow_graph_in_notebook.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,49 @@
    def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
    n = strip_def.node.add()
    n.MergeFrom(n0)
    if n.op == 'Const':
    tensor = n.attr['value'].tensor
    size = len(tensor.tensor_content)
    if size > max_const_size:
    tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>"%size)
    return strip_def

    def rename_nodes(graph_def, rename_func):
    res_def = tf.GraphDef()
    for n0 in graph_def.node:
    n = res_def.node.add()
    n.MergeFrom(n0)
    n.name = rename_func(n.name)
    for i, s in enumerate(n.input):
    n.input[i] = rename_func(s) if s[0]!='^' else '^'+rename_func(s[1:])
    return res_def

    def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
    graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
    <script>
    function load() {{
    document.getElementById("{id}").pbtxt = {data};
    }}
    </script>
    <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
    <div style="height:600px">
    <tf-graph-basic id="{id}"></tf-graph-basic>
    </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
    <iframe seamless style="width:800px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))

    # Visualizing the network graph. Be sure expand the "mixed" nodes to see their
    # internal structure. We are going to visualize "Conv2D" nodes.
    tmp_def = rename_nodes(graph.as_graph_def(), lambda s:"/".join(s.split('_',1)))
    show_graph(tmp_def)