aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/vk/tracking/db/format.py
diff options
context:
space:
mode:
Diffstat (limited to 'vk/tracking/db/format.py')
-rw-r--r--vk/tracking/db/format.py64
1 files changed, 53 insertions, 11 deletions
diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py
index 9d5c6e4..e1b34a1 100644
--- a/vk/tracking/db/format.py
+++ b/vk/tracking/db/format.py
@@ -3,7 +3,9 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+from contextlib import contextmanager
from enum import Enum
+import sys
from . import backend
@@ -12,25 +14,65 @@ class Format(Enum):
LOG = 'log'
NULL = 'null'
- def create_writer(self, fd):
- if self is Format.CSV:
- return backend.csv.Writer(fd)
- elif self is Format.LOG:
- return backend.log.Writer(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_output_file(path) as fd:
+ if self is Format.CSV:
+ yield backend.csv.Writer(fd)
+ elif self is Format.LOG:
+ yield backend.log.Writer(fd)
+ elif self is Format.NULL:
+ yield backend.null.Writer()
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
+
+ @contextmanager
+ def _open_output_file(self, path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
+ elif self is Format.CSV or self is Format.LOG:
+ fd = open(path, 'w', encoding='utf-8')
elif self is Format.NULL:
- return backend.null.Writer(fd)
+ pass
else:
raise NotImplementedError('unsupported database format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
+
+ @contextmanager
+ def create_reader(self, path=None):
+ with self._open_input_file(path) as fd:
+ if self is Format.CSV:
+ yield backend.csv.Reader(fd)
+ elif self is Format.LOG:
+ raise NotImplementedError('cannot read from a log file')
+ elif self is Format.NULL:
+ yield backend.null.Reader()
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
- def create_reader(self, fd):
- if self is Format.CSV:
- return backend.csv.Reader(fd)
+ @contextmanager
+ def _open_input_file(self, path=None):
+ fd = sys.stdin
+ if path is None:
+ pass
+ elif self is Format.CSV:
+ fd = open(path, encoding='utf-8')
elif self is Format.LOG:
- raise NotImplementedError()
+ raise NotImplementedError('cannot read from a log file')
elif self is Format.NULL:
- return backend.null.Reader(fd)
+ pass
else:
raise NotImplementedError('unsupported database format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdin:
+ fd.close()
def __str__(self):
return self.value