From db723bbd4832c9901363586f08543e43ef7ffa4e Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Thu, 26 Jan 2017 22:33:34 +0300 Subject: refactoring Context managers everywhere! --- vk/tracking/db/backend/csv.py | 12 -------- vk/tracking/db/backend/log.py | 6 ---- vk/tracking/db/backend/null.py | 16 ++--------- vk/tracking/db/format.py | 64 ++++++++++++++++++++++++++++++++++-------- 4 files changed, 55 insertions(+), 43 deletions(-) (limited to 'vk') diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py index f20c617..e607774 100644 --- a/vk/tracking/db/backend/csv.py +++ b/vk/tracking/db/backend/csv.py @@ -14,12 +14,6 @@ class Writer: self._fd = fd self._writer = csv.writer(fd, lineterminator='\n') - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def on_initial_status(self, user): self._write_record(user) self._fd.flush() @@ -47,12 +41,6 @@ class Reader(Iterable): def __init__(self, fd): self._reader = csv.reader(fd) - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def __iter__(self): return map(Reader._record_from_row, self._reader) diff --git a/vk/tracking/db/backend/log.py b/vk/tracking/db/backend/log.py index 6eebc35..f248d65 100644 --- a/vk/tracking/db/backend/log.py +++ b/vk/tracking/db/backend/log.py @@ -15,12 +15,6 @@ class Writer: datefmt='%Y-%m-%d %H:%M:%S')) self._logger.addHandler(handler) - def __enter__(self): - return self - - def __exit__(self, *args): - pass - def info(self, msg): self._logger.info(msg) diff --git a/vk/tracking/db/backend/null.py b/vk/tracking/db/backend/null.py index 6454f84..3cdb41e 100644 --- a/vk/tracking/db/backend/null.py +++ b/vk/tracking/db/backend/null.py @@ -6,13 +6,7 @@ from collections.abc import Iterable class Writer: - def __init__(self, fd): - pass - - def __enter__(self): - return self - - def __exit__(self, *args): + def __init__(self): pass def on_initial_status(self, user): @@ -25,13 +19,7 @@ class Writer: pass class Reader(Iterable): - def __init__(self, fd): - pass - - def __enter__(self): - return self - - def __exit__(self, *args): + def __init__(self): pass def __iter__(self): diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py index 9d5c6e4..e1b34a1 100644 --- a/vk/tracking/db/format.py +++ b/vk/tracking/db/format.py @@ -3,7 +3,9 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. +from contextlib import contextmanager from enum import Enum +import sys from . import backend @@ -12,25 +14,65 @@ class Format(Enum): LOG = 'log' NULL = 'null' - def create_writer(self, fd): - if self is Format.CSV: - return backend.csv.Writer(fd) - elif self is Format.LOG: - return backend.log.Writer(fd) + @contextmanager + def create_writer(self, path=None): + with self._open_output_file(path) as fd: + if self is Format.CSV: + yield backend.csv.Writer(fd) + elif self is Format.LOG: + yield backend.log.Writer(fd) + elif self is Format.NULL: + yield backend.null.Writer() + else: + raise NotImplementedError('unsupported database format: ' + str(self)) + + @contextmanager + def _open_output_file(self, path=None): + fd = sys.stdout + if path is None: + pass + elif self is Format.CSV or self is Format.LOG: + fd = open(path, 'w', encoding='utf-8') elif self is Format.NULL: - return backend.null.Writer(fd) + pass else: raise NotImplementedError('unsupported database format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdout: + fd.close() + + @contextmanager + def create_reader(self, path=None): + with self._open_input_file(path) as fd: + if self is Format.CSV: + yield backend.csv.Reader(fd) + elif self is Format.LOG: + raise NotImplementedError('cannot read from a log file') + elif self is Format.NULL: + yield backend.null.Reader() + else: + raise NotImplementedError('unsupported database format: ' + str(self)) - def create_reader(self, fd): - if self is Format.CSV: - return backend.csv.Reader(fd) + @contextmanager + def _open_input_file(self, path=None): + fd = sys.stdin + if path is None: + pass + elif self is Format.CSV: + fd = open(path, encoding='utf-8') elif self is Format.LOG: - raise NotImplementedError() + raise NotImplementedError('cannot read from a log file') elif self is Format.NULL: - return backend.null.Reader(fd) + pass else: raise NotImplementedError('unsupported database format: ' + str(self)) + try: + yield fd + finally: + if fd is not sys.stdin: + fd.close() def __str__(self): return self.value -- cgit v1.2.3