aboutsummaryrefslogtreecommitdiffstatshomepage
path: root/bin/online_sessions.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--bin/online_sessions.py57
1 files changed, 38 insertions, 19 deletions
diff --git a/bin/online_sessions.py b/bin/online_sessions.py
index ec26689..7a763e5 100644
--- a/bin/online_sessions.py
+++ b/bin/online_sessions.py
@@ -6,6 +6,7 @@
import argparse
import csv
from collections import OrderedDict
+from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from enum import Enum
import json
@@ -244,15 +245,34 @@ class OutputFormat(Enum):
JSON = 'json'
PLOT = 'plot'
- def create_writer(self, fd):
- if self is OutputFormat.CSV:
- return OutputWriterCSV(fd)
- elif self is OutputFormat.JSON:
- return OutputWriterJSON(fd)
+ @contextmanager
+ def create_writer(self, path=None):
+ with self._open_file(path) as fd:
+ if self is OutputFormat.CSV:
+ yield OutputWriterCSV(fd)
+ elif self is OutputFormat.JSON:
+ yield OutputWriterJSON(fd)
+ elif self is OutputFormat.PLOT:
+ yield OutputWriterPlot(fd)
+ else:
+ raise NotImplementedError('unsupported output format: ' + str(self))
+
+ @contextmanager
+ def _open_file(self, path=None):
+ fd = sys.stdout
+ if path is None:
+ pass
+ elif self is OutputFormat.CSV or self is OutputFormat.JSON:
+ fd = open(path, 'w', encoding='utf-8')
elif self is OutputFormat.PLOT:
- return OutputWriterPlot(fd)
+ fd = open(path, 'wb')
else:
raise NotImplementedError('unsupported output format: ' + str(self))
+ try:
+ yield fd
+ finally:
+ if fd is not sys.stdout:
+ fd.close()
def __str__(self):
return self.value
@@ -293,12 +313,9 @@ def _parse_args(args=None):
parser = argparse.ArgumentParser(
description='View/visualize the amount of time people spend online.')
- parser.add_argument('db_fd', metavar='input',
- type=argparse.FileType('r', encoding='utf-8'),
- help='database file path')
- parser.add_argument('fd', metavar='output', nargs='?',
- type=argparse.FileType('w', encoding='utf-8'),
- default=sys.stdout,
+ parser.add_argument('db_path', metavar='input', nargs='?',
+ help='database file path (standard input by default)')
+ parser.add_argument('out_path', metavar='output', nargs='?',
help='output file path (standard output by default)')
parser.add_argument('-g', '--group-by',
type=_parse_group_by,
@@ -310,7 +327,7 @@ def _parse_args(args=None):
default=DatabaseFormat.CSV,
choices=DatabaseFormat,
help='specify database format')
- parser.add_argument('-o', '--output-format', dest='fmt',
+ parser.add_argument('-o', '--output-format', dest='out_fmt',
type=_parse_output_format,
choices=OutputFormat,
default=OutputFormat.CSV,
@@ -325,8 +342,8 @@ def _parse_args(args=None):
return parser.parse_args(args)
def process_online_sessions(
- db_fd, db_fmt=DatabaseFormat.CSV,
- fd=sys.stdout, fmt=OutputFormat.CSV,
+ db_path=None, db_fmt=DatabaseFormat.CSV,
+ out_path=None, out_fmt=OutputFormat.CSV,
group_by=GroupBy.USER,
time_from=None, time_to=None):
@@ -334,10 +351,12 @@ def process_online_sessions(
if time_from > time_to:
time_from, time_to = time_to, time_from
- with db_fmt.create_reader(db_fd) as db_reader:
- output_writer = fmt.create_writer(fd)
- output_writer.process_database(
- group_by, db_reader, time_from=time_from, time_to=time_to)
+ with db_fmt.create_reader(db_path) as db_reader:
+ with out_fmt.create_writer(out_path) as out_writer:
+ out_writer.process_database(
+ group_by, db_reader,
+ time_from=time_from,
+ time_to=time_to)
def main(args=None):
process_online_sessions(**vars(_parse_args(args)))