about | summary | refs | log | tree | commit | diff | stats | homepage
diff options
context:
space:
mode:
author: Egor Tensin <Egor.Tensin@gmail.com> 2017-01-28 02:01:44 +0300
committer: Egor Tensin <Egor.Tensin@gmail.com> 2017-01-28 02:01:44 +0300
commit: f989577d524f27c079af619b9d0b6b4a9d70c386 (patch)
tree: bad199636fc9cd7152e64e8ffb6075c13755412a
parent: refactoring (diff)
download: vk-scripts-f989577d524f27c079af619b9d0b6b4a9d70c386.tar.gz
download: vk-scripts-f989577d524f27c079af619b9d0b6b4a9d70c386.zip
bin: move file i/o to a separate module
-rw-r--r-- bin/mutual_friends.py | 64
-rw-r--r-- bin/online_sessions.py | 101
-rw-r--r-- bin/utils/output.py | 43
3 files changed, 113 insertions, 95 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py
index 76bf00f..3b0576f 100644
--- a/bin/mutual_friends.py
+++ b/bin/mutual_friends.py
@@ -3,17 +3,17 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+import abc
import argparse
from collections import OrderedDict
-from contextlib import contextmanager
-import csv
from enum import Enum
-import json
import sys
from vk.api import API, Language
from vk.user import UserField
+from .utils import output
+
_OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME
def _query_friend_list(api, user):
@@ -25,25 +25,25 @@ def _filter_user_fields(user):
new_user[str(field)] = user[field] if field in user else None
return new_user
-class OutputWriterCSV:
+class OutputSinkMutualFriends(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def write_mutual_friends(self, friend_list):
+ pass
+
+class OutputSinkCSV(OutputSinkMutualFriends):
def __init__(self, fd=sys.stdout):
- self._writer = csv.writer(fd, lineterminator='\n')
+ self._writer = output.OutputWriterCSV(fd)
def write_mutual_friends(self, friend_list):
for user in friend_list:
- user = _filter_user_fields(user)
- self._writer.writerow(user.values())
+ self._writer.write_row(user.values())
-class OutputWriterJSON:
+class OutputSinkJSON(OutputSinkMutualFriends):
def __init__(self, fd=sys.stdout):
- self._fd = fd
+ self._writer = output.OutputWriterJSON(fd)
def write_mutual_friends(self, friend_list):
- arr = []
- for user in friend_list:
- arr.append(_filter_user_fields(user))
- self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False))
- self._fd.write('\n')
+ self._writer.write(friend_list)
class OutputFormat(Enum):
CSV = 'csv'
@@ -52,29 +52,17 @@ class OutputFormat(Enum):
def __str__(self):
return self.value
- @contextmanager
- def create_writer(self, path=None):
- with self._open_file(path) as fd:
- if self is OutputFormat.CSV:
- yield OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- yield OutputWriterJSON(fd)
- else:
- raise NotImplementedError('unsupported output format: ' + str(self))
-
@staticmethod
- @contextmanager
- def _open_file(path=None):
- fd = sys.stdout
- if path is None:
- pass
+ def open_file(path=None):
+ return output.open_text_file(path)
+
+ def create_sink(self, fd=sys.stdout):
+ if self is OutputFormat.CSV:
+ return OutputSinkCSV(fd)
+ elif self is OutputFormat.JSON:
+ return OutputSinkJSON(fd)
else:
- fd = open(path, 'w', encoding='utf-8')
- try:
- yield fd
- finally:
- if fd is not sys.stdout:
- fd.close()
+ raise NotImplementedError('unsupported output format: ' + str(self))
def _parse_output_format(s):
try:
@@ -107,9 +95,11 @@ def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV):
friend_lists = (frozenset(_query_friend_list(api, user)) for user in users)
mutual_friends = frozenset.intersection(*friend_lists)
+ mutual_friends = [_filter_user_fields(user) for user in mutual_friends]
- with out_fmt.create_writer(out_path) as writer:
- writer.write_mutual_friends(mutual_friends)
+ with out_fmt.open_file(out_path) as out_fd:
+ sink = out_fmt.create_sink(out_fd)
+ sink.write_mutual_friends(mutual_friends)
def main(args=None):
write_mutual_friends(**vars(_parse_args(args)))
diff --git a/bin/online_sessions.py b/bin/online_sessions.py
index 5a9bdc8..cb3f4bf 100644
--- a/bin/online_sessions.py
+++ b/bin/online_sessions.py
@@ -3,13 +3,11 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+import abc
import argparse
-import csv
from collections import OrderedDict
-from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from enum import Enum
-import json
import sys
from vk.tracking import OnlineSessionEnumerator
@@ -17,6 +15,7 @@ from vk.tracking.db import Format as DatabaseFormat
from vk.user import UserField
from .utils.bar_chart import BarChartBuilder
+from .utils import output
class GroupBy(Enum):
USER = 'user'
@@ -24,6 +23,9 @@ class GroupBy(Enum):
WEEKDAY = 'weekday'
HOUR = 'hour'
+ def __str__(self):
+ return self.value
+
def group(self, db_reader, time_from=None, time_to=None):
online_streaks = OnlineSessionEnumerator(time_from, time_to)
if self is GroupBy.USER:
@@ -37,9 +39,6 @@ class GroupBy(Enum):
else:
raise NotImplementedError('unsupported grouping: ' + str(self))
- def __str__(self):
- return self.value
-
_OUTPUT_USER_FIELDS = (
UserField.UID,
UserField.FIRST_NAME,
@@ -47,6 +46,11 @@ _OUTPUT_USER_FIELDS = (
UserField.DOMAIN,
)
+class OutputSinkOnlineSessions(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ def process_database(self, group_by, db_reader, time_from=None, time_to=None):
+ pass
+
class OutputConverterCSV:
@staticmethod
def convert_user(user):
@@ -64,9 +68,9 @@ class OutputConverterCSV:
def convert_hour(hour):
return [str(timedelta(hours=hour))]
-class OutputWriterCSV:
+class OutputSinkCSV(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
- self._writer = csv.writer(fd, lineterminator='\n')
+ self._writer = output.OutputWriterCSV(fd)
_CONVERT_KEY = {
GroupBy.USER: OutputConverterCSV.convert_user,
@@ -77,18 +81,15 @@ class OutputWriterCSV:
@staticmethod
def _key_to_row(group_by, key):
- if group_by not in OutputWriterCSV._CONVERT_KEY:
+ if group_by not in OutputSinkCSV._CONVERT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
- return OutputWriterCSV._CONVERT_KEY[group_by](key)
+ return OutputSinkCSV._CONVERT_KEY[group_by](key)
def process_database(self, group_by, db_reader, time_from=None, time_to=None):
for key, duration in group_by.group(db_reader, time_from, time_to).items():
row = self._key_to_row(group_by, key)
row.append(str(duration))
- self._write_row(row)
-
- def _write_row(self, row):
- self._writer.writerow(row)
+ self._writer.write_row(row)
class OutputConverterJSON:
_DATE_FIELD = 'date'
@@ -124,9 +125,9 @@ class OutputConverterJSON:
obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour))
return obj
-class OutputWriterJSON:
+class OutputSinkJSON(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
- self._fd = fd
+ self._writer = output.OutputWriterJSON(fd)
_DURATION_FIELD = 'duration'
@@ -141,13 +142,9 @@ class OutputWriterJSON:
@staticmethod
def _key_to_object(group_by, key):
- if not group_by in OutputWriterJSON._CONVERT_KEY:
+ if not group_by in OutputSinkJSON._CONVERT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
- return OutputWriterJSON._CONVERT_KEY[group_by](key)
-
- def _write(self, entries):
- self._fd.write(json.dumps(entries, indent=3, ensure_ascii=False))
- self._fd.write('\n')
+ return OutputSinkJSON._CONVERT_KEY[group_by](key)
def process_database(self, group_by, db_reader, time_from=None, time_to=None):
entries = []
@@ -155,7 +152,7 @@ class OutputWriterJSON:
entry = self._key_to_object(group_by, key)
entry[self._DURATION_FIELD] = str(duration)
entries.append(entry)
- self._write(entries)
+ self._writer.write(entries)
class OutputConverterPlot:
@staticmethod
@@ -174,7 +171,7 @@ class OutputConverterPlot:
def convert_hour(hour):
return '{}:00'.format(hour)
-class OutputWriterPlot:
+class OutputSinkPlot(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
self._fd = fd
@@ -189,9 +186,9 @@ class OutputWriterPlot:
@staticmethod
def _format_key(group_by, key):
- if group_by not in OutputWriterPlot._FORMAT_KEY:
+ if group_by not in OutputSinkPlot._FORMAT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
- return OutputWriterPlot._FORMAT_KEY[group_by](key)
+ return OutputSinkPlot._FORMAT_KEY[group_by](key)
@staticmethod
def _format_duration(seconds, _):
@@ -203,11 +200,11 @@ class OutputWriterPlot:
@staticmethod
def _extract_labels(group_by, durations):
- return tuple(map(lambda key: OutputWriterPlot._format_key(group_by, key), durations.keys()))
+ return tuple(map(lambda key: OutputSinkPlot._format_key(group_by, key), durations.keys()))
@staticmethod
def _extract_values(durations):
- return tuple(map(OutputWriterPlot._duration_to_seconds, durations.values()))
+ return tuple(map(OutputSinkPlot._duration_to_seconds, durations.values()))
def process_database(
self, group_by, db_reader, time_from=None, time_to=None):
@@ -215,7 +212,7 @@ class OutputWriterPlot:
durations = group_by.group(db_reader, time_from, time_to)
bar_chart = BarChartBuilder()
- bar_chart.set_title(OutputWriterPlot.TITLE)
+ bar_chart.set_title(OutputSinkPlot.TITLE)
bar_chart.enable_grid_for_values()
bar_chart.only_integer_values()
bar_chart.set_property(bar_chart.get_values_labels(),
@@ -245,37 +242,24 @@ class OutputFormat(Enum):
JSON = 'json'
PLOT = 'plot'
- @contextmanager
- def create_writer(self, path=None):
- with self._open_file(path) as fd:
- if self is OutputFormat.CSV:
- yield OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- yield OutputWriterJSON(fd)
- elif self is OutputFormat.PLOT:
- yield OutputWriterPlot(fd)
- else:
- raise NotImplementedError('unsupported output format: ' + str(self))
-
- @contextmanager
- def _open_file(self, path=None):
- fd = sys.stdout
- if path is None:
- pass
- elif self is OutputFormat.CSV or self is OutputFormat.JSON:
- fd = open(path, 'w', encoding='utf-8')
+ def __str__(self):
+ return self.value
+
+ def create_sink(self, fd=sys.stdout):
+ if self is OutputFormat.CSV:
+ return OutputSinkCSV(fd)
+ elif self is OutputFormat.JSON:
+ return OutputSinkJSON(fd)
elif self is OutputFormat.PLOT:
- fd = open(path, 'wb')
+ return OutputSinkPlot(fd)
else:
raise NotImplementedError('unsupported output format: ' + str(self))
- try:
- yield fd
- finally:
- if fd is not sys.stdout:
- fd.close()
- def __str__(self):
- return self.value
+ def open_file(self, path=None):
+ if self is OutputFormat.PLOT:
+ return output.open_binary_file(path)
+ else:
+ return output.open_text_file(path)
def _parse_group_by(s):
try:
@@ -352,8 +336,9 @@ def process_online_sessions(
time_from, time_to = time_to, time_from
with db_fmt.create_reader(db_path) as db_reader:
- with out_fmt.create_writer(out_path) as out_writer:
- out_writer.process_database(
+ with out_fmt.open_file(out_path) as out_fd:
+ out_sink = out_fmt.create_sink(out_fd)
+ out_sink.process_database(
group_by, db_reader,
time_from=time_from,
time_to=time_to)
diff --git a/bin/utils/output.py b/bin/utils/output.py
new file mode 100644
index 0000000..8954c8b
--- /dev/null
+++ b/bin/utils/output.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+from contextlib import contextmanager
+import csv
+import json
+import sys
+
+class OutputWriterJSON:
+ def __init__(self, fd=sys.stdout):
+ self._fd = fd
+
+ def write(self, something):
+ self._fd.write(json.dumps(something, indent=3, ensure_ascii=False))
+ self._fd.write('\n')
+
+class OutputWriterCSV:
+ def __init__(self, fd=sys.stdout):
+ self._writer = csv.writer(fd, lineterminator='\n')
+
+ def write_row(self, row):
+ self._writer.writerow(row)
+
+@contextmanager
+def _open_file(path=None, **kwargs):
+ fd = sys.stdout
+ if path is None:
+ pass
+ else:
+ fd = open(path, **kwargs)
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
+
+def open_text_file(path=None):
+ return _open_file(path, mode='w', encoding='utf-8')
+
+def open_binary_file(path=None):
+ return _open_file(path, mode='wb')