diff options
author | Egor Tensin <Egor.Tensin@gmail.com> | 2017-01-28 02:01:44 +0300 |
---|---|---|
committer | Egor Tensin <Egor.Tensin@gmail.com> | 2017-01-28 02:01:44 +0300 |
commit | f989577d524f27c079af619b9d0b6b4a9d70c386 (patch) | |
tree | bad199636fc9cd7152e64e8ffb6075c13755412a | |
parent | refactoring (diff) | |
download | vk-scripts-f989577d524f27c079af619b9d0b6b4a9d70c386.tar.gz vk-scripts-f989577d524f27c079af619b9d0b6b4a9d70c386.zip |
bin: move file i/o to a separate module
-rw-r--r-- | bin/mutual_friends.py | 64 | ||||
-rw-r--r-- | bin/online_sessions.py | 101 | ||||
-rw-r--r-- | bin/utils/output.py | 43 |
3 files changed, 113 insertions, 95 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py index 76bf00f..3b0576f 100644 --- a/bin/mutual_friends.py +++ b/bin/mutual_friends.py @@ -3,17 +3,17 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. +import abc import argparse from collections import OrderedDict -from contextlib import contextmanager -import csv from enum import Enum -import json import sys from vk.api import API, Language from vk.user import UserField +from .utils import output + _OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME def _query_friend_list(api, user): @@ -25,25 +25,25 @@ def _filter_user_fields(user): new_user[str(field)] = user[field] if field in user else None return new_user -class OutputWriterCSV: +class OutputSinkMutualFriends(metaclass=abc.ABCMeta): + @abc.abstractmethod + def write_mutual_friends(self, friend_list): + pass + +class OutputSinkCSV(OutputSinkMutualFriends): def __init__(self, fd=sys.stdout): - self._writer = csv.writer(fd, lineterminator='\n') + self._writer = output.OutputWriterCSV(fd) def write_mutual_friends(self, friend_list): for user in friend_list: - user = _filter_user_fields(user) - self._writer.writerow(user.values()) + self._writer.write_row(user.values()) -class OutputWriterJSON: +class OutputSinkJSON(OutputSinkMutualFriends): def __init__(self, fd=sys.stdout): - self._fd = fd + self._writer = output.OutputWriterJSON(fd) def write_mutual_friends(self, friend_list): - arr = [] - for user in friend_list: - arr.append(_filter_user_fields(user)) - self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False)) - self._fd.write('\n') + self._writer.write(friend_list) class OutputFormat(Enum): CSV = 'csv' @@ -52,29 +52,17 @@ class OutputFormat(Enum): def __str__(self): return self.value - @contextmanager - def create_writer(self, path=None): - with self._open_file(path) as fd: - if self is OutputFormat.CSV: - yield OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - yield OutputWriterJSON(fd) - else: - raise NotImplementedError('unsupported output format: ' + str(self)) - @staticmethod - @contextmanager - def _open_file(path=None): - fd = sys.stdout - if path is None: - pass + def open_file(path=None): + return output.open_text_file(path) + + def create_sink(self, fd=sys.stdout): + if self is OutputFormat.CSV: + return OutputSinkCSV(fd) + elif self is OutputFormat.JSON: + return OutputSinkJSON(fd) else: - fd = open(path, 'w', encoding='utf-8') - try: - yield fd - finally: - if fd is not sys.stdout: - fd.close() + raise NotImplementedError('unsupported output format: ' + str(self)) def _parse_output_format(s): try: @@ -107,9 +95,11 @@ def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) mutual_friends = frozenset.intersection(*friend_lists) + mutual_friends = [_filter_user_fields(user) for user in mutual_friends] - with out_fmt.create_writer(out_path) as writer: - writer.write_mutual_friends(mutual_friends) + with out_fmt.open_file(out_path) as out_fd: + sink = out_fmt.create_sink(out_fd) + sink.write_mutual_friends(mutual_friends) def main(args=None): write_mutual_friends(**vars(_parse_args(args))) diff --git a/bin/online_sessions.py b/bin/online_sessions.py index 5a9bdc8..cb3f4bf 100644 --- a/bin/online_sessions.py +++ b/bin/online_sessions.py @@ -3,13 +3,11 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. +import abc import argparse -import csv from collections import OrderedDict -from contextlib import contextmanager from datetime import datetime, timedelta, timezone from enum import Enum -import json import sys from vk.tracking import OnlineSessionEnumerator @@ -17,6 +15,7 @@ from vk.tracking.db import Format as DatabaseFormat from vk.user import UserField from .utils.bar_chart import BarChartBuilder +from .utils import output class GroupBy(Enum): USER = 'user' @@ -24,6 +23,9 @@ class GroupBy(Enum): WEEKDAY = 'weekday' HOUR = 'hour' + def __str__(self): + return self.value + def group(self, db_reader, time_from=None, time_to=None): online_streaks = OnlineSessionEnumerator(time_from, time_to) if self is GroupBy.USER: @@ -37,9 +39,6 @@ class GroupBy(Enum): else: raise NotImplementedError('unsupported grouping: ' + str(self)) - def __str__(self): - return self.value - _OUTPUT_USER_FIELDS = ( UserField.UID, UserField.FIRST_NAME, @@ -47,6 +46,11 @@ _OUTPUT_USER_FIELDS = ( UserField.DOMAIN, ) +class OutputSinkOnlineSessions(metaclass=abc.ABCMeta): + @abc.abstractmethod + def process_database(self, group_by, db_reader, time_from=None, time_to=None): + pass + class OutputConverterCSV: @staticmethod def convert_user(user): @@ -64,9 +68,9 @@ class OutputConverterCSV: def convert_hour(hour): return [str(timedelta(hours=hour))] -class OutputWriterCSV: +class OutputSinkCSV(OutputSinkOnlineSessions): def __init__(self, fd=sys.stdout): - self._writer = csv.writer(fd, lineterminator='\n') + self._writer = output.OutputWriterCSV(fd) _CONVERT_KEY = { GroupBy.USER: OutputConverterCSV.convert_user, @@ -77,18 +81,15 @@ class OutputWriterCSV: @staticmethod def _key_to_row(group_by, key): - if group_by not in OutputWriterCSV._CONVERT_KEY: + if group_by not in OutputSinkCSV._CONVERT_KEY: raise NotImplementedError('unsupported grouping: ' + str(group_by)) - return OutputWriterCSV._CONVERT_KEY[group_by](key) + return OutputSinkCSV._CONVERT_KEY[group_by](key) def process_database(self, group_by, db_reader, time_from=None, time_to=None): for key, duration in group_by.group(db_reader, time_from, time_to).items(): row = self._key_to_row(group_by, key) row.append(str(duration)) - self._write_row(row) - - def _write_row(self, row): - self._writer.writerow(row) + self._writer.write_row(row) class OutputConverterJSON: _DATE_FIELD = 'date' @@ -124,9 +125,9 @@ class OutputConverterJSON: obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour)) return obj -class OutputWriterJSON: +class OutputSinkJSON(OutputSinkOnlineSessions): def __init__(self, fd=sys.stdout): - self._fd = fd + self._writer = output.OutputWriterJSON(fd) _DURATION_FIELD = 'duration' @@ -141,13 +142,9 @@ class OutputWriterJSON: @staticmethod def _key_to_object(group_by, key): - if not group_by in OutputWriterJSON._CONVERT_KEY: + if not group_by in OutputSinkJSON._CONVERT_KEY: raise NotImplementedError('unsupported grouping: ' + str(group_by)) - return OutputWriterJSON._CONVERT_KEY[group_by](key) - - def _write(self, entries): - self._fd.write(json.dumps(entries, indent=3, ensure_ascii=False)) - self._fd.write('\n') + return OutputSinkJSON._CONVERT_KEY[group_by](key) def process_database(self, group_by, db_reader, time_from=None, time_to=None): entries = [] @@ -155,7 +152,7 @@ class OutputWriterJSON: entry = self._key_to_object(group_by, key) entry[self._DURATION_FIELD] = str(duration) entries.append(entry) - self._write(entries) + self._writer.write(entries) class OutputConverterPlot: @staticmethod @@ -174,7 +171,7 @@ class OutputConverterPlot: def convert_hour(hour): return '{}:00'.format(hour) -class OutputWriterPlot: +class OutputSinkPlot(OutputSinkOnlineSessions): def __init__(self, fd=sys.stdout): self._fd = fd @@ -189,9 +186,9 @@ class OutputWriterPlot: @staticmethod def _format_key(group_by, key): - if group_by not in OutputWriterPlot._FORMAT_KEY: + if group_by not in OutputSinkPlot._FORMAT_KEY: raise NotImplementedError('unsupported grouping: ' + str(group_by)) - return OutputWriterPlot._FORMAT_KEY[group_by](key) + return OutputSinkPlot._FORMAT_KEY[group_by](key) @staticmethod def _format_duration(seconds, _): @@ -203,11 +200,11 @@ class OutputWriterPlot: @staticmethod def _extract_labels(group_by, durations): - return tuple(map(lambda key: OutputWriterPlot._format_key(group_by, key), durations.keys())) + return tuple(map(lambda key: OutputSinkPlot._format_key(group_by, key), durations.keys())) @staticmethod def _extract_values(durations): - return tuple(map(OutputWriterPlot._duration_to_seconds, durations.values())) + return tuple(map(OutputSinkPlot._duration_to_seconds, durations.values())) def process_database( self, group_by, db_reader, time_from=None, time_to=None): @@ -215,7 +212,7 @@ class OutputWriterPlot: durations = group_by.group(db_reader, time_from, time_to) bar_chart = BarChartBuilder() - bar_chart.set_title(OutputWriterPlot.TITLE) + bar_chart.set_title(OutputSinkPlot.TITLE) bar_chart.enable_grid_for_values() bar_chart.only_integer_values() bar_chart.set_property(bar_chart.get_values_labels(), @@ -245,37 +242,24 @@ class OutputFormat(Enum): JSON = 'json' PLOT = 'plot' - @contextmanager - def create_writer(self, path=None): - with self._open_file(path) as fd: - if self is OutputFormat.CSV: - yield OutputWriterCSV(fd) - elif self is OutputFormat.JSON: - yield OutputWriterJSON(fd) - elif self is OutputFormat.PLOT: - yield OutputWriterPlot(fd) - else: - raise NotImplementedError('unsupported output format: ' + str(self)) - - @contextmanager - def _open_file(self, path=None): - fd = sys.stdout - if path is None: - pass - elif self is OutputFormat.CSV or self is OutputFormat.JSON: - fd = open(path, 'w', encoding='utf-8') + def __str__(self): + return self.value + + def create_sink(self, fd=sys.stdout): + if self is OutputFormat.CSV: + return OutputSinkCSV(fd) + elif self is OutputFormat.JSON: + return OutputSinkJSON(fd) elif self is OutputFormat.PLOT: - fd = open(path, 'wb') + return OutputSinkPlot(fd) else: raise NotImplementedError('unsupported output format: ' + str(self)) - try: - yield fd - finally: - if fd is not sys.stdout: - fd.close() - def __str__(self): - return self.value + def open_file(self, path=None): + if self is OutputFormat.PLOT: + return output.open_binary_file(path) + else: + return output.open_text_file(path) def _parse_group_by(s): try: @@ -352,8 +336,9 @@ def process_online_sessions( time_from, time_to = time_to, time_from with db_fmt.create_reader(db_path) as db_reader: - with out_fmt.create_writer(out_path) as out_writer: - out_writer.process_database( + with out_fmt.open_file(out_path) as out_fd: + out_sink = out_fmt.create_sink(out_fd) + out_sink.process_database( group_by, db_reader, time_from=time_from, time_to=time_to) diff --git a/bin/utils/output.py b/bin/utils/output.py new file mode 100644 index 0000000..8954c8b --- /dev/null +++ b/bin/utils/output.py @@ -0,0 +1,43 @@ +# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com> +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +from contextlib import contextmanager +import csv +import json +import sys + +class OutputWriterJSON: + def __init__(self, fd=sys.stdout): + self._fd = fd + + def write(self, something): + self._fd.write(json.dumps(something, indent=3, ensure_ascii=False)) + self._fd.write('\n') + +class OutputWriterCSV: + def __init__(self, fd=sys.stdout): + self._writer = csv.writer(fd, lineterminator='\n') + + def write_row(self, row): + self._writer.writerow(row) + +@contextmanager +def _open_file(path=None, **kwargs): + fd = sys.stdout + if path is None: + pass + else: + fd = open(path, **kwargs) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() + +def open_text_file(path=None): + return _open_file(path, mode='w', encoding='utf-8') + +def open_binary_file(path=None): + return _open_file(path, mode='wb') |