aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/vk
diff options
context:
space:
mode:
Diffstat (limited to 'vk')
-rw-r--r--vk/tracking/db/backend/csv.py12
-rw-r--r--vk/tracking/db/backend/log.py6
-rw-r--r--vk/tracking/db/backend/null.py16
-rw-r--r--vk/tracking/db/format.py64
4 files changed, 55 insertions, 43 deletions
diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py
index f20c617..e607774 100644
--- a/vk/tracking/db/backend/csv.py
+++ b/vk/tracking/db/backend/csv.py
@@ -14,12 +14,6 @@ class Writer:
self._fd = fd
self._writer = csv.writer(fd, lineterminator='\n')
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def on_initial_status(self, user):
self._write_record(user)
self._fd.flush()
@@ -47,12 +41,6 @@ class Reader(Iterable):
def __init__(self, fd):
self._reader = csv.reader(fd)
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def __iter__(self):
return map(Reader._record_from_row, self._reader)
diff --git a/vk/tracking/db/backend/log.py b/vk/tracking/db/backend/log.py
index 6eebc35..f248d65 100644
--- a/vk/tracking/db/backend/log.py
+++ b/vk/tracking/db/backend/log.py
@@ -15,12 +15,6 @@ class Writer:
datefmt='%Y-%m-%d %H:%M:%S'))
self._logger.addHandler(handler)
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- pass
-
def info(self, msg):
self._logger.info(msg)
diff --git a/vk/tracking/db/backend/null.py b/vk/tracking/db/backend/null.py
index 6454f84..3cdb41e 100644
--- a/vk/tracking/db/backend/null.py
+++ b/vk/tracking/db/backend/null.py
@@ -6,13 +6,7 @@
from collections.abc import Iterable
class Writer:
- def __init__(self, fd):
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
+ def __init__(self):
pass
def on_initial_status(self, user):
@@ -25,13 +19,7 @@ class Writer:
pass
class Reader(Iterable):
- def __init__(self, fd):
- pass
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
+ def __init__(self):
pass
def __iter__(self):
diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py
index 9d5c6e4..e1b34a1 100644
--- a/vk/tracking/db/format.py
+++ b/vk/tracking/db/format.py
@@ -3,7 +3,9 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+from contextlib import contextmanager
from enum import Enum
+import sys
from . import backend
@@ -12,25 +14,65 @@ class Format(Enum):
LOG = 'log'
NULL = 'null'
- def create_writer(self, fd):
- if self is Format.CSV:
- return backend.csv.Writer(fd)
- elif self is Format.LOG:
- return backend.log.Writer(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_output_file(path) as fd:
+ if self is Format.CSV:
+ yield backend.csv.Writer(fd)
+ elif self is Format.LOG:
+ yield backend.log.Writer(fd)
+ elif self is Format.NULL:
+ yield backend.null.Writer()
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
+
+ @contextmanager
+ def _open_output_file(self, path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
+ elif self is Format.CSV or self is Format.LOG:
+ fd = open(path, 'w', encoding='utf-8')
elif self is Format.NULL:
- return backend.null.Writer(fd)
+ pass
else:
raise NotImplementedError('unsupported database format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
+
+ @contextmanager
+ def create_reader(self, path=None):
+ with self._open_input_file(path) as fd:
+ if self is Format.CSV:
+ yield backend.csv.Reader(fd)
+ elif self is Format.LOG:
+ raise NotImplementedError('cannot read from a log file')
+ elif self is Format.NULL:
+ yield backend.null.Reader()
+ else:
+ raise NotImplementedError('unsupported database format: ' + str(self))
- def create_reader(self, fd):
- if self is Format.CSV:
- return backend.csv.Reader(fd)
+ @contextmanager
+ def _open_input_file(self, path=None):
+ fd = sys.stdin
+ if path is None:
+ pass
+ elif self is Format.CSV:
+ fd = open(path, encoding='utf-8')
elif self is Format.LOG:
- raise NotImplementedError()
+ raise NotImplementedError('cannot read from a log file')
elif self is Format.NULL:
- return backend.null.Reader(fd)
+ pass
else:
raise NotImplementedError('unsupported database format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdin:
+ fd.close()
def __str__(self):
return self.value