From db723bbd4832c9901363586f08543e43ef7ffa4e Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Thu, 26 Jan 2017 22:33:34 +0300 Subject: refactoring Context managers everywhere! --- bin/online_sessions.py | 57 +++++++++++++++++++++++++++++++++----------------- 1 file changed, 38 insertions(+), 19 deletions(-) (limited to 'bin/online_sessions.py') diff --git a/bin/online_sessions.py b/bin/online_sessions.py index ec26689..7a763e5 100644 --- a/bin/online_sessions.py +++ b/bin/online_sessions.py @@ -6,6 +6,7 @@ import argparse import csv from collections import OrderedDict +from contextlib import contextmanager from datetime import datetime, timedelta, timezone from enum import Enum import json @@ -244,15 +245,34 @@ class OutputFormat(Enum): JSON = 'json' PLOT = 'plot' - def create_writer(self, fd): - if self is OutputFormat.CSV: - return OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - return OutputWriterJSON(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_file(path) as fd: + if self is OutputFormat.CSV: + yield OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + yield OutputWriterJSON(fd) + elif self is OutputFormat.PLOT: + yield OutputWriterPlot(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) + + @contextmanager + def _open_file(self, path=None): + fd = sys.stdout + if path is None: + pass + elif self is OutputFormat.CSV or self is OutputFormat.JSON: + fd = open(path, 'w', encoding='utf-8') elif self is OutputFormat.PLOT: - return OutputWriterPlot(fd) + fd = open(path, 'wb') else: raise NotImplementedError('unsupported output format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() def __str__(self): return self.value @@ -293,12 +313,9 @@ def _parse_args(args=None): parser = argparse.ArgumentParser( description='View/visualize the amount of time people spend online.') - parser.add_argument('db_fd', metavar='input', - type=argparse.FileType('r', encoding='utf-8'), - help='database file path') - parser.add_argument('fd', metavar='output', nargs='?', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('db_path', metavar='input', nargs='?', + help='database file path (standard input by default)') + parser.add_argument('out_path', metavar='output', nargs='?', help='output file path (standard output by default)') parser.add_argument('-g', '--group-by', type=_parse_group_by, @@ -310,7 +327,7 @@ def _parse_args(args=None): default=DatabaseFormat.CSV, choices=DatabaseFormat, help='specify database format') - parser.add_argument('-o', '--output-format', dest='fmt', + parser.add_argument('-o', '--output-format', dest='out_fmt', type=_parse_output_format, choices=OutputFormat, default=OutputFormat.CSV, @@ -325,8 +342,8 @@ def _parse_args(args=None): return parser.parse_args(args) def process_online_sessions( - db_fd, db_fmt=DatabaseFormat.CSV, - fd=sys.stdout, fmt=OutputFormat.CSV, + db_path=None, db_fmt=DatabaseFormat.CSV, + out_path=None, out_fmt=OutputFormat.CSV, group_by=GroupBy.USER, time_from=None, time_to=None): @@ -334,10 +351,12 @@ def process_online_sessions( if time_from > time_to: time_from, time_to = time_to, time_from - with db_fmt.create_reader(db_fd) as db_reader: - output_writer = fmt.create_writer(fd) - output_writer.process_database( - group_by, db_reader, time_from=time_from, time_to=time_to) + with db_fmt.create_reader(db_path) as db_reader: + with out_fmt.create_writer(out_path) as out_writer: + out_writer.process_database( + group_by, db_reader, + time_from=time_from, + time_to=time_to) def main(args=None): process_online_sessions(**vars(_parse_args(args))) -- cgit v1.2.3