aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/bin/mutual_friends.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bin/mutual_friends.py58
1 files changed, 31 insertions, 27 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py
index d8591cf..76bf00f 100644
--- a/bin/mutual_friends.py
+++ b/bin/mutual_friends.py
@@ -5,6 +5,7 @@
import argparse
from collections import OrderedDict
+from contextlib import contextmanager
import csv
from enum import Enum
import json
@@ -28,12 +29,6 @@ class OutputWriterCSV:
def __init__(self, fd=sys.stdout):
self._writer = csv.writer(fd, lineterminator='\n')
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def write_mutual_friends(self, friend_list):
for user in friend_list:
user = _filter_user_fields(user)
@@ -42,18 +37,13 @@ class OutputWriterCSV:
class OutputWriterJSON:
def __init__(self, fd=sys.stdout):
self._fd = fd
- self._arr = []
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- self._fd.write(json.dumps(self._arr, indent=3, ensure_ascii=False))
- self._fd.write('\n')
def write_mutual_friends(self, friend_list):
+ arr = []
for user in friend_list:
- self._arr.append(_filter_user_fields(user))
+ arr.append(_filter_user_fields(user))
+ self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False))
+ self._fd.write('\n')
class OutputFormat(Enum):
CSV = 'csv'
@@ -62,13 +52,29 @@ class OutputFormat(Enum):
def __str__(self):
return self.value
- def create_writer(self, fd=sys.stdout):
- if self is OutputFormat.CSV:
- return OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- return OutputWriterJSON(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_file(path) as fd:
+ if self is OutputFormat.CSV:
+ yield OutputWriterCSV(fd)
+ elif self is OutputFormat.JSON:
+ yield OutputWriterJSON(fd)
+ else:
+ raise NotImplementedError('unsupported output format: ' + str(self))
+
+ @staticmethod
+ @contextmanager
+ def _open_file(path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
else:
- raise NotImplementedError('unsupported output format: ' + str(self))
+ fd = open(path, 'w', encoding='utf-8')
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
def _parse_output_format(s):
try:
@@ -85,26 +91,24 @@ def _parse_args(args=None):
parser.add_argument('uids', metavar='UID', nargs='+',
help='user IDs or "screen names"')
- parser.add_argument('-f', '--format', dest='fmt',
+ parser.add_argument('-f', '--format', dest='out_fmt',
type=_parse_output_format,
default=OutputFormat.CSV,
choices=OutputFormat,
help='specify output format')
- parser.add_argument('-o', '--output', metavar='PATH', dest='fd',
- type=argparse.FileType('w', encoding='utf-8'),
- default=sys.stdout,
+ parser.add_argument('-o', '--output', metavar='PATH', dest='out_path',
help='set output file path (standard output by default)')
return parser.parse_args(args)
-def write_mutual_friends(uids, fmt=OutputFormat.CSV, fd=sys.stdout):
+def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV):
api = API(Language.EN)
users = api.users_get(uids)
friend_lists = (frozenset(_query_friend_list(api, user)) for user in users)
mutual_friends = frozenset.intersection(*friend_lists)
- with fmt.create_writer(fd) as writer:
+ with out_fmt.create_writer(out_path) as writer:
writer.write_mutual_friends(mutual_friends)
def main(args=None):