aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorEgor Tensin <Egor.Tensin@gmail.com>2017-01-28 02:40:27 +0300
committerEgor Tensin <Egor.Tensin@gmail.com>2017-01-28 02:40:27 +0300
commit1c0e2c13d4f0acbe42e43d886df4f1067506efed (patch)
tree8d76116d718293275315605b66f1466af7d09c27
parentbin: move file i/o to a separate module (diff)
downloadvk-scripts-1c0e2c13d4f0acbe42e43d886df4f1067506efed.tar.gz
vk-scripts-1c0e2c13d4f0acbe42e43d886df4f1067506efed.zip
vk: move file i/o to a separate module
-rw-r--r--bin/online_sessions.py3
-rw-r--r--bin/track_status.py7
-rw-r--r--vk/tracking/db/backend/csv.py14
-rw-r--r--vk/tracking/db/format.py83
-rw-r--r--vk/tracking/db/io.py43
5 files changed, 84 insertions, 66 deletions
diff --git a/bin/online_sessions.py b/bin/online_sessions.py
index cb3f4bf..be3a4f2 100644
--- a/bin/online_sessions.py
+++ b/bin/online_sessions.py
@@ -335,7 +335,8 @@ def process_online_sessions(
if time_from > time_to:
time_from, time_to = time_to, time_from
- with db_fmt.create_reader(db_path) as db_reader:
+ with db_fmt.open_input_file(db_path) as db_fd:
+ db_reader = db_fmt.create_reader(db_fd)
with out_fmt.open_file(out_path) as out_fd:
out_sink = out_fmt.create_sink(out_fd)
out_sink.process_database(
diff --git a/bin/track_status.py b/bin/track_status.py
index 404c1ec..4c679c5 100644
--- a/bin/track_status.py
+++ b/bin/track_status.py
@@ -63,10 +63,13 @@ def track_status(
if db_fmt is DatabaseFormat.LOG or db_path is None:
db_fmt = DatabaseFormat.NULL
- with DatabaseFormat.LOG.create_writer(log_path) as log_writer:
+ with DatabaseFormat.LOG.open_output_file(log_path) as log_fd:
+ log_writer = DatabaseFormat.LOG.create_writer(log_fd)
tracker.add_database_writer(log_writer)
- with db_fmt.create_writer(db_path) as db_writer:
+ with db_fmt.open_output_file(db_path) as db_fd:
+ db_writer = db_fmt.create_writer(db_fd)
tracker.add_database_writer(db_writer)
+
tracker.loop(uids)
def main(args=None):
diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py
index e607774..5b2cc7d 100644
--- a/vk/tracking/db/backend/csv.py
+++ b/vk/tracking/db/backend/csv.py
@@ -4,23 +4,20 @@
# Distributed under the MIT License.
from collections.abc import Iterable
-import csv
+from ..io import InputReaderCSV, OutputWriterCSV
from ..record import Record
from ..timestamp import Timestamp
class Writer:
def __init__(self, fd):
- self._fd = fd
- self._writer = csv.writer(fd, lineterminator='\n')
+ self._writer = OutputWriterCSV(fd)
def on_initial_status(self, user):
self._write_record(user)
- self._fd.flush()
def on_status_update(self, user):
self._write_record(user)
- self._fd.flush()
def on_connection_error(self, e):
pass
@@ -28,10 +25,7 @@ class Writer:
def _write_record(self, user):
if not self:
return
- self._write_row(self._record_to_row(Record.from_user(user)))
-
- def _write_row(self, row):
- self._writer.writerow(row)
+ self._writer.write_row(self._record_to_row(Record.from_user(user)))
@staticmethod
def _record_to_row(record):
@@ -39,7 +33,7 @@ class Writer:
class Reader(Iterable):
def __init__(self, fd):
- self._reader = csv.reader(fd)
+ self._reader = InputReaderCSV(fd)
def __iter__(self):
return map(Reader._record_from_row, self._reader)
diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py
index e1b34a1..7b3f312 100644
--- a/vk/tracking/db/format.py
+++ b/vk/tracking/db/format.py
@@ -3,76 +3,53 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
-from contextlib import contextmanager
from enum import Enum
import sys
-from . import backend
+from . import backend, io
class Format(Enum):
CSV = 'csv'
LOG = 'log'
NULL = 'null'
- @contextmanager
- def create_writer(self, path=None):
- with self._open_output_file(path) as fd:
- if self is Format.CSV:
- yield backend.csv.Writer(fd)
- elif self is Format.LOG:
- yield backend.log.Writer(fd)
- elif self is Format.NULL:
- yield backend.null.Writer()
- else:
- raise NotImplementedError('unsupported database format: ' + str(self))
+ def __str__(self):
+ return self.value
- @contextmanager
- def _open_output_file(self, path=None):
- fd = sys.stdout
- if path is None:
- pass
- elif self is Format.CSV or self is Format.LOG:
- fd = open(path, 'w', encoding='utf-8')
+ def create_writer(self, fd=sys.stdout):
+ if self is Format.CSV:
+ return backend.csv.Writer(fd)
+ elif self is Format.LOG:
+ return backend.log.Writer(fd)
elif self is Format.NULL:
- pass
+ return backend.null.Writer()
else:
raise NotImplementedError('unsupported database format: ' + str(self))
- try:
- yield fd
- finally:
- if fd is not sys.stdout:
- fd.close()
- @contextmanager
- def create_reader(self, path=None):
- with self._open_input_file(path) as fd:
- if self is Format.CSV:
- yield backend.csv.Reader(fd)
- elif self is Format.LOG:
- raise NotImplementedError('cannot read from a log file')
- elif self is Format.NULL:
- yield backend.null.Reader()
- else:
- raise NotImplementedError('unsupported database format: ' + str(self))
+ def open_output_file(self, path=None):
+ if self is Format.CSV or self is Format.LOG:
+ return io.open_output_text_file(path)
+ elif self is Format.NULL:
+ return io.open_output_text_file(None)
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
- @contextmanager
- def _open_input_file(self, path=None):
- fd = sys.stdin
- if path is None:
- pass
- elif self is Format.CSV:
- fd = open(path, encoding='utf-8')
+ def create_reader(self, fd=sys.stdin):
+ if self is Format.CSV:
+ return backend.csv.Reader(fd)
elif self is Format.LOG:
- raise NotImplementedError('cannot read from a log file')
+ return NotImplementedError('cannot read from a log file')
elif self is Format.NULL:
- pass
+ return backend.null.Reader()
else:
raise NotImplementedError('unsupported database format: ' + str(self))
- try:
- yield fd
- finally:
- if fd is not sys.stdin:
- fd.close()
- def __str__(self):
- return self.value
+ def open_input_file(self, path=None):
+ if self is Format.CSV:
+ return io.open_input_text_file(path)
+ elif self is Format.LOG:
+ raise NotImplementedError('cannot read from a log file')
+ elif self is Format.NULL:
+ return io.open_input_text_file(None)
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
diff --git a/vk/tracking/db/io.py b/vk/tracking/db/io.py
new file mode 100644
index 0000000..04bab3d
--- /dev/null
+++ b/vk/tracking/db/io.py
@@ -0,0 +1,43 @@
+# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+from contextlib import contextmanager
+import csv
+import sys
+
+class OutputWriterCSV:
+ def __init__(self, fd=sys.stdout):
+ self._fd = fd
+ self._writer = csv.writer(fd, lineterminator='\n')
+
+ def write_row(self, row):
+ self._writer.writerow(row)
+ self._fd.flush()
+
+class InputReaderCSV:
+ def __init__(self, fd=sys.stdin):
+ self._reader = csv.reader(fd)
+
+ def __iter__(self):
+ return iter(self._reader)
+
+@contextmanager
+def _open_file(path=None, default=None, **kwargs):
+ fd = default
+ if path is None:
+ pass
+ else:
+ fd = open(path, **kwargs)
+ try:
+ yield fd
+ finally:
+ if fd is not default:
+ fd.close()
+
+def open_output_text_file(path=None):
+ return _open_file(path, default=sys.stdout, mode='w', encoding='utf-8')
+
+def open_input_text_file(path=None):
+ return _open_file(path, default=sys.stdin, mode='r', encoding='utf-8')