From db723bbd4832c9901363586f08543e43ef7ffa4e Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Thu, 26 Jan 2017 22:33:34 +0300 Subject: refactoring Context managers everywhere! --- bin/mutual_friends.py | 58 ++++++++++++++++++++------------------ bin/online_sessions.py | 57 ++++++++++++++++++++++++------------- bin/track_status.py | 18 +++++------- vk/tracking/db/backend/csv.py | 12 -------- vk/tracking/db/backend/log.py | 6 ---- vk/tracking/db/backend/null.py | 16 ++--------- vk/tracking/db/format.py | 64 ++++++++++++++++++++++++++++++++++-------- 7 files changed, 131 insertions(+), 100 deletions(-) diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py index d8591cf..76bf00f 100644 --- a/bin/mutual_friends.py +++ b/bin/mutual_friends.py @@ -5,6 +5,7 @@ import argparse from collections import OrderedDict +from contextlib import contextmanager import csv from enum import Enum import json @@ -28,12 +29,6 @@ class OutputWriterCSV: def __init__(self, fd=sys.stdout): self._writer = csv.writer(fd, lineterminator='\n') - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def write_mutual_friends(self, friend_list): for user in friend_list: user = _filter_user_fields(user) @@ -42,18 +37,13 @@ class OutputWriterCSV: class OutputWriterJSON: def __init__(self, fd=sys.stdout): self._fd = fd - self._arr = [] - - def __enter__(self): - return self - - def __exit__(self, *args): - self._fd.write(json.dumps(self._arr, indent=3, ensure_ascii=False)) - self._fd.write('\n') def write_mutual_friends(self, friend_list): + arr = [] for user in friend_list: - self._arr.append(_filter_user_fields(user)) + arr.append(_filter_user_fields(user)) + self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False)) + self._fd.write('\n') class OutputFormat(Enum): CSV = 'csv' @@ -62,13 +52,29 @@ class OutputFormat(Enum): def __str__(self): return self.value - def create_writer(self, fd=sys.stdout): - if self is OutputFormat.CSV: - return OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - return OutputWriterJSON(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_file(path) as fd: + if self is OutputFormat.CSV: + yield OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + yield OutputWriterJSON(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) + + @staticmethod + @contextmanager + def _open_file(path=None): + fd = sys.stdout + if path is None: + pass else: - raise NotImplementedError('unsupported output format: ' + str(self)) + fd = open(path, 'w', encoding='utf-8') + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() def _parse_output_format(s): try: @@ -85,26 +91,24 @@ def _parse_args(args=None): parser.add_argument('uids', metavar='UID', nargs='+', help='user IDs or "screen names"') - parser.add_argument('-f', '--format', dest='fmt', + parser.add_argument('-f', '--format', dest='out_fmt', type=_parse_output_format, default=OutputFormat.CSV, choices=OutputFormat, help='specify output format') - parser.add_argument('-o', '--output', metavar='PATH', dest='fd', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('-o', '--output', metavar='PATH', dest='out_path', help='set output file path (standard output by default)') return parser.parse_args(args) -def write_mutual_friends(uids, fmt=OutputFormat.CSV, fd=sys.stdout): +def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): api = API(Language.EN) users = api.users_get(uids) friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) mutual_friends = frozenset.intersection(*friend_lists) - with fmt.create_writer(fd) as writer: + with out_fmt.create_writer(out_path) as writer: writer.write_mutual_friends(mutual_friends) def main(args=None): diff --git a/bin/online_sessions.py b/bin/online_sessions.py index ec26689..7a763e5 100644 --- a/bin/online_sessions.py +++ b/bin/online_sessions.py @@ -6,6 +6,7 @@ import argparse import csv from collections import OrderedDict +from contextlib import contextmanager from datetime import datetime, timedelta, timezone from enum import Enum import json @@ -244,15 +245,34 @@ class OutputFormat(Enum): JSON = 'json' PLOT = 'plot' - def create_writer(self, fd): - if self is OutputFormat.CSV: - return OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - return OutputWriterJSON(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_file(path) as fd: + if self is OutputFormat.CSV: + yield OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + yield OutputWriterJSON(fd) + elif self is OutputFormat.PLOT: + yield OutputWriterPlot(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) + + @contextmanager + def _open_file(self, path=None): + fd = sys.stdout + if path is None: + pass + elif self is OutputFormat.CSV or self is OutputFormat.JSON: + fd = open(path, 'w', encoding='utf-8') elif self is OutputFormat.PLOT: - return OutputWriterPlot(fd) + fd = open(path, 'wb') else: raise NotImplementedError('unsupported output format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() def __str__(self): return self.value @@ -293,12 +313,9 @@ def _parse_args(args=None): parser = argparse.ArgumentParser( description='View/visualize the amount of time people spend online.') - parser.add_argument('db_fd', metavar='input', - type=argparse.FileType('r', encoding='utf-8'), - help='database file path') - parser.add_argument('fd', metavar='output', nargs='?', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('db_path', metavar='input', nargs='?', + help='database file path (standard input by default)') + parser.add_argument('out_path', metavar='output', nargs='?', help='output file path (standard output by default)') parser.add_argument('-g', '--group-by', type=_parse_group_by, @@ -310,7 +327,7 @@ def _parse_args(args=None): default=DatabaseFormat.CSV, choices=DatabaseFormat, help='specify database format') - parser.add_argument('-o', '--output-format', dest='fmt', + parser.add_argument('-o', '--output-format', dest='out_fmt', type=_parse_output_format, choices=OutputFormat, default=OutputFormat.CSV, @@ -325,8 +342,8 @@ def _parse_args(args=None): return parser.parse_args(args) def process_online_sessions( - db_fd, db_fmt=DatabaseFormat.CSV, - fd=sys.stdout, fmt=OutputFormat.CSV, + db_path=None, db_fmt=DatabaseFormat.CSV, + out_path=None, out_fmt=OutputFormat.CSV, group_by=GroupBy.USER, time_from=None, time_to=None): @@ -334,10 +351,12 @@ def process_online_sessions( if time_from > time_to: time_from, time_to = time_to, time_from - with db_fmt.create_reader(db_fd) as db_reader: - output_writer = fmt.create_writer(fd) - output_writer.process_database( - group_by, db_reader, time_from=time_from, time_to=time_to) + with db_fmt.create_reader(db_path) as db_reader: + with out_fmt.create_writer(out_path) as out_writer: + out_writer.process_database( + group_by, db_reader, + time_from=time_from, + time_to=time_to) def main(args=None): process_online_sessions(**vars(_parse_args(args))) diff --git a/bin/track_status.py b/bin/track_status.py index d8d8908..404c1ec 100644 --- a/bin/track_status.py +++ b/bin/track_status.py @@ -40,36 +40,32 @@ def _parse_args(args=None): type=_parse_positive_integer, default=DEFAULT_TIMEOUT, help='set refresh interval') - parser.add_argument('-l', '--log', metavar='PATH', dest='log_fd', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('-l', '--log', metavar='PATH', dest='log_path', help='set log file path (standard output by default)') parser.add_argument('-f', '--format', dest='db_fmt', type=_parse_database_format, choices=DatabaseFormat, default=DEFAULT_DB_FORMAT, help='specify database format') - parser.add_argument('-o', '--output', metavar='PATH', dest='db_fd', - type=argparse.FileType('w', encoding='utf-8'), - default=None, + parser.add_argument('-o', '--output', metavar='PATH', dest='db_path', help='set database file path') return parser.parse_args(args) def track_status( uids, timeout=DEFAULT_TIMEOUT, - log_fd=sys.stdout, - db_fd=None, db_fmt=DEFAULT_DB_FORMAT): + log_path=None, + db_path=None, db_fmt=DEFAULT_DB_FORMAT): api = API(Language.EN, deactivated_users=False) tracker = StatusTracker(api, timeout) - if db_fmt is DatabaseFormat.LOG or db_fd is None: + if db_fmt is DatabaseFormat.LOG or db_path is None: db_fmt = DatabaseFormat.NULL - with DatabaseFormat.LOG.create_writer(log_fd) as log_writer: + with DatabaseFormat.LOG.create_writer(log_path) as log_writer: tracker.add_database_writer(log_writer) - with db_fmt.create_writer(db_fd) as db_writer: + with db_fmt.create_writer(db_path) as db_writer: tracker.add_database_writer(db_writer) tracker.loop(uids) diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py index f20c617..e607774 100644 --- a/vk/tracking/db/backend/csv.py +++ b/vk/tracking/db/backend/csv.py @@ -14,12 +14,6 @@ class Writer: self._fd = fd self._writer = csv.writer(fd, lineterminator='\n') - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def on_initial_status(self, user): self._write_record(user) self._fd.flush() @@ -47,12 +41,6 @@ class Reader(Iterable): def __init__(self, fd): self._reader = csv.reader(fd) - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def __iter__(self): return map(Reader._record_from_row, self._reader) diff --git a/vk/tracking/db/backend/log.py b/vk/tracking/db/backend/log.py index 6eebc35..f248d65 100644 --- a/vk/tracking/db/backend/log.py +++ b/vk/tracking/db/backend/log.py @@ -15,12 +15,6 @@ class Writer: datefmt='%Y-%m-%d %H:%M:%S')) self._logger.addHandler(handler) - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def info(self, msg): self._logger.info(msg) diff --git a/vk/tracking/db/backend/null.py b/vk/tracking/db/backend/null.py index 6454f84..3cdb41e 100644 --- a/vk/tracking/db/backend/null.py +++ b/vk/tracking/db/backend/null.py @@ -6,13 +6,7 @@ from collections.abc import Iterable class Writer: - def __init__(self, fd): - pass - - def __enter__(self): - return self - - def __exit__(self, *args): + def __init__(self): pass def on_initial_status(self, user): @@ -25,13 +19,7 @@ class Writer: pass class Reader(Iterable): - def __init__(self, fd): - pass - - def __enter__(self): - return self - - def __exit__(self, *args): + def __init__(self): pass def __iter__(self): diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py index 9d5c6e4..e1b34a1 100644 --- a/vk/tracking/db/format.py +++ b/vk/tracking/db/format.py @@ -3,7 +3,9 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. +from contextlib import contextmanager from enum import Enum +import sys from . import backend @@ -12,25 +14,65 @@ class Format(Enum): LOG = 'log' NULL = 'null' - def create_writer(self, fd): - if self is Format.CSV: - return backend.csv.Writer(fd) - elif self is Format.LOG: - return backend.log.Writer(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_output_file(path) as fd: + if self is Format.CSV: + yield backend.csv.Writer(fd) + elif self is Format.LOG: + yield backend.log.Writer(fd) + elif self is Format.NULL: + yield backend.null.Writer() + else: + raise NotImplementedError('unsupported database format: ' + str(self)) + + @contextmanager + def _open_output_file(self, path=None): + fd = sys.stdout + if path is None: + pass + elif self is Format.CSV or self is Format.LOG: + fd = open(path, 'w', encoding='utf-8') elif self is Format.NULL: - return backend.null.Writer(fd) + pass else: raise NotImplementedError('unsupported database format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() + + @contextmanager + def create_reader(self, path=None): + with self._open_input_file(path) as fd: + if self is Format.CSV: + yield backend.csv.Reader(fd) + elif self is Format.LOG: + raise NotImplementedError('cannot read from a log file') + elif self is Format.NULL: + yield backend.null.Reader() + else: + raise NotImplementedError('unsupported database format: ' + str(self)) - def create_reader(self, fd): - if self is Format.CSV: - return backend.csv.Reader(fd) + @contextmanager + def _open_input_file(self, path=None): + fd = sys.stdin + if path is None: + pass + elif self is Format.CSV: + fd = open(path, encoding='utf-8') elif self is Format.LOG: - raise NotImplementedError() + raise NotImplementedError('cannot read from a log file') elif self is Format.NULL: - return backend.null.Reader(fd) + pass else: raise NotImplementedError('unsupported database format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdin: + fd.close() def __str__(self): return self.value -- cgit v1.2.3