aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/vk/utils/tracking/utils/online_streak_duration.py
diff options
context:
space:
mode:
Diffstat (limited to 'vk/utils/tracking/utils/online_streak_duration.py')
-rw-r--r--vk/utils/tracking/utils/online_streak_duration.py206
1 files changed, 129 insertions, 77 deletions
diff --git a/vk/utils/tracking/utils/online_streak_duration.py b/vk/utils/tracking/utils/online_streak_duration.py
index aa99152..0d9313b 100644
--- a/vk/utils/tracking/utils/online_streak_duration.py
+++ b/vk/utils/tracking/utils/online_streak_duration.py
@@ -12,19 +12,24 @@ import sys
import matplotlib.pyplot as plt
import numpy as np
-from .. import OnlineStreakEnumerator
+from ..online_streaks import OnlineStreakEnumerator, Weekday
from ..db import Format as DatabaseFormat
from vk.user import UserField
-def process_database(db_reader, writer):
- by_user = OnlineStreakEnumerator().group_by_user(db_reader)
- for user, duration in by_user.items():
- writer.add_user_duration(user, duration)
-
-class OutputFormat(Enum):
- CSV = 'csv'
- JSON = 'json'
- IMG = 'img'
+class Grouping(Enum):
+ USER = 'user'
+ DATE = 'date'
+ WEEKDAY = 'weekday'
+
+ def enum_durations(self, db_reader):
+ if self is Grouping.USER:
+ return OnlineStreakEnumerator().group_by_user(db_reader)
+ elif self is Grouping.DATE:
+ return OnlineStreakEnumerator().group_by_date(db_reader)
+ elif self is Grouping.WEEKDAY:
+ return OnlineStreakEnumerator().group_by_weekday(db_reader)
+ else:
+ raise NotImplementedError('unsupported grouping: ' + str(self))
def __str__(self):
return self.value
@@ -40,49 +45,80 @@ class OutputWriterCSV:
def __init__(self, fd=sys.stdout):
self._writer = csv.writer(fd, lineterminator='\n')
- def __enter__(self):
- return self
+ def _user_to_row(user):
+ return [user[field] for field in _USER_FIELDS]
+
+ def _date_to_row(date):
+ return [str(date)]
+
+ def _weekday_to_row(weekday):
+ return [str(weekday)]
- def __exit__(self, *args):
- pass
+ _CONVERT_KEY_TO_ROW = {
+ Grouping.USER: _user_to_row,
+ Grouping.DATE: _date_to_row,
+ Grouping.WEEKDAY: _weekday_to_row,
+ }
+
+ @staticmethod
+ def _key_to_row(grouping, key):
+ if grouping not in OutputWriterCSV._CONVERT_KEY_TO_ROW:
+ raise NotImplementedError('unsupported grouping: ' + str(grouping))
+ return OutputWriterCSV._CONVERT_KEY_TO_ROW[grouping](key)
- def add_user_duration(self, user, duration):
- self._write_row(self._user_duration_to_row(user, duration))
+ def process_database(self, grouping, db_reader):
+ for key, duration in grouping.enum_durations(db_reader).items():
+ row = self._key_to_row(grouping, key)
+ row.append(str(duration))
+ self._write_row(row)
def _write_row(self, row):
self._writer.writerow(row)
- @staticmethod
- def _user_duration_to_row(user, duration):
- row = []
- for field in _USER_FIELDS:
- row.append(user[field])
- row.append(str(duration))
- return row
+_DATE_FIELD = 'date'
+_WEEKDAY_FIELD = 'weekday'
class OutputWriterJSON:
def __init__(self, fd=sys.stdout):
self._fd = fd
- self._array = []
- def __enter__(self):
- return self
+ def _user_to_object(user):
+ obj = OrderedDict()
+ for field in _USER_FIELDS:
+ obj[str(field)] = user[field]
+ return obj
- def __exit__(self, *args):
- self._fd.write(json.dumps(self._array, indent=3))
+ def _date_to_object(date):
+ obj = OrderedDict()
+ obj[_DATE_FIELD] = str(date)
+ return obj
- def add_user_duration(self, user, duration):
- self._array.append(self._user_duration_to_object(user, duration))
+ def _weekday_to_object(weekday):
+ obj = OrderedDict()
+ obj[_WEEKDAY_FIELD] = str(weekday)
+ return obj
_DURATION_FIELD = 'duration'
+ _CONVERT_KEY_TO_OBJECT = {
+ Grouping.USER: _user_to_object,
+ Grouping.DATE: _date_to_object,
+ Grouping.WEEKDAY: _weekday_to_object,
+ }
+
@staticmethod
- def _user_duration_to_object(user, duration):
- record = OrderedDict()
- for field in _USER_FIELDS:
- record[str(field)] = user[field]
- record[OutputWriterJSON._DURATION_FIELD] = str(duration)
- return record
+ def _key_to_object(grouping, key):
+ if not grouping in OutputWriterJSON._CONVERT_KEY_TO_OBJECT:
+ raise NotImplementedError('unsupported grouping: ' + str(grouping))
+ return OutputWriterJSON._CONVERT_KEY_TO_OBJECT[grouping](key)
+
+ def process_database(self, grouping, db_reader):
+ arr = []
+ for key, duration in grouping.enum_durations(db_reader).items():
+ obj = self._key_to_object(grouping, key)
+ obj[self._DURATION_FIELD] = str(duration)
+ arr.append(obj)
+ self._fd.write(json.dumps(arr, indent=3))
class BarChartBuilder:
_BAR_HEIGHT = 1.
@@ -164,17 +200,29 @@ class BarChartBuilder:
class PlotBuilder:
def __init__(self, fd=sys.stdout):
- self._duration_by_user = {}
self._fd = fd
- pass
- def __enter__(self):
- return self
-
- @staticmethod
def _format_user(user):
return '{}\n{}'.format(user.get_first_name(), user.get_last_name())
+ def _format_date(date):
+ return str(date)
+
+ def _format_weekday(weekday):
+ return str(weekday)
+
+ _FORMAT_KEY = {
+ Grouping.USER: _format_user,
+ Grouping.DATE: _format_date,
+ Grouping.WEEKDAY: _format_weekday,
+ }
+
+ @staticmethod
+ def _format_key(grouping, key):
+ if grouping not in PlotBuilder._FORMAT_KEY:
+ raise NotImplementedError('unsupported grouping: ' + str(grouping))
+ return PlotBuilder._FORMAT_KEY[grouping](key)
+
@staticmethod
def _format_duration(seconds, _):
return str(timedelta(seconds=seconds))
@@ -183,13 +231,17 @@ class PlotBuilder:
def _duration_to_seconds(td):
return td.total_seconds()
- def _get_users(self):
- return tuple(map(self._format_user, self._duration_by_user.keys()))
+ @staticmethod
+ def _extract_labels(grouping, durations):
+ return tuple(map(lambda key: PlotBuilder._format_key(grouping, key), durations.keys()))
+
+ @staticmethod
+ def _extract_values(durations):
+ return tuple(map(PlotBuilder._duration_to_seconds, durations.values()))
- def _get_durations(self):
- return tuple(map(self._duration_to_seconds, self._duration_by_user.values()))
+ def process_database(self, grouping, db_reader):
+ durations = grouping.enum_durations(db_reader)
- def __exit__(self, *args):
bar_chart = BarChartBuilder()
bar_chart.set_title('How much time people spend online?')
@@ -200,13 +252,13 @@ class PlotBuilder:
fontsize='small', rotation=30)
bar_chart.set_value_label_formatter(self._format_duration)
- users = self._get_users()
- durations = self._get_durations()
+ labels = self._extract_labels(grouping, durations)
+ durations = self._extract_values(durations)
- if not self._duration_by_user or not max(durations):
+ if not labels or not max(durations):
bar_chart.set_value_axis_limits(0)
- bars = bar_chart.plot_bars(users, durations)
+ bars = bar_chart.plot_bars(labels, durations)
bar_chart.set_property(bars, alpha=.33)
if self._fd is sys.stdout:
@@ -214,43 +266,39 @@ class PlotBuilder:
else:
bar_chart.save(self._fd)
- def add_user_duration(self, user, duration):
- #if len(self._duration_by_user) >= 1:
- # return
- #if duration.total_seconds():
- # return
- self._duration_by_user[user] = duration # + timedelta(seconds=3)
-
-def open_output_writer_csv(fd):
- return OutputWriterCSV(fd)
-
-def open_output_writer_json(fd):
- return OutputWriterJSON(fd)
+class OutputFormat(Enum):
+ CSV = 'csv'
+ JSON = 'json'
+ IMG = 'img'
-def open_output_writer_img(fd):
- return PlotBuilder(fd)
+ def create_writer(self, fd):
+ if self is OutputFormat.CSV:
+ return OutputWriterCSV(fd)
+ elif self is OutputFormat.JSON:
+ return OutputWriterJSON(fd)
+ elif self is OutputFormat.IMG:
+ return PlotBuilder(fd)
+ else:
+ raise NotImplementedError('unsupported output format: ' + str(self))
-def open_output_writer(fd, fmt):
- if fmt is OutputFormat.CSV:
- return open_output_writer_csv(fd)
- elif fmt is OutputFormat.JSON:
- return open_output_writer_json(fd)
- elif fmt is OutputFormat.IMG:
- return open_output_writer_img(fd)
- else:
- raise NotImplementedError('unsupported output type: ' + str(fmt))
+ def __str__(self):
+ return self.value
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
+ def grouping(s):
+ try:
+ return Grouping(s)
+ except ValueError:
+ raise argparse.ArgumentTypeError()
def database_format(s):
try:
return DatabaseFormat(s)
except ValueError:
raise argparse.ArgumentTypeError()
-
def output_format(s):
try:
return OutputFormat(s)
@@ -262,6 +310,10 @@ if __name__ == '__main__':
parser.add_argument('output', type=argparse.FileType('w'),
nargs='?', default=sys.stdout,
help='output path (standard output by default)')
+ parser.add_argument('--grouping', type=grouping,
+ choices=tuple(grouping for grouping in Grouping),
+ default=Grouping.USER,
+ help='set grouping')
parser.add_argument('--input-format', type=database_format,
choices=tuple(fmt for fmt in DatabaseFormat),
default=DatabaseFormat.CSV,
@@ -274,5 +326,5 @@ if __name__ == '__main__':
args = parser.parse_args()
with args.input_format.create_reader(args.input) as db_reader:
- with open_output_writer(args.output, args.output_format) as output_writer:
- process_database(db_reader, output_writer)
+ output_writer = args.output_format.create_writer(args.output)
+ output_writer.process_database(args.grouping, db_reader)