aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/bin/online_sessions.py
diff options
context:
space:
mode:
authorEgor Tensin <Egor.Tensin@gmail.com>2017-01-28 02:01:44 +0300
committerEgor Tensin <Egor.Tensin@gmail.com>2017-01-28 02:01:44 +0300
commitf989577d524f27c079af619b9d0b6b4a9d70c386 (patch)
treebad199636fc9cd7152e64e8ffb6075c13755412a /bin/online_sessions.py
parentrefactoring (diff)
downloadvk-scripts-f989577d524f27c079af619b9d0b6b4a9d70c386.tar.gz
vk-scripts-f989577d524f27c079af619b9d0b6b4a9d70c386.zip
bin: move file i/o to a separate module
Diffstat (limited to 'bin/online_sessions.py')
-rw-r--r--bin/online_sessions.py101
1 files changed, 43 insertions, 58 deletions
diff --git a/bin/online_sessions.py b/bin/online_sessions.py
index 5a9bdc8..cb3f4bf 100644
--- a/bin/online_sessions.py
+++ b/bin/online_sessions.py
@@ -3,13 +3,11 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+import abc
import argparse
-import csv
from collections import OrderedDict
-from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from enum import Enum
-import json
import sys
from vk.tracking import OnlineSessionEnumerator
@@ -17,6 +15,7 @@ from vk.tracking.db import Format as DatabaseFormat
from vk.user import UserField
from .utils.bar_chart import BarChartBuilder
+from .utils import output
class GroupBy(Enum):
USER = 'user'
@@ -24,6 +23,9 @@ class GroupBy(Enum):
WEEKDAY = 'weekday'
HOUR = 'hour'
+ def __str__(self):
+ return self.value
+
def group(self, db_reader, time_from=None, time_to=None):
online_streaks = OnlineSessionEnumerator(time_from, time_to)
if self is GroupBy.USER:
@@ -37,9 +39,6 @@ class GroupBy(Enum):
else:
raise NotImplementedError('unsupported grouping: ' + str(self))
- def __str__(self):
- return self.value
-
_OUTPUT_USER_FIELDS = (
UserField.UID,
UserField.FIRST_NAME,
@@ -47,6 +46,11 @@ _OUTPUT_USER_FIELDS = (
UserField.DOMAIN,
)
+class OutputSinkOnlineSessions(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def process_database(self, group_by, db_reader, time_from=None, time_to=None):
+ pass
+
class OutputConverterCSV:
@staticmethod
def convert_user(user):
@@ -64,9 +68,9 @@ class OutputConverterCSV:
def convert_hour(hour):
return [str(timedelta(hours=hour))]
-class OutputWriterCSV:
+class OutputSinkCSV(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
- self._writer = csv.writer(fd, lineterminator='\n')
+ self._writer = output.OutputWriterCSV(fd)
_CONVERT_KEY = {
GroupBy.USER: OutputConverterCSV.convert_user,
@@ -77,18 +81,15 @@ class OutputWriterCSV:
@staticmethod
def _key_to_row(group_by, key):
- if group_by not in OutputWriterCSV._CONVERT_KEY:
+ if group_by not in OutputSinkCSV._CONVERT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
- return OutputWriterCSV._CONVERT_KEY[group_by](key)
+ return OutputSinkCSV._CONVERT_KEY[group_by](key)
def process_database(self, group_by, db_reader, time_from=None, time_to=None):
for key, duration in group_by.group(db_reader, time_from, time_to).items():
row = self._key_to_row(group_by, key)
row.append(str(duration))
- self._write_row(row)
-
- def _write_row(self, row):
- self._writer.writerow(row)
+ self._writer.write_row(row)
class OutputConverterJSON:
_DATE_FIELD = 'date'
@@ -124,9 +125,9 @@ class OutputConverterJSON:
obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour))
return obj
-class OutputWriterJSON:
+class OutputSinkJSON(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
- self._fd = fd
+ self._writer = output.OutputWriterJSON(fd)
_DURATION_FIELD = 'duration'
@@ -141,13 +142,9 @@ class OutputWriterJSON:
@staticmethod
def _key_to_object(group_by, key):
- if not group_by in OutputWriterJSON._CONVERT_KEY:
+ if not group_by in OutputSinkJSON._CONVERT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
- return OutputWriterJSON._CONVERT_KEY[group_by](key)
-
- def _write(self, entries):
- self._fd.write(json.dumps(entries, indent=3, ensure_ascii=False))
- self._fd.write('\n')
+ return OutputSinkJSON._CONVERT_KEY[group_by](key)
def process_database(self, group_by, db_reader, time_from=None, time_to=None):
entries = []
@@ -155,7 +152,7 @@ class OutputWriterJSON:
entry = self._key_to_object(group_by, key)
entry[self._DURATION_FIELD] = str(duration)
entries.append(entry)
- self._write(entries)
+ self._writer.write(entries)
class OutputConverterPlot:
@staticmethod
@@ -174,7 +171,7 @@ class OutputConverterPlot:
def convert_hour(hour):
return '{}:00'.format(hour)
-class OutputWriterPlot:
+class OutputSinkPlot(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
self._fd = fd
@@ -189,9 +186,9 @@ class OutputWriterPlot:
@staticmethod
def _format_key(group_by, key):
- if group_by not in OutputWriterPlot._FORMAT_KEY:
+ if group_by not in OutputSinkPlot._FORMAT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
- return OutputWriterPlot._FORMAT_KEY[group_by](key)
+ return OutputSinkPlot._FORMAT_KEY[group_by](key)
@staticmethod
def _format_duration(seconds, _):
@@ -203,11 +200,11 @@ class OutputWriterPlot:
@staticmethod
def _extract_labels(group_by, durations):
- return tuple(map(lambda key: OutputWriterPlot._format_key(group_by, key), durations.keys()))
+ return tuple(map(lambda key: OutputSinkPlot._format_key(group_by, key), durations.keys()))
@staticmethod
def _extract_values(durations):
- return tuple(map(OutputWriterPlot._duration_to_seconds, durations.values()))
+ return tuple(map(OutputSinkPlot._duration_to_seconds, durations.values()))
def process_database(
self, group_by, db_reader, time_from=None, time_to=None):
@@ -215,7 +212,7 @@ class OutputWriterPlot:
durations = group_by.group(db_reader, time_from, time_to)
bar_chart = BarChartBuilder()
- bar_chart.set_title(OutputWriterPlot.TITLE)
+ bar_chart.set_title(OutputSinkPlot.TITLE)
bar_chart.enable_grid_for_values()
bar_chart.only_integer_values()
bar_chart.set_property(bar_chart.get_values_labels(),
@@ -245,37 +242,24 @@ class OutputFormat(Enum):
JSON = 'json'
PLOT = 'plot'
- @contextmanager
- def create_writer(self, path=None):
- with self._open_file(path) as fd:
- if self is OutputFormat.CSV:
- yield OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- yield OutputWriterJSON(fd)
- elif self is OutputFormat.PLOT:
- yield OutputWriterPlot(fd)
- else:
- raise NotImplementedError('unsupported output format: ' + str(self))
-
- @contextmanager
- def _open_file(self, path=None):
- fd = sys.stdout
- if path is None:
- pass
- elif self is OutputFormat.CSV or self is OutputFormat.JSON:
- fd = open(path, 'w', encoding='utf-8')
+ def __str__(self):
+ return self.value
+
+ def create_sink(self, fd=sys.stdout):
+ if self is OutputFormat.CSV:
+ return OutputSinkCSV(fd)
+ elif self is OutputFormat.JSON:
+ return OutputSinkJSON(fd)
elif self is OutputFormat.PLOT:
- fd = open(path, 'wb')
+ return OutputSinkPlot(fd)
else:
raise NotImplementedError('unsupported output format: ' + str(self))
- try:
- yield fd
- finally:
- if fd is not sys.stdout:
- fd.close()
- def __str__(self):
- return self.value
+ def open_file(self, path=None):
+ if self is OutputFormat.PLOT:
+ return output.open_binary_file(path)
+ else:
+ return output.open_text_file(path)
def _parse_group_by(s):
try:
@@ -352,8 +336,9 @@ def process_online_sessions(
time_from, time_to = time_to, time_from
with db_fmt.create_reader(db_path) as db_reader:
- with out_fmt.create_writer(out_path) as out_writer:
- out_writer.process_database(
+ with out_fmt.open_file(out_path) as out_fd:
+ out_sink = out_fmt.create_sink(out_fd)
+ out_sink.process_database(
group_by, db_reader,
time_from=time_from,
time_to=time_to)