diff options
Diffstat (limited to '')
-rw-r--r-- | bin/mutual_friends.py | 64 |
1 files changed, 27 insertions, 37 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py index 76bf00f..3b0576f 100644 --- a/bin/mutual_friends.py +++ b/bin/mutual_friends.py @@ -3,17 +3,17 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. +import abc import argparse from collections import OrderedDict -from contextlib import contextmanager -import csv from enum import Enum -import json import sys from vk.api import API, Language from vk.user import UserField +from .utils import output + _OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME def _query_friend_list(api, user): @@ -25,25 +25,25 @@ def _filter_user_fields(user): new_user[str(field)] = user[field] if field in user else None return new_user -class OutputWriterCSV: +class OutputSinkMutualFriends(metaclass=abc.ABCMeta): + @abc.abstractmethod + def write_mutual_friends(self, friend_list): + pass + +class OutputSinkCSV(OutputSinkMutualFriends): def __init__(self, fd=sys.stdout): - self._writer = csv.writer(fd, lineterminator='\n') + self._writer = output.OutputWriterCSV(fd) def write_mutual_friends(self, friend_list): for user in friend_list: - user = _filter_user_fields(user) - self._writer.writerow(user.values()) + self._writer.write_row(user.values()) -class OutputWriterJSON: +class OutputSinkJSON(OutputSinkMutualFriends): def __init__(self, fd=sys.stdout): - self._fd = fd + self._writer = output.OutputWriterJSON(fd) def write_mutual_friends(self, friend_list): - arr = [] - for user in friend_list: - arr.append(_filter_user_fields(user)) - self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False)) - self._fd.write('\n') + self._writer.write(friend_list) class OutputFormat(Enum): CSV = 'csv' @@ -52,29 +52,17 @@ class OutputFormat(Enum): def __str__(self): return self.value - @contextmanager - def create_writer(self, path=None): - with self._open_file(path) as fd: - if self is OutputFormat.CSV: - yield OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - yield OutputWriterJSON(fd) - else: - raise NotImplementedError('unsupported output format: ' + str(self)) - @staticmethod - @contextmanager - def _open_file(path=None): - fd = sys.stdout - if path is None: - pass + def open_file(path=None): + return output.open_text_file(path) + + def create_sink(self, fd=sys.stdout): + if self is OutputFormat.CSV: + return OutputSinkCSV(fd) + elif self is OutputFormat.JSON: + return OutputSinkJSON(fd) else: - fd = open(path, 'w', encoding='utf-8') - try: - yield fd - finally: - if fd is not sys.stdout: - fd.close() + raise NotImplementedError('unsupported output format: ' + str(self)) def _parse_output_format(s): try: @@ -107,9 +95,11 @@ def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) mutual_friends = frozenset.intersection(*friend_lists) + mutual_friends = [_filter_user_fields(user) for user in mutual_friends] - with out_fmt.create_writer(out_path) as writer: - writer.write_mutual_friends(mutual_friends) + with out_fmt.open_file(out_path) as out_fd: + sink = out_fmt.create_sink(out_fd) + sink.write_mutual_friends(mutual_friends) def main(args=None): write_mutual_friends(**vars(_parse_args(args))) |