diff options
author | Egor Tensin <Egor.Tensin@gmail.com> | 2016-06-16 23:32:24 +0300 |
---|---|---|
committer | Egor Tensin <Egor.Tensin@gmail.com> | 2016-06-16 23:32:24 +0300 |
commit | eb930123454771b80465505579d723c92b3dd84c (patch) | |
tree | 9aeded9c85edc2ae9c00ae904d252558731d93f7 | |
parent | make "last seen" timestamps timezone-aware (diff) | |
download | vk-scripts-eb930123454771b80465505579d723c92b3dd84c.tar.gz vk-scripts-eb930123454771b80465505579d723c92b3dd84c.zip |
Refactor and add support for more user fields
Also includes a bunch of other minor improvements.
-rw-r--r-- | mutual_friends.py | 10 | ||||
-rw-r--r-- | vk/api.py | 4 | ||||
-rw-r--r-- | vk/user.py | 208 | ||||
-rw-r--r-- | vk/utils/tracking/db/reader/csv.py | 14 | ||||
-rw-r--r-- | vk/utils/tracking/db/record.py | 114 | ||||
-rw-r--r-- | vk/utils/tracking/db/writer/csv.py | 6 | ||||
-rw-r--r-- | vk/utils/tracking/logger.py | 2 | ||||
-rw-r--r-- | vk/utils/tracking/status_tracker.py | 4 | ||||
-rw-r--r-- | vk/utils/tracking/utils/how_much_online.py | 18 |
9 files changed, 262 insertions, 118 deletions
diff --git a/mutual_friends.py b/mutual_friends.py index 7df408a..b0bf363 100644 --- a/mutual_friends.py +++ b/mutual_friends.py @@ -9,12 +9,12 @@ import json import sys import vk.api -from vk.user import Field +from vk.user import UserField -def query_friend_list(api, user): - return api.friends_get(user.get_uid(), fields=Field.SCREEN_NAME) +OUTPUT_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME, UserField.SCREEN_NAME -OUTPUT_FIELDS = Field.UID, Field.FIRST_NAME, Field.LAST_NAME, Field.SCREEN_NAME +def query_friend_list(api, user): + return api.friends_get(user.get_uid(), fields=OUTPUT_FIELDS) def extract_output_fields(user): new_user = OrderedDict() @@ -67,7 +67,7 @@ if __name__ == '__main__': args = parser.parse_args() api = vk.api.API(vk.api.Language.EN) - users = api.users_get(args.uids, fields=Field.SCREEN_NAME) + users = api.users_get(args.uids) friend_lists = map(lambda user: frozenset(query_friend_list(api, user)), users) mutual_friends = frozenset.intersection(*friend_lists) @@ -64,13 +64,13 @@ class API: return str(xs) def users_get(self, user_ids, fields=()): - return map(User, self._call_method( + return map(User.from_api_response, self._call_method( Method.USERS_GET, user_ids=self._format_param_values(user_ids), fields=self._format_param_values(fields))) def friends_get(self, user_id, fields=()): - return map(User, self._call_method( + return map(User.from_api_response, self._call_method( Method.FRIENDS_GET, user_id=str(user_id), fields=self._format_param_values(fields))) @@ -2,11 +2,13 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. 
+from collections import OrderedDict +from collections.abc import Hashable, Mapping, MutableMapping from datetime import datetime, timezone from enum import Enum from numbers import Real, Integral -class Field(Enum): +class UserField(Enum): UID = 'uid' FIRST_NAME = 'first_name' LAST_NAME = 'last_name' @@ -17,115 +19,187 @@ class Field(Enum): def __str__(self): return self.value -class User: - def __init__(self, impl): - self._impl = impl +class LastSeenField(Enum): + TIME = 'time' def __str__(self): - return str(self._impl) + return self.value + +class LastSeen(MutableMapping): + @staticmethod + def from_api_response(source): + instance = LastSeen() + for field in LastSeenField: + if str(field) in source: + instance[field] = source[str(field)] + return instance + + def __init__(self, fields=None): + if fields is None: + fields = OrderedDict() + self._fields = fields + + def __getitem__(self, field): + return self._fields[field] + + def __setitem__(self, field, value): + self._fields[field] = self.parse(field, value) + + def __delitem__(self, field): + del self._fields[field] + + def __iter__(self, field): + return iter(self._fields) + + def __len__(self, field): + return len(self._fields) + + @staticmethod + def parse(field, value): + if field in LastSeen._FIELD_PARSERS: + return LastSeen._FIELD_PARSERS[field](value) + else: + return LastSeen._DEFAULT_FIELD_PARSER(value) + + def _parse_time(x): + if isinstance(x, datetime): + if x.tzinfo is None or x.tzinfo.utcoffset(x) is None: + x = x.replace(tzinfo=timezone.utc) + return x + elif isinstance(x, Real) or isinstance(x, Integral): + return datetime.fromtimestamp(x, tz=timezone.utc) + else: + raise TypeError() + + _FIELD_PARSERS = { + LastSeenField.TIME: _parse_time, + } + + _DEFAULT_FIELD_PARSER = str + + def has_time(self): + return LastSeenField.TIME in self + + def get_time(self): + return self[LastSeenField.TIME] + + def set_time(self, t): + self[LastSeenField.TIME] = t + +class User(Hashable, MutableMapping): + 
@staticmethod + def from_api_response(source): + instance = User() + for field in UserField: + if str(field) in source: + instance[field] = source[str(field)] + return instance + + def __init__(self, fields=None): + if fields is None: + fields = OrderedDict() + self._fields = fields def __eq__(self, other): - return self.get_uid() == other.get_uid() + return self._fields == other._fields - def __hash__(self): + def __hash__(self, fields=None): return hash(self.get_uid()) + def __getitem__(self, field): + return self._fields[field] + + def __setitem__(self, field, value): + self._fields[field] = self.parse(field, value) + + def __delitem__(self, field): + del self._fields[field] + def __iter__(self): - return iter(self._impl) + return iter(self._fields) - def __contains__(self, field): - if field is Field.LAST_SEEN: - return self._has_last_seen() - return self._normalize_field(field) in self._impl + def __len__(self): + return len(self._fields) - def __getitem__(self, field): - if field is Field.LAST_SEEN: - return self._get_last_seen() - return self._impl[self._normalize_field(field)] + @staticmethod + def parse(field, value): + if field in User._FIELD_PARSERS: + return User._FIELD_PARSERS[field](value) + else: + return User._DEFAULT_FIELD_PARSER(value) - def __setitem__(self, field, value): - if field is Field.LAST_SEEN: - self._set_last_seen(value) + def _parse_last_seen(x): + if isinstance(x, LastSeen): + return x + elif isinstance(x, Mapping): + return LastSeen.from_api_response(x) else: - self._impl[self._normalize_field(field)] = value + raise TypeError() - @staticmethod - def _normalize_field(field): - if isinstance(field, Field): - return field.value - return field + _FIELD_PARSERS = { + UserField.UID: int, + UserField.ONLINE: bool, + UserField.LAST_SEEN: _parse_last_seen, + } + + _DEFAULT_FIELD_PARSER = str def get_uid(self): - return self[Field.UID] + return self[UserField.UID] def get_first_name(self): - return self[Field.FIRST_NAME] + return 
self[UserField.FIRST_NAME] def set_first_name(self, name): - self[Field.FIRST_NAME] = name + self[UserField.FIRST_NAME] = name def has_last_name(self): - return Field.LAST_NAME in self and self.get_last_name() + return UserField.LAST_NAME in self and self.get_last_name() def get_last_name(self): - return self[Field.LAST_NAME] + return self[UserField.LAST_NAME] def set_last_name(self, name): - self[Field.LAST_NAME] = name + self[UserField.LAST_NAME] = name def has_screen_name(self): - return Field.SCREEN_NAME in self + return UserField.SCREEN_NAME in self def get_screen_name(self): if self.has_screen_name(): - return self[Field.SCREEN_NAME] + return self[UserField.SCREEN_NAME] else: return 'id' + str(self.get_uid()) def set_screen_name(self, name): - self[Field.SCREEN_NAME] = name + self[UserField.SCREEN_NAME] = name - def has_online(self): - return Field.ONLINE in self + def has_online_flag(self): + return UserField.ONLINE in self def is_online(self): - return bool(self[Field.ONLINE]) - - def set_online(self, value=True): - self[Field.ONLINE] = value + return self[UserField.ONLINE] - @staticmethod - def _last_seen_from_timestamp(t): - return datetime.fromtimestamp(t, timezone.utc) - - @staticmethod - def _last_seen_to_timestamp(t): - if isinstance(t, datetime): - return t.timestamp() - elif isinstance(t, Real) or isinstance(t, Integral): - return t - else: - raise TypeError('"last seen" time must be either a `datetime` or a POSIX timestamp') + def is_offline(self): + return not self.is_online() - def _has_last_seen(self): - return Field.LAST_SEEN.value in self._impl and 'time' in self._impl[Field.LAST_SEEN.value] + def set_online_flag(self, value=True): + self[UserField.ONLINE] = value def has_last_seen(self): - return self._has_last_seen() + return UserField.LAST_SEEN in self - def _get_last_seen(self): - return self._last_seen_from_timestamp(self._impl[Field.LAST_SEEN.value]['time']) + def get_last_seen(self): + return self[UserField.LAST_SEEN] - def 
get_last_seen_utc(self): - return self._get_last_seen() + def set_last_seen(self, last_seen): + self[UserField.LAST_SEEN] = last_seen - def get_last_seen_local(self): - return self._get_last_seen().astimezone() + def get_last_seen_time(self): + return self.has_last_seen() and self.get_last_seen().has_time() - def _set_last_seen(self, t): - if Field.LAST_SEEN.value not in self._impl: - self._impl[Field.LAST_SEEN.value] = {} - self._impl[Field.LAST_SEEN.value]['time'] = self._last_seen_to_timestamp(t) + def get_last_seen_time(self): + return self[UserField.LAST_SEEN].get_time() - def set_last_seen(self, t): - self._set_last_seen(t) + def get_last_seen_time_local(self): + return self[UserField.LAST_SEEN].get_time().astimezone() diff --git a/vk/utils/tracking/db/reader/csv.py b/vk/utils/tracking/db/reader/csv.py index b66e397..e9c9407 100644 --- a/vk/utils/tracking/db/reader/csv.py +++ b/vk/utils/tracking/db/reader/csv.py @@ -2,11 +2,12 @@ # This file is licensed under the terms of the MIT License. # See LICENSE.txt for details. +from collections.abc import Iterable import csv -from ..record import Record +from ..record import Record, Timestamp -class Reader: +class Reader(Iterable): def __init__(self, path): self._fd = open(path) self._reader = csv.reader(self._fd) @@ -19,4 +20,11 @@ class Reader: self._fd.__exit__(*args) def __iter__(self): - return map(Record.from_row, self._reader) + return map(Reader._record_from_row, self._reader) + + @staticmethod + def _record_from_row(row): + record = Record(Timestamp.from_string(row[0])) + for i in range(len(Record.FIELDS)): + record[Record.FIELDS[i]] = row[i + 1] + return record diff --git a/vk/utils/tracking/db/record.py b/vk/utils/tracking/db/record.py index 7cb054f..4748a37 100644 --- a/vk/utils/tracking/db/record.py +++ b/vk/utils/tracking/db/record.py @@ -3,15 +3,51 @@ # See LICENSE.txt for details. 
from collections import OrderedDict +from collections.abc import MutableMapping from datetime import datetime, timezone -from vk.user import Field as UserField +from vk.user import LastSeen, User, UserField -def _gen_timestamp(): - return datetime.now(timezone.utc).replace(microsecond=0) +class Timestamp: + @staticmethod + def _new(): + return datetime.utcnow() + + @staticmethod + def _is_timezone_aware(dt): + return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None -class Record: - _USER_FIELDS = ( + @staticmethod + def _lose_timezone(dt): + if Timestamp._is_timezone_aware(dt): + return dt.astimezone(timezone.utc).replace(tzinfo=None) + return dt + + def __init__(self, dt=None): + if dt is None: + dt = self._new() + dt = dt.replace(microsecond=0) + dt = self._lose_timezone(dt) + self._dt = dt + + @staticmethod + def from_string(s): + return Timestamp(datetime.strptime(s, '%Y-%m-%dT%H:%M:%SZ')) + + def __str__(self): + return self._dt.isoformat() + 'Z' + + @staticmethod + def from_last_seen(ls): + return Timestamp(ls.get_time()) + + def to_last_seen(self): + ls = LastSeen() + ls.set_time(self._dt) + return ls + +class Record(MutableMapping): + FIELDS = ( UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME, @@ -20,48 +56,52 @@ class Record: UserField.LAST_SEEN, ) - def __init__(self, fields, timestamp=None): + def __init__(self, timestamp=None, fields=None): + if timestamp is None: + timestamp = Timestamp() + if fields is None: + fields = OrderedDict() + self._timestamp = timestamp self._fields = fields - self._timestamp = timestamp if timestamp is not None else _gen_timestamp() - - def __iter__(self): - return iter(self._fields) - - def __contains__(self, field): - return field in self._fields def __getitem__(self, field): + if field is UserField.LAST_SEEN: + return Timestamp.from_last_seen(self._fields[field]) return self._fields[field] def __setitem__(self, field, value): - self._fields[field] = value + if field is UserField.LAST_SEEN: + if 
isinstance(value, str): + value = Timestamp.from_string(value).to_last_seen() + elif isinstance(value, Timestamp): + value = value.to_last_seen() + elif isinstance(value, LastSeen): + pass + else: + raise TypeError() + self._fields[field] = User.parse(field, value) - def get_timestamp(self): - return self._timestamp + def __delitem__(self, field): + del self._fields[field] - @staticmethod - def _timestamp_from_string(s): - return datetime.fromtimestamp(s) + def __iter__(self): + return iter(self._fields) + + def __len__(self): + return len(self._fields) - def timestamp_to_string(self): - return self.get_timestamp().isoformat() + def get_timestamp(self): + return self._timestamp @staticmethod def from_user(user): - fields = OrderedDict() - for field in Record._USER_FIELDS: - fields[field] = user[field] - if UserField.LAST_SEEN in Record._USER_FIELDS: - fields[UserField.LAST_SEEN] = fields[UserField.LAST_SEEN].isoformat() - return Record(fields) + instance = Record() + for field in Record.FIELDS: + instance[field] = user[field] + return instance - @staticmethod - def from_row(row): - timestamp = Record._timestamp_from_string(row[0]) - fields = OrderedDict() - for i in range(len(Record._USER_FIELDS)): - fields[Record._USER_FIELDS[i]] = row[i + 1] - return Record(fields, timestamp) - - def to_row(self): - return [self.timestamp_to_string()] + [self[field] for field in self] + def to_user(self): + user = User() + for field in self: + user[field] = self[field] + return user diff --git a/vk/utils/tracking/db/writer/csv.py b/vk/utils/tracking/db/writer/csv.py index 8dc2299..8c635b4 100644 --- a/vk/utils/tracking/db/writer/csv.py +++ b/vk/utils/tracking/db/writer/csv.py @@ -36,8 +36,12 @@ class Writer: def write_record(self, user): if not self: return - self._write_row(Record.from_user(user).to_row()) + self._write_row(self._record_to_row(Record.from_user(user))) self.flush() def _write_row(self, row): self._writer.writerow(row) + + @staticmethod + def 
_record_to_row(record): + return [str(record.get_timestamp())] + [str(record[field]) for field in record] diff --git a/vk/utils/tracking/logger.py b/vk/utils/tracking/logger.py index 795f799..a36e679 100644 --- a/vk/utils/tracking/logger.py +++ b/vk/utils/tracking/logger.py @@ -50,7 +50,7 @@ class Logger: @staticmethod def _format_user_last_seen(user): - return '{} was last seen at {}'.format(Logger._format_user(user), user.get_last_seen_local()) + return '{} was last seen at {}'.format(Logger._format_user(user), user.get_last_seen_time_local()) @staticmethod def _format_user_went_online(user): diff --git a/vk/utils/tracking/status_tracker.py b/vk/utils/tracking/status_tracker.py index dad14c0..f208884 100644 --- a/vk/utils/tracking/status_tracker.py +++ b/vk/utils/tracking/status_tracker.py @@ -6,7 +6,7 @@ from collections import Callable import time import vk.error -from vk.user import Field +from vk.user import UserField class StatusTracker: DEFAULT_TIMEOUT = 5 @@ -38,7 +38,7 @@ class StatusTracker: if not isinstance(fn, Callable): raise TypeError() - _USER_FIELDS = Field.SCREEN_NAME, Field.ONLINE, Field.LAST_SEEN + _USER_FIELDS = UserField.SCREEN_NAME, UserField.ONLINE, UserField.LAST_SEEN def _query_status(self, uids): return {user.get_uid(): user for user in self._api.users_get(uids, self._USER_FIELDS)} diff --git a/vk/utils/tracking/utils/how_much_online.py b/vk/utils/tracking/utils/how_much_online.py new file mode 100644 index 0000000..c1cf05f --- /dev/null +++ b/vk/utils/tracking/utils/how_much_online.py @@ -0,0 +1,18 @@ +# Copyright 2016 Egor Tensin <Egor.Tensin@gmail.com> +# This file is licensed under the terms of the MIT License. +# See LICENSE.txt for details. 
+ +from vk.utils.tracking.db.reader import * + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument('input', help='status database path') + + args = parser.parse_args() + + with csv.Reader(args.input) as csv_reader: + for record in csv_reader: + print(record.get_timestamp()) |