diff options
Diffstat (limited to 'vk/tracking/db/format.py')
-rw-r--r-- | vk/tracking/db/format.py | 64 |
1 files changed, 53 insertions, 11 deletions
diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py index 9d5c6e4..e1b34a1 100644 --- a/vk/tracking/db/format.py +++ b/vk/tracking/db/format.py @@ -3,7 +3,9 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. +from contextlib import contextmanager from enum import Enum +import sys from . import backend @@ -12,25 +14,65 @@ class Format(Enum): LOG = 'log' NULL = 'null' - def create_writer(self, fd): - if self is Format.CSV: - return backend.csv.Writer(fd) - elif self is Format.LOG: - return backend.log.Writer(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_output_file(path) as fd: + if self is Format.CSV: + yield backend.csv.Writer(fd) + elif self is Format.LOG: + yield backend.log.Writer(fd) + elif self is Format.NULL: + yield backend.null.Writer() + else: + raise NotImplementedError('unsupported database format: ' + str(self)) + + @contextmanager + def _open_output_file(self, path=None): + fd = sys.stdout + if path is None: + pass + elif self is Format.CSV or self is Format.LOG: + fd = open(path, 'w', encoding='utf-8') elif self is Format.NULL: - return backend.null.Writer(fd) + pass else: raise NotImplementedError('unsupported database format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() + + @contextmanager + def create_reader(self, path=None): + with self._open_input_file(path) as fd: + if self is Format.CSV: + yield backend.csv.Reader(fd) + elif self is Format.LOG: + raise NotImplementedError('cannot read from a log file') + elif self is Format.NULL: + yield backend.null.Reader() + else: + raise NotImplementedError('unsupported database format: ' + str(self)) - def create_reader(self, fd): - if self is Format.CSV: - return backend.csv.Reader(fd) + @contextmanager + def _open_input_file(self, path=None): + fd = sys.stdin + if path is None: + pass + elif self is Format.CSV: + fd = open(path, encoding='utf-8') elif self is Format.LOG: - raise NotImplementedError() + raise NotImplementedError('cannot read from a log file') elif self is Format.NULL: - return backend.null.Reader(fd) + pass else: raise NotImplementedError('unsupported database format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdin: + fd.close() def __str__(self): return self.value |