From db723bbd4832c9901363586f08543e43ef7ffa4e Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Thu, 26 Jan 2017 22:33:34 +0300 Subject: refactoring Context managers everywhere! --- bin/mutual_friends.py | 58 +++++++++++++++++++++++++++------------------------ 1 file changed, 31 insertions(+), 27 deletions(-) (limited to 'bin/mutual_friends.py') diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py index d8591cf..76bf00f 100644 --- a/bin/mutual_friends.py +++ b/bin/mutual_friends.py @@ -5,6 +5,7 @@ import argparse from collections import OrderedDict +from contextlib import contextmanager import csv from enum import Enum import json @@ -28,12 +29,6 @@ class OutputWriterCSV: def __init__(self, fd=sys.stdout): self._writer = csv.writer(fd, lineterminator='\n') - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def write_mutual_friends(self, friend_list): for user in friend_list: user = _filter_user_fields(user) @@ -42,18 +37,13 @@ class OutputWriterCSV: class OutputWriterJSON: def __init__(self, fd=sys.stdout): self._fd = fd - self._arr = [] - - def __enter__(self): - return self - - def __exit__(self, *args): - self._fd.write(json.dumps(self._arr, indent=3, ensure_ascii=False)) - self._fd.write('\n') def write_mutual_friends(self, friend_list): + arr = [] for user in friend_list: - self._arr.append(_filter_user_fields(user)) + arr.append(_filter_user_fields(user)) + self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False)) + self._fd.write('\n') class OutputFormat(Enum): CSV = 'csv' @@ -62,13 +52,29 @@ class OutputFormat(Enum): def __str__(self): return self.value - def create_writer(self, fd=sys.stdout): - if self is OutputFormat.CSV: - return OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - return OutputWriterJSON(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_file(path) as fd: + if self is OutputFormat.CSV: + yield OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + yield OutputWriterJSON(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) + + @staticmethod + @contextmanager + def _open_file(path=None): + fd = sys.stdout + if path is None: + pass else: - raise NotImplementedError('unsupported output format: ' + str(self)) + fd = open(path, 'w', encoding='utf-8') + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() def _parse_output_format(s): try: @@ -85,26 +91,24 @@ def _parse_args(args=None): parser.add_argument('uids', metavar='UID', nargs='+', help='user IDs or "screen names"') - parser.add_argument('-f', '--format', dest='fmt', + parser.add_argument('-f', '--format', dest='out_fmt', type=_parse_output_format, default=OutputFormat.CSV, choices=OutputFormat, help='specify output format') - parser.add_argument('-o', '--output', metavar='PATH', dest='fd', - type=argparse.FileType('w', encoding='utf-8'), - default=sys.stdout, + parser.add_argument('-o', '--output', metavar='PATH', dest='out_path', help='set output file path (standard output by default)') return parser.parse_args(args) -def write_mutual_friends(uids, fmt=OutputFormat.CSV, fd=sys.stdout): +def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): api = API(Language.EN) users = api.users_get(uids) friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) mutual_friends = frozenset.intersection(*friend_lists) - with fmt.create_writer(fd) as writer: + with out_fmt.create_writer(out_path) as writer: writer.write_mutual_friends(mutual_friends) def main(args=None): -- cgit v1.2.3