diff options
author | Egor Tensin <Egor.Tensin@gmail.com> | 2016-06-20 04:35:51 +0300 |
---|---|---|
committer | Egor Tensin <Egor.Tensin@gmail.com> | 2016-06-20 04:35:51 +0300 |
commit | 77360caa776be3151c9fe0acd88a4f06a7344c50 (patch) | |
tree | 3658707cea0a8cd33f53bba38c870e47246f3efd /bin/online_duration.py | |
parent | mutual_friends.py: refactoring (diff) | |
download | vk-scripts-77360caa776be3151c9fe0acd88a4f06a7344c50.tar.gz vk-scripts-77360caa776be3151c9fe0acd88a4f06a7344c50.zip |
online_duration.py: refactoring + Pylint fixes
Diffstat (limited to 'bin/online_duration.py')
-rw-r--r-- | bin/online_duration.py | 264 |
1 files changed, 155 insertions, 109 deletions
diff --git a/bin/online_duration.py b/bin/online_duration.py index ca6fd66..b462b1a 100644 --- a/bin/online_duration.py +++ b/bin/online_duration.py @@ -2,6 +2,7 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +import argparse import csv from collections import OrderedDict from datetime import datetime, timedelta, timezone @@ -12,7 +13,7 @@ import sys import matplotlib.pyplot as plt import numpy as np -from vk.tracking import OnlineStreakEnumerator, Weekday +from vk.tracking import OnlineStreakEnumerator from vk.tracking.db import Format as DatabaseFormat from vk.user import UserField @@ -37,41 +38,46 @@ class Grouping(Enum): def __str__(self): return self.value -_USER_FIELDS = ( +_OUTPUT_USER_FIELDS = ( UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME, UserField.SCREEN_NAME, ) -class OutputWriterCSV: - def __init__(self, fd=sys.stdout): - self._writer = csv.writer(fd, lineterminator='\n') - - def _user_to_row(user): - return [user[field] for field in _USER_FIELDS] +class OutputConverterCSV: + @staticmethod + def convert_user(user): + return [user[field] for field in _OUTPUT_USER_FIELDS] - def _date_to_row(date): + @staticmethod + def convert_date(date): return [str(date)] - def _weekday_to_row(weekday): + @staticmethod + def convert_weekday(weekday): return [str(weekday)] - def _hour_to_row(hour): + @staticmethod + def convert_hour(hour): return [str(timedelta(hours=hour))] - _CONVERT_KEY_TO_ROW = { - Grouping.USER: _user_to_row, - Grouping.DATE: _date_to_row, - Grouping.WEEKDAY: _weekday_to_row, - Grouping.HOUR: _hour_to_row, +class OutputWriterCSV: + def __init__(self, fd=sys.stdout): + self._writer = csv.writer(fd, lineterminator='\n') + + _CONVERT_KEY = { + Grouping.USER: OutputConverterCSV.convert_user, + Grouping.DATE: OutputConverterCSV.convert_date, + Grouping.WEEKDAY: OutputConverterCSV.convert_weekday, + Grouping.HOUR: OutputConverterCSV.convert_hour, } @staticmethod def _key_to_row(grouping, key): - if grouping not in OutputWriterCSV._CONVERT_KEY_TO_ROW: + if grouping not in OutputWriterCSV._CONVERT_KEY: raise NotImplementedError('unsupported grouping: ' + str(grouping)) - return OutputWriterCSV._CONVERT_KEY_TO_ROW[grouping](key) + return OutputWriterCSV._CONVERT_KEY[grouping](key) def process_database(self, grouping, db_reader, date_from=None, date_to=None): for key, duration in grouping.enum_durations(db_reader, date_from, date_to).items(): @@ -82,49 +88,58 @@ class OutputWriterCSV: def _write_row(self, row): self._writer.writerow(row) -_DATE_FIELD = 'date' -_WEEKDAY_FIELD = 'weekday' -_HOUR_FIELD = 'hour' +class OutputConverterJSON: + _DATE_FIELD = 'date' + _WEEKDAY_FIELD = 'weekday' + _HOUR_FIELD = 'hour' -class OutputWriterJSON: - def __init__(self, fd=sys.stdout): - self._fd = fd - - def _user_to_object(user): + @staticmethod + def convert_user(user): obj = OrderedDict() - for field in _USER_FIELDS: + for field in _OUTPUT_USER_FIELDS: obj[str(field)] = user[field] return obj - def _date_to_object(date): + @staticmethod + def convert_date(date): obj = OrderedDict() - obj[_DATE_FIELD] = str(date) + obj[OutputConverterJSON._DATE_FIELD] = str(date) return obj - def _weekday_to_object(weekday): + @staticmethod + def convert_weekday(weekday): obj = OrderedDict() - obj[_WEEKDAY_FIELD] = str(weekday) + obj[OutputConverterJSON._WEEKDAY_FIELD] = str(weekday) return obj - def _hour_to_object(hour): + @staticmethod + def convert_hour(hour): obj = OrderedDict() - obj[_HOUR_FIELD] = str(timedelta(hours=hour)) + obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour)) return obj +class OutputWriterJSON: + def __init__(self, fd=sys.stdout): + self._fd = fd + _DURATION_FIELD = 'duration' - _CONVERT_KEY_TO_OBJECT = { - Grouping.USER: _user_to_object, - Grouping.DATE: _date_to_object, - Grouping.WEEKDAY: _weekday_to_object, - Grouping.HOUR: _hour_to_object, + _CONVERT_KEY = { + Grouping.USER: OutputConverterJSON.convert_user, + Grouping.DATE: OutputConverterJSON.convert_date, + Grouping.WEEKDAY: OutputConverterJSON.convert_weekday, + Grouping.HOUR: OutputConverterJSON.convert_hour, } @staticmethod def _key_to_object(grouping, key): - if not grouping in OutputWriterJSON._CONVERT_KEY_TO_OBJECT: + if not grouping in OutputWriterJSON._CONVERT_KEY: raise NotImplementedError('unsupported grouping: ' + str(grouping)) - return OutputWriterJSON._CONVERT_KEY_TO_OBJECT[grouping](key) + return OutputWriterJSON._CONVERT_KEY[grouping](key) + + def _write(self, x): + self._fd.write(json.dumps(x, indent=3, ensure_ascii=False)) + self._fd.write('\n') def process_database(self, grouping, db_reader, date_from=None, date_to=None): arr = [] @@ -132,7 +147,7 @@ class OutputWriterJSON: obj = self._key_to_object(grouping, key) obj[self._DURATION_FIELD] = str(duration) arr.append(obj) - self._fd.write(json.dumps(arr, indent=3)) + self._write(arr) class BarChartBuilder: _BAR_HEIGHT = 1. @@ -172,7 +187,8 @@ class BarChartBuilder: from matplotlib.ticker import MaxNLocator self._get_value_axis().set_major_locator(MaxNLocator(integer=True)) - def set_property(self, *args, **kwargs): + @staticmethod + def set_property(*args, **kwargs): plt.setp(*args, **kwargs) def _set_size(self, inches, dim=0): @@ -210,40 +226,48 @@ class BarChartBuilder: return self._ax.barh(bar_offsets, values, align='center', height=self._BAR_HEIGHT) - def show(self): + @staticmethod + def show(): plt.show() def save(self, path): self._fig.savefig(path, bbox_inches='tight') -class PlotBuilder: - def __init__(self, fd=sys.stdout): - self._fd = fd - - def _format_user(user): +class OutputConverterPlot: + @staticmethod + def convert_user(user): return '{}\n{}'.format(user.get_first_name(), user.get_last_name()) - def _format_date(date): + @staticmethod + def convert_date(date): return str(date) - def _format_weekday(weekday): + @staticmethod + def convert_weekday(weekday): return str(weekday) - def _format_hour(hour): + @staticmethod + def convert_hour(hour): return '{}:00'.format(hour) +class OutputWriterPlot: + def __init__(self, fd=sys.stdout): + self._fd = fd + + TITLE = 'How much time people spend online' + _FORMAT_KEY = { - Grouping.USER: _format_user, - Grouping.DATE: _format_date, - Grouping.WEEKDAY: _format_weekday, - Grouping.HOUR: _format_hour, + Grouping.USER: OutputConverterPlot.convert_user, + Grouping.DATE: OutputConverterPlot.convert_date, + Grouping.WEEKDAY: OutputConverterPlot.convert_weekday, + Grouping.HOUR: OutputConverterPlot.convert_hour, } @staticmethod def _format_key(grouping, key): - if grouping not in PlotBuilder._FORMAT_KEY: + if grouping not in OutputWriterPlot._FORMAT_KEY: raise NotImplementedError('unsupported grouping: ' + str(grouping)) - return PlotBuilder._FORMAT_KEY[grouping](key) + return OutputWriterPlot._FORMAT_KEY[grouping](key) @staticmethod def _format_duration(seconds, _): @@ -255,18 +279,18 @@ class PlotBuilder: @staticmethod def _extract_labels(grouping, durations): - return tuple(map(lambda key: PlotBuilder._format_key(grouping, key), durations.keys())) + return tuple(map(lambda key: OutputWriterPlot._format_key(grouping, key), durations.keys())) @staticmethod def _extract_values(durations): - return tuple(map(PlotBuilder._duration_to_seconds, durations.values())) + return tuple(map(OutputWriterPlot._duration_to_seconds, durations.values())) def process_database(self, grouping, db_reader, date_from=None, date_to=None): durations = grouping.enum_durations(db_reader, date_from, date_to) bar_chart = BarChartBuilder() - bar_chart.set_title('How much time people spend online') + bar_chart.set_title(OutputWriterPlot.TITLE) bar_chart.set_value_grid() bar_chart.set_integer_values_only() @@ -291,80 +315,102 @@ class PlotBuilder: class OutputFormat(Enum): CSV = 'csv' JSON = 'json' - IMG = 'img' + PLOT = 'plot' def create_writer(self, fd): if self is OutputFormat.CSV: return OutputWriterCSV(fd) elif self is OutputFormat.JSON: return OutputWriterJSON(fd) - elif self is OutputFormat.IMG: - return PlotBuilder(fd) + elif self is OutputFormat.PLOT: + return OutputWriterPlot(fd) else: raise NotImplementedError('unsupported output format: ' + str(self)) def __str__(self): return self.value -if __name__ == '__main__': - import argparse - +def _parse_grouping(s): + try: + return Grouping(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid grouping: ' + str(s)) + +def _parse_database_format(s): + try: + return DatabaseFormat(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid database format: ' + str(s)) + +def _parse_output_format(s): + try: + return OutputFormat(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid output format: ' + str(s)) + +_DATE_RANGE_LIMIT_FORMAT = '%Y-%m-%dT%H:%M:%SZ' + +def _parse_date_range_limit(s): + try: + dt = datetime.strptime(s, _DATE_RANGE_LIMIT_FORMAT) + return dt.replace(tzinfo=timezone.utc) + except ValueError: + msg = 'invalid date range limit (must be in the \'{}\' format): {}' + raise argparse.ArgumentTypeError( + msg.format(_DATE_RANGE_LIMIT_FORMAT, s)) + +def _parse_args(args=sys.argv): parser = argparse.ArgumentParser( description='View/visualize the amount of time people spend online.') - def grouping(s): - try: - return Grouping(s) - except ValueError: - raise argparse.ArgumentError() - def database_format(s): - try: - return DatabaseFormat(s) - except ValueError: - raise argparse.ArgumentError() - def output_format(s): - try: - return OutputFormat(s) - except ValueError: - raise argparse.ArgumentError() - def date_range_limit(s): - try: - return datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ').replace(tzinfo=timezone.utc) - except ValueError: - raise argparse.ArgumentError() - - parser.add_argument('input', type=argparse.FileType('r'), + parser.add_argument('db_fd', metavar='input', + type=argparse.FileType('r', encoding='utf-8'), help='database path') - parser.add_argument('output', type=argparse.FileType('w'), - nargs='?', default=sys.stdout, + parser.add_argument('fd', metavar='output', nargs='?', + type=argparse.FileType('w', encoding='utf-8'), + default=sys.stdout, help='output path (standard output by default)') - parser.add_argument('--grouping', type=grouping, + parser.add_argument('--grouping', + type=_parse_grouping, default=Grouping.USER, choices=tuple(grouping for grouping in Grouping), - default=Grouping.USER, help='set grouping') - parser.add_argument('--input-format', type=database_format, - choices=tuple(fmt for fmt in DatabaseFormat), + parser.add_argument('--input-format', dest='db_fmt', + type=_parse_database_format, default=DatabaseFormat.CSV, + choices=tuple(fmt for fmt in DatabaseFormat), help='specify database format') - parser.add_argument('--output-format', type=output_format, + parser.add_argument('--output-format', dest='fmt', + type=_parse_output_format, default=OutputFormat.CSV, choices=tuple(fmt for fmt in OutputFormat), - default=OutputFormat.CSV, help='specify output format') - parser.add_argument('--from', type=date_range_limit, default=None, - dest='date_from', + parser.add_argument('--from', dest='date_from', + type=_parse_date_range_limit, default=None, help='set the date to process database records from') - parser.add_argument('--to', type=date_range_limit, default=None, - dest='date_to', + parser.add_argument('--to', dest='date_to', + type=_parse_date_range_limit, default=None, help='set the date to process database record to') - args = parser.parse_args() + return parser.parse_args(args[1:]) + +def write_online_duration(db_fd, fd=sys.stdout, + db_fmt=DatabaseFormat.CSV, + fmt=OutputFormat.CSV, + grouping=Grouping.USER, + date_from=None, date_to=None): + + if date_from is not None and date_to is not None: + if date_from > date_to: + date_from, date_to = date_to, date_from - if args.date_from is not None and args.date_to is not None: - if args.date_from > args.date_to: - args.date_from, args.date_to = args.date_to, args.date_from + with db_fmt.create_reader(db_fd) as db_reader: + output_writer = fmt.create_writer(fd) + output_writer.process_database(grouping, db_reader, + date_from=date_from, + date_to=date_to) - with args.input_format.create_reader(args.input) as db_reader: - output_writer = args.output_format.create_writer(args.output) - output_writer.process_database( - args.grouping, db_reader, date_from=args.date_from, - date_to=args.date_to) +def main(args=sys.argv): + args = _parse_args(args) + write_online_duration(**vars(args)) + +if __name__ == '__main__': + main() |