diff options
author | Egor Tensin <Egor.Tensin@gmail.com> | 2019-12-23 07:20:36 +0300 |
---|---|---|
committer | Egor Tensin <Egor.Tensin@gmail.com> | 2019-12-23 07:20:36 +0300 |
commit | 7b8cc8a9f455eda41b9c7d70f4561a84fcda941e (patch) | |
tree | b9e262e9a1dbb663c3b9f704a9fe4daf54be0ce9 | |
parent | Travis: online_sessions.sh: refactoring (diff) | |
download | vk-scripts-7b8cc8a9f455eda41b9c7d70f4561a84fcda941e.tar.gz vk-scripts-7b8cc8a9f455eda41b9c7d70f4561a84fcda941e.zip |
pylint/pep8 fixes
-rw-r--r-- | bin/mutual_friends.py | 17 | ||||
-rw-r--r-- | bin/online_sessions.py | 45 | ||||
-rw-r--r-- | bin/show_status.py | 7 | ||||
-rw-r--r-- | bin/track_status.py | 10 | ||||
-rw-r--r-- | bin/utils/bar_chart.py | 2 | ||||
-rw-r--r-- | bin/utils/io.py | 7 | ||||
-rw-r--r-- | vk/api.py | 23 | ||||
-rw-r--r-- | vk/error.py | 3 | ||||
-rw-r--r-- | vk/last_seen.py | 4 | ||||
-rw-r--r-- | vk/platform.py | 2 | ||||
-rw-r--r-- | vk/tracking/db/backend/csv.py | 2 | ||||
-rw-r--r-- | vk/tracking/db/backend/log.py | 1 | ||||
-rw-r--r-- | vk/tracking/db/backend/null.py | 2 | ||||
-rw-r--r-- | vk/tracking/db/format.py | 29 | ||||
-rw-r--r-- | vk/tracking/db/io.py | 8 | ||||
-rw-r--r-- | vk/tracking/db/meta.py | 2 | ||||
-rw-r--r-- | vk/tracking/db/record.py | 1 | ||||
-rw-r--r-- | vk/tracking/db/timestamp.py | 1 | ||||
-rw-r--r-- | vk/tracking/online_sessions.py | 2 | ||||
-rw-r--r-- | vk/tracking/status_tracker.py | 4 | ||||
-rw-r--r-- | vk/user.py | 21 |
21 files changed, 141 insertions, 52 deletions
diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py index 82015bb..550d957 100644 --- a/bin/mutual_friends.py +++ b/bin/mutual_friends.py @@ -14,22 +14,27 @@ from vk.user import UserField from .utils import io + _OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME + def _query_friend_list(api, user): return api.friends_get(user.get_uid(), fields=_OUTPUT_USER_FIELDS) + def _filter_user_fields(user): new_user = OrderedDict() for field in _OUTPUT_USER_FIELDS: new_user[str(field)] = user[field] if field in user else None return new_user + class OutputSinkMutualFriends(metaclass=abc.ABCMeta): @abc.abstractmethod def write_mutual_friends(self, friend_list): pass + class OutputSinkCSV(OutputSinkMutualFriends): def __init__(self, fd=sys.stdout): self._writer = io.FileWriterCSV(fd) @@ -38,6 +43,7 @@ class OutputSinkCSV(OutputSinkMutualFriends): for user in friend_list: self._writer.write_row(user.values()) + class OutputSinkJSON(OutputSinkMutualFriends): def __init__(self, fd=sys.stdout): self._writer = io.FileWriterJSON(fd) @@ -45,6 +51,7 @@ class OutputSinkJSON(OutputSinkMutualFriends): def write_mutual_friends(self, friend_list): self._writer.write(friend_list) + class OutputFormat(Enum): CSV = 'csv' JSON = 'json' @@ -59,10 +66,10 @@ class OutputFormat(Enum): def create_sink(self, fd=sys.stdout): if self is OutputFormat.CSV: return OutputSinkCSV(fd) - elif self is OutputFormat.JSON: + if self is OutputFormat.JSON: return OutputSinkJSON(fd) - else: - raise NotImplementedError('unsupported output format: ' + str(self)) + raise NotImplementedError('unsupported output format: ' + str(self)) + def _parse_output_format(s): try: @@ -70,6 +77,7 @@ def _parse_output_format(s): except ValueError: raise argparse.ArgumentTypeError('invalid output format: ' + s) + def _parse_args(args=None): if args is None: args = sys.argv[1:] @@ -89,6 +97,7 @@ def _parse_args(args=None): return parser.parse_args(args) + def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): api = API() users = api.users_get(uids) @@ -101,8 +110,10 @@ def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): sink = out_fmt.create_sink(out_fd) sink.write_mutual_friends(mutual_friends) + def main(args=None): write_mutual_friends(**vars(_parse_args(args))) + if __name__ == '__main__': main() diff --git a/bin/online_sessions.py b/bin/online_sessions.py index 670a8c3..e0de7c9 100644 --- a/bin/online_sessions.py +++ b/bin/online_sessions.py @@ -17,6 +17,7 @@ from vk.user import UserField from .utils.bar_chart import BarChartBuilder from .utils import io + class GroupBy(Enum): USER = 'user' DATE = 'date' @@ -30,14 +31,14 @@ class GroupBy(Enum): online_streaks = OnlineSessionEnumerator(time_from, time_to) if self is GroupBy.USER: return online_streaks.group_by_user(db_reader) - elif self is GroupBy.DATE: + if self is GroupBy.DATE: return online_streaks.group_by_date(db_reader) - elif self is GroupBy.WEEKDAY: + if self is GroupBy.WEEKDAY: return online_streaks.group_by_weekday(db_reader) - elif self is GroupBy.HOUR: + if self is GroupBy.HOUR: return online_streaks.group_by_hour(db_reader) - else: - raise NotImplementedError('unsupported grouping: ' + str(self)) + raise NotImplementedError('unsupported grouping: ' + str(self)) + _OUTPUT_USER_FIELDS = ( UserField.UID, @@ -46,11 +47,13 @@ _OUTPUT_USER_FIELDS = ( UserField.DOMAIN, ) + class OutputSinkOnlineSessions(metaclass=abc.ABCMeta): @abc.abstractmethod def process_database(self, group_by, db_reader, time_from=None, time_to=None): pass + class OutputConverterCSV: @staticmethod def convert_user(user): @@ -68,6 +71,7 @@ class OutputConverterCSV: def convert_hour(hour): return [str(timedelta(hours=hour))] + class OutputSinkCSV(OutputSinkOnlineSessions): def __init__(self, fd=sys.stdout): self._writer = io.FileWriterCSV(fd) @@ -91,6 +95,7 @@ class OutputSinkCSV(OutputSinkOnlineSessions): row.append(str(duration)) self._writer.write_row(row) + class OutputConverterJSON: _DATE_FIELD = 'date' _WEEKDAY_FIELD = 'weekday' @@ -125,6 +130,7 @@ class OutputConverterJSON: obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour)) return obj + class OutputSinkJSON(OutputSinkOnlineSessions): def __init__(self, fd=sys.stdout): self._writer = io.FileWriterJSON(fd) @@ -142,7 +148,7 @@ class OutputSinkJSON(OutputSinkOnlineSessions): @staticmethod def _key_to_object(group_by, key): - if not group_by in OutputSinkJSON._CONVERT_KEY: + if group_by not in OutputSinkJSON._CONVERT_KEY: raise NotImplementedError('unsupported grouping: ' + str(group_by)) return OutputSinkJSON._CONVERT_KEY[group_by](key) @@ -154,6 +160,7 @@ class OutputSinkJSON(OutputSinkOnlineSessions): entries.append(entry) self._writer.write(entries) + class OutputConverterPlot: @staticmethod def convert_user(user): @@ -171,6 +178,7 @@ class OutputConverterPlot: def convert_hour(hour): return '{}:00'.format(hour) + class OutputSinkPlot(OutputSinkOnlineSessions): def __init__(self, fd=sys.stdout): self._fd = fd @@ -200,11 +208,11 @@ class OutputSinkPlot(OutputSinkOnlineSessions): @staticmethod def _extract_labels(group_by, durations): - return tuple(map(lambda key: OutputSinkPlot._format_key(group_by, key), durations.keys())) + return (OutputSinkPlot._format_key(group_by, key) for key in durations.keys()) @staticmethod def _extract_values(durations): - return tuple(map(OutputSinkPlot._duration_to_seconds, durations.values())) + return (OutputSinkPlot._duration_to_seconds(duration) for duration in durations.values()) def process_database( self, group_by, db_reader, time_from=None, time_to=None): @@ -219,8 +227,8 @@ class OutputSinkPlot(OutputSinkOnlineSessions): fontsize='small', rotation=30) bar_chart.set_value_label_formatter(self._format_duration) - labels = self._extract_labels(group_by, durations) - durations = self._extract_values(durations) + labels = tuple(self._extract_labels(group_by, durations)) + durations = tuple(self._extract_values(durations)) if group_by is GroupBy.HOUR: bar_chart.labels_align_middle = False @@ -237,6 +245,7 @@ class OutputSinkPlot(OutputSinkOnlineSessions): else: bar_chart.save(self._fd) + class OutputFormat(Enum): CSV = 'csv' JSON = 'json' @@ -248,38 +257,42 @@ class OutputFormat(Enum): def create_sink(self, fd=sys.stdout): if self is OutputFormat.CSV: return OutputSinkCSV(fd) - elif self is OutputFormat.JSON: + if self is OutputFormat.JSON: return OutputSinkJSON(fd) - elif self is OutputFormat.PLOT: + if self is OutputFormat.PLOT: return OutputSinkPlot(fd) - else: - raise NotImplementedError('unsupported output format: ' + str(self)) + raise NotImplementedError('unsupported output format: ' + str(self)) def open_file(self, path=None): if self is OutputFormat.PLOT: return io.open_output_binary_file(path) return io.open_output_text_file(path) + def _parse_group_by(s): try: return GroupBy(s) except ValueError: raise argparse.ArgumentTypeError('invalid "group by" value: ' + s) + def _parse_database_format(s): try: return DatabaseFormat(s) except ValueError: raise argparse.ArgumentTypeError('invalid database format: ' + s) + def _parse_output_format(s): try: return OutputFormat(s) except ValueError: raise argparse.ArgumentTypeError('invalid output format: ' + s) + _DATE_RANGE_LIMIT_FORMAT = '%Y-%m-%dT%H:%M:%SZ' + def _parse_date_range_limit(s): try: dt = datetime.strptime(s, _DATE_RANGE_LIMIT_FORMAT) @@ -289,6 +302,7 @@ def _parse_date_range_limit(s): raise argparse.ArgumentTypeError( msg.format(_DATE_RANGE_LIMIT_FORMAT, s)) + def _parse_args(args=None): if args is None: args = sys.argv[1:] @@ -324,6 +338,7 @@ def _parse_args(args=None): return parser.parse_args(args) + def process_online_sessions( db_path=None, db_fmt=DatabaseFormat.CSV, out_path=None, out_fmt=OutputFormat.CSV, @@ -343,8 +358,10 @@ def process_online_sessions( time_from=time_from, time_to=time_to) + def main(args=None): process_online_sessions(**vars(_parse_args(args))) + if __name__ == '__main__': main() diff --git a/bin/show_status.py b/bin/show_status.py index 1ca9cb5..cf59280 100644 --- a/bin/show_status.py +++ b/bin/show_status.py @@ -3,12 +3,14 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. -import argparse, sys +import argparse +import sys from vk.api import API from vk.tracking import StatusTracker from vk.tracking.db import Format as DatabaseFormat + def _parse_args(args=None): if args is None: args = sys.argv[1:] @@ -23,6 +25,7 @@ def _parse_args(args=None): return parser.parse_args(args) + def track_status(uids, log_path=None): api = API() tracker = StatusTracker(api) @@ -32,8 +35,10 @@ def track_status(uids, log_path=None): tracker.add_database_writer(log_writer) tracker.query_status(uids) + def main(args=None): track_status(**vars(_parse_args(args))) + if __name__ == '__main__': main() diff --git a/bin/track_status.py b/bin/track_status.py index 9c32ae9..2a974a5 100644 --- a/bin/track_status.py +++ b/bin/track_status.py @@ -3,15 +3,18 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. -import argparse, sys +import argparse +import sys from vk.api import API from vk.tracking import StatusTracker from vk.tracking.db import Format as DatabaseFormat + DEFAULT_TIMEOUT = StatusTracker.DEFAULT_TIMEOUT DEFAULT_DB_FORMAT = DatabaseFormat.CSV + def _parse_positive_integer(s): try: n = int(s) @@ -21,12 +24,14 @@ def _parse_positive_integer(s): raise argparse.ArgumentTypeError('must be a positive integer: ' + s) return n + def _parse_database_format(s): try: return DatabaseFormat(s) except ValueError: raise argparse.ArgumentTypeError('invalid database format: ' + s) + def _parse_args(args=None): if args is None: args = sys.argv[1:] @@ -52,6 +57,7 @@ def _parse_args(args=None): return parser.parse_args(args) + def track_status( uids, timeout=DEFAULT_TIMEOUT, log_path=None, @@ -72,8 +78,10 @@ def track_status( tracker.loop(uids) + def main(args=None): track_status(**vars(_parse_args(args))) + if __name__ == '__main__': main() diff --git a/bin/utils/bar_chart.py b/bin/utils/bar_chart.py index 522dfed..f051efc 100644 --- a/bin/utils/bar_chart.py +++ b/bin/utils/bar_chart.py @@ -7,6 +7,7 @@ import matplotlib.pyplot as plt from matplotlib import ticker import numpy as np + class BarChartBuilder: _BAR_HEIGHT = .5 @@ -130,6 +131,7 @@ class BarChartBuilder: def save(self, path): self._fig.savefig(path, bbox_inches='tight') + if __name__ == '__main__': import argparse parser = argparse.ArgumentParser() diff --git a/bin/utils/io.py b/bin/utils/io.py index 04baa6a..bb8eef9 100644 --- a/bin/utils/io.py +++ b/bin/utils/io.py @@ -8,6 +8,7 @@ import csv import json import sys + class FileWriterJSON: def __init__(self, fd=sys.stdout): self._fd = fd @@ -16,13 +17,14 @@ class FileWriterJSON: self._fd.write(json.dumps(something, indent=3, ensure_ascii=False)) self._fd.write('\n') + class FileWriterCSV: def __init__(self, fd=sys.stdout): self._writer = csv.writer(fd, lineterminator='\n') @staticmethod def _convert_row_old_python(row): - if isinstance(row, list) or isinstance(row, tuple): + if isinstance(row, (list, tuple)): return row return list(row) @@ -31,6 +33,7 @@ class FileWriterCSV: row = self._convert_row_old_python(row) self._writer.writerow(row) + @contextmanager def _open_file(path=None, default=None, **kwargs): if path is None: @@ -39,8 +42,10 @@ def _open_file(path=None, default=None, **kwargs): with open(path, **kwargs) as fd: yield fd + def open_output_text_file(path=None): return _open_file(path, default=sys.stdout, mode='w', encoding='utf-8') + def open_output_binary_file(path=None): return _open_file(path, default=sys.stdout, mode='wb') @@ -3,8 +3,7 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. -from collections import Iterable -from collections.abc import Mapping +from collections.abc import Iterable, Mapping from enum import Enum import json from urllib.error import URLError @@ -14,21 +13,24 @@ from urllib.request import urlopen import vk.error from vk.user import User + def _split_url(url): return urllib.parse.urlsplit(url)[:3] + def _is_empty_param_value(value): return isinstance(value, str) and not value + def _filter_empty_params(params, empty_params=False): if empty_params: return params if isinstance(params, Mapping): return {name: value for name, value in params.items() if not _is_empty_param_value(value)} - elif isinstance(params, Iterable): + if isinstance(params, Iterable): return [(name, value) for name, value in params if not _is_empty_param_value(value)] - else: - raise TypeError() + raise TypeError() + def _build_url(scheme, host, path, params=None, empty_params=False): if params is None: @@ -43,6 +45,7 @@ def _build_url(scheme, host, path, params=None, empty_params=False): path = urllib.parse.quote(path) return urllib.parse.urlunsplit((scheme, host, path, params, '')) + def _join_param_values(values): if isinstance(values, str): return values @@ -50,13 +53,16 @@ def _join_param_values(values): return ','.join(map(str, values)) return values + def _join_path(base, url): if not base.endswith('/'): base += '/' return urllib.parse.urljoin(base, url) + ACCESS_TOKEN = '9722cef09722cef09722cef071974b8cbe997229722cef0cbabfd816916af6c7bd37006' + class Version(Enum): V5_73 = '5.73' DEFAULT = V5_73 @@ -64,6 +70,7 @@ class Version(Enum): def __str__(self): return self.value + class Language(Enum): EN = 'en' DEFAULT = EN @@ -71,6 +78,7 @@ class Language(Enum): def __str__(self): return self.value + class Method(Enum): USERS_GET = 'users.get' FRIENDS_GET = 'friends.get' @@ -78,6 +86,7 @@ class Method(Enum): def __str__(self): return self.value + class CommonParameters(Enum): ACCESS_TOKEN = 'access_token' VERSION = 'v' @@ -86,6 +95,7 @@ class CommonParameters(Enum): def __str__(self): return self.value + class API: _ROOT_URL = 'https://api.vk.com/method/' @@ -136,4 +146,5 @@ class API: fields=_join_param_values(fields)) if 'items' not in response: raise vk.error.InvalidAPIResponseError(response) - return self._filter_response_with_users(response['items'], deactivated_users) + return self._filter_response_with_users(response['items'], + deactivated_users) diff --git a/vk/error.py b/vk/error.py index 3a309c3..5f71e47 100644 --- a/vk/error.py +++ b/vk/error.py @@ -3,9 +3,11 @@ # For details, see https://github.com/egor-tensin/vk-scripts. # Distributed under the MIT License. + class APIError(RuntimeError): pass + class InvalidAPIResponseError(APIError): def __init__(self, response): super().__init__() @@ -14,5 +16,6 @@ class InvalidAPIResponseError(APIError): def __str__(self): return str(self.response) + class APIConnectionError(APIError): pass diff --git a/vk/last_seen.py b/vk/last_seen.py index 25975f2..cf76db5 100644 --- a/vk/last_seen.py +++ b/vk/last_seen.py @@ -11,6 +11,7 @@ from numbers import Integral, Real from .platform import Platform + def _parse_time(x): if isinstance(x, datetime): if x.tzinfo is None or x.tzinfo.utcoffset(x) is None: @@ -20,6 +21,7 @@ def _parse_time(x): return datetime.fromtimestamp(x, tz=timezone.utc) raise TypeError() + def _parse_platform(x): if x in Platform: return x @@ -27,6 +29,7 @@ def _parse_platform(x): return Platform.from_string(x) return Platform(x) + class LastSeenField(Enum): TIME = 'time' PLATFORM = 'platform' @@ -34,6 +37,7 @@ class LastSeenField(Enum): def __str__(self): return self.value + class LastSeen(MutableMapping): @staticmethod def from_api_response(source): diff --git a/vk/platform.py b/vk/platform.py index 8bf23c5..5cd78b8 100644 --- a/vk/platform.py +++ b/vk/platform.py @@ -6,6 +6,7 @@ from enum import Enum import re + class Platform(Enum): MOBILE = 1 IPHONE = 2 @@ -43,6 +44,7 @@ class Platform(Enum): def get_descr_text_capitalized(self): return self._capitalize_first_letter(self.get_descr_text()) + _PLATFORM_DESCRIPTIONS = { Platform.MOBILE: '"mobile" web version (or unrecognized mobile app)', Platform.IPHONE: 'official iPhone app', diff --git a/vk/tracking/db/backend/csv.py b/vk/tracking/db/backend/csv.py index 4943ff2..43038e4 100644 --- a/vk/tracking/db/backend/csv.py +++ b/vk/tracking/db/backend/csv.py @@ -8,6 +8,7 @@ from ..io import FileReaderCSV, FileWriterCSV from ..record import Record from ..timestamp import Timestamp + class Writer(meta.Writer): def __init__(self, fd): self._writer = FileWriterCSV(fd) @@ -30,6 +31,7 @@ class Writer(meta.Writer): def _record_to_row(record): return [str(record.get_timestamp())] + [str(record[field]) for field in record] + class Reader(meta.Reader): def __init__(self, fd): self._reader = FileReaderCSV(fd) diff --git a/vk/tracking/db/backend/log.py b/vk/tracking/db/backend/log.py index 814cabc..d301856 100644 --- a/vk/tracking/db/backend/log.py +++ b/vk/tracking/db/backend/log.py @@ -7,6 +7,7 @@ import logging from .. import meta + class Writer(meta.Writer): def __init__(self, fd): self._logger = logging.getLogger(__file__) diff --git a/vk/tracking/db/backend/null.py b/vk/tracking/db/backend/null.py index 663af10..80a66b4 100644 --- a/vk/tracking/db/backend/null.py +++ b/vk/tracking/db/backend/null.py @@ -5,6 +5,7 @@ from .. import meta + class Writer(meta.Writer): def __init__(self): pass @@ -18,6 +19,7 @@ class Writer(meta.Writer): def on_connection_error(self, e): pass + class Reader(meta.Reader): def __init__(self): pass diff --git a/vk/tracking/db/format.py b/vk/tracking/db/format.py index f9a670c..028d403 100644 --- a/vk/tracking/db/format.py +++ b/vk/tracking/db/format.py @@ -8,6 +8,7 @@ import sys from . import backend, io + class Format(Enum): CSV = 'csv' LOG = 'log' @@ -19,22 +20,20 @@ class Format(Enum): def create_writer(self, fd=sys.stdout): if self is Format.CSV: return backend.csv.Writer(fd) - elif self is Format.LOG: + if self is Format.LOG: return backend.log.Writer(fd) - elif self is Format.NULL: + if self is Format.NULL: return backend.null.Writer() - else: - raise NotImplementedError('unsupported database format: ' + str(self)) + raise NotImplementedError('unsupported database format: ' + str(self)) def open_output_file(self, path=None): if self is Format.CSV: return self._open_output_database_file(path) - elif self is Format.LOG: + if self is Format.LOG: return self._open_output_log_file(path) - elif self is Format.NULL: + if self is Format.NULL: return self._open_output_database_file(None) - else: - raise NotImplementedError('unsupported database format: ' + str(self)) + raise NotImplementedError('unsupported database format: ' + str(self)) @staticmethod def _open_output_log_file(path): @@ -47,19 +46,17 @@ class Format(Enum): def create_reader(self, fd=sys.stdin): if self is Format.CSV: return backend.csv.Reader(fd) - elif self is Format.LOG: + if self is Format.LOG: return NotImplementedError('cannot read from a log file') - elif self is Format.NULL: + if self is Format.NULL: return backend.null.Reader() - else: - raise NotImplementedError('unsupported database format: ' + str(self)) + raise NotImplementedError('unsupported database format: ' + str(self)) def open_input_file(self, path=None): if self is Format.CSV: return io.open_input_text_file(path) - elif self is Format.LOG: + if self is Format.LOG: raise NotImplementedError('cannot read from a log file') - elif self is Format.NULL: + if self is Format.NULL: return io.open_input_text_file(None) - else: - raise NotImplementedError('unsupported database format: ' + str(self)) + raise NotImplementedError('unsupported database format: ' + str(self)) diff --git a/vk/tracking/db/io.py b/vk/tracking/db/io.py index 37d9c53..a89865f 100644 --- a/vk/tracking/db/io.py +++ b/vk/tracking/db/io.py @@ -7,6 +7,7 @@ from contextlib import contextmanager import csv import sys + class FileWriterCSV: def __init__(self, fd=sys.stdout): self._fd = fd @@ -14,7 +15,7 @@ class FileWriterCSV: @staticmethod def _convert_row_old_python(row): - if isinstance(row, list) or isinstance(row, tuple): + if isinstance(row, (list, tuple)): return row return list(row) @@ -24,6 +25,7 @@ class FileWriterCSV: self._writer.writerow(row) self._fd.flush() + class FileReaderCSV: def __init__(self, fd=sys.stdin): self._reader = csv.reader(fd) @@ -31,6 +33,7 @@ class FileReaderCSV: def __iter__(self): return iter(self._reader) + @contextmanager def _open_file(path=None, default=None, **kwargs): if path is None: @@ -39,12 +42,15 @@ def _open_file(path=None, default=None, **kwargs): with open(path, **kwargs) as fd: yield fd + _DEFAULT_ENCODING = 'utf-8' + def open_output_text_file(path=None, mode='w'): return _open_file(path, default=sys.stdout, mode=mode, encoding=_DEFAULT_ENCODING) + def open_input_text_file(path=None): return _open_file(path, default=sys.stdin, mode='r', encoding=_DEFAULT_ENCODING) diff --git a/vk/tracking/db/meta.py b/vk/tracking/db/meta.py index 024d9d8..eb15c74 100644 --- a/vk/tracking/db/meta.py +++ b/vk/tracking/db/meta.py @@ -6,6 +6,7 @@ import abc from collections.abc import Iterable + class Writer(metaclass=abc.ABCMeta): @abc.abstractmethod def on_initial_status(self, user): @@ -19,6 +20,7 @@ class Writer(metaclass=abc.ABCMeta): def on_connection_error(self, e): pass + class Reader(Iterable, metaclass=abc.ABCMeta): @abc.abstractmethod def __iter__(self): diff --git a/vk/tracking/db/record.py b/vk/tracking/db/record.py index dfd47c6..6998238 100644 --- a/vk/tracking/db/record.py +++ b/vk/tracking/db/record.py @@ -12,6 +12,7 @@ from vk.user import User, UserField from .timestamp import Timestamp + class Record(MutableMapping): FIELDS = ( UserField.UID, diff --git a/vk/tracking/db/timestamp.py b/vk/tracking/db/timestamp.py index b2219ca..1309797 100644 --- a/vk/tracking/db/timestamp.py +++ b/vk/tracking/db/timestamp.py @@ -5,6 +5,7 @@ from datetime import datetime, timezone + class Timestamp: @staticmethod def _new(): diff --git a/vk/tracking/online_sessions.py b/vk/tracking/online_sessions.py index 204e1cc..c43e11c 100644 --- a/vk/tracking/online_sessions.py +++ b/vk/tracking/online_sessions.py @@ -8,6 +8,7 @@ from collections.abc import MutableMapping from datetime import timedelta from enum import Enum + class Weekday(Enum): MONDAY = 0 TUESDAY = 1 @@ -20,6 +21,7 @@ class Weekday(Enum): def __str__(self): return self.name[0] + self.name[1:].lower() + class OnlineSessionEnumerator(MutableMapping): def __init__(self, time_from=None, time_to=None): self._records = {} diff --git a/vk/tracking/status_tracker.py b/vk/tracking/status_tracker.py index 30e0f97..b87e059 100644 --- a/vk/tracking/status_tracker.py +++ b/vk/tracking/status_tracker.py @@ -11,6 +11,7 @@ import signal import vk.error from vk.user import UserField + class StatusTracker: DEFAULT_TIMEOUT = 5 @@ -49,7 +50,8 @@ class StatusTracker: _USER_FIELDS = UserField.DOMAIN, UserField.ONLINE, UserField.LAST_SEEN, def _query_status(self, uids): - user_list = self._api.users_get(uids, self._USER_FIELDS, deactivated_users=False) + user_list = self._api.users_get(uids, self._USER_FIELDS, + deactivated_users=False) return {user.get_uid(): user for user in user_list} def _notify_status(self, user): @@ -9,31 +9,33 @@ from enum import Enum from .last_seen import LastSeen + def _parse_last_seen(x): if isinstance(x, LastSeen): return x - elif isinstance(x, Mapping): + if isinstance(x, Mapping): return LastSeen.from_api_response(x) - else: - raise TypeError() + raise TypeError() + def _parse_bool(x): if isinstance(x, str): if str(True) == x: return True - elif str(False) == x: + if str(False) == x: return False - else: - raise ValueError() - else: - return bool(x) + raise ValueError() + return bool(x) + def _parse_hidden(x): return _parse_bool(x) + def _parse_online_flag(x): return _parse_bool(x) + class UserField(Enum): UID = 'id' FIRST_NAME = 'first_name' @@ -48,6 +50,7 @@ class UserField(Enum): def __str__(self): return self.value + class DeactivationReason(Enum): DELETED = 'deleted' BANNED = 'banned' @@ -55,9 +58,11 @@ class DeactivationReason(Enum): def __str__(self): return self.value + def _parse_deactivated(s): return DeactivationReason(s) + class User(Hashable, MutableMapping): @staticmethod def from_api_response(source): |