From 5be66f6e80775d834e6497a91dd296fab12e8725 Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Tue, 19 Jul 2016 21:21:57 +0300 Subject: refactoring --- bin/online_duration.py | 157 ++++++++++++++++++++++++++----------------------- bin/track_status.py | 7 ++- 2 files changed, 89 insertions(+), 75 deletions(-) diff --git a/bin/online_duration.py b/bin/online_duration.py index 476f20c..d5b8f35 100644 --- a/bin/online_duration.py +++ b/bin/online_duration.py @@ -17,21 +17,22 @@ from vk.tracking import OnlineStreakEnumerator from vk.tracking.db import Format as DatabaseFormat from vk.user import UserField -class Grouping(Enum): +class GroupBy(Enum): USER = 'user' DATE = 'date' WEEKDAY = 'weekday' HOUR = 'hour' def enum_durations(self, db_reader, date_from=None, date_to=None): - if self is Grouping.USER: - return OnlineStreakEnumerator(date_from, date_to).group_by_user(db_reader) - elif self is Grouping.DATE: - return OnlineStreakEnumerator(date_from, date_to).group_by_date(db_reader) - elif self is Grouping.WEEKDAY: - return OnlineStreakEnumerator(date_from, date_to).group_by_weekday(db_reader) - elif self is Grouping.HOUR: - return OnlineStreakEnumerator(date_from, date_to).group_by_hour(db_reader) + online_streaks = OnlineStreakEnumerator(date_from, date_to) + if self is GroupBy.USER: + return online_streaks.group_by_user(db_reader) + elif self is GroupBy.DATE: + return online_streaks.group_by_date(db_reader) + elif self is GroupBy.WEEKDAY: + return online_streaks.group_by_weekday(db_reader) + elif self is GroupBy.HOUR: + return online_streaks.group_by_hour(db_reader) else: raise NotImplementedError('unsupported grouping: ' + str(self)) @@ -67,21 +68,21 @@ class OutputWriterCSV: self._writer = csv.writer(fd, lineterminator='\n') _CONVERT_KEY = { - Grouping.USER: OutputConverterCSV.convert_user, - Grouping.DATE: OutputConverterCSV.convert_date, - Grouping.WEEKDAY: OutputConverterCSV.convert_weekday, - Grouping.HOUR: OutputConverterCSV.convert_hour, + GroupBy.USER: OutputConverterCSV.convert_user, + GroupBy.DATE: OutputConverterCSV.convert_date, + GroupBy.WEEKDAY: OutputConverterCSV.convert_weekday, + GroupBy.HOUR: OutputConverterCSV.convert_hour, } @staticmethod - def _key_to_row(grouping, key): - if grouping not in OutputWriterCSV._CONVERT_KEY: - raise NotImplementedError('unsupported grouping: ' + str(grouping)) - return OutputWriterCSV._CONVERT_KEY[grouping](key) - - def process_database(self, grouping, db_reader, date_from=None, date_to=None): - for key, duration in grouping.enum_durations(db_reader, date_from, date_to).items(): - row = self._key_to_row(grouping, key) + def _key_to_row(group_by, key): + if group_by not in OutputWriterCSV._CONVERT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(group_by)) + return OutputWriterCSV._CONVERT_KEY[group_by](key) + + def process_database(self, group_by, db_reader, date_from=None, date_to=None): + for key, duration in group_by.enum_durations(db_reader, date_from, date_to).items(): + row = self._key_to_row(group_by, key) row.append(str(duration)) self._write_row(row) @@ -131,26 +132,26 @@ class OutputWriterJSON: assert _DURATION_FIELD not in map(str, _OUTPUT_USER_FIELDS) _CONVERT_KEY = { - Grouping.USER: OutputConverterJSON.convert_user, - Grouping.DATE: OutputConverterJSON.convert_date, - Grouping.WEEKDAY: OutputConverterJSON.convert_weekday, - Grouping.HOUR: OutputConverterJSON.convert_hour, + GroupBy.USER: OutputConverterJSON.convert_user, + GroupBy.DATE: OutputConverterJSON.convert_date, + GroupBy.WEEKDAY: OutputConverterJSON.convert_weekday, + GroupBy.HOUR: OutputConverterJSON.convert_hour, } @staticmethod - def _key_to_object(grouping, key): - if not grouping in OutputWriterJSON._CONVERT_KEY: - raise NotImplementedError('unsupported grouping: ' + str(grouping)) - return OutputWriterJSON._CONVERT_KEY[grouping](key) + def _key_to_object(group_by, key): + if not group_by in OutputWriterJSON._CONVERT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(group_by)) + return OutputWriterJSON._CONVERT_KEY[group_by](key) def _write(self, x): self._fd.write(json.dumps(x, indent=3, ensure_ascii=False)) self._fd.write('\n') - def process_database(self, grouping, db_reader, date_from=None, date_to=None): + def process_database(self, group_by, db_reader, date_from=None, date_to=None): arr = [] - for key, duration in grouping.enum_durations(db_reader, date_from, date_to).items(): - obj = self._key_to_object(grouping, key) + for key, duration in group_by.enum_durations(db_reader, date_from, date_to).items(): + obj = self._key_to_object(group_by, key) obj[self._DURATION_FIELD] = str(duration) arr.append(obj) self._write(arr) @@ -209,7 +210,11 @@ class BarChartBuilder: def set_height(self, inches): self._set_size(inches, dim=1) - def plot_bars(self, bar_labels, values, datetime_ticks=False): + def plot_bars( + self, bar_labels, bar_lengths, + bars_between_ticks=False, + inches_per_bar=1): + numof_bars = len(bar_labels) if not numof_bars: @@ -217,20 +222,23 @@ class BarChartBuilder: self._get_bar_axis().set_tick_params(labelleft=False) return [] - self.set_height(numof_bars / 2 if datetime_ticks else numof_bars) + self.set_height(inches_per_bar * numof_bars) bar_offsets = np.arange(numof_bars) * 2 * self._BAR_HEIGHT + self._BAR_HEIGHT - bar_axis_min, bar_axis_max = 0, 2 * self._BAR_HEIGHT * numof_bars - if datetime_ticks: + if bars_between_ticks: self._get_bar_axis().set_ticks(bar_offsets - self._BAR_HEIGHT) else: self._get_bar_axis().set_ticks(bar_offsets) - self._get_bar_axis().set_ticklabels(bar_labels) + bar_axis_min = 0 + bar_axis_max = 2 * self._BAR_HEIGHT * numof_bars self.set_bar_axis_limits(bar_axis_min, bar_axis_max) - return self._ax.barh(bar_offsets, values, align='center', height=self._BAR_HEIGHT) + self._get_bar_axis().set_ticklabels(bar_labels) + + return self._ax.barh( + bar_offsets, bar_lengths, align='center', height=self._BAR_HEIGHT) @staticmethod def show(): @@ -263,17 +271,17 @@ class OutputWriterPlot: TITLE = 'How much time people spend online' _FORMAT_KEY = { - Grouping.USER: OutputConverterPlot.convert_user, - Grouping.DATE: OutputConverterPlot.convert_date, - Grouping.WEEKDAY: OutputConverterPlot.convert_weekday, - Grouping.HOUR: OutputConverterPlot.convert_hour, + GroupBy.USER: OutputConverterPlot.convert_user, + GroupBy.DATE: OutputConverterPlot.convert_date, + GroupBy.WEEKDAY: OutputConverterPlot.convert_weekday, + GroupBy.HOUR: OutputConverterPlot.convert_hour, } @staticmethod - def _format_key(grouping, key): - if grouping not in OutputWriterPlot._FORMAT_KEY: - raise NotImplementedError('unsupported grouping: ' + str(grouping)) - return OutputWriterPlot._FORMAT_KEY[grouping](key) + def _format_key(group_by, key): + if group_by not in OutputWriterPlot._FORMAT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(group_by)) + return OutputWriterPlot._FORMAT_KEY[group_by](key) @staticmethod def _format_duration(seconds, _): @@ -284,15 +292,18 @@ class OutputWriterPlot: return td.total_seconds() @staticmethod - def _extract_labels(grouping, durations): - return tuple(map(lambda key: OutputWriterPlot._format_key(grouping, key), durations.keys())) + def _extract_labels(group_by, durations): + return tuple(map(lambda key: OutputWriterPlot._format_key(group_by, key), durations.keys())) @staticmethod def _extract_values(durations): return tuple(map(OutputWriterPlot._duration_to_seconds, durations.values())) - def process_database(self, grouping, db_reader, date_from=None, date_to=None): - durations = grouping.enum_durations(db_reader, date_from, date_to) + def process_database( + self, group_by, db_reader, date_from=None, date_to=None): + + durations = group_by.enum_durations( + db_reader, date_from, date_to) bar_chart = BarChartBuilder() @@ -300,17 +311,20 @@ class OutputWriterPlot: bar_chart.set_value_grid() bar_chart.set_integer_values_only() - bar_chart.set_property(bar_chart.get_value_labels(), - fontsize='small', rotation=30) + bar_chart.set_property( + bar_chart.get_value_labels(), fontsize='small', rotation=30) bar_chart.set_value_label_formatter(self._format_duration) - labels = self._extract_labels(grouping, durations) + labels = self._extract_labels(group_by, durations) durations = self._extract_values(durations) if not labels or not max(durations): bar_chart.set_value_axis_limits(0) - bars = bar_chart.plot_bars(labels, durations, grouping is Grouping.HOUR) + bars = bar_chart.plot_bars( + labels, durations, + bars_between_ticks=group_by is GroupBy.HOUR, + inches_per_bar=.5 if group_by is GroupBy.HOUR else 1) bar_chart.set_property(bars, alpha=.33) if self._fd is sys.stdout: @@ -336,11 +350,11 @@ class OutputFormat(Enum): def __str__(self): return self.value -def _parse_grouping(s): +def _parse_group_by(s): try: - return Grouping(s) + return GroupBy(s) except ValueError: - raise argparse.ArgumentTypeError('invalid grouping: ' + s) + raise argparse.ArgumentTypeError('invalid "group by" value: ' + s) def _parse_database_format(s): try: @@ -371,16 +385,16 @@ def _parse_args(args=sys.argv): parser.add_argument('db_fd', metavar='input', type=argparse.FileType('r', encoding='utf-8'), - help='database path') + help='database file path') parser.add_argument('fd', metavar='output', nargs='?', type=argparse.FileType('w', encoding='utf-8'), default=sys.stdout, - help='output path (standard output by default)') - parser.add_argument('-g', '--grouping', - type=_parse_grouping, - choices=Grouping, - default=Grouping.USER, - help='group database records by date, weekday, etc.') + help='output file path (standard output by default)') + parser.add_argument('-g', '--group-by', + type=_parse_group_by, + choices=GroupBy, + default=GroupBy.USER, + help='group online streaks by user/date/etc.') parser.add_argument('-i', '--input-format', dest='db_fmt', type=_parse_database_format, default=DatabaseFormat.CSV, @@ -400,11 +414,11 @@ def _parse_args(args=sys.argv): return parser.parse_args(args[1:]) -def write_online_duration(db_fd, fd=sys.stdout, - db_fmt=DatabaseFormat.CSV, - fmt=OutputFormat.CSV, - grouping=Grouping.USER, - date_from=None, date_to=None): +def write_online_duration( + db_fd, db_fmt=DatabaseFormat.CSV, + fd=sys.stdout, fmt=OutputFormat.CSV, + group_by=GroupBy.USER, + date_from=None, date_to=None): if date_from is not None and date_to is not None: if date_from > date_to: @@ -412,9 +426,8 @@ def write_online_duration(db_fd, fd=sys.stdout, with db_fmt.create_reader(db_fd) as db_reader: output_writer = fmt.create_writer(fd) - output_writer.process_database(grouping, db_reader, - date_from=date_from, - date_to=date_to) + output_writer.process_database( + group_by, db_reader, date_from=date_from, date_to=date_to) def main(args=sys.argv): args = _parse_args(args) diff --git a/bin/track_status.py b/bin/track_status.py index 19fae07..a5befef 100644 --- a/bin/track_status.py +++ b/bin/track_status.py @@ -52,9 +52,10 @@ def _parse_args(args=sys.argv): return parser.parse_args(args[1:]) -def track_status(uids, timeout=DEFAULT_TIMEOUT, - log_fd=sys.stdout, - db_fd=None, db_fmt=DEFAULT_DB_FORMAT): +def track_status( + uids, timeout=DEFAULT_TIMEOUT, + log_fd=sys.stdout, + db_fd=None, db_fmt=DEFAULT_DB_FORMAT): api = API(Language.EN, deactivated_users=False) tracker = StatusTracker(api, timeout) -- cgit v1.2.3