diff options
author | Egor Tensin <Egor.Tensin@gmail.com> | 2017-01-28 02:40:27 +0300 |
---|---|---|
committer | Egor Tensin <Egor.Tensin@gmail.com> | 2017-01-28 02:40:27 +0300 |
commit | 1c0e2c13d4f0acbe42e43d886df4f1067506efed (patch) | |
tree | 8d76116d718293275315605b66f1466af7d09c27 | |
parent | bin: move file i/o to a separate module (diff) | |
download | vk-scripts-1c0e2c13d4f0acbe42e43d886df4f1067506efed.tar.gz vk-scripts-1c0e2c13d4f0acbe42e43d886df4f1067506efed.zip |
vk: move file i/o to a separate module
-rw-r--r-- | bin/online_sessions.py | 3 | ||||
-rw-r--r-- | bin/track_status.py | 7 | ||||
-rw-r--r-- | vk/tracking/db/backend/csv.py | 14 | ||||
-rw-r--r-- | vk/tracking/db/format.py | 83 | ||||
-rw-r--r-- | vk/tracking/db/io.py | 43 |
5 files changed, 84 insertions, 66 deletions
diff --git a/bin/online_sessions.py b/bin/online_sessions.py index cb3f4bf..be3a4f2 100644 --- a/bin/online_sessions.py +++ b/bin/online_sessions.py @@ -335,7 +335,8 @@ def process_online_sessions( if time_from > time_to: time_from, time_to = time_to, time_from - with db_fmt.create_reader(db_path) as db_reader: + with db_fmt.open_input_file(db_path) as db_fd: + db_reader = db_fmt.create_reader(db_fd) with out_fmt.open_file(out_path) as out_fd: out_sink = out_fmt.create_sink(out_fd) out_sink.process_database( diff --git a/bin/track_status.py b/bin/track_status.py index 404c1ec..4c679c5 100644 --- a/bin/track_status.py +++ b/bin/track_status.py @@ -63,10 +63,13 @@ def track_status( if db_fmt is DatabaseFormat.LOG or db_path is None: db_fmt = DatabaseFormat.NULL - with DatabaseFormat.LOG.create_writer(log_path) as log_writer: + with DatabaseFormat.LOG.open_output_file(log_path) as log_fd: + log_writer = DatabaseFormat.LOG.create_writer(log_fd) tracker.add_database_writer(log_writer) - with db_fmt.create_writer(db_path) as db_writer: + with db_fmt.open_output_file(db_path) as db_fd: + db_writer = db_fmt.create_writer(db_fd) tracker.add_database_writer(db_writer) + tracker.loop(uids) def main(args=None): diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py index e607774..5b2cc7d 100644 --- a/vk/tracking/db/backend/csv.py +++ b/vk/tracking/db/backend/csv.py @@ -4,23 +4,20 @@ # Distributed under the MIT License. from collections.abc import Iterable -import csv +from ..io import InputReaderCSV, OutputWriterCSV from ..record import Record from ..timestamp import Timestamp class Writer: def __init__(self, fd): - self._fd = fd - self._writer = csv.writer(fd, lineterminator='\n') + self._writer = OutputWriterCSV(fd) def on_initial_status(self, user): self._write_record(user) - self._fd.flush() def on_status_update(self, user): self._write_record(user) - self._fd.flush() def on_connection_error(self, e): pass @@ -28,10 +25,7 @@ class Writer: def _write_record(self, user): if not self: return - self._write_row(self._record_to_row(Record.from_user(user))) - - def _write_row(self, row): - self._writer.writerow(row) + self._writer.write_row(self._record_to_row(Record.from_user(user))) @staticmethod def _record_to_row(record): @@ -39,7 +33,7 @@ class Writer: class Reader(Iterable): def __init__(self, fd): - self._reader = csv.reader(fd) + self._reader = InputReaderCSV(fd) def __iter__(self): return map(Reader._record_from_row, self._reader) diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py index e1b34a1..7b3f312 100644 --- a/vk/tracking/db/format.py +++ b/vk/tracking/db/format.py @@ -3,76 +3,53 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. -from contextlib import contextmanager from enum import Enum import sys -from . import backend +from . import backend, io class Format(Enum): CSV = 'csv' LOG = 'log' NULL = 'null' - @contextmanager - def create_writer(self, path=None): - with self._open_output_file(path) as fd: - if self is Format.CSV: - yield backend.csv.Writer(fd) - elif self is Format.LOG: - yield backend.log.Writer(fd) - elif self is Format.NULL: - yield backend.null.Writer() - else: - raise NotImplementedError('unsupported database format: ' + str(self)) + def __str__(self): + return self.value - @contextmanager - def _open_output_file(self, path=None): - fd = sys.stdout - if path is None: - pass - elif self is Format.CSV or self is Format.LOG: - fd = open(path, 'w', encoding='utf-8') + def create_writer(self, fd=sys.stdout): + if self is Format.CSV: + return backend.csv.Writer(fd) + elif self is Format.LOG: + return backend.log.Writer(fd) elif self is Format.NULL: - pass + return backend.null.Writer() else: raise NotImplementedError('unsupported database format: ' + str(self)) - try: - yield fd - finally: - if fd is not sys.stdout: - fd.close() - @contextmanager - def create_reader(self, path=None): - with self._open_input_file(path) as fd: - if self is Format.CSV: - yield backend.csv.Reader(fd) - elif self is Format.LOG: - raise NotImplementedError('cannot read from a log file') - elif self is Format.NULL: - yield backend.null.Reader() - else: - raise NotImplementedError('unsupported database format: ' + str(self)) + def open_output_file(self, path=None): + if self is Format.CSV or self is Format.LOG: + return io.open_output_text_file(path) + elif self is Format.NULL: + return io.open_output_text_file(None) + else: + raise NotImplementedError('unsupported database format: ' + str(self)) - @contextmanager - def _open_input_file(self, path=None): - fd = sys.stdin - if path is None: - pass - elif self is Format.CSV: - fd = open(path, encoding='utf-8') + def create_reader(self, fd=sys.stdin): + if self is Format.CSV: + return backend.csv.Reader(fd) elif self is Format.LOG: - raise NotImplementedError('cannot read from a log file') + return NotImplementedError('cannot read from a log file') elif self is Format.NULL: - pass + return backend.null.Reader() else: raise NotImplementedError('unsupported database format: ' + str(self)) - try: - yield fd - finally: - if fd is not sys.stdin: - fd.close() - def __str__(self): - return self.value + def open_input_file(self, path=None): + if self is Format.CSV: + return io.open_input_text_file(path) + elif self is Format.LOG: + raise NotImplementedError('cannot read from a log file') + elif self is Format.NULL: + return io.open_input_text_file(None) + else: + raise NotImplementedError('unsupported database format: ' + str(self)) diff --git a/vk/tracking/db/io.py b/vk/tracking/db/io.py new file mode 100644 index 0000000..04bab3d --- /dev/null +++ b/vk/tracking/db/io.py @@ -0,0 +1,43 @@ +# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com> +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +from contextlib import contextmanager +import csv +import sys + +class OutputWriterCSV: + def __init__(self, fd=sys.stdout): + self._fd = fd + self._writer = csv.writer(fd, lineterminator='\n') + + def write_row(self, row): + self._writer.writerow(row) + self._fd.flush() + +class InputReaderCSV: + def __init__(self, fd=sys.stdin): + self._reader = csv.reader(fd) + + def __iter__(self): + return iter(self._reader) + +@contextmanager +def _open_file(path=None, default=None, **kwargs): + fd = default + if path is None: + pass + else: + fd = open(path, **kwargs) + try: + yield fd + finally: + if fd is not default: + fd.close() + +def open_output_text_file(path=None): + return _open_file(path, default=sys.stdout, mode='w', encoding='utf-8') + +def open_input_text_file(path=None): + return _open_file(path, default=sys.stdin, mode='r', encoding='utf-8') |