aboutsummaryrefslogtreecommitdiffstatshomepage
diff options
context:
space:
mode:
authorEgor Tensin <Egor.Tensin@gmail.com>2019-12-23 07:20:36 +0300
committerEgor Tensin <Egor.Tensin@gmail.com>2019-12-23 07:20:36 +0300
commit7b8cc8a9f455eda41b9c7d70f4561a84fcda941e (patch)
treeb9e262e9a1dbb663c3b9f704a9fe4daf54be0ce9
parentTravis: online_sessions.sh: refactoring (diff)
downloadvk-scripts-7b8cc8a9f455eda41b9c7d70f4561a84fcda941e.tar.gz
vk-scripts-7b8cc8a9f455eda41b9c7d70f4561a84fcda941e.zip
pylint/pep8 fixes
-rw-r--r--bin/mutual_friends.py17
-rw-r--r--bin/online_sessions.py45
-rw-r--r--bin/show_status.py7
-rw-r--r--bin/track_status.py10
-rw-r--r--bin/utils/bar_chart.py2
-rw-r--r--bin/utils/io.py7
-rw-r--r--vk/api.py23
-rw-r--r--vk/error.py3
-rw-r--r--vk/last_seen.py4
-rw-r--r--vk/platform.py2
-rw-r--r--vk/tracking/db/backend/csv.py2
-rw-r--r--vk/tracking/db/backend/log.py1
-rw-r--r--vk/tracking/db/backend/null.py2
-rw-r--r--vk/tracking/db/format.py29
-rw-r--r--vk/tracking/db/io.py8
-rw-r--r--vk/tracking/db/meta.py2
-rw-r--r--vk/tracking/db/record.py1
-rw-r--r--vk/tracking/db/timestamp.py1
-rw-r--r--vk/tracking/online_sessions.py2
-rw-r--r--vk/tracking/status_tracker.py4
-rw-r--r--vk/user.py21
21 files changed, 141 insertions, 52 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py
index 82015bb..550d957 100644
--- a/bin/mutual_friends.py
+++ b/bin/mutual_friends.py
@@ -14,22 +14,27 @@ from vk.user import UserField
from .utils import io
+
_OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME
+
def _query_friend_list(api, user):
return api.friends_get(user.get_uid(), fields=_OUTPUT_USER_FIELDS)
+
def _filter_user_fields(user):
new_user = OrderedDict()
for field in _OUTPUT_USER_FIELDS:
new_user[str(field)] = user[field] if field in user else None
return new_user
+
class OutputSinkMutualFriends(metaclass=abc.ABCMeta):
@abc.abstractmethod
def write_mutual_friends(self, friend_list):
pass
+
class OutputSinkCSV(OutputSinkMutualFriends):
def __init__(self, fd=sys.stdout):
self._writer = io.FileWriterCSV(fd)
@@ -38,6 +43,7 @@ class OutputSinkCSV(OutputSinkMutualFriends):
for user in friend_list:
self._writer.write_row(user.values())
+
class OutputSinkJSON(OutputSinkMutualFriends):
def __init__(self, fd=sys.stdout):
self._writer = io.FileWriterJSON(fd)
@@ -45,6 +51,7 @@ class OutputSinkJSON(OutputSinkMutualFriends):
def write_mutual_friends(self, friend_list):
self._writer.write(friend_list)
+
class OutputFormat(Enum):
CSV = 'csv'
JSON = 'json'
@@ -59,10 +66,10 @@ class OutputFormat(Enum):
def create_sink(self, fd=sys.stdout):
if self is OutputFormat.CSV:
return OutputSinkCSV(fd)
- elif self is OutputFormat.JSON:
+ if self is OutputFormat.JSON:
return OutputSinkJSON(fd)
- else:
- raise NotImplementedError('unsupported output format: ' + str(self))
+ raise NotImplementedError('unsupported output format: ' + str(self))
+
def _parse_output_format(s):
try:
@@ -70,6 +77,7 @@ def _parse_output_format(s):
except ValueError:
raise argparse.ArgumentTypeError('invalid output format: ' + s)
+
def _parse_args(args=None):
if args is None:
args = sys.argv[1:]
@@ -89,6 +97,7 @@ def _parse_args(args=None):
return parser.parse_args(args)
+
def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV):
api = API()
users = api.users_get(uids)
@@ -101,8 +110,10 @@ def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV):
sink = out_fmt.create_sink(out_fd)
sink.write_mutual_friends(mutual_friends)
+
def main(args=None):
write_mutual_friends(**vars(_parse_args(args)))
+
if __name__ == '__main__':
main()
diff --git a/bin/online_sessions.py b/bin/online_sessions.py
index 670a8c3..e0de7c9 100644
--- a/bin/online_sessions.py
+++ b/bin/online_sessions.py
@@ -17,6 +17,7 @@ from vk.user import UserField
from .utils.bar_chart import BarChartBuilder
from .utils import io
+
class GroupBy(Enum):
USER = 'user'
DATE = 'date'
@@ -30,14 +31,14 @@ class GroupBy(Enum):
online_streaks = OnlineSessionEnumerator(time_from, time_to)
if self is GroupBy.USER:
return online_streaks.group_by_user(db_reader)
- elif self is GroupBy.DATE:
+ if self is GroupBy.DATE:
return online_streaks.group_by_date(db_reader)
- elif self is GroupBy.WEEKDAY:
+ if self is GroupBy.WEEKDAY:
return online_streaks.group_by_weekday(db_reader)
- elif self is GroupBy.HOUR:
+ if self is GroupBy.HOUR:
return online_streaks.group_by_hour(db_reader)
- else:
- raise NotImplementedError('unsupported grouping: ' + str(self))
+ raise NotImplementedError('unsupported grouping: ' + str(self))
+
_OUTPUT_USER_FIELDS = (
UserField.UID,
@@ -46,11 +47,13 @@ _OUTPUT_USER_FIELDS = (
UserField.DOMAIN,
)
+
class OutputSinkOnlineSessions(metaclass=abc.ABCMeta):
@abc.abstractmethod
def process_database(self, group_by, db_reader, time_from=None, time_to=None):
pass
+
class OutputConverterCSV:
@staticmethod
def convert_user(user):
@@ -68,6 +71,7 @@ class OutputConverterCSV:
def convert_hour(hour):
return [str(timedelta(hours=hour))]
+
class OutputSinkCSV(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
self._writer = io.FileWriterCSV(fd)
@@ -91,6 +95,7 @@ class OutputSinkCSV(OutputSinkOnlineSessions):
row.append(str(duration))
self._writer.write_row(row)
+
class OutputConverterJSON:
_DATE_FIELD = 'date'
_WEEKDAY_FIELD = 'weekday'
@@ -125,6 +130,7 @@ class OutputConverterJSON:
obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour))
return obj
+
class OutputSinkJSON(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
self._writer = io.FileWriterJSON(fd)
@@ -142,7 +148,7 @@ class OutputSinkJSON(OutputSinkOnlineSessions):
@staticmethod
def _key_to_object(group_by, key):
- if not group_by in OutputSinkJSON._CONVERT_KEY:
+ if group_by not in OutputSinkJSON._CONVERT_KEY:
raise NotImplementedError('unsupported grouping: ' + str(group_by))
return OutputSinkJSON._CONVERT_KEY[group_by](key)
@@ -154,6 +160,7 @@ class OutputSinkJSON(OutputSinkOnlineSessions):
entries.append(entry)
self._writer.write(entries)
+
class OutputConverterPlot:
@staticmethod
def convert_user(user):
@@ -171,6 +178,7 @@ class OutputConverterPlot:
def convert_hour(hour):
return '{}:00'.format(hour)
+
class OutputSinkPlot(OutputSinkOnlineSessions):
def __init__(self, fd=sys.stdout):
self._fd = fd
@@ -200,11 +208,11 @@ class OutputSinkPlot(OutputSinkOnlineSessions):
@staticmethod
def _extract_labels(group_by, durations):
- return tuple(map(lambda key: OutputSinkPlot._format_key(group_by, key), durations.keys()))
+ return (OutputSinkPlot._format_key(group_by, key) for key in durations.keys())
@staticmethod
def _extract_values(durations):
- return tuple(map(OutputSinkPlot._duration_to_seconds, durations.values()))
+ return (OutputSinkPlot._duration_to_seconds(duration) for duration in durations.values())
def process_database(
self, group_by, db_reader, time_from=None, time_to=None):
@@ -219,8 +227,8 @@ class OutputSinkPlot(OutputSinkOnlineSessions):
fontsize='small', rotation=30)
bar_chart.set_value_label_formatter(self._format_duration)
- labels = self._extract_labels(group_by, durations)
- durations = self._extract_values(durations)
+ labels = tuple(self._extract_labels(group_by, durations))
+ durations = tuple(self._extract_values(durations))
if group_by is GroupBy.HOUR:
bar_chart.labels_align_middle = False
@@ -237,6 +245,7 @@ class OutputSinkPlot(OutputSinkOnlineSessions):
else:
bar_chart.save(self._fd)
+
class OutputFormat(Enum):
CSV = 'csv'
JSON = 'json'
@@ -248,38 +257,42 @@ class OutputFormat(Enum):
def create_sink(self, fd=sys.stdout):
if self is OutputFormat.CSV:
return OutputSinkCSV(fd)
- elif self is OutputFormat.JSON:
+ if self is OutputFormat.JSON:
return OutputSinkJSON(fd)
- elif self is OutputFormat.PLOT:
+ if self is OutputFormat.PLOT:
return OutputSinkPlot(fd)
- else:
- raise NotImplementedError('unsupported output format: ' + str(self))
+ raise NotImplementedError('unsupported output format: ' + str(self))
def open_file(self, path=None):
if self is OutputFormat.PLOT:
return io.open_output_binary_file(path)
return io.open_output_text_file(path)
+
def _parse_group_by(s):
try:
return GroupBy(s)
except ValueError:
raise argparse.ArgumentTypeError('invalid "group by" value: ' + s)
+
def _parse_database_format(s):
try:
return DatabaseFormat(s)
except ValueError:
raise argparse.ArgumentTypeError('invalid database format: ' + s)
+
def _parse_output_format(s):
try:
return OutputFormat(s)
except ValueError:
raise argparse.ArgumentTypeError('invalid output format: ' + s)
+
_DATE_RANGE_LIMIT_FORMAT = '%Y-%m-%dT%H:%M:%SZ'
+
def _parse_date_range_limit(s):
try:
dt = datetime.strptime(s, _DATE_RANGE_LIMIT_FORMAT)
@@ -289,6 +302,7 @@ def _parse_date_range_limit(s):
raise argparse.ArgumentTypeError(
msg.format(_DATE_RANGE_LIMIT_FORMAT, s))
+
def _parse_args(args=None):
if args is None:
args = sys.argv[1:]
@@ -324,6 +338,7 @@ def _parse_args(args=None):
return parser.parse_args(args)
+
def process_online_sessions(
db_path=None, db_fmt=DatabaseFormat.CSV,
out_path=None, out_fmt=OutputFormat.CSV,
@@ -343,8 +358,10 @@ def process_online_sessions(
time_from=time_from,
time_to=time_to)
+
def main(args=None):
process_online_sessions(**vars(_parse_args(args)))
+
if __name__ == '__main__':
main()
diff --git a/bin/show_status.py b/bin/show_status.py
index 1ca9cb5..cf59280 100644
--- a/bin/show_status.py
+++ b/bin/show_status.py
@@ -3,12 +3,14 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
-import argparse, sys
+import argparse
+import sys
from vk.api import API
from vk.tracking import StatusTracker
from vk.tracking.db import Format as DatabaseFormat
+
def _parse_args(args=None):
if args is None:
args = sys.argv[1:]
@@ -23,6 +25,7 @@ def _parse_args(args=None):
return parser.parse_args(args)
+
def track_status(uids, log_path=None):
api = API()
tracker = StatusTracker(api)
@@ -32,8 +35,10 @@ def track_status(uids, log_path=None):
tracker.add_database_writer(log_writer)
tracker.query_status(uids)
+
def main(args=None):
track_status(**vars(_parse_args(args)))
+
if __name__ == '__main__':
main()
diff --git a/bin/track_status.py b/bin/track_status.py
index 9c32ae9..2a974a5 100644
--- a/bin/track_status.py
+++ b/bin/track_status.py
@@ -3,15 +3,18 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
-import argparse, sys
+import argparse
+import sys
from vk.api import API
from vk.tracking import StatusTracker
from vk.tracking.db import Format as DatabaseFormat
+
DEFAULT_TIMEOUT = StatusTracker.DEFAULT_TIMEOUT
DEFAULT_DB_FORMAT = DatabaseFormat.CSV
+
def _parse_positive_integer(s):
try:
n = int(s)
@@ -21,12 +24,14 @@ def _parse_positive_integer(s):
raise argparse.ArgumentTypeError('must be a positive integer: ' + s)
return n
+
def _parse_database_format(s):
try:
return DatabaseFormat(s)
except ValueError:
raise argparse.ArgumentTypeError('invalid database format: ' + s)
+
def _parse_args(args=None):
if args is None:
args = sys.argv[1:]
@@ -52,6 +57,7 @@ def _parse_args(args=None):
return parser.parse_args(args)
+
def track_status(
uids, timeout=DEFAULT_TIMEOUT,
log_path=None,
@@ -72,8 +78,10 @@ def track_status(
tracker.loop(uids)
+
def main(args=None):
track_status(**vars(_parse_args(args)))
+
if __name__ == '__main__':
main()
diff --git a/bin/utils/bar_chart.py b/bin/utils/bar_chart.py
index 522dfed..f051efc 100644
--- a/bin/utils/bar_chart.py
+++ b/bin/utils/bar_chart.py
@@ -7,6 +7,7 @@ import matplotlib.pyplot as plt
from matplotlib import ticker
import numpy as np
+
class BarChartBuilder:
_BAR_HEIGHT = .5
@@ -130,6 +131,7 @@ class BarChartBuilder:
def save(self, path):
self._fig.savefig(path, bbox_inches='tight')
+
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser()
diff --git a/bin/utils/io.py b/bin/utils/io.py
index 04baa6a..bb8eef9 100644
--- a/bin/utils/io.py
+++ b/bin/utils/io.py
@@ -8,6 +8,7 @@ import csv
import json
import sys
+
class FileWriterJSON:
def __init__(self, fd=sys.stdout):
self._fd = fd
@@ -16,13 +17,14 @@ class FileWriterJSON:
self._fd.write(json.dumps(something, indent=3, ensure_ascii=False))
self._fd.write('\n')
+
class FileWriterCSV:
def __init__(self, fd=sys.stdout):
self._writer = csv.writer(fd, lineterminator='\n')
@staticmethod
def _convert_row_old_python(row):
- if isinstance(row, list) or isinstance(row, tuple):
+ if isinstance(row, (list, tuple)):
return row
return list(row)
@@ -31,6 +33,7 @@ class FileWriterCSV:
row = self._convert_row_old_python(row)
self._writer.writerow(row)
+
@contextmanager
def _open_file(path=None, default=None, **kwargs):
if path is None:
@@ -39,8 +42,10 @@ def _open_file(path=None, default=None, **kwargs):
with open(path, **kwargs) as fd:
yield fd
+
def open_output_text_file(path=None):
return _open_file(path, default=sys.stdout, mode='w', encoding='utf-8')
+
def open_output_binary_file(path=None):
return _open_file(path, default=sys.stdout, mode='wb')
diff --git a/vk/api.py b/vk/api.py
index d79122d..0ba2f44 100644
--- a/vk/api.py
+++ b/vk/api.py
@@ -3,8 +3,7 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
-from collections import Iterable
-from collections.abc import Mapping
+from collections.abc import Iterable, Mapping
from enum import Enum
import json
from urllib.error import URLError
@@ -14,21 +13,24 @@ from urllib.request import urlopen
import vk.error
from vk.user import User
+
def _split_url(url):
return urllib.parse.urlsplit(url)[:3]
+
def _is_empty_param_value(value):
return isinstance(value, str) and not value
+
def _filter_empty_params(params, empty_params=False):
if empty_params:
return params
if isinstance(params, Mapping):
return {name: value for name, value in params.items() if not _is_empty_param_value(value)}
- elif isinstance(params, Iterable):
+ if isinstance(params, Iterable):
return [(name, value) for name, value in params if not _is_empty_param_value(value)]
- else:
- raise TypeError()
+ raise TypeError()
+
def _build_url(scheme, host, path, params=None, empty_params=False):
if params is None:
@@ -43,6 +45,7 @@ def _build_url(scheme, host, path, params=None, empty_params=False):
path = urllib.parse.quote(path)
return urllib.parse.urlunsplit((scheme, host, path, params, ''))
+
def _join_param_values(values):
if isinstance(values, str):
return values
@@ -50,13 +53,16 @@ def _join_param_values(values):
return ','.join(map(str, values))
return values
+
def _join_path(base, url):
if not base.endswith('/'):
base += '/'
return urllib.parse.urljoin(base, url)
+
ACCESS_TOKEN = '9722cef09722cef09722cef071974b8cbe997229722cef0cbabfd816916af6c7bd37006'
+
class Version(Enum):
V5_73 = '5.73'
DEFAULT = V5_73
@@ -64,6 +70,7 @@ class Version(Enum):
def __str__(self):
return self.value
+
class Language(Enum):
EN = 'en'
DEFAULT = EN
@@ -71,6 +78,7 @@ class Language(Enum):
def __str__(self):
return self.value
+
class Method(Enum):
USERS_GET = 'users.get'
FRIENDS_GET = 'friends.get'
@@ -78,6 +86,7 @@ class Method(Enum):
def __str__(self):
return self.value
+
class CommonParameters(Enum):
ACCESS_TOKEN = 'access_token'
VERSION = 'v'
@@ -86,6 +95,7 @@ class CommonParameters(Enum):
def __str__(self):
return self.value
+
class API:
_ROOT_URL = 'https://api.vk.com/method/'
@@ -136,4 +146,5 @@ class API:
fields=_join_param_values(fields))
if 'items' not in response:
raise vk.error.InvalidAPIResponseError(response)
- return self._filter_response_with_users(response['items'], deactivated_users)
+ return self._filter_response_with_users(response['items'],
+ deactivated_users)
diff --git a/vk/error.py b/vk/error.py
index 3a309c3..5f71e47 100644
--- a/vk/error.py
+++ b/vk/error.py
@@ -3,9 +3,11 @@
# For details, see https://github.com/egor-tensin/vk-scripts.
# Distributed under the MIT License.
+
class APIError(RuntimeError):
pass
+
class InvalidAPIResponseError(APIError):
def __init__(self, response):
super().__init__()
@@ -14,5 +16,6 @@ class InvalidAPIResponseError(APIError):
def __str__(self):
return str(self.response)
+
class APIConnectionError(APIError):
pass
diff --git a/vk/last_seen.py b/vk/last_seen.py
index 25975f2..cf76db5 100644
--- a/vk/last_seen.py
+++ b/vk/last_seen.py
@@ -11,6 +11,7 @@ from numbers import Integral, Real
from .platform import Platform
+
def _parse_time(x):
if isinstance(x, datetime):
if x.tzinfo is None or x.tzinfo.utcoffset(x) is None:
@@ -20,6 +21,7 @@ def _parse_time(x):
return datetime.fromtimestamp(x, tz=timezone.utc)
raise TypeError()
+
def _parse_platform(x):
if x in Platform:
return x
@@ -27,6 +29,7 @@ def _parse_platform(x):
return Platform.from_string(x)
return Platform(x)
+
class LastSeenField(Enum):
TIME = 'time'
PLATFORM = 'platform'
@@ -34,6 +37,7 @@ class LastSeenField(Enum):
def __str__(self):
return self.value
+
class LastSeen(MutableMapping):
@staticmethod
def from_api_response(source):
diff --git a/vk/platform.py b/vk/platform.py
index 8bf23c5..5cd78b8 100644
--- a/vk/platform.py
+++ b/vk/platform.py
@@ -6,6 +6,7 @@
from enum import Enum
import re
+
class Platform(Enum):
MOBILE = 1
IPHONE = 2
@@ -43,6 +44,7 @@ class Platform(Enum):
def get_descr_text_capitalized(self):
return self._capitalize_first_letter(self.get_descr_text())
+
_PLATFORM_DESCRIPTIONS = {
Platform.MOBILE: '"mobile" web version (or unrecognized mobile app)',
Platform.IPHONE: 'official iPhone app',
diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py
index 4943ff2..43038e4 100644
--- a/vk/tracking/db/backend/csv.py
+++ b/vk/tracking/db/backend/csv.py
@@ -8,6 +8,7 @@ from ..io import FileReaderCSV, FileWriterCSV
from ..record import Record
from ..timestamp import Timestamp
+
class Writer(meta.Writer):
def __init__(self, fd):
self._writer = FileWriterCSV(fd)
@@ -30,6 +31,7 @@ class Writer(meta.Writer):
def _record_to_row(record):
return [str(record.get_timestamp())] + [str(record[field]) for field in record]
+
class Reader(meta.Reader):
def __init__(self, fd):
self._reader = FileReaderCSV(fd)
diff --git a/vk/tracking/db/backend/log.py b/vk/tracking/db/backend/log.py
index 814cabc..d301856 100644
--- a/vk/tracking/db/backend/log.py
+++ b/vk/tracking/db/backend/log.py
@@ -7,6 +7,7 @@ import logging
from .. import meta
+
class Writer(meta.Writer):
def __init__(self, fd):
self._logger = logging.getLogger(__file__)
diff --git a/vk/tracking/db/backend/null.py b/vk/tracking/db/backend/null.py
index 663af10..80a66b4 100644
--- a/vk/tracking/db/backend/null.py
+++ b/vk/tracking/db/backend/null.py
@@ -5,6 +5,7 @@
from .. import meta
+
class Writer(meta.Writer):
def __init__(self):
pass
@@ -18,6 +19,7 @@ class Writer(meta.Writer):
def on_connection_error(self, e):
pass
+
class Reader(meta.Reader):
def __init__(self):
pass
diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py
index f9a670c..028d403 100644
--- a/vk/tracking/db/format.py
+++ b/vk/tracking/db/format.py
@@ -8,6 +8,7 @@ import sys
from . import backend, io
+
class Format(Enum):
CSV = 'csv'
LOG = 'log'
@@ -19,22 +20,20 @@ class Format(Enum):
def create_writer(self, fd=sys.stdout):
if self is Format.CSV:
return backend.csv.Writer(fd)
- elif self is Format.LOG:
+ if self is Format.LOG:
return backend.log.Writer(fd)
- elif self is Format.NULL:
+ if self is Format.NULL:
return backend.null.Writer()
- else:
- raise NotImplementedError('unsupported database format: ' + str(self))
+ raise NotImplementedError('unsupported database format: ' + str(self))
def open_output_file(self, path=None):
if self is Format.CSV:
return self._open_output_database_file(path)
- elif self is Format.LOG:
+ if self is Format.LOG:
return self._open_output_log_file(path)
- elif self is Format.NULL:
+ if self is Format.NULL:
return self._open_output_database_file(None)
- else:
- raise NotImplementedError('unsupported database format: ' + str(self))
+ raise NotImplementedError('unsupported database format: ' + str(self))
@staticmethod
def _open_output_log_file(path):
@@ -47,19 +46,17 @@ class Format(Enum):
def create_reader(self, fd=sys.stdin):
if self is Format.CSV:
return backend.csv.Reader(fd)
- elif self is Format.LOG:
+ if self is Format.LOG:
return NotImplementedError('cannot read from a log file')
- elif self is Format.NULL:
+ if self is Format.NULL:
return backend.null.Reader()
- else:
- raise NotImplementedError('unsupported database format: ' + str(self))
+ raise NotImplementedError('unsupported database format: ' + str(self))
def open_input_file(self, path=None):
if self is Format.CSV:
return io.open_input_text_file(path)
- elif self is Format.LOG:
+ if self is Format.LOG:
raise NotImplementedError('cannot read from a log file')
- elif self is Format.NULL:
+ if self is Format.NULL:
return io.open_input_text_file(None)
- else:
- raise NotImplementedError('unsupported database format: ' + str(self))
+ raise NotImplementedError('unsupported database format: ' + str(self))
diff --git a/vk/tracking/db/io.py b/vk/tracking/db/io.py
index 37d9c53..a89865f 100644
--- a/vk/tracking/db/io.py
+++ b/vk/tracking/db/io.py
@@ -7,6 +7,7 @@ from contextlib import contextmanager
import csv
import sys
+
class FileWriterCSV:
def __init__(self, fd=sys.stdout):
self._fd = fd
@@ -14,7 +15,7 @@ class FileWriterCSV:
@staticmethod
def _convert_row_old_python(row):
- if isinstance(row, list) or isinstance(row, tuple):
+ if isinstance(row, (list, tuple)):
return row
return list(row)
@@ -24,6 +25,7 @@ class FileWriterCSV:
self._writer.writerow(row)
self._fd.flush()
+
class FileReaderCSV:
def __init__(self, fd=sys.stdin):
self._reader = csv.reader(fd)
@@ -31,6 +33,7 @@ class FileReaderCSV:
def __iter__(self):
return iter(self._reader)
+
@contextmanager
def _open_file(path=None, default=None, **kwargs):
if path is None:
@@ -39,12 +42,15 @@ def _open_file(path=None, default=None, **kwargs):
with open(path, **kwargs) as fd:
yield fd
+
_DEFAULT_ENCODING = 'utf-8'
+
def open_output_text_file(path=None, mode='w'):
return _open_file(path, default=sys.stdout, mode=mode,
encoding=_DEFAULT_ENCODING)
+
def open_input_text_file(path=None):
return _open_file(path, default=sys.stdin, mode='r',
encoding=_DEFAULT_ENCODING)
diff --git a/vk/tracking/db/meta.py b/vk/tracking/db/meta.py
index 024d9d8..eb15c74 100644
--- a/vk/tracking/db/meta.py
+++ b/vk/tracking/db/meta.py
@@ -6,6 +6,7 @@
import abc
from collections.abc import Iterable
+
class Writer(metaclass=abc.ABCMeta):
@abc.abstractmethod
def on_initial_status(self, user):
@@ -19,6 +20,7 @@ class Writer(metaclass=abc.ABCMeta):
def on_connection_error(self, e):
pass
+
class Reader(Iterable, metaclass=abc.ABCMeta):
@abc.abstractmethod
def __iter__(self):
diff --git a/vk/tracking/db/record.py b/vk/tracking/db/record.py
index dfd47c6..6998238 100644
--- a/vk/tracking/db/record.py
+++ b/vk/tracking/db/record.py
@@ -12,6 +12,7 @@ from vk.user import User, UserField
from .timestamp import Timestamp
+
class Record(MutableMapping):
FIELDS = (
UserField.UID,
diff --git a/vk/tracking/db/timestamp.py b/vk/tracking/db/timestamp.py
index b2219ca..1309797 100644
--- a/vk/tracking/db/timestamp.py
+++ b/vk/tracking/db/timestamp.py
@@ -5,6 +5,7 @@
from datetime import datetime, timezone
+
class Timestamp:
@staticmethod
def _new():
diff --git a/vk/tracking/online_sessions.py b/vk/tracking/online_sessions.py
index 204e1cc..c43e11c 100644
--- a/vk/tracking/online_sessions.py
+++ b/vk/tracking/online_sessions.py
@@ -8,6 +8,7 @@ from collections.abc import MutableMapping
from datetime import timedelta
from enum import Enum
+
class Weekday(Enum):
MONDAY = 0
TUESDAY = 1
@@ -20,6 +21,7 @@ class Weekday(Enum):
def __str__(self):
return self.name[0] + self.name[1:].lower()
+
class OnlineSessionEnumerator(MutableMapping):
def __init__(self, time_from=None, time_to=None):
self._records = {}
diff --git a/vk/tracking/status_tracker.py b/vk/tracking/status_tracker.py
index 30e0f97..b87e059 100644
--- a/vk/tracking/status_tracker.py
+++ b/vk/tracking/status_tracker.py
@@ -11,6 +11,7 @@ import signal
import vk.error
from vk.user import UserField
+
class StatusTracker:
DEFAULT_TIMEOUT = 5
@@ -49,7 +50,8 @@ class StatusTracker:
_USER_FIELDS = UserField.DOMAIN, UserField.ONLINE, UserField.LAST_SEEN,
def _query_status(self, uids):
- user_list = self._api.users_get(uids, self._USER_FIELDS, deactivated_users=False)
+ user_list = self._api.users_get(uids, self._USER_FIELDS,
+ deactivated_users=False)
return {user.get_uid(): user for user in user_list}
def _notify_status(self, user):
diff --git a/vk/user.py b/vk/user.py
index c3c997b..eb65eaf 100644
--- a/vk/user.py
+++ b/vk/user.py
@@ -9,31 +9,33 @@ from enum import Enum
from .last_seen import LastSeen
+
def _parse_last_seen(x):
if isinstance(x, LastSeen):
return x
- elif isinstance(x, Mapping):
+ if isinstance(x, Mapping):
return LastSeen.from_api_response(x)
- else:
- raise TypeError()
+ raise TypeError()
+
def _parse_bool(x):
if isinstance(x, str):
if str(True) == x:
return True
- elif str(False) == x:
+ if str(False) == x:
return False
- else:
- raise ValueError()
- else:
- return bool(x)
+ raise ValueError()
+ return bool(x)
+
def _parse_hidden(x):
return _parse_bool(x)
+
def _parse_online_flag(x):
return _parse_bool(x)
+
class UserField(Enum):
UID = 'id'
FIRST_NAME = 'first_name'
@@ -48,6 +50,7 @@ class UserField(Enum):
def __str__(self):
return self.value
+
class DeactivationReason(Enum):
DELETED = 'deleted'
BANNED = 'banned'
@@ -55,9 +58,11 @@ class DeactivationReason(Enum):
def __str__(self):
return self.value
+
def _parse_deactivated(s):
return DeactivationReason(s)
+
class User(Hashable, MutableMapping):
@staticmethod
def from_api_response(source):