Created December 17, 2020 13:02
Save HedgehogCode/1037d9756692b618c086a22de5ad8ab0 to your computer and use it in GitHub Desktop.
Filter values from a TensorBoard log directory.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python | |
| """Filter values by tags from a tensorboard tf record. | |
| """ | |
| from __future__ import print_function | |
| import os | |
| import re | |
| import sys | |
| import argparse | |
| from tensorflow.python.lib.io import tf_record | |
| from tensorboard.compat.proto import event_pb2 | |
| from tensorboard.compat.proto import summary_pb2 | |
| from tensorboard.summary.writer.event_file_writer import EventFileWriter | |
def events_iterator(path):
    """Lazily decode the TFRecord file at ``path`` into ``Event`` messages.

    Each raw record yielded by ``tf_record.tf_record_iterator`` is parsed with
    ``event_pb2.Event.FromString``.  Nothing is read until the first item is
    requested, because this is a generator function.
    """
    yield from map(event_pb2.Event.FromString,
                   tf_record.tf_record_iterator(path))
def main(args):
    """Copy the event files in ``args.log_dir`` to ``args.output_log_dir``,
    dropping summary values whose tag matches ``args.exclude``.

    Non-summary events are copied unchanged.  For summary events, each value
    whose tag matches the exclude regex is removed.  Matching uses
    ``re.match``, so the pattern is anchored at the start of the tag.  If no
    values remain, the whole event is dropped.  When ``args.exclude`` is
    ``None`` (``-e`` not given), nothing is excluded.

    Args:
        args: Parsed command-line namespace with ``log_dir``,
            ``output_log_dir`` and ``exclude`` attributes.
    """
    # `--exclude` is optional, and argparse leaves it as None when omitted.
    # re.compile(None) raises TypeError, so treat a missing regex as
    # "exclude nothing".
    exclude_regex = None
    if args.exclude is not None:
        exclude_regex = re.compile(args.exclude)

    writer = EventFileWriter(args.output_log_dir)
    try:
        # Sort the names so output order is deterministic.  Skip anything that
        # is not a regular file: tf_record_iterator cannot read a directory.
        for filename in sorted(os.listdir(args.log_dir)):
            path = os.path.join(args.log_dir, filename)
            if not os.path.isfile(path):
                continue
            for event in events_iterator(path):
                if not event.HasField('summary'):
                    # Not a summary: always include.
                    writer.add_event(event)
                    continue

                # A summary: keep only values whose tag is not excluded.
                included_values = [
                    value for value in event.summary.value
                    if exclude_regex is None
                    or not exclude_regex.match(value.tag)
                ]
                if not included_values:
                    continue

                # Start from a full copy of the original event so fields other
                # than the summary values are carried over, then swap in the
                # filtered values.
                new_event = event_pb2.Event()
                new_event.CopyFrom(event)
                del new_event.summary.value[:]
                new_event.summary.value.extend(included_values)
                writer.add_event(new_event)
    finally:
        # close() flushes pending events and closes the output file, even if
        # reading an input file failed part-way through.
        writer.close()
def dir_path(string):
    """argparse ``type=`` callback that accepts only existing directories.

    Args:
        string: The raw command-line value.

    Returns:
        ``string`` unchanged if it names an existing directory.

    Raises:
        argparse.ArgumentTypeError: If ``string`` is not a directory.
            argparse turns this exception into a normal usage error.  Other
            exception types, such as ``NotADirectoryError``, are not caught
            by argparse and would escape ``parse_args`` as a raw traceback.
    """
    if not os.path.isdir(string):
        raise argparse.ArgumentTypeError(
            "{!r} is not a directory".format(string))
    return string
def parse_args(arguments):
    """Build the command-line parser and parse ``arguments``.

    Args:
        arguments: Argument strings to parse, without the program name.

    Returns:
        An ``argparse.Namespace`` with ``log_dir`` (validated by
        ``dir_path``), ``output_log_dir`` and ``exclude`` (``None`` when
        ``-e``/``--exclude`` is not given).
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )
    parser.add_argument('log_dir', type=dir_path, help="Log directory")
    parser.add_argument('output_log_dir', type=str,
                        help="Output log directory")
    parser.add_argument(
        '-e', '--exclude',
        type=str,
        help="Regex for tags which should be excluded.",
    )
    parsed = parser.parse_args(arguments)
    return parsed
if __name__ == '__main__':
    # Parse the CLI arguments (minus the program name), run the filter, and
    # use main()'s return value as the process exit status.
    cli_args = parse_args(sys.argv[1:])
    sys.exit(main(cli_args))
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.