diff options
Diffstat (limited to 'vk/utils/tracking/utils/online_streak_duration.py')
-rw-r--r-- | vk/utils/tracking/utils/online_streak_duration.py | 206 |
1 files changed, 129 insertions, 77 deletions
diff --git a/vk/utils/tracking/utils/online_streak_duration.py b/vk/utils/tracking/utils/online_streak_duration.py index aa99152..0d9313b 100644 --- a/vk/utils/tracking/utils/online_streak_duration.py +++ b/vk/utils/tracking/utils/online_streak_duration.py @@ -12,19 +12,24 @@ import sys import matplotlib.pyplot as plt import numpy as np -from .. import OnlineStreakEnumerator +from ..online_streaks import OnlineStreakEnumerator, Weekday from ..db import Format as DatabaseFormat from vk.user import UserField -def process_database(db_reader, writer): - by_user = OnlineStreakEnumerator().group_by_user(db_reader) - for user, duration in by_user.items(): - writer.add_user_duration(user, duration) - -class OutputFormat(Enum): - CSV = 'csv' - JSON = 'json' - IMG = 'img' +class Grouping(Enum): + USER = 'user' + DATE = 'date' + WEEKDAY = 'weekday' + + def enum_durations(self, db_reader): + if self is Grouping.USER: + return OnlineStreakEnumerator().group_by_user(db_reader) + elif self is Grouping.DATE: + return OnlineStreakEnumerator().group_by_date(db_reader) + elif self is Grouping.WEEKDAY: + return OnlineStreakEnumerator().group_by_weekday(db_reader) + else: + raise NotImplementedError('unsupported grouping: ' + str(self)) def __str__(self): return self.value @@ -40,49 +45,80 @@ class OutputWriterCSV: def __init__(self, fd=sys.stdout): self._writer = csv.writer(fd, lineterminator='\n') - def __enter__(self): - return self + def _user_to_row(user): + return [user[field] for field in _USER_FIELDS] + + def _date_to_row(date): + return [str(date)] + + def _weekday_to_row(weekday): + return [str(weekday)] - def __exit__(self, *args): - pass + _CONVERT_KEY_TO_ROW = { + Grouping.USER: _user_to_row, + Grouping.DATE: _date_to_row, + Grouping.WEEKDAY: _weekday_to_row, + } + + @staticmethod + def _key_to_row(grouping, key): + if grouping not in OutputWriterCSV._CONVERT_KEY_TO_ROW: + raise NotImplementedError('unsupported grouping: ' + str(grouping)) + return OutputWriterCSV._CONVERT_KEY_TO_ROW[grouping](key) - def add_user_duration(self, user, duration): - self._write_row(self._user_duration_to_row(user, duration)) + def process_database(self, grouping, db_reader): + for key, duration in grouping.enum_durations(db_reader).items(): + row = self._key_to_row(grouping, key) + row.append(str(duration)) + self._write_row(row) def _write_row(self, row): self._writer.writerow(row) - @staticmethod - def _user_duration_to_row(user, duration): - row = [] - for field in _USER_FIELDS: - row.append(user[field]) - row.append(str(duration)) - return row +_DATE_FIELD = 'date' +_WEEKDAY_FIELD = 'weekday' class OutputWriterJSON: def __init__(self, fd=sys.stdout): self._fd = fd - self._array = [] - def __enter__(self): - return self + def _user_to_object(user): + obj = OrderedDict() + for field in _USER_FIELDS: + obj[str(field)] = user[field] + return obj - def __exit__(self, *args): - self._fd.write(json.dumps(self._array, indent=3)) + def _date_to_object(date): + obj = OrderedDict() + obj[_DATE_FIELD] = str(date) + return obj - def add_user_duration(self, user, duration): - self._array.append(self._user_duration_to_object(user, duration)) + def _weekday_to_object(weekday): + obj = OrderedDict() + obj[_WEEKDAY_FIELD] = str(weekday) + return obj _DURATION_FIELD = 'duration' + _CONVERT_KEY_TO_OBJECT = { + Grouping.USER: _user_to_object, + Grouping.DATE: _date_to_object, + Grouping.WEEKDAY: _weekday_to_object, + } + @staticmethod - def _user_duration_to_object(user, duration): - record = OrderedDict() - for field in _USER_FIELDS: - record[str(field)] = user[field] - record[OutputWriterJSON._DURATION_FIELD] = str(duration) - return record + def _key_to_object(grouping, key): + if not grouping in OutputWriterJSON._CONVERT_KEY_TO_OBJECT: + raise NotImplementedError('unsupported grouping: ' + str(grouping)) + return OutputWriterJSON._CONVERT_KEY_TO_OBJECT[grouping](key) + + def process_database(self, grouping, db_reader): + arr = [] + for key, duration in grouping.enum_durations(db_reader).items(): + obj = self._key_to_object(grouping, key) + obj[self._DURATION_FIELD] = str(duration) + arr.append(obj) + self._fd.write(json.dumps(arr, indent=3)) class BarChartBuilder: _BAR_HEIGHT = 1. @@ -164,17 +200,29 @@ class BarChartBuilder: class PlotBuilder: def __init__(self, fd=sys.stdout): - self._duration_by_user = {} self._fd = fd - pass - def __enter__(self): - return self - - @staticmethod def _format_user(user): return '{}\n{}'.format(user.get_first_name(), user.get_last_name()) + def _format_date(date): + return str(date) + + def _format_weekday(weekday): + return str(weekday) + + _FORMAT_KEY = { + Grouping.USER: _format_user, + Grouping.DATE: _format_date, + Grouping.WEEKDAY: _format_weekday, + } + + @staticmethod + def _format_key(grouping, key): + if grouping not in PlotBuilder._FORMAT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(grouping)) + return PlotBuilder._FORMAT_KEY[grouping](key) + @staticmethod def _format_duration(seconds, _): return str(timedelta(seconds=seconds)) @@ -183,13 +231,17 @@ class PlotBuilder: def _duration_to_seconds(td): return td.total_seconds() - def _get_users(self): - return tuple(map(self._format_user, self._duration_by_user.keys())) + @staticmethod + def _extract_labels(grouping, durations): + return tuple(map(lambda key: PlotBuilder._format_key(grouping, key), durations.keys())) + + @staticmethod + def _extract_values(durations): + return tuple(map(PlotBuilder._duration_to_seconds, durations.values())) - def _get_durations(self): - return tuple(map(self._duration_to_seconds, self._duration_by_user.values())) + def process_database(self, grouping, db_reader): + durations = grouping.enum_durations(db_reader) - def __exit__(self, *args): bar_chart = BarChartBuilder() bar_chart.set_title('How much time people spend online?') @@ -200,13 +252,13 @@ class PlotBuilder: fontsize='small', rotation=30) bar_chart.set_value_label_formatter(self._format_duration) - users = self._get_users() - durations = self._get_durations() + labels = self._extract_labels(grouping, durations) + durations = self._extract_values(durations) - if not self._duration_by_user or not max(durations): + if not labels or not max(durations): bar_chart.set_value_axis_limits(0) - bars = bar_chart.plot_bars(users, durations) + bars = bar_chart.plot_bars(labels, durations) bar_chart.set_property(bars, alpha=.33) if self._fd is sys.stdout: @@ -214,43 +266,39 @@ class PlotBuilder: else: bar_chart.save(self._fd) - def add_user_duration(self, user, duration): - #if len(self._duration_by_user) >= 1: - # return - #if duration.total_seconds(): - # return - self._duration_by_user[user] = duration # + timedelta(seconds=3) - -def open_output_writer_csv(fd): - return OutputWriterCSV(fd) - -def open_output_writer_json(fd): - return OutputWriterJSON(fd) +class OutputFormat(Enum): + CSV = 'csv' + JSON = 'json' + IMG = 'img' -def open_output_writer_img(fd): - return PlotBuilder(fd) + def create_writer(self, fd): + if self is OutputFormat.CSV: + return OutputWriterCSV(fd) + elif self is OutputFormat.JSON: + return OutputWriterJSON(fd) + elif self is OutputFormat.IMG: + return PlotBuilder(fd) + else: + raise NotImplementedError('unsupported output format: ' + str(self)) -def open_output_writer(fd, fmt): - if fmt is OutputFormat.CSV: - return open_output_writer_csv(fd) - elif fmt is OutputFormat.JSON: - return open_output_writer_json(fd) - elif fmt is OutputFormat.IMG: - return open_output_writer_img(fd) - else: - raise NotImplementedError('unsupported output type: ' + str(fmt)) + def __str__(self): + return self.value if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() + def grouping(s): + try: + return Grouping(s) + except ValueError: + raise argparse.ArgumentTypeError() def database_format(s): try: return DatabaseFormat(s) except ValueError: raise argparse.ArgumentTypeError() - def output_format(s): try: return OutputFormat(s) @@ -262,6 +310,10 @@ if __name__ == '__main__': parser.add_argument('output', type=argparse.FileType('w'), nargs='?', default=sys.stdout, help='output path (standard output by default)') + parser.add_argument('--grouping', type=grouping, + choices=tuple(grouping for grouping in Grouping), + default=Grouping.USER, + help='set grouping') parser.add_argument('--input-format', type=database_format, choices=tuple(fmt for fmt in DatabaseFormat), default=DatabaseFormat.CSV, @@ -274,5 +326,5 @@ if __name__ == '__main__': args = parser.parse_args() with args.input_format.create_reader(args.input) as db_reader: - with open_output_writer(args.output, args.output_format) as output_writer: - process_database(db_reader, output_writer) + output_writer = args.output_format.create_writer(args.output) + output_writer.process_database(args.grouping, db_reader) |