From 9a724815823e034e384f7a0d2de946ef6e090487 Mon Sep 17 00:00:00 2001 From: Egor Tensin Date: Mon, 3 May 2021 21:24:16 +0300 Subject: move scripts from bin/ to vk/ This is done in preparation to moving to PyPI. TODO: * update docs, * merge/rename some scripts. --- .ci/bin/mutual_friends.sh | 4 +- .ci/bin/online_sessions.sh | 2 +- .ci/bin/show_status.sh | 2 +- .ci/bin/track_status.sh | 2 +- bin/__init__.py | 0 bin/mutual_friends.py | 119 -------------- bin/online_sessions.py | 367 -------------------------------------------- bin/show_status.py | 44 ------ bin/track_status.py | 87 ----------- bin/utils/__init__.py | 0 bin/utils/bar_chart.py | 182 ---------------------- bin/utils/io.py | 51 ------ vk/mutuals.py | 118 ++++++++++++++ vk/tracking/sessions.py | 366 +++++++++++++++++++++++++++++++++++++++++++ vk/tracking/show_status.py | 44 ++++++ vk/tracking/track_status.py | 87 +++++++++++ vk/utils/__init__.py | 0 vk/utils/bar_chart.py | 182 ++++++++++++++++++++++ vk/utils/io.py | 51 ++++++ 19 files changed, 853 insertions(+), 855 deletions(-) delete mode 100644 bin/__init__.py delete mode 100644 bin/mutual_friends.py delete mode 100644 bin/online_sessions.py delete mode 100644 bin/show_status.py delete mode 100644 bin/track_status.py delete mode 100644 bin/utils/__init__.py delete mode 100644 bin/utils/bar_chart.py delete mode 100644 bin/utils/io.py create mode 100644 vk/mutuals.py create mode 100644 vk/tracking/sessions.py create mode 100644 vk/tracking/show_status.py create mode 100644 vk/tracking/track_status.py create mode 100644 vk/utils/__init__.py create mode 100644 vk/utils/bar_chart.py create mode 100644 vk/utils/io.py diff --git a/.ci/bin/mutual_friends.sh b/.ci/bin/mutual_friends.sh index 4a00fb8..383dbdd 100755 --- a/.ci/bin/mutual_friends.sh +++ b/.ci/bin/mutual_friends.sh @@ -12,8 +12,8 @@ script_dir="$( cd -- "$script_dir" && pwd )" readonly script_dir test_users() { - "$script_dir/../lib/test.sh" bin.mutual_friends --format csv "$@" - "$script_dir/../lib/test.sh" bin.mutual_friends --format json "$@" + "$script_dir/../lib/test.sh" vk.mutuals --format csv "$@" + "$script_dir/../lib/test.sh" vk.mutuals --format json "$@" } main() { diff --git a/.ci/bin/online_sessions.sh b/.ci/bin/online_sessions.sh index d64ef2f..7a01dc3 100755 --- a/.ci/bin/online_sessions.sh +++ b/.ci/bin/online_sessions.sh @@ -31,7 +31,7 @@ test_output() { trap "$rm_aux_files" RETURN - "$script_dir/../lib/test.sh" bin.online_sessions "$@" "$db_path" "$output_path" + "$script_dir/../lib/test.sh" vk.tracking.sessions "$@" "$db_path" "$output_path" if file --brief --dereference --mime -- "$output_path" | grep --quiet -- 'charset=binary$'; then dump 'Output is a binary file, not going to show that' diff --git a/.ci/bin/show_status.sh b/.ci/bin/show_status.sh index c9f8e26..b3f4cfe 100755 --- a/.ci/bin/show_status.sh +++ b/.ci/bin/show_status.sh @@ -12,7 +12,7 @@ script_dir="$( cd -- "$script_dir" && pwd )" readonly script_dir test_users() { - "$script_dir/../lib/test.sh" bin.show_status "$@" + "$script_dir/../lib/test.sh" vk.tracking.show_status "$@" } main() { diff --git a/.ci/bin/track_status.sh b/.ci/bin/track_status.sh index 1a7a7dc..9118144 100755 --- a/.ci/bin/track_status.sh +++ b/.ci/bin/track_status.sh @@ -30,7 +30,7 @@ test_users() { rm_aux_files="$( printf -- 'rm -f -- %q %q' "$log_path" "$db_path" )" trap "$rm_aux_files" RETURN - "$script_dir/../lib/test.sh" bin.track_status "$@" --log "$log_path" --format csv --output "$db_path" & + "$script_dir/../lib/test.sh" vk.tracking.track_status "$@" --log "$log_path" --format csv --output "$db_path" & local pid="$!" sleep 3 diff --git a/bin/__init__.py b/bin/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bin/mutual_friends.py b/bin/mutual_friends.py deleted file mode 100644 index 550d957..0000000 --- a/bin/mutual_friends.py +++ /dev/null @@ -1,119 +0,0 @@ -# Copyright (c) 2015 Egor Tensin -# This file is part of the "VK scripts" project. -# For details, see https://github.com/egor-tensin/vk-scripts. -# Distributed under the MIT License. - -import abc -import argparse -from collections import OrderedDict -from enum import Enum -import sys - -from vk.api import API -from vk.user import UserField - -from .utils import io - - -_OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME - - -def _query_friend_list(api, user): - return api.friends_get(user.get_uid(), fields=_OUTPUT_USER_FIELDS) - - -def _filter_user_fields(user): - new_user = OrderedDict() - for field in _OUTPUT_USER_FIELDS: - new_user[str(field)] = user[field] if field in user else None - return new_user - - -class OutputSinkMutualFriends(metaclass=abc.ABCMeta): - @abc.abstractmethod - def write_mutual_friends(self, friend_list): - pass - - -class OutputSinkCSV(OutputSinkMutualFriends): - def __init__(self, fd=sys.stdout): - self._writer = io.FileWriterCSV(fd) - - def write_mutual_friends(self, friend_list): - for user in friend_list: - self._writer.write_row(user.values()) - - -class OutputSinkJSON(OutputSinkMutualFriends): - def __init__(self, fd=sys.stdout): - self._writer = io.FileWriterJSON(fd) - - def write_mutual_friends(self, friend_list): - self._writer.write(friend_list) - - -class OutputFormat(Enum): - CSV = 'csv' - JSON = 'json' - - def __str__(self): - return self.value - - @staticmethod - def open_file(path=None): - return io.open_output_text_file(path) - - def create_sink(self, fd=sys.stdout): - if self is OutputFormat.CSV: - return OutputSinkCSV(fd) - if self is OutputFormat.JSON: - return OutputSinkJSON(fd) - raise NotImplementedError('unsupported output format: ' + str(self)) - - -def _parse_output_format(s): - try: - return OutputFormat(s) - except ValueError: - raise argparse.ArgumentTypeError('invalid output format: ' + s) - - -def _parse_args(args=None): - if args is None: - args = sys.argv[1:] - - parser = argparse.ArgumentParser( - description='Learn who your ex and her new boyfriend are both friends with.') - - parser.add_argument('uids', metavar='UID', nargs='+', - help='user IDs or "screen names"') - parser.add_argument('-f', '--format', dest='out_fmt', - type=_parse_output_format, - default=OutputFormat.CSV, - choices=OutputFormat, - help='specify output format') - parser.add_argument('-o', '--output', metavar='PATH', dest='out_path', - help='set output file path (standard output by default)') - - return parser.parse_args(args) - - -def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): - api = API() - users = api.users_get(uids) - - friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) - mutual_friends = frozenset.intersection(*friend_lists) - mutual_friends = [_filter_user_fields(user) for user in mutual_friends] - - with out_fmt.open_file(out_path) as out_fd: - sink = out_fmt.create_sink(out_fd) - sink.write_mutual_friends(mutual_friends) - - -def main(args=None): - write_mutual_friends(**vars(_parse_args(args))) - - -if __name__ == '__main__': - main() diff --git a/bin/online_sessions.py b/bin/online_sessions.py deleted file mode 100644 index e0de7c9..0000000 --- a/bin/online_sessions.py +++ /dev/null @@ -1,367 +0,0 @@ -# Copyright (c) 2016 Egor Tensin -# This file is part of the "VK scripts" project. -# For details, see https://github.com/egor-tensin/vk-scripts. -# Distributed under the MIT License. - -import abc -import argparse -from collections import OrderedDict -from datetime import datetime, timedelta, timezone -from enum import Enum -import sys - -from vk.tracking import OnlineSessionEnumerator -from vk.tracking.db import Format as DatabaseFormat -from vk.user import UserField - -from .utils.bar_chart import BarChartBuilder -from .utils import io - - -class GroupBy(Enum): - USER = 'user' - DATE = 'date' - WEEKDAY = 'weekday' - HOUR = 'hour' - - def __str__(self): - return self.value - - def group(self, db_reader, time_from=None, time_to=None): - online_streaks = OnlineSessionEnumerator(time_from, time_to) - if self is GroupBy.USER: - return online_streaks.group_by_user(db_reader) - if self is GroupBy.DATE: - return online_streaks.group_by_date(db_reader) - if self is GroupBy.WEEKDAY: - return online_streaks.group_by_weekday(db_reader) - if self is GroupBy.HOUR: - return online_streaks.group_by_hour(db_reader) - raise NotImplementedError('unsupported grouping: ' + str(self)) - - -_OUTPUT_USER_FIELDS = ( - UserField.UID, - UserField.FIRST_NAME, - UserField.LAST_NAME, - UserField.DOMAIN, -) - - -class OutputSinkOnlineSessions(metaclass=abc.ABCMeta): - @abc.abstractmethod - def process_database(self, group_by, db_reader, time_from=None, time_to=None): - pass - - -class OutputConverterCSV: - @staticmethod - def convert_user(user): - return [user[field] for field in _OUTPUT_USER_FIELDS] - - @staticmethod - def convert_date(date): - return [str(date)] - - @staticmethod - def convert_weekday(weekday): - return [str(weekday)] - - @staticmethod - def convert_hour(hour): - return [str(timedelta(hours=hour))] - - -class OutputSinkCSV(OutputSinkOnlineSessions): - def __init__(self, fd=sys.stdout): - self._writer = io.FileWriterCSV(fd) - - _CONVERT_KEY = { - GroupBy.USER: OutputConverterCSV.convert_user, - GroupBy.DATE: OutputConverterCSV.convert_date, - GroupBy.WEEKDAY: OutputConverterCSV.convert_weekday, - GroupBy.HOUR: OutputConverterCSV.convert_hour, - } - - @staticmethod - def _key_to_row(group_by, key): - if group_by not in OutputSinkCSV._CONVERT_KEY: - raise NotImplementedError('unsupported grouping: ' + str(group_by)) - return OutputSinkCSV._CONVERT_KEY[group_by](key) - - def process_database(self, group_by, db_reader, time_from=None, time_to=None): - for key, duration in group_by.group(db_reader, time_from, time_to).items(): - row = self._key_to_row(group_by, key) - row.append(str(duration)) - self._writer.write_row(row) - - -class OutputConverterJSON: - _DATE_FIELD = 'date' - _WEEKDAY_FIELD = 'weekday' - _HOUR_FIELD = 'hour' - - assert _DATE_FIELD not in map(str, _OUTPUT_USER_FIELDS) - assert _WEEKDAY_FIELD not in map(str, _OUTPUT_USER_FIELDS) - assert _HOUR_FIELD not in map(str, _OUTPUT_USER_FIELDS) - - @staticmethod - def convert_user(user): - obj = OrderedDict() - for field in _OUTPUT_USER_FIELDS: - obj[str(field)] = user[field] - return obj - - @staticmethod - def convert_date(date): - obj = OrderedDict() - obj[OutputConverterJSON._DATE_FIELD] = str(date) - return obj - - @staticmethod - def convert_weekday(weekday): - obj = OrderedDict() - obj[OutputConverterJSON._WEEKDAY_FIELD] = str(weekday) - return obj - - @staticmethod - def convert_hour(hour): - obj = OrderedDict() - obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour)) - return obj - - -class OutputSinkJSON(OutputSinkOnlineSessions): - def __init__(self, fd=sys.stdout): - self._writer = io.FileWriterJSON(fd) - - _DURATION_FIELD = 'duration' - - assert _DURATION_FIELD not in map(str, _OUTPUT_USER_FIELDS) - - _CONVERT_KEY = { - GroupBy.USER: OutputConverterJSON.convert_user, - GroupBy.DATE: OutputConverterJSON.convert_date, - GroupBy.WEEKDAY: OutputConverterJSON.convert_weekday, - GroupBy.HOUR: OutputConverterJSON.convert_hour, - } - - @staticmethod - def _key_to_object(group_by, key): - if group_by not in OutputSinkJSON._CONVERT_KEY: - raise NotImplementedError('unsupported grouping: ' + str(group_by)) - return OutputSinkJSON._CONVERT_KEY[group_by](key) - - def process_database(self, group_by, db_reader, time_from=None, time_to=None): - entries = [] - for key, duration in group_by.group(db_reader, time_from, time_to).items(): - entry = self._key_to_object(group_by, key) - entry[self._DURATION_FIELD] = str(duration) - entries.append(entry) - self._writer.write(entries) - - -class OutputConverterPlot: - @staticmethod - def convert_user(user): - return '{}\n{}'.format(user.get_first_name(), user.get_last_name()) - - @staticmethod - def convert_date(date): - return str(date) - - @staticmethod - def convert_weekday(weekday): - return str(weekday) - - @staticmethod - def convert_hour(hour): - return '{}:00'.format(hour) - - -class OutputSinkPlot(OutputSinkOnlineSessions): - def __init__(self, fd=sys.stdout): - self._fd = fd - - TITLE = 'How much time people spend online' - - _FORMAT_KEY = { - GroupBy.USER: OutputConverterPlot.convert_user, - GroupBy.DATE: OutputConverterPlot.convert_date, - GroupBy.WEEKDAY: OutputConverterPlot.convert_weekday, - GroupBy.HOUR: OutputConverterPlot.convert_hour, - } - - @staticmethod - def _format_key(group_by, key): - if group_by not in OutputSinkPlot._FORMAT_KEY: - raise NotImplementedError('unsupported grouping: ' + str(group_by)) - return OutputSinkPlot._FORMAT_KEY[group_by](key) - - @staticmethod - def _format_duration(seconds, _): - return str(timedelta(seconds=seconds)) - - @staticmethod - def _duration_to_seconds(td): - return td.total_seconds() - - @staticmethod - def _extract_labels(group_by, durations): - return (OutputSinkPlot._format_key(group_by, key) for key in durations.keys()) - - @staticmethod - def _extract_values(durations): - return (OutputSinkPlot._duration_to_seconds(duration) for duration in durations.values()) - - def process_database( - self, group_by, db_reader, time_from=None, time_to=None): - - durations = group_by.group(db_reader, time_from, time_to) - - bar_chart = BarChartBuilder() - bar_chart.set_title(OutputSinkPlot.TITLE) - bar_chart.enable_grid_for_values() - bar_chart.only_integer_values() - bar_chart.set_property(bar_chart.get_values_labels(), - fontsize='small', rotation=30) - bar_chart.set_value_label_formatter(self._format_duration) - - labels = tuple(self._extract_labels(group_by, durations)) - durations = tuple(self._extract_values(durations)) - - if group_by is GroupBy.HOUR: - bar_chart.labels_align_middle = False - bar_height = bar_chart.THIN_BAR_HEIGHT - else: - bar_height = bar_chart.THICK_BAR_HEIGHT - - bars = bar_chart.plot_bars( - labels, durations, bar_height=bar_height) - bar_chart.set_property(bars, alpha=.33) - - if self._fd is sys.stdout: - bar_chart.show() - else: - bar_chart.save(self._fd) - - -class OutputFormat(Enum): - CSV = 'csv' - JSON = 'json' - PLOT = 'plot' - - def __str__(self): - return self.value - - def create_sink(self, fd=sys.stdout): - if self is OutputFormat.CSV: - return OutputSinkCSV(fd) - if self is OutputFormat.JSON: - return OutputSinkJSON(fd) - if self is OutputFormat.PLOT: - return OutputSinkPlot(fd) - raise NotImplementedError('unsupported output format: ' + str(self)) - - def open_file(self, path=None): - if self is OutputFormat.PLOT: - return io.open_output_binary_file(path) - return io.open_output_text_file(path) - - -def _parse_group_by(s): - try: - return GroupBy(s) - except ValueError: - raise argparse.ArgumentTypeError('invalid "group by" value: ' + s) - - -def _parse_database_format(s): - try: - return DatabaseFormat(s) - except ValueError: - raise argparse.ArgumentTypeError('invalid database format: ' + s) - - -def _parse_output_format(s): - try: - return OutputFormat(s) - except ValueError: - raise argparse.ArgumentTypeError('invalid output format: ' + s) - - -_DATE_RANGE_LIMIT_FORMAT = '%Y-%m-%dT%H:%M:%SZ' - - -def _parse_date_range_limit(s): - try: - dt = datetime.strptime(s, _DATE_RANGE_LIMIT_FORMAT) - return dt.replace(tzinfo=timezone.utc) - except ValueError: - msg = 'invalid date range limit (must be in the \'{}\' format): {}' - raise argparse.ArgumentTypeError( - msg.format(_DATE_RANGE_LIMIT_FORMAT, s)) - - -def _parse_args(args=None): - if args is None: - args = sys.argv[1:] - - parser = argparse.ArgumentParser( - description='View/visualize the amount of time people spend online.') - - parser.add_argument('db_path', metavar='input', nargs='?', - help='database file path (standard input by default)') - parser.add_argument('out_path', metavar='output', nargs='?', - help='output file path (standard output by default)') - parser.add_argument('-g', '--group-by', - type=_parse_group_by, - choices=GroupBy, - default=GroupBy.USER, - help='group online sessions by user/date/etc.') - parser.add_argument('-i', '--input-format', dest='db_fmt', - type=_parse_database_format, - default=DatabaseFormat.CSV, - choices=DatabaseFormat, - help='specify database format') - parser.add_argument('-o', '--output-format', dest='out_fmt', - type=_parse_output_format, - choices=OutputFormat, - default=OutputFormat.CSV, - help='specify output format') - parser.add_argument('-a', '--from', dest='time_from', - type=_parse_date_range_limit, default=None, - help='discard online activity prior to this moment') - parser.add_argument('-b', '--to', dest='time_to', - type=_parse_date_range_limit, default=None, - help='discard online activity after this moment') - - return parser.parse_args(args) - - -def process_online_sessions( - db_path=None, db_fmt=DatabaseFormat.CSV, - out_path=None, out_fmt=OutputFormat.CSV, - group_by=GroupBy.USER, - time_from=None, time_to=None): - - if time_from is not None and time_to is not None: - if time_from > time_to: - time_from, time_to = time_to, time_from - - with db_fmt.open_input_file(db_path) as db_fd: - db_reader = db_fmt.create_reader(db_fd) - with out_fmt.open_file(out_path) as out_fd: - out_sink = out_fmt.create_sink(out_fd) - out_sink.process_database( - group_by, db_reader, - time_from=time_from, - time_to=time_to) - - -def main(args=None): - process_online_sessions(**vars(_parse_args(args))) - - -if __name__ == '__main__': - main() diff --git a/bin/show_status.py b/bin/show_status.py deleted file mode 100644 index cf59280..0000000 --- a/bin/show_status.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright (c) 2019 Egor Tensin -# This file is part of the "VK scripts" project. -# For details, see https://github.com/egor-tensin/vk-scripts. -# Distributed under the MIT License. - -import argparse -import sys - -from vk.api import API -from vk.tracking import StatusTracker -from vk.tracking.db import Format as DatabaseFormat - - -def _parse_args(args=None): - if args is None: - args = sys.argv[1:] - - parser = argparse.ArgumentParser( - description='Show if people are online/offline.') - - parser.add_argument('uids', metavar='UID', nargs='+', - help='user IDs or "screen names"') - parser.add_argument('-l', '--log', metavar='PATH', dest='log_path', - help='set log file path (standard output by default)') - - return parser.parse_args(args) - - -def track_status(uids, log_path=None): - api = API() - tracker = StatusTracker(api) - - with DatabaseFormat.LOG.open_output_file(log_path) as log_fd: - log_writer = DatabaseFormat.LOG.create_writer(log_fd) - tracker.add_database_writer(log_writer) - tracker.query_status(uids) - - -def main(args=None): - track_status(**vars(_parse_args(args))) - - -if __name__ == '__main__': - main() diff --git a/bin/track_status.py b/bin/track_status.py deleted file mode 100644 index 2a974a5..0000000 --- a/bin/track_status.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) 2015 Egor Tensin -# This file is part of the "VK scripts" project. -# For details, see https://github.com/egor-tensin/vk-scripts. -# Distributed under the MIT License. - -import argparse -import sys - -from vk.api import API -from vk.tracking import StatusTracker -from vk.tracking.db import Format as DatabaseFormat - - -DEFAULT_TIMEOUT = StatusTracker.DEFAULT_TIMEOUT -DEFAULT_DB_FORMAT = DatabaseFormat.CSV - - -def _parse_positive_integer(s): - try: - n = int(s) - except ValueError: - raise argparse.ArgumentTypeError('must be a positive integer: ' + s) - if n < 1: - raise argparse.ArgumentTypeError('must be a positive integer: ' + s) - return n - - -def _parse_database_format(s): - try: - return DatabaseFormat(s) - except ValueError: - raise argparse.ArgumentTypeError('invalid database format: ' + s) - - -def _parse_args(args=None): - if args is None: - args = sys.argv[1:] - - parser = argparse.ArgumentParser( - description='Track when people go online/offline.') - - parser.add_argument('uids', metavar='UID', nargs='+', - help='user IDs or "screen names"') - parser.add_argument('-t', '--timeout', metavar='SECONDS', - type=_parse_positive_integer, - default=DEFAULT_TIMEOUT, - help='set refresh interval') - parser.add_argument('-l', '--log', metavar='PATH', dest='log_path', - help='set log file path (standard output by default)') - parser.add_argument('-f', '--format', dest='db_fmt', - type=_parse_database_format, - choices=DatabaseFormat, - default=DEFAULT_DB_FORMAT, - help='specify database format') - parser.add_argument('-o', '--output', metavar='PATH', dest='db_path', - help='set database file path') - - return parser.parse_args(args) - - -def track_status( - uids, timeout=DEFAULT_TIMEOUT, - log_path=None, - db_path=None, db_fmt=DEFAULT_DB_FORMAT): - - api = API() - tracker = StatusTracker(api, timeout) - - if db_fmt is DatabaseFormat.LOG or db_path is None: - db_fmt = DatabaseFormat.NULL - - with DatabaseFormat.LOG.open_output_file(log_path) as log_fd: - log_writer = DatabaseFormat.LOG.create_writer(log_fd) - tracker.add_database_writer(log_writer) - with db_fmt.open_output_file(db_path) as db_fd: - db_writer = db_fmt.create_writer(db_fd) - tracker.add_database_writer(db_writer) - - tracker.loop(uids) - - -def main(args=None): - track_status(**vars(_parse_args(args))) - - -if __name__ == '__main__': - main() diff --git a/bin/utils/__init__.py b/bin/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bin/utils/bar_chart.py b/bin/utils/bar_chart.py deleted file mode 100644 index f051efc..0000000 --- a/bin/utils/bar_chart.py +++ /dev/null @@ -1,182 +0,0 @@ -# Copyright (c) 2017 Egor Tensin -# This file is part of the "VK scripts" project. -# For details, see https://github.com/egor-tensin/vk-scripts. -# Distributed under the MIT License. - -import matplotlib.pyplot as plt -from matplotlib import ticker -import numpy as np - - -class BarChartBuilder: - _BAR_HEIGHT = .5 - - THICK_BAR_HEIGHT = _BAR_HEIGHT - THIN_BAR_HEIGHT = THICK_BAR_HEIGHT / 2 - - def __init__(self, labels_align_middle=True): - self._fig, self._ax = plt.subplots() - self.labels_align_middle = labels_align_middle - - def set_title(self, title): - self._ax.set_title(title) - - def _get_categories_axis(self): - return self._ax.get_yaxis() - - def _get_values_axis(self): - return self._ax.get_xaxis() - - def set_categories_axis_limits(self, start=None, end=None): - bottom, top = self._ax.get_ylim() - if start is not None: - bottom = start - if end is not None: - top = end - self._ax.set_ylim(bottom=bottom, top=top) - - def set_values_axis_limits(self, start=None, end=None): - left, right = self._ax.get_xlim() - if start is not None: - left = start - if end is not None: - right = end - self._ax.set_xlim(left=left, right=right) - - def enable_grid_for_categories(self): - self._get_categories_axis().grid() - - def enable_grid_for_values(self): - self._get_values_axis().grid() - - def get_categories_labels(self): - return self._get_categories_axis().get_ticklabels() - - def get_values_labels(self): - return self._get_values_axis().get_ticklabels() - - def hide_categories(self): - self._get_categories_axis().set_major_locator(ticker.NullLocator()) - - def set_value_label_formatter(self, fn): - self._get_values_axis().set_major_formatter(ticker.FuncFormatter(fn)) - - def any_values(self): - self._get_values_axis().set_major_locator(ticker.AutoLocator()) - - def only_integer_values(self): - self._get_values_axis().set_major_locator(ticker.MaxNLocator(integer=True)) - - @staticmethod - def set_property(*args, **kwargs): - plt.setp(*args, **kwargs) - - def _set_size(self, inches, dim=0): - fig_size = self._fig.get_size_inches() - assert len(fig_size) == 2 - fig_size[dim] = inches - self._fig.set_size_inches(fig_size, forward=True) - - def set_width(self, inches): - self._set_size(inches) - - def set_height(self, inches): - self._set_size(inches, dim=1) - - _DEFAULT_VALUES_AXIS_MAX = 1 - assert _DEFAULT_VALUES_AXIS_MAX > 0 - - def plot_bars(self, categories, values, bar_height=THICK_BAR_HEIGHT): - numof_bars = len(categories) - inches_per_bar = 2 * bar_height - categories_axis_max = inches_per_bar * numof_bars - - if not numof_bars: - categories_axis_max += inches_per_bar - - self.set_height(categories_axis_max) - self.set_categories_axis_limits(0, categories_axis_max) - - if not numof_bars: - self.set_values_axis_limits(0, self._DEFAULT_VALUES_AXIS_MAX) - self.hide_categories() - return [] - - bar_offset = inches_per_bar / 2 - bar_offsets = inches_per_bar * np.arange(numof_bars) + bar_offset - - if self.labels_align_middle: - self._get_categories_axis().set_ticks(bar_offsets) - else: - self._get_categories_axis().set_ticks(bar_offsets - bar_offset) - - self._get_categories_axis().set_ticklabels(categories) - - bars = self._ax.barh(bar_offsets, values, align='center', - height=bar_height) - - if min(values) >= 0: - self.set_values_axis_limits(start=0) - if np.isclose(max(values), 0.): - self.set_values_axis_limits(end=self._DEFAULT_VALUES_AXIS_MAX) - elif max(values) < 0: - self.set_values_axis_limits(end=0) - - return bars - - @staticmethod - def show(): - plt.show() - - def save(self, path): - self._fig.savefig(path, bbox_inches='tight') - - -if __name__ == '__main__': - import argparse - parser = argparse.ArgumentParser() - - parser.add_argument('--categories', nargs='*', metavar='LABEL', - default=[]) - parser.add_argument('--values', nargs='*', metavar='N', - default=[], type=float) - - parser.add_argument('--output', '-o', help='set output file path') - - parser.add_argument('--align-middle', action='store_true', - dest='labels_align_middle', - help='align labels to the middle of the bars') - - parser.add_argument('--integer-values', action='store_true', - dest='only_integer_values') - parser.add_argument('--any-values', action='store_false', - dest='only_integer_values') - - parser.add_argument('--grid-categories', action='store_true') - parser.add_argument('--grid-values', action='store_true') - - args = parser.parse_args() - - if len(args.categories) < len(args.values): - parser.error('too many bar values') - if len(args.categories) > len(args.values): - args.values.extend([0.] * (len(args.categories) - len(args.values))) - - builder = BarChartBuilder(labels_align_middle=args.labels_align_middle) - - if args.only_integer_values: - builder.only_integer_values() - else: - builder.any_values() - - if args.grid_categories: - builder.enable_grid_for_categories() - if args.grid_values: - builder.enable_grid_for_values() - - builder.plot_bars(args.categories, args.values) - - if args.output is None: - builder.show() - else: - builder.save(args.output) diff --git a/bin/utils/io.py b/bin/utils/io.py deleted file mode 100644 index bb8eef9..0000000 --- a/bin/utils/io.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) 2017 Egor Tensin -# This file is part of the "VK scripts" project. -# For details, see https://github.com/egor-tensin/vk-scripts. -# Distributed under the MIT License. - -from contextlib import contextmanager -import csv -import json -import sys - - -class FileWriterJSON: - def __init__(self, fd=sys.stdout): - self._fd = fd - - def write(self, something): - self._fd.write(json.dumps(something, indent=3, ensure_ascii=False)) - self._fd.write('\n') - - -class FileWriterCSV: - def __init__(self, fd=sys.stdout): - self._writer = csv.writer(fd, lineterminator='\n') - - @staticmethod - def _convert_row_old_python(row): - if isinstance(row, (list, tuple)): - return row - return list(row) - - def write_row(self, row): - if sys.version_info < (3, 5): - row = self._convert_row_old_python(row) - self._writer.writerow(row) - - -@contextmanager -def _open_file(path=None, default=None, **kwargs): - if path is None: - yield default - else: - with open(path, **kwargs) as fd: - yield fd - - -def open_output_text_file(path=None): - return _open_file(path, default=sys.stdout, mode='w', encoding='utf-8') - - -def open_output_binary_file(path=None): - return _open_file(path, default=sys.stdout, mode='wb') diff --git a/vk/mutuals.py b/vk/mutuals.py new file mode 100644 index 0000000..644baf8 --- /dev/null +++ b/vk/mutuals.py @@ -0,0 +1,118 @@ +# Copyright (c) 2015 Egor Tensin +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +import abc +import argparse +from collections import OrderedDict +from enum import Enum +import sys + +from vk.api import API +from vk.user import UserField +from vk.utils import io + + +_OUTPUT_USER_FIELDS = UserField.UID, UserField.FIRST_NAME, UserField.LAST_NAME + + +def _query_friend_list(api, user): + return api.friends_get(user.get_uid(), fields=_OUTPUT_USER_FIELDS) + + +def _filter_user_fields(user): + new_user = OrderedDict() + for field in _OUTPUT_USER_FIELDS: + new_user[str(field)] = user[field] if field in user else None + return new_user + + +class OutputSinkMutualFriends(metaclass=abc.ABCMeta): + @abc.abstractmethod + def write_mutual_friends(self, friend_list): + pass + + +class OutputSinkCSV(OutputSinkMutualFriends): + def __init__(self, fd=sys.stdout): + self._writer = io.FileWriterCSV(fd) + + def write_mutual_friends(self, friend_list): + for user in friend_list: + self._writer.write_row(user.values()) + + +class OutputSinkJSON(OutputSinkMutualFriends): + def __init__(self, fd=sys.stdout): + self._writer = io.FileWriterJSON(fd) + + def write_mutual_friends(self, friend_list): + self._writer.write(friend_list) + + +class OutputFormat(Enum): + CSV = 'csv' + JSON = 'json' + + def __str__(self): + return self.value + + @staticmethod + def open_file(path=None): + return io.open_output_text_file(path) + + def create_sink(self, fd=sys.stdout): + if self is OutputFormat.CSV: + return OutputSinkCSV(fd) + if self is OutputFormat.JSON: + return OutputSinkJSON(fd) + raise NotImplementedError('unsupported output format: ' + str(self)) + + +def _parse_output_format(s): + try: + return OutputFormat(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid output format: ' + s) + + +def _parse_args(args=None): + if args is None: + args = sys.argv[1:] + + parser = argparse.ArgumentParser( + description='Learn who your ex and her new boyfriend are both friends with.') + + parser.add_argument('uids', metavar='UID', nargs='+', + help='user IDs or "screen names"') + parser.add_argument('-f', '--format', dest='out_fmt', + type=_parse_output_format, + default=OutputFormat.CSV, + choices=OutputFormat, + help='specify output format') + parser.add_argument('-o', '--output', metavar='PATH', dest='out_path', + help='set output file path (standard output by default)') + + return parser.parse_args(args) + + +def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV): + api = API() + users = api.users_get(uids) + + friend_lists = (frozenset(_query_friend_list(api, user)) for user in users) + mutual_friends = frozenset.intersection(*friend_lists) + mutual_friends = [_filter_user_fields(user) for user in mutual_friends] + + with out_fmt.open_file(out_path) as out_fd: + sink = out_fmt.create_sink(out_fd) + sink.write_mutual_friends(mutual_friends) + + +def main(args=None): + write_mutual_friends(**vars(_parse_args(args))) + + +if __name__ == '__main__': + main() diff --git a/vk/tracking/sessions.py b/vk/tracking/sessions.py new file mode 100644 index 0000000..dd8b32f --- /dev/null +++ b/vk/tracking/sessions.py @@ -0,0 +1,366 @@ +# Copyright (c) 2016 Egor Tensin +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +import abc +import argparse +from collections import OrderedDict +from datetime import datetime, timedelta, timezone +from enum import Enum +import sys + +from vk.tracking import OnlineSessionEnumerator +from vk.tracking.db import Format as DatabaseFormat +from vk.user import UserField +from vk.utils.bar_chart import BarChartBuilder +from vk.utils import io + + +class GroupBy(Enum): + USER = 'user' + DATE = 'date' + WEEKDAY = 'weekday' + HOUR = 'hour' + + def __str__(self): + return self.value + + def group(self, db_reader, time_from=None, time_to=None): + online_streaks = OnlineSessionEnumerator(time_from, time_to) + if self is GroupBy.USER: + return online_streaks.group_by_user(db_reader) + if self is GroupBy.DATE: + return online_streaks.group_by_date(db_reader) + if self is GroupBy.WEEKDAY: + return online_streaks.group_by_weekday(db_reader) + if self is GroupBy.HOUR: + return online_streaks.group_by_hour(db_reader) + raise NotImplementedError('unsupported grouping: ' + str(self)) + + +_OUTPUT_USER_FIELDS = ( + UserField.UID, + UserField.FIRST_NAME, + UserField.LAST_NAME, + UserField.DOMAIN, +) + + +class OutputSinkOnlineSessions(metaclass=abc.ABCMeta): + @abc.abstractmethod + def process_database(self, group_by, db_reader, time_from=None, time_to=None): + pass + + +class OutputConverterCSV: + @staticmethod + def convert_user(user): + return [user[field] for field in _OUTPUT_USER_FIELDS] + + @staticmethod + def convert_date(date): + return [str(date)] + + @staticmethod + def convert_weekday(weekday): + return [str(weekday)] + + @staticmethod + def convert_hour(hour): + return [str(timedelta(hours=hour))] + + +class OutputSinkCSV(OutputSinkOnlineSessions): + def __init__(self, fd=sys.stdout): + self._writer = io.FileWriterCSV(fd) + + _CONVERT_KEY = { + GroupBy.USER: OutputConverterCSV.convert_user, + GroupBy.DATE: OutputConverterCSV.convert_date, + GroupBy.WEEKDAY: OutputConverterCSV.convert_weekday, + GroupBy.HOUR: OutputConverterCSV.convert_hour, + } + + @staticmethod + def _key_to_row(group_by, key): + if group_by not in OutputSinkCSV._CONVERT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(group_by)) + return OutputSinkCSV._CONVERT_KEY[group_by](key) + + def process_database(self, group_by, db_reader, time_from=None, time_to=None): + for key, duration in group_by.group(db_reader, time_from, time_to).items(): + row = self._key_to_row(group_by, key) + row.append(str(duration)) + self._writer.write_row(row) + + +class OutputConverterJSON: + _DATE_FIELD = 'date' + _WEEKDAY_FIELD = 'weekday' + _HOUR_FIELD = 'hour' + + assert _DATE_FIELD not in map(str, _OUTPUT_USER_FIELDS) + assert _WEEKDAY_FIELD not in map(str, _OUTPUT_USER_FIELDS) + assert _HOUR_FIELD not in map(str, _OUTPUT_USER_FIELDS) + + @staticmethod + def convert_user(user): + obj = OrderedDict() + for field in _OUTPUT_USER_FIELDS: + obj[str(field)] = user[field] + return obj + + @staticmethod + def convert_date(date): + obj = OrderedDict() + obj[OutputConverterJSON._DATE_FIELD] = str(date) + return obj + + @staticmethod + def convert_weekday(weekday): + obj = OrderedDict() + obj[OutputConverterJSON._WEEKDAY_FIELD] = str(weekday) + return obj + + @staticmethod + def convert_hour(hour): + obj = OrderedDict() + obj[OutputConverterJSON._HOUR_FIELD] = str(timedelta(hours=hour)) + return obj + + +class OutputSinkJSON(OutputSinkOnlineSessions): + def __init__(self, fd=sys.stdout): + self._writer = io.FileWriterJSON(fd) + + _DURATION_FIELD = 'duration' + + assert _DURATION_FIELD not in map(str, _OUTPUT_USER_FIELDS) + + _CONVERT_KEY = { + GroupBy.USER: OutputConverterJSON.convert_user, + GroupBy.DATE: OutputConverterJSON.convert_date, + GroupBy.WEEKDAY: OutputConverterJSON.convert_weekday, + GroupBy.HOUR: OutputConverterJSON.convert_hour, + } + + @staticmethod + def _key_to_object(group_by, key): + if group_by not in OutputSinkJSON._CONVERT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(group_by)) + return OutputSinkJSON._CONVERT_KEY[group_by](key) + + def process_database(self, group_by, db_reader, time_from=None, time_to=None): + entries = [] + for key, duration in group_by.group(db_reader, time_from, time_to).items(): + entry = self._key_to_object(group_by, key) + entry[self._DURATION_FIELD] = str(duration) + entries.append(entry) + self._writer.write(entries) + + +class OutputConverterPlot: + @staticmethod + def convert_user(user): + return '{}\n{}'.format(user.get_first_name(), user.get_last_name()) + + @staticmethod + def convert_date(date): + return str(date) + + @staticmethod + def convert_weekday(weekday): + return str(weekday) + + @staticmethod + def convert_hour(hour): + return '{}:00'.format(hour) + + +class OutputSinkPlot(OutputSinkOnlineSessions): + def __init__(self, fd=sys.stdout): + self._fd = fd + + TITLE = 'How much time people spend online' + + _FORMAT_KEY = { + GroupBy.USER: OutputConverterPlot.convert_user, + GroupBy.DATE: OutputConverterPlot.convert_date, + GroupBy.WEEKDAY: OutputConverterPlot.convert_weekday, + GroupBy.HOUR: OutputConverterPlot.convert_hour, + } + + @staticmethod + def _format_key(group_by, key): + if group_by not in OutputSinkPlot._FORMAT_KEY: + raise NotImplementedError('unsupported grouping: ' + str(group_by)) + return OutputSinkPlot._FORMAT_KEY[group_by](key) + + @staticmethod + def _format_duration(seconds, _): + return str(timedelta(seconds=seconds)) + + @staticmethod + def _duration_to_seconds(td): + return td.total_seconds() + + @staticmethod + def _extract_labels(group_by, durations): + return (OutputSinkPlot._format_key(group_by, key) for key in durations.keys()) + + @staticmethod + def _extract_values(durations): + return (OutputSinkPlot._duration_to_seconds(duration) for duration in durations.values()) + + def process_database( + self, group_by, db_reader, time_from=None, time_to=None): + + durations = group_by.group(db_reader, time_from, time_to) + + bar_chart = BarChartBuilder() + bar_chart.set_title(OutputSinkPlot.TITLE) + bar_chart.enable_grid_for_values() + bar_chart.only_integer_values() + bar_chart.set_property(bar_chart.get_values_labels(), + fontsize='small', rotation=30) + bar_chart.set_value_label_formatter(self._format_duration) + + labels = tuple(self._extract_labels(group_by, durations)) + durations = tuple(self._extract_values(durations)) + + if group_by is GroupBy.HOUR: + bar_chart.labels_align_middle = False + bar_height = bar_chart.THIN_BAR_HEIGHT + else: + bar_height = bar_chart.THICK_BAR_HEIGHT + + bars = bar_chart.plot_bars( + labels, durations, bar_height=bar_height) + bar_chart.set_property(bars, alpha=.33) + + if self._fd is sys.stdout: + bar_chart.show() + else: + bar_chart.save(self._fd) + + +class OutputFormat(Enum): + CSV = 'csv' + JSON = 'json' + PLOT = 'plot' + + def __str__(self): + return self.value + + def create_sink(self, fd=sys.stdout): + if self is OutputFormat.CSV: + return OutputSinkCSV(fd) + if self is OutputFormat.JSON: + return OutputSinkJSON(fd) + if self is OutputFormat.PLOT: + return OutputSinkPlot(fd) + raise NotImplementedError('unsupported output format: ' + str(self)) + + def open_file(self, path=None): + if self is OutputFormat.PLOT: + return io.open_output_binary_file(path) + return io.open_output_text_file(path) + + +def _parse_group_by(s): + try: + return GroupBy(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid "group by" value: ' + s) + + +def _parse_database_format(s): + try: + return DatabaseFormat(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid database format: ' + s) + + +def _parse_output_format(s): + try: + return OutputFormat(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid output format: ' + s) + + +_DATE_RANGE_LIMIT_FORMAT = '%Y-%m-%dT%H:%M:%SZ' + + +def _parse_date_range_limit(s): + try: + dt = datetime.strptime(s, _DATE_RANGE_LIMIT_FORMAT) + return dt.replace(tzinfo=timezone.utc) + except ValueError: + msg = 'invalid date range limit (must be in the \'{}\' format): {}' + raise argparse.ArgumentTypeError( + msg.format(_DATE_RANGE_LIMIT_FORMAT, s)) + + +def _parse_args(args=None): + if args is None: + args = sys.argv[1:] + + parser = argparse.ArgumentParser( + description='View/visualize the amount of time people spend online.') + + parser.add_argument('db_path', metavar='input', nargs='?', + help='database file path (standard input by default)') + parser.add_argument('out_path', metavar='output', nargs='?', + help='output file path (standard output by default)') + parser.add_argument('-g', '--group-by', + type=_parse_group_by, + choices=GroupBy, + default=GroupBy.USER, + help='group online sessions by user/date/etc.') + parser.add_argument('-i', '--input-format', dest='db_fmt', + type=_parse_database_format, + default=DatabaseFormat.CSV, + choices=DatabaseFormat, + help='specify database format') + parser.add_argument('-o', '--output-format', dest='out_fmt', + type=_parse_output_format, + choices=OutputFormat, + default=OutputFormat.CSV, + help='specify output format') + parser.add_argument('-a', '--from', dest='time_from', + type=_parse_date_range_limit, default=None, + help='discard online activity prior to this moment') + parser.add_argument('-b', '--to', dest='time_to', + type=_parse_date_range_limit, default=None, + help='discard online activity after this moment') + + return parser.parse_args(args) + + +def process_online_sessions( + db_path=None, db_fmt=DatabaseFormat.CSV, + out_path=None, out_fmt=OutputFormat.CSV, + group_by=GroupBy.USER, + time_from=None, time_to=None): + + if time_from is not None and time_to is not None: + if time_from > time_to: + time_from, time_to = time_to, time_from + + with db_fmt.open_input_file(db_path) as db_fd: + db_reader = db_fmt.create_reader(db_fd) + with out_fmt.open_file(out_path) as out_fd: + out_sink = out_fmt.create_sink(out_fd) + out_sink.process_database( + group_by, db_reader, + time_from=time_from, + time_to=time_to) + + +def main(args=None): + process_online_sessions(**vars(_parse_args(args))) + + +if __name__ == '__main__': + main() diff --git a/vk/tracking/show_status.py b/vk/tracking/show_status.py new file mode 100644 index 0000000..cf59280 --- /dev/null +++ b/vk/tracking/show_status.py @@ -0,0 +1,44 @@ +# Copyright (c) 2019 Egor Tensin +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +import argparse +import sys + +from vk.api import API +from vk.tracking import StatusTracker +from vk.tracking.db import Format as DatabaseFormat + + +def _parse_args(args=None): + if args is None: + args = sys.argv[1:] + + parser = argparse.ArgumentParser( + description='Show if people are online/offline.') + + parser.add_argument('uids', metavar='UID', nargs='+', + help='user IDs or "screen names"') + parser.add_argument('-l', '--log', metavar='PATH', dest='log_path', + help='set log file path (standard output by default)') + + return parser.parse_args(args) + + +def track_status(uids, log_path=None): + api = API() + tracker = StatusTracker(api) + + with DatabaseFormat.LOG.open_output_file(log_path) as log_fd: + log_writer = DatabaseFormat.LOG.create_writer(log_fd) + tracker.add_database_writer(log_writer) + tracker.query_status(uids) + + +def main(args=None): + track_status(**vars(_parse_args(args))) + + +if __name__ == '__main__': + main() diff --git a/vk/tracking/track_status.py b/vk/tracking/track_status.py new file mode 100644 index 0000000..2a974a5 --- /dev/null +++ b/vk/tracking/track_status.py @@ -0,0 +1,87 @@ +# Copyright (c) 2015 Egor Tensin +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +import argparse +import sys + +from vk.api import API +from vk.tracking import StatusTracker +from vk.tracking.db import Format as DatabaseFormat + + +DEFAULT_TIMEOUT = StatusTracker.DEFAULT_TIMEOUT +DEFAULT_DB_FORMAT = DatabaseFormat.CSV + + +def _parse_positive_integer(s): + try: + n = int(s) + except ValueError: + raise argparse.ArgumentTypeError('must be a positive integer: ' + s) + if n < 1: + raise argparse.ArgumentTypeError('must be a positive integer: ' + s) + return n + + +def _parse_database_format(s): + try: + return DatabaseFormat(s) + except ValueError: + raise argparse.ArgumentTypeError('invalid database format: ' + s) + + +def _parse_args(args=None): + if args is None: + args = sys.argv[1:] + + parser = argparse.ArgumentParser( + description='Track when people go online/offline.') + + parser.add_argument('uids', metavar='UID', nargs='+', + help='user IDs or "screen names"') + parser.add_argument('-t', '--timeout', metavar='SECONDS', + type=_parse_positive_integer, + default=DEFAULT_TIMEOUT, + help='set refresh interval') + parser.add_argument('-l', '--log', metavar='PATH', dest='log_path', + help='set log file path (standard output by default)') + parser.add_argument('-f', '--format', dest='db_fmt', + type=_parse_database_format, + choices=DatabaseFormat, + default=DEFAULT_DB_FORMAT, + help='specify database format') + parser.add_argument('-o', '--output', metavar='PATH', dest='db_path', + help='set database file path') + + return parser.parse_args(args) + + +def track_status( + uids, timeout=DEFAULT_TIMEOUT, + log_path=None, + db_path=None, db_fmt=DEFAULT_DB_FORMAT): + + api = API() + tracker = StatusTracker(api, timeout) + + if db_fmt is DatabaseFormat.LOG or db_path is None: + db_fmt = DatabaseFormat.NULL + + with DatabaseFormat.LOG.open_output_file(log_path) as log_fd: + log_writer = DatabaseFormat.LOG.create_writer(log_fd) + tracker.add_database_writer(log_writer) + with db_fmt.open_output_file(db_path) as db_fd: + db_writer = db_fmt.create_writer(db_fd) + tracker.add_database_writer(db_writer) + + tracker.loop(uids) + + +def main(args=None): + track_status(**vars(_parse_args(args))) + + +if __name__ == '__main__': + main() diff --git a/vk/utils/__init__.py b/vk/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vk/utils/bar_chart.py b/vk/utils/bar_chart.py new file mode 100644 index 0000000..f051efc --- /dev/null +++ b/vk/utils/bar_chart.py @@ -0,0 +1,182 @@ +# Copyright (c) 2017 Egor Tensin +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +import matplotlib.pyplot as plt +from matplotlib import ticker +import numpy as np + + +class BarChartBuilder: + _BAR_HEIGHT = .5 + + THICK_BAR_HEIGHT = _BAR_HEIGHT + THIN_BAR_HEIGHT = THICK_BAR_HEIGHT / 2 + + def __init__(self, labels_align_middle=True): + self._fig, self._ax = plt.subplots() + self.labels_align_middle = labels_align_middle + + def set_title(self, title): + self._ax.set_title(title) + + def _get_categories_axis(self): + return self._ax.get_yaxis() + + def _get_values_axis(self): + return self._ax.get_xaxis() + + def set_categories_axis_limits(self, start=None, end=None): + bottom, top = self._ax.get_ylim() + if start is not None: + bottom = start + if end is not None: + top = end + self._ax.set_ylim(bottom=bottom, top=top) + + def set_values_axis_limits(self, start=None, end=None): + left, right = self._ax.get_xlim() + if start is not None: + left = start + if end is not None: + right = end + self._ax.set_xlim(left=left, right=right) + + def enable_grid_for_categories(self): + self._get_categories_axis().grid() + + def enable_grid_for_values(self): + self._get_values_axis().grid() + + def get_categories_labels(self): + return self._get_categories_axis().get_ticklabels() + + def get_values_labels(self): + return self._get_values_axis().get_ticklabels() + + def hide_categories(self): + self._get_categories_axis().set_major_locator(ticker.NullLocator()) + + def set_value_label_formatter(self, fn): + self._get_values_axis().set_major_formatter(ticker.FuncFormatter(fn)) + + def any_values(self): + self._get_values_axis().set_major_locator(ticker.AutoLocator()) + + def only_integer_values(self): + self._get_values_axis().set_major_locator(ticker.MaxNLocator(integer=True)) + + @staticmethod + def set_property(*args, **kwargs): + plt.setp(*args, **kwargs) + + def _set_size(self, inches, dim=0): + fig_size = self._fig.get_size_inches() + assert len(fig_size) == 2 + fig_size[dim] = inches + self._fig.set_size_inches(fig_size, forward=True) + + def set_width(self, inches): + self._set_size(inches) + + def set_height(self, inches): + self._set_size(inches, dim=1) + + _DEFAULT_VALUES_AXIS_MAX = 1 + assert _DEFAULT_VALUES_AXIS_MAX > 0 + + def plot_bars(self, categories, values, bar_height=THICK_BAR_HEIGHT): + numof_bars = len(categories) + inches_per_bar = 2 * bar_height + categories_axis_max = inches_per_bar * numof_bars + + if not numof_bars: + categories_axis_max += inches_per_bar + + self.set_height(categories_axis_max) + self.set_categories_axis_limits(0, categories_axis_max) + + if not numof_bars: + self.set_values_axis_limits(0, self._DEFAULT_VALUES_AXIS_MAX) + self.hide_categories() + return [] + + bar_offset = inches_per_bar / 2 + bar_offsets = inches_per_bar * np.arange(numof_bars) + bar_offset + + if self.labels_align_middle: + self._get_categories_axis().set_ticks(bar_offsets) + else: + self._get_categories_axis().set_ticks(bar_offsets - bar_offset) + + self._get_categories_axis().set_ticklabels(categories) + + bars = self._ax.barh(bar_offsets, values, align='center', + height=bar_height) + + if min(values) >= 0: + self.set_values_axis_limits(start=0) + if np.isclose(max(values), 0.): + self.set_values_axis_limits(end=self._DEFAULT_VALUES_AXIS_MAX) + elif max(values) < 0: + self.set_values_axis_limits(end=0) + + return bars + + @staticmethod + def show(): + plt.show() + + def save(self, path): + self._fig.savefig(path, bbox_inches='tight') + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument('--categories', nargs='*', metavar='LABEL', + default=[]) + parser.add_argument('--values', nargs='*', metavar='N', + default=[], type=float) + + parser.add_argument('--output', '-o', help='set output file path') + + parser.add_argument('--align-middle', action='store_true', + dest='labels_align_middle', + help='align labels to the middle of the bars') + + parser.add_argument('--integer-values', action='store_true', + dest='only_integer_values') + parser.add_argument('--any-values', action='store_false', + dest='only_integer_values') + + parser.add_argument('--grid-categories', action='store_true') + parser.add_argument('--grid-values', action='store_true') + + args = parser.parse_args() + + if len(args.categories) < len(args.values): + parser.error('too many bar values') + if len(args.categories) > len(args.values): + args.values.extend([0.] * (len(args.categories) - len(args.values))) + + builder = BarChartBuilder(labels_align_middle=args.labels_align_middle) + + if args.only_integer_values: + builder.only_integer_values() + else: + builder.any_values() + + if args.grid_categories: + builder.enable_grid_for_categories() + if args.grid_values: + builder.enable_grid_for_values() + + builder.plot_bars(args.categories, args.values) + + if args.output is None: + builder.show() + else: + builder.save(args.output) diff --git a/vk/utils/io.py b/vk/utils/io.py new file mode 100644 index 0000000..bb8eef9 --- /dev/null +++ b/vk/utils/io.py @@ -0,0 +1,51 @@ +# Copyright (c) 2017 Egor Tensin +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +from contextlib import contextmanager +import csv +import json +import sys + + +class FileWriterJSON: + def __init__(self, fd=sys.stdout): + self._fd = fd + + def write(self, something): + self._fd.write(json.dumps(something, indent=3, ensure_ascii=False)) + self._fd.write('\n') + + +class FileWriterCSV: + def __init__(self, fd=sys.stdout): + self._writer = csv.writer(fd, lineterminator='\n') + + @staticmethod + def _convert_row_old_python(row): + if isinstance(row, (list, tuple)): + return row + return list(row) + + def write_row(self, row): + if sys.version_info < (3, 5): + row = self._convert_row_old_python(row) + self._writer.writerow(row) + + +@contextmanager +def _open_file(path=None, default=None, **kwargs): + if path is None: + yield default + else: + with open(path, **kwargs) as fd: + yield fd + + +def open_output_text_file(path=None): + return _open_file(path, default=sys.stdout, mode='w', encoding='utf-8') + + +def open_output_binary_file(path=None): + return _open_file(path, default=sys.stdout, mode='wb') -- cgit v1.2.3