From db723bbd4832c9901363586f08543e43ef7ffa4e Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Thu, 26 Jan 2017 22:33:34 +0300 Subject: refactoring Context managers everywhere! --- bin/mutual_friends.py | 58 +++++++++++++++++++++++++++----------------------- bin/online_sessions.py | 57 ++++++++++++++++++++++++++++++++----------------- bin/track_status.py | 18 ++++++---------- 3 files changed, 76 insertions(+), 57 deletions(-) (limited to 'bin') diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py index d8591cf..76bf00f 100644 --- a/bin/mutual_friends.py +++ b/bin/mutual_friends.py @@ -5,6 +5,7 @@ import argparse from collections import OrderedDict +from contextlib import contextmanager import csv from enum import Enum import json @@ -28,12 +29,6 @@ class OutputWriterCSV: def __init__(self, fd=sys.stdout): self._writer = csv.writer(fd, lineterminator='\n') - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def write_mutual_friends(self, friend_list): for user in friend_list: user = _filter_user_fields(user) @@ -42,18 +37,13 @@ class OutputWriterCSV: class OutputWriterJSON: def __init__(self, fd=sys.stdout): self._fd = fd - self._arr = [] - - def __enter__(self): - return self - - def __exit__(self, *args): - self._fd.write(json.dumps(self._arr, indent=3, ensure_ascii=False)) - self._fd.write('\n') def write_mutual_friends(self, friend_list): + arr = [] for user in friend_list: - self._arr.append(_filter_user_fields(user)) + arr.append(_filter_user_fields(user)) + self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False)) + self._fd.write('\n') class OutputFormat(Enum): CSV = 'csv' @@ -62,13 +52,29 @@ class OutputFormat(Enum): def __str__(self): return self.value - def create_writer(self, fd=sys.stdout): - if self is OutputFormat.CSV: - return OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - return OutputWriterJSON(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_file(path) as fd: + if self is OutputFormat.CSV: + yield OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + yield OutputWriterJSON(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) + + @staticmethod + @contextmanager + def _open_file(path=None): + fd = sys.stdout + if path is None: + pass else: - raise NotImplementedError('unsupported output format: ' + str(self)) + fd = open(path, 'w', encoding='utf-8') + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() def _parse_output_format(s): try: @@ -85,26 +91,24 @@ def _parse_args(args=None): parser.add_argument('uids', metavar='UID', nargs='+', help='user IDs or "screen names"') - parser.add_argument('-f', '--format', dest='fmt', + parser.add_argument('-f', '--format', dest='out_fmt', type=_parse_output_format, default=OutputFormat.CSV, choices=OutputFormat, help='specify output format') - parser.add_argument('-o', '--output', metavar='PATH', dest='fd', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('-o', '--output', metavar='PATH', dest='out_path', help='set output file path (standard output by default)') return parser.parse_args(args) -def write_mutual_friends(uids, fmt=OutputFormat.CSV, fd=sys.stdout): +def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): api = API(Language.EN) users = api.users_get(uids) friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) mutual_friends = frozenset.intersection(*friend_lists) - with fmt.create_writer(fd) as writer: + with out_fmt.create_writer(out_path) as writer: writer.write_mutual_friends(mutual_friends) def main(args=None): diff --git a/bin/online_sessions.py b/bin/online_sessions.py index ec26689..7a763e5 100644 --- a/bin/online_sessions.py +++ b/bin/online_sessions.py @@ -6,6 +6,7 @@ import argparse import csv from collections import OrderedDict +from contextlib import contextmanager from datetime import datetime, timedelta, timezone from enum import Enum import json @@ -244,15 +245,34 @@ class OutputFormat(Enum): JSON = 'json' PLOT = 'plot' - def create_writer(self, fd): - if self is OutputFormat.CSV: - return OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - return OutputWriterJSON(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_file(path) as fd: + if self is OutputFormat.CSV: + yield OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + yield OutputWriterJSON(fd) + elif self is OutputFormat.PLOT: + yield OutputWriterPlot(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) + + @contextmanager + def _open_file(self, path=None): + fd = sys.stdout + if path is None: + pass + elif self is OutputFormat.CSV or self is OutputFormat.JSON: + fd = open(path, 'w', encoding='utf-8') elif self is OutputFormat.PLOT: - return OutputWriterPlot(fd) + fd = open(path, 'wb') else: raise NotImplementedError('unsupported output format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() def __str__(self): return self.value @@ -293,12 +313,9 @@ def _parse_args(args=None): parser = argparse.ArgumentParser( description='View/visualize the amount of time people spend online.') - parser.add_argument('db_fd', metavar='input', - type=argparse.FileType('r', encoding='utf-8'), - help='database file path') - parser.add_argument('fd', metavar='output', nargs='?', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('db_path', metavar='input', nargs='?', + help='database file path (standard input by default)') + parser.add_argument('out_path', metavar='output', nargs='?', help='output file path (standard output by default)') parser.add_argument('-g', '--group-by', type=_parse_group_by, @@ -310,7 +327,7 @@ def _parse_args(args=None): default=DatabaseFormat.CSV, choices=DatabaseFormat, help='specify database format') - parser.add_argument('-o', '--output-format', dest='fmt', + parser.add_argument('-o', '--output-format', dest='out_fmt', type=_parse_output_format, choices=OutputFormat, default=OutputFormat.CSV, @@ -325,8 +342,8 @@ def _parse_args(args=None): return parser.parse_args(args) def process_online_sessions( - db_fd, db_fmt=DatabaseFormat.CSV, - fd=sys.stdout, fmt=OutputFormat.CSV, + db_path=None, db_fmt=DatabaseFormat.CSV, + out_path=None, out_fmt=OutputFormat.CSV, group_by=GroupBy.USER, time_from=None, time_to=None): @@ -334,10 +351,12 @@ def process_online_sessions( if time_from > time_to: time_from, time_to = time_to, time_from - with db_fmt.create_reader(db_fd) as db_reader: - output_writer = fmt.create_writer(fd) - output_writer.process_database( - group_by, db_reader, time_from=time_from, time_to=time_to) + with db_fmt.create_reader(db_path) as db_reader: + with out_fmt.create_writer(out_path) as out_writer: + out_writer.process_database( + group_by, db_reader, + time_from=time_from, + time_to=time_to) def main(args=None): process_online_sessions(**vars(_parse_args(args))) diff --git a/bin/track_status.py b/bin/track_status.py index d8d8908..404c1ec 100644 --- a/bin/track_status.py +++ b/bin/track_status.py @@ -40,36 +40,32 @@ def _parse_args(args=None): type=_parse_positive_integer, default=DEFAULT_TIMEOUT, help='set refresh interval') - parser.add_argument('-l', '--log', metavar='PATH', dest='log_fd', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('-l', '--log', metavar='PATH', dest='log_path', help='set log file path (standard output by default)') parser.add_argument('-f', '--format', dest='db_fmt', type=_parse_database_format, choices=DatabaseFormat, default=DEFAULT_DB_FORMAT, help='specify database format') - parser.add_argument('-o', '--output', metavar='PATH', dest='db_fd', - type=argparse.FileType('w', encoding='utf-8'), - default=None, + parser.add_argument('-o', '--output', metavar='PATH', dest='db_path', help='set database file path') return parser.parse_args(args) def track_status( uids, timeout=DEFAULT_TIMEOUT, - log_fd=sys.stdout, - db_fd=None, db_fmt=DEFAULT_DB_FORMAT): + log_path=None, + db_path=None, db_fmt=DEFAULT_DB_FORMAT): api = API(Language.EN, deactivated_users=False) tracker = StatusTracker(api, timeout) - if db_fmt is DatabaseFormat.LOG or db_fd is None: + if db_fmt is DatabaseFormat.LOG or db_path is None: db_fmt = DatabaseFormat.NULL - with DatabaseFormat.LOG.create_writer(log_fd) as log_writer: + with DatabaseFormat.LOG.create_writer(log_path) as log_writer: tracker.add_database_writer(log_writer) - with db_fmt.create_writer(db_fd) as db_writer: + with db_fmt.create_writer(db_path) as db_writer: tracker.add_database_writer(db_writer) tracker.loop(uids) -- cgit v1.2.3