aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
-rw-r--r--bin/mutual_friends.py58
-rw-r--r--bin/online_sessions.py57
-rw-r--r--bin/track_status.py18
-rw-r--r--vk/tracking/db/backend/csv.py12
-rw-r--r--vk/tracking/db/backend/log.py6
-rw-r--r--vk/tracking/db/backend/null.py16
-rw-r--r--vk/tracking/db/format.py64
7 files changed, 131 insertions, 100 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py
index d8591cf..76bf00f 100644
--- a/bin/mutual_friends.py
+++ b/bin/mutual_friends.py
@@ -5,6 +5,7 @@
import argparse
from collections import OrderedDict
+from contextlib import contextmanager
import csv
from enum import Enum
import json
@@ -28,12 +29,6 @@ class OutputWriterCSV:
def __init__(self, fd=sys.stdout):
self._writer = csv.writer(fd, lineterminator='\n')
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def write_mutual_friends(self, friend_list):
for user in friend_list:
user = _filter_user_fields(user)
@@ -42,18 +37,13 @@ class OutputWriterCSV:
class OutputWriterJSON:
def __init__(self, fd=sys.stdout):
self._fd = fd
- self._arr = []
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- self._fd.write(json.dumps(self._arr, indent=3, ensure_ascii=False))
- self._fd.write('\n')
def write_mutual_friends(self, friend_list):
+ arr = []
for user in friend_list:
- self._arr.append(_filter_user_fields(user))
+ arr.append(_filter_user_fields(user))
+ self._fd.write(json.dumps(arr, indent=3, ensure_ascii=False))
+ self._fd.write('\n')
class OutputFormat(Enum):
CSV = 'csv'
@@ -62,13 +52,29 @@ class OutputFormat(Enum):
def __str__(self):
return self.value
- def create_writer(self, fd=sys.stdout):
- if self is OutputFormat.CSV:
- return OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- return OutputWriterJSON(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_file(path) as fd:
+ if self is OutputFormat.CSV:
+ yield OutputWriterCSV(fd)
+ elif self is OutputFormat.JSON:
+ yield OutputWriterJSON(fd)
+ else:
+ raise NotImplementedError('unsupported output format: ' + str(self))
+
+ @staticmethod
+ @contextmanager
+ def _open_file(path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
else:
- raise NotImplementedError('unsupported output format: ' + str(self))
+ fd = open(path, 'w', encoding='utf-8')
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
def _parse_output_format(s):
try:
@@ -85,26 +91,24 @@ def _parse_args(args=None):
parser.add_argument('uids', metavar='UID', nargs='+',
help='user IDs or "screen names"')
- parser.add_argument('-f', '--format', dest='fmt',
+ parser.add_argument('-f', '--format', dest='out_fmt',
type=_parse_output_format,
default=OutputFormat.CSV,
choices=OutputFormat,
help='specify output format')
- parser.add_argument('-o', '--output', metavar='PATH', dest='fd',
- type=argparse.FileType('w', encoding='utf-8'),
- default=sys.stdout,
+ parser.add_argument('-o', '--output', metavar='PATH', dest='out_path',
help='set output file path (standard output by default)')
return parser.parse_args(args)
-def write_mutual_friends(uids, fmt=OutputFormat.CSV, fd=sys.stdout):
+def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV):
api = API(Language.EN)
users = api.users_get(uids)
friend_lists = (frozenset(_query_friend_list(api, user)) for user in users)
mutual_friends = frozenset.intersection(*friend_lists)
- with fmt.create_writer(fd) as writer:
+ with out_fmt.create_writer(out_path) as writer:
writer.write_mutual_friends(mutual_friends)
def main(args=None):
diff --git a/bin/online_sessions.py b/bin/online_sessions.py
index ec26689..7a763e5 100644
--- a/bin/online_sessions.py
+++ b/bin/online_sessions.py
@@ -6,6 +6,7 @@
import argparse
import csv
from collections import OrderedDict
+from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from enum import Enum
import json
@@ -244,15 +245,34 @@ class OutputFormat(Enum):
JSON = 'json'
PLOT = 'plot'
- def create_writer(self, fd):
- if self is OutputFormat.CSV:
- return OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- return OutputWriterJSON(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_file(path) as fd:
+ if self is OutputFormat.CSV:
+ yield OutputWriterCSV(fd)
+ elif self is OutputFormat.JSON:
+ yield OutputWriterJSON(fd)
+ elif self is OutputFormat.PLOT:
+ yield OutputWriterPlot(fd)
+ else:
+ raise NotImplementedError('unsupported output format: ' + str(self))
+
+ @contextmanager
+ def _open_file(self, path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
+ elif self is OutputFormat.CSV or self is OutputFormat.JSON:
+ fd = open(path, 'w', encoding='utf-8')
elif self is OutputFormat.PLOT:
- return OutputWriterPlot(fd)
+ fd = open(path, 'wb')
else:
raise NotImplementedError('unsupported output format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
def __str__(self):
return self.value
@@ -293,12 +313,9 @@ def _parse_args(args=None):
parser = argparse.ArgumentParser(
description='View/visualize the amount of time people spend online.')
- parser.add_argument('db_fd', metavar='input',
- type=argparse.FileType('r', encoding='utf-8'),
- help='database file path')
- parser.add_argument('fd', metavar='output', nargs='?',
- type=argparse.FileType('w', encoding='utf-8'),
- default=sys.stdout,
+ parser.add_argument('db_path', metavar='input', nargs='?',
+ help='database file path (standard input by default)')
+ parser.add_argument('out_path', metavar='output', nargs='?',
help='output file path (standard output by default)')
parser.add_argument('-g', '--group-by',
type=_parse_group_by,
@@ -310,7 +327,7 @@ def _parse_args(args=None):
default=DatabaseFormat.CSV,
choices=DatabaseFormat,
help='specify database format')
- parser.add_argument('-o', '--output-format', dest='fmt',
+ parser.add_argument('-o', '--output-format', dest='out_fmt',
type=_parse_output_format,
choices=OutputFormat,
default=OutputFormat.CSV,
@@ -325,8 +342,8 @@ def _parse_args(args=None):
return parser.parse_args(args)
def process_online_sessions(
- db_fd, db_fmt=DatabaseFormat.CSV,
- fd=sys.stdout, fmt=OutputFormat.CSV,
+ db_path=None, db_fmt=DatabaseFormat.CSV,
+ out_path=None, out_fmt=OutputFormat.CSV,
group_by=GroupBy.USER,
time_from=None, time_to=None):
@@ -334,10 +351,12 @@ def process_online_sessions(
if time_from > time_to:
time_from, time_to = time_to, time_from
- with db_fmt.create_reader(db_fd) as db_reader:
- output_writer = fmt.create_writer(fd)
- output_writer.process_database(
- group_by, db_reader, time_from=time_from, time_to=time_to)
+ with db_fmt.create_reader(db_path) as db_reader:
+ with out_fmt.create_writer(out_path) as out_writer:
+ out_writer.process_database(
+ group_by, db_reader,
+ time_from=time_from,
+ time_to=time_to)
def main(args=None):
process_online_sessions(**vars(_parse_args(args)))
diff --git a/bin/track_status.py b/bin/track_status.py
index d8d8908..404c1ec 100644
--- a/bin/track_status.py
+++ b/bin/track_status.py
@@ -40,36 +40,32 @@ def _parse_args(args=None):
type=_parse_positive_integer,
default=DEFAULT_TIMEOUT,
help='set refresh interval')
- parser.add_argument('-l', '--log', metavar='PATH', dest='log_fd',
- type=argparse.FileType('w', encoding='utf-8'),
- default=sys.stdout,
+ parser.add_argument('-l', '--log', metavar='PATH', dest='log_path',
help='set log file path (standard output by default)')
parser.add_argument('-f', '--format', dest='db_fmt',
type=_parse_database_format,
choices=DatabaseFormat,
default=DEFAULT_DB_FORMAT,
help='specify database format')
- parser.add_argument('-o', '--output', metavar='PATH', dest='db_fd',
- type=argparse.FileType('w', encoding='utf-8'),
- default=None,
+ parser.add_argument('-o', '--output', metavar='PATH', dest='db_path',
help='set database file path')
return parser.parse_args(args)
def track_status(
uids, timeout=DEFAULT_TIMEOUT,
- log_fd=sys.stdout,
- db_fd=None, db_fmt=DEFAULT_DB_FORMAT):
+ log_path=None,
+ db_path=None, db_fmt=DEFAULT_DB_FORMAT):
api = API(Language.EN, deactivated_users=False)
tracker = StatusTracker(api, timeout)
- if db_fmt is DatabaseFormat.LOG or db_fd is None:
+ if db_fmt is DatabaseFormat.LOG or db_path is None:
db_fmt = DatabaseFormat.NULL
- with DatabaseFormat.LOG.create_writer(log_fd) as log_writer:
+ with DatabaseFormat.LOG.create_writer(log_path) as log_writer:
tracker.add_database_writer(log_writer)
- with db_fmt.create_writer(db_fd) as db_writer:
+ with db_fmt.create_writer(db_path) as db_writer:
tracker.add_database_writer(db_writer)
tracker.loop(uids)
diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py
index f20c617..e607774 100644
--- a/vk/tracking/db/backend/csv.py
+++ b/vk/tracking/db/backend/csv.py
@@ -14,12 +14,6 @@ class Writer:
self._fd = fd
self._writer = csv.writer(fd, lineterminator='\n')
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def on_initial_status(self, user):
self._write_record(user)
self._fd.flush()
@@ -47,12 +41,6 @@ class Reader(Iterable):
def __init__(self, fd):
self._reader = csv.reader(fd)
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def __iter__(self):
return map(Reader._record_from_row, self._reader)
diff --git a/vk/tracking/db/backend/log.py b/vk/tracking/db/backend/log.py
index 6eebc35..f248d65 100644
--- a/vk/tracking/db/backend/log.py
+++ b/vk/tracking/db/backend/log.py
@@ -15,12 +15,6 @@ class Writer:
datefmt='%Y-%m-%d %H:%M:%S'))
self._logger.addHandler(handler)
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def info(self, msg):
self._logger.info(msg)
diff --git a/vk/tracking/db/backend/null.py b/vk/tracking/db/backend/null.py
index 6454f84..3cdb41e 100644
--- a/vk/tracking/db/backend/null.py
+++ b/vk/tracking/db/backend/null.py
@@ -6,13 +6,7 @@
from collections.abc import Iterable
class Writer:
- def __init__(self, fd):
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
+ def __init__(self):
pass
def on_initial_status(self, user):
@@ -25,13 +19,7 @@ class Writer:
pass
class Reader(Iterable):
- def __init__(self, fd):
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
+ def __init__(self):
pass
def __iter__(self):
diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py
index 9d5c6e4..e1b34a1 100644
--- a/vk/tracking/db/format.py
+++ b/vk/tracking/db/format.py
@@ -3,7 +3,9 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+from contextlib import contextmanager
from enum import Enum
+import sys
from . import backend
@@ -12,25 +14,65 @@ class Format(Enum):
LOG = 'log'
NULL = 'null'
- def create_writer(self, fd):
- if self is Format.CSV:
- return backend.csv.Writer(fd)
- elif self is Format.LOG:
- return backend.log.Writer(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_output_file(path) as fd:
+ if self is Format.CSV:
+ yield backend.csv.Writer(fd)
+ elif self is Format.LOG:
+ yield backend.log.Writer(fd)
+ elif self is Format.NULL:
+ yield backend.null.Writer()
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
+
+ @contextmanager
+ def _open_output_file(self, path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
+ elif self is Format.CSV or self is Format.LOG:
+ fd = open(path, 'w', encoding='utf-8')
elif self is Format.NULL:
- return backend.null.Writer(fd)
+ pass
else:
raise NotImplementedError('unsupported database format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
+
+ @contextmanager
+ def create_reader(self, path=None):
+ with self._open_input_file(path) as fd:
+ if self is Format.CSV:
+ yield backend.csv.Reader(fd)
+ elif self is Format.LOG:
+ raise NotImplementedError('cannot read from a log file')
+ elif self is Format.NULL:
+ yield backend.null.Reader()
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
- def create_reader(self, fd):
- if self is Format.CSV:
- return backend.csv.Reader(fd)
+ @contextmanager
+ def _open_input_file(self, path=None):
+ fd = sys.stdin
+ if path is None:
+ pass
+ elif self is Format.CSV:
+ fd = open(path, encoding='utf-8')
elif self is Format.LOG:
- raise NotImplementedError()
+ raise NotImplementedError('cannot read from a log file')
elif self is Format.NULL:
- return backend.null.Reader(fd)
+ pass
else:
raise NotImplementedError('unsupported database format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdin:
+ fd.close()
def __str__(self):
return self.value