aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bin/mutual_friends.py28
1 files changed, 17 insertions, 11 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py
index b72b317..bb59ef4 100644
--- a/bin/mutual_friends.py
+++ b/bin/mutual_friends.py
@@ -2,6 +2,7 @@
# This file is licensed under the terms of the MIT License.
# See LICENSE.txt for details.
+import argparse
from collections import OrderedDict
import csv
from enum import Enum
@@ -33,7 +34,7 @@ class OutputWriterCSV:
pass
def write_mutual_friends(self, friend_list):
- for user in mutual_friends:
+ for user in friend_list:
user = extract_output_fields(user)
self._writer.writerow(user.values())
@@ -47,6 +48,7 @@ class OutputWriterJSON:
def __exit__(self, *args):
self._fd.write(json.dumps(self._array, indent=3, ensure_ascii=False))
+ self._fd.write('\n')
def write_mutual_friends(self, friend_list):
for user in friend_list:
@@ -67,21 +69,19 @@ class OutputFormat(Enum):
else:
raise NotImplementedError('unsupported output format: ' + str(self))
-if __name__ == '__main__':
- import argparse
-
- def output_format(s):
- try:
- return OutputFormat(s)
- except ValueError:
- raise argparse.ArgumentError()
+def parse_output_format(s):
+ try:
+ return OutputFormat(s)
+ except ValueError:
+ raise argparse.ArgumentTypeError('invalid output format: ' + str(s))
+def parse_args(args=sys.argv):
parser = argparse.ArgumentParser(
description='Learn who your ex and her new boyfriend are both friends with.')
parser.add_argument(metavar='UID', dest='uids', nargs='+',
help='user IDs or "screen names"')
- parser.add_argument('--output-format', type=output_format,
+ parser.add_argument('--output-format', type=parse_output_format,
choices=tuple(fmt for fmt in OutputFormat),
default=OutputFormat.CSV,
help='specify output format')
@@ -89,7 +89,10 @@ if __name__ == '__main__':
default=sys.stdout,
help='set output file path (standard output by default)')
- args = parser.parse_args()
+ return parser.parse_args(args)
+
+def main(args=sys.argv):
+ args = parse_args(args)
api = API(Language.EN)
users = api.users_get(args.uids)
@@ -99,3 +102,6 @@ if __name__ == '__main__':
with args.output_format.create_writer(args.output) as writer:
writer.write_mutual_friends(mutual_friends)
+
+if __name__ == '__main__':
+ main()