about | summary | refs | log | tree | commit | diff | stats | homepage
path: root/vk
diff options
context:
space:
mode:
Diffstat (limited to 'vk')
-rw-r--r-- vk/mutuals.py 118
-rw-r--r-- vk/tracking/sessions.py 366
-rw-r--r-- vk/tracking/show_status.py 44
-rw-r--r-- vk/tracking/track_status.py 87
-rw-r--r-- vk/utils/__init__.py 0
-rw-r--r-- vk/utils/bar_chart.py 182
-rw-r--r-- vk/utils/io.py 51
7 files changed, 848 insertions, 0 deletions
diff --git a/vk/mutuals.py b/vk/mutuals.py
new file mode 100644
index 0000000..644baf8
--- /dev/null
+++ b/vk/mutuals.py
@@ -0,0 +1,118 @@
+# Copyright (c) 2015 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+import abc
+import argparse
+from collections import OrderedDict
+from enum import Enum
+import sys
+
+from vk.api import API
+from vk.user import UserField
+from vk.utils import io
+
+
# User fields requested from the API and written to the output, in order.
_OUTPUT_USER_FIELDS = (
    UserField.UID,
    UserField.FIRST_NAME,
    UserField.LAST_NAME,
)
+
+
def _query_friend_list(api, user):
    """Fetch the friends of `user`, requesting only the output fields."""
    uid = user.get_uid()
    return api.friends_get(uid, fields=_OUTPUT_USER_FIELDS)
+
+
def _filter_user_fields(user):
    """Reduce `user` to an ordered mapping of the output fields.

    Keys are the string forms of the fields; a field that is missing from
    `user` is mapped to None.
    """
    return OrderedDict(
        (str(field), user[field] if field in user else None)
        for field in _OUTPUT_USER_FIELDS)
+
+
class OutputSinkMutualFriends(metaclass=abc.ABCMeta):
    """Interface of the objects that write out a list of mutual friends."""

    @abc.abstractmethod
    def write_mutual_friends(self, friend_list):
        """Write `friend_list` to the underlying output."""
+
+
class OutputSinkCSV(OutputSinkMutualFriends):
    """Writes mutual friends as CSV rows, one user per row."""

    def __init__(self, fd=sys.stdout):
        self._writer = io.FileWriterCSV(fd)

    def write_mutual_friends(self, friend_list):
        """Write the field values of every user in `friend_list` as a row."""
        write_row = self._writer.write_row
        for friend in friend_list:
            write_row(friend.values())
+
+
class OutputSinkJSON(OutputSinkMutualFriends):
    """Writes the whole list of mutual friends as a single JSON document."""

    def __init__(self, fd=sys.stdout):
        self._writer = io.FileWriterJSON(fd)

    def write_mutual_friends(self, friend_list):
        """Serialize `friend_list` in one go."""
        json_writer = self._writer
        json_writer.write(friend_list)
+
+
class OutputFormat(Enum):
    """Supported output formats; values are the command line spellings."""

    CSV = 'csv'
    JSON = 'json'

    def __str__(self):
        return self.value

    @staticmethod
    def open_file(path=None):
        """Open `path` for text output via `io.open_output_text_file`."""
        return io.open_output_text_file(path)

    def create_sink(self, fd=sys.stdout):
        """Create the output sink matching this format, writing to `fd`.

        Raises NotImplementedError if the format has no sink.
        """
        if self is OutputFormat.JSON:
            return OutputSinkJSON(fd)
        if self is OutputFormat.CSV:
            return OutputSinkCSV(fd)
        raise NotImplementedError('unsupported output format: ' + str(self))
+
+
def _parse_output_format(s):
    """Convert a command line string to an `OutputFormat` member.

    Raises argparse.ArgumentTypeError if `s` is not a known format.
    """
    try:
        fmt = OutputFormat(s)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid output format: ' + s)
    return fmt
+
+
def _parse_args(args=None):
    """Parse command line arguments (`sys.argv[1:]` if `args` is None)."""
    parser = argparse.ArgumentParser(
        description='Learn who your ex and her new boyfriend are both friends with.')

    parser.add_argument(
        'uids', metavar='UID', nargs='+',
        help='user IDs or "screen names"')
    parser.add_argument(
        '-f', '--format', dest='out_fmt',
        type=_parse_output_format,
        choices=OutputFormat,
        default=OutputFormat.CSV,
        help='specify output format')
    parser.add_argument(
        '-o', '--output', dest='out_path', metavar='PATH',
        help='set output file path (standard output by default)')

    argv = sys.argv[1:] if args is None else args
    return parser.parse_args(argv)
+
+
def write_mutual_friends(uids, out_path=None, out_fmt=OutputFormat.CSV):
    """Write the friends that all of the given users have in common.

    `uids` is passed to `API.users_get()`; the friend list of every
    returned user is queried and the lists are intersected.  The result is
    reduced to the output fields and written to `out_path` (opened through
    `out_fmt.open_file()`) using the sink created by `out_fmt`.

    If the API returns no users at all, an empty result is written.
    """
    api = API()
    users = api.users_get(uids)

    friend_lists = [frozenset(_query_friend_list(api, user)) for user in users]
    if friend_lists:
        mutual_friends = frozenset.intersection(*friend_lists)
    else:
        # frozenset.intersection() called on the class needs at least one
        # set; without users there are no mutual friends to report.
        mutual_friends = frozenset()
    mutual_friends = [_filter_user_fields(user) for user in mutual_friends]

    with out_fmt.open_file(out_path) as out_fd:
        sink = out_fmt.create_sink(out_fd)
        sink.write_mutual_friends(mutual_friends)
+
+
def main(args=None):
    """Script entry point: parse `args` and write the mutual friends."""
    options = _parse_args(args)
    write_mutual_friends(**vars(options))


if __name__ == '__main__':
    main()
diff --git a/vk/tracking/sessions.py b/vk/tracking/sessions.py
new file mode 100644
index 0000000..dd8b32f
--- /dev/null
+++ b/vk/tracking/sessions.py
@@ -0,0 +1,366 @@
+# Copyright (c) 2016 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+import abc
+import argparse
+from collections import OrderedDict
+from datetime import datetime, timedelta, timezone
+from enum import Enum
+import sys
+
+from vk.tracking import OnlineSessionEnumerator
+from vk.tracking.db import Format as DatabaseFormat
+from vk.user import UserField
+from vk.utils.bar_chart import BarChartBuilder
+from vk.utils import io
+
+
class GroupBy(Enum):
    """Ways of grouping online sessions; values are the CLI spellings."""

    USER = 'user'
    DATE = 'date'
    WEEKDAY = 'weekday'
    HOUR = 'hour'

    def __str__(self):
        return self.value

    def group(self, db_reader, time_from=None, time_to=None):
        """Group the records of `db_reader` according to this member.

        An `OnlineSessionEnumerator` limited to `time_from`/`time_to` is
        created and the result of its matching `group_by_*` method is
        returned.  Raises NotImplementedError for an unknown grouping.
        """
        online_streaks = OnlineSessionEnumerator(time_from, time_to)
        method_names = {
            GroupBy.USER: 'group_by_user',
            GroupBy.DATE: 'group_by_date',
            GroupBy.WEEKDAY: 'group_by_weekday',
            GroupBy.HOUR: 'group_by_hour',
        }
        method_name = method_names.get(self)
        if method_name is None:
            raise NotImplementedError('unsupported grouping: ' + str(self))
        return getattr(online_streaks, method_name)(db_reader)
+
+
# User fields included in the output when sessions are grouped by user,
# in the order they are written.
_OUTPUT_USER_FIELDS = (
    UserField.UID,
    UserField.FIRST_NAME,
    UserField.LAST_NAME,
    UserField.DOMAIN,
)
+
+
class OutputSinkOnlineSessions(metaclass=abc.ABCMeta):
    """Interface of the objects that output grouped online sessions."""

    @abc.abstractmethod
    def process_database(self, group_by, db_reader, time_from=None, time_to=None):
        """Group the records of `db_reader` by `group_by` and output them."""
+
+
class OutputConverterCSV:
    """Converts grouping keys to the leading cells of a CSV row."""

    @staticmethod
    def convert_user(user):
        """Return the output fields of `user` as a list."""
        cells = []
        for field in _OUTPUT_USER_FIELDS:
            cells.append(user[field])
        return cells

    @staticmethod
    def convert_date(date):
        """Return a single-cell row holding `str(date)`."""
        text = str(date)
        return [text]

    @staticmethod
    def convert_weekday(weekday):
        """Return a single-cell row holding `str(weekday)`."""
        text = str(weekday)
        return [text]

    @staticmethod
    def convert_hour(hour):
        """Return a single-cell row holding the hour as `H:MM:SS`."""
        offset = timedelta(hours=hour)
        return [str(offset)]
+
+
class OutputSinkCSV(OutputSinkOnlineSessions):
    """Writes grouped online durations as CSV rows."""

    def __init__(self, fd=sys.stdout):
        self._writer = io.FileWriterCSV(fd)

    # Maps every grouping to the function converting its keys to row cells.
    _CONVERT_KEY = {
        GroupBy.USER: OutputConverterCSV.convert_user,
        GroupBy.DATE: OutputConverterCSV.convert_date,
        GroupBy.WEEKDAY: OutputConverterCSV.convert_weekday,
        GroupBy.HOUR: OutputConverterCSV.convert_hour,
    }

    @staticmethod
    def _key_to_row(group_by, key):
        """Convert `key` to a list of cells; unknown groupings are rejected."""
        convert = OutputSinkCSV._CONVERT_KEY.get(group_by)
        if convert is None:
            raise NotImplementedError('unsupported grouping: ' + str(group_by))
        return convert(key)

    def process_database(self, group_by, db_reader, time_from=None, time_to=None):
        """Write one row per group: the key cells followed by the duration."""
        grouped = group_by.group(db_reader, time_from, time_to)
        for key, duration in grouped.items():
            cells = self._key_to_row(group_by, key)
            cells.append(str(duration))
            self._writer.write_row(cells)
+
+
class OutputConverterJSON:
    """Converts grouping keys to ordered mappings for JSON output."""

    _DATE_FIELD = 'date'
    _WEEKDAY_FIELD = 'weekday'
    _HOUR_FIELD = 'hour'

    # The key names above must not clash with the user field names.
    assert _DATE_FIELD not in map(str, _OUTPUT_USER_FIELDS)
    assert _WEEKDAY_FIELD not in map(str, _OUTPUT_USER_FIELDS)
    assert _HOUR_FIELD not in map(str, _OUTPUT_USER_FIELDS)

    @staticmethod
    def convert_user(user):
        """Map the string form of every output field to its value."""
        return OrderedDict(
            (str(field), user[field]) for field in _OUTPUT_USER_FIELDS)

    @staticmethod
    def convert_date(date):
        """Return a mapping with the single key 'date'."""
        return OrderedDict([(OutputConverterJSON._DATE_FIELD, str(date))])

    @staticmethod
    def convert_weekday(weekday):
        """Return a mapping with the single key 'weekday'."""
        return OrderedDict([(OutputConverterJSON._WEEKDAY_FIELD, str(weekday))])

    @staticmethod
    def convert_hour(hour):
        """Return a mapping with the single key 'hour' (as `H:MM:SS`)."""
        offset = timedelta(hours=hour)
        return OrderedDict([(OutputConverterJSON._HOUR_FIELD, str(offset))])
+
+
class OutputSinkJSON(OutputSinkOnlineSessions):
    """Writes grouped online durations as one JSON list of objects."""

    def __init__(self, fd=sys.stdout):
        self._writer = io.FileWriterJSON(fd)

    _DURATION_FIELD = 'duration'

    # The duration key must not clash with the user field names.
    assert _DURATION_FIELD not in map(str, _OUTPUT_USER_FIELDS)

    # Maps every grouping to the function converting its keys to objects.
    _CONVERT_KEY = {
        GroupBy.USER: OutputConverterJSON.convert_user,
        GroupBy.DATE: OutputConverterJSON.convert_date,
        GroupBy.WEEKDAY: OutputConverterJSON.convert_weekday,
        GroupBy.HOUR: OutputConverterJSON.convert_hour,
    }

    @staticmethod
    def _key_to_object(group_by, key):
        """Convert `key` to a mapping; unknown groupings are rejected."""
        convert = OutputSinkJSON._CONVERT_KEY.get(group_by)
        if convert is None:
            raise NotImplementedError('unsupported grouping: ' + str(group_by))
        return convert(key)

    def process_database(self, group_by, db_reader, time_from=None, time_to=None):
        """Collect one object per group and write them all at once."""
        grouped = group_by.group(db_reader, time_from, time_to)
        objects = []
        for key, duration in grouped.items():
            obj = self._key_to_object(group_by, key)
            obj[self._DURATION_FIELD] = str(duration)
            objects.append(obj)
        self._writer.write(objects)
+
+
class OutputConverterPlot:
    """Converts grouping keys to the text labels used on the plot."""

    @staticmethod
    def convert_user(user):
        """Return the first and the last name on separate lines."""
        first_name = user.get_first_name()
        last_name = user.get_last_name()
        return '{}\n{}'.format(first_name, last_name)

    @staticmethod
    def convert_date(date):
        """Return `str(date)`."""
        label = str(date)
        return label

    @staticmethod
    def convert_weekday(weekday):
        """Return `str(weekday)`."""
        label = str(weekday)
        return label

    @staticmethod
    def convert_hour(hour):
        """Return the hour followed by ':00'."""
        template = '{}:00'
        return template.format(hour)
+
+
class OutputSinkPlot(OutputSinkOnlineSessions):
    """Plots grouped online durations as a bar chart."""

    def __init__(self, fd=sys.stdout):
        self._fd = fd

    TITLE = 'How much time people spend online'

    # Maps every grouping to the function converting its keys to labels.
    _FORMAT_KEY = {
        GroupBy.USER: OutputConverterPlot.convert_user,
        GroupBy.DATE: OutputConverterPlot.convert_date,
        GroupBy.WEEKDAY: OutputConverterPlot.convert_weekday,
        GroupBy.HOUR: OutputConverterPlot.convert_hour,
    }

    @staticmethod
    def _format_key(group_by, key):
        """Convert `key` to a bar label; unknown groupings are rejected."""
        formatter = OutputSinkPlot._FORMAT_KEY.get(group_by)
        if formatter is None:
            raise NotImplementedError('unsupported grouping: ' + str(group_by))
        return formatter(key)

    @staticmethod
    def _format_duration(seconds, _):
        """Format a number of seconds as text; the 2nd argument is ignored."""
        duration = timedelta(seconds=seconds)
        return str(duration)

    @staticmethod
    def _duration_to_seconds(td):
        """Return the total number of seconds in `td`."""
        return td.total_seconds()

    @staticmethod
    def _extract_labels(group_by, durations):
        """Return an iterator over the bar labels of `durations`' keys."""
        return (
            OutputSinkPlot._format_key(group_by, key)
            for key in durations.keys()
        )

    @staticmethod
    def _extract_values(durations):
        """Return an iterator over `durations`' values, in seconds."""
        return map(OutputSinkPlot._duration_to_seconds, durations.values())

    def process_database(
            self, group_by, db_reader, time_from=None, time_to=None):
        """Group the records of `db_reader` and plot the durations.

        The chart is shown interactively if the sink writes to `sys.stdout`
        and saved to the sink's file otherwise.
        """
        grouped = group_by.group(db_reader, time_from, time_to)

        chart = BarChartBuilder()
        chart.set_title(OutputSinkPlot.TITLE)
        chart.enable_grid_for_values()
        chart.only_integer_values()
        chart.set_property(
            chart.get_values_labels(), fontsize='small', rotation=30)
        chart.set_value_label_formatter(self._format_duration)

        labels = tuple(self._extract_labels(group_by, grouped))
        seconds = tuple(self._extract_values(grouped))

        if group_by is GroupBy.HOUR:
            chart.labels_align_middle = False
            bar_height = chart.THIN_BAR_HEIGHT
        else:
            bar_height = chart.THICK_BAR_HEIGHT

        bars = chart.plot_bars(labels, seconds, bar_height=bar_height)
        chart.set_property(bars, alpha=.33)

        if self._fd is sys.stdout:
            chart.show()
        else:
            chart.save(self._fd)
+
+
class OutputFormat(Enum):
    """Supported output formats; values are the command line spellings."""

    CSV = 'csv'
    JSON = 'json'
    PLOT = 'plot'

    def __str__(self):
        return self.value

    def create_sink(self, fd=sys.stdout):
        """Create the output sink matching this format, writing to `fd`.

        Raises NotImplementedError if the format has no sink.
        """
        if self is OutputFormat.PLOT:
            return OutputSinkPlot(fd)
        if self is OutputFormat.JSON:
            return OutputSinkJSON(fd)
        if self is OutputFormat.CSV:
            return OutputSinkCSV(fd)
        raise NotImplementedError('unsupported output format: ' + str(self))

    def open_file(self, path=None):
        """Open `path` for output: in binary mode for plots, text otherwise."""
        if self is OutputFormat.PLOT:
            opener = io.open_output_binary_file
        else:
            opener = io.open_output_text_file
        return opener(path)
+
+
def _parse_group_by(s):
    """Convert a command line string to a `GroupBy` member.

    Raises argparse.ArgumentTypeError if `s` is not a known grouping.
    """
    try:
        grouping = GroupBy(s)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid "group by" value: ' + s)
    return grouping
+
+
def _parse_database_format(s):
    """Convert a command line string to a `DatabaseFormat` member.

    Raises argparse.ArgumentTypeError if `s` is not a known format.
    """
    try:
        fmt = DatabaseFormat(s)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid database format: ' + s)
    return fmt
+
+
def _parse_output_format(s):
    """Convert a command line string to an `OutputFormat` member.

    Raises argparse.ArgumentTypeError if `s` is not a known format.
    """
    try:
        fmt = OutputFormat(s)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid output format: ' + s)
    return fmt
+
+
+_DATE_RANGE_LIMIT_FORMAT = '%Y-%m-%dT%H:%M:%SZ'
+
+
+def _parse_date_range_limit(s):
+ try:
+ dt = datetime.strptime(s, _DATE_RANGE_LIMIT_FORMAT)
+ return dt.replace(tzinfo=timezone.utc)
+ except ValueError:
+ msg = 'invalid date range limit (must be in the \'{}\' format): {}'
+ raise argparse.ArgumentTypeError(
+ msg.format(_DATE_RANGE_LIMIT_FORMAT, s))
+
+
def _parse_args(args=None):
    """Parse command line arguments (`sys.argv[1:]` if `args` is None)."""
    parser = argparse.ArgumentParser(
        description='View/visualize the amount of time people spend online.')

    # Both file paths are optional positional arguments.
    parser.add_argument(
        'db_path', metavar='input', nargs='?',
        help='database file path (standard input by default)')
    parser.add_argument(
        'out_path', metavar='output', nargs='?',
        help='output file path (standard output by default)')

    parser.add_argument(
        '-g', '--group-by',
        type=_parse_group_by,
        choices=GroupBy,
        default=GroupBy.USER,
        help='group online sessions by user/date/etc.')
    parser.add_argument(
        '-i', '--input-format', dest='db_fmt',
        type=_parse_database_format,
        choices=DatabaseFormat,
        default=DatabaseFormat.CSV,
        help='specify database format')
    parser.add_argument(
        '-o', '--output-format', dest='out_fmt',
        type=_parse_output_format,
        choices=OutputFormat,
        default=OutputFormat.CSV,
        help='specify output format')

    parser.add_argument(
        '-a', '--from', dest='time_from',
        type=_parse_date_range_limit, default=None,
        help='discard online activity prior to this moment')
    parser.add_argument(
        '-b', '--to', dest='time_to',
        type=_parse_date_range_limit, default=None,
        help='discard online activity after this moment')

    argv = sys.argv[1:] if args is None else args
    return parser.parse_args(argv)
+
+
def process_online_sessions(
        db_path=None, db_fmt=DatabaseFormat.CSV,
        out_path=None, out_fmt=OutputFormat.CSV,
        group_by=GroupBy.USER,
        time_from=None, time_to=None):
    """Read a status database and output the grouped online sessions.

    The database at `db_path` is read using `db_fmt`; the sessions between
    `time_from` and `time_to` are grouped by `group_by` and written to
    `out_path` in the `out_fmt` format.
    """
    both_limits_given = time_from is not None and time_to is not None
    if both_limits_given and time_from > time_to:
        # The limits were given in the wrong order; swap them.
        time_from, time_to = time_to, time_from

    with db_fmt.open_input_file(db_path) as db_fd:
        db_reader = db_fmt.create_reader(db_fd)
        with out_fmt.open_file(out_path) as out_fd:
            sink = out_fmt.create_sink(out_fd)
            sink.process_database(
                group_by, db_reader, time_from=time_from, time_to=time_to)
+
+
def main(args=None):
    """Script entry point: parse `args` and process the online sessions."""
    options = _parse_args(args)
    process_online_sessions(**vars(options))


if __name__ == '__main__':
    main()
diff --git a/vk/tracking/show_status.py b/vk/tracking/show_status.py
new file mode 100644
index 0000000..cf59280
--- /dev/null
+++ b/vk/tracking/show_status.py
@@ -0,0 +1,44 @@
+# Copyright (c) 2019 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+import argparse
+import sys
+
+from vk.api import API
+from vk.tracking import StatusTracker
+from vk.tracking.db import Format as DatabaseFormat
+
+
+def _parse_args(args=None):
+ if args is None:
+ args = sys.argv[1:]
+
+ parser = argparse.ArgumentParser(
+ description='Show if people are online/offline.')
+
+ parser.add_argument('uids', metavar='UID', nargs='+',
+ help='user IDs or "screen names"')
+ parser.add_argument('-l', '--log', metavar='PATH', dest='log_path',
+ help='set log file path (standard output by default)')
+
+ return parser.parse_args(args)
+
+
def track_status(uids, log_path=None):
    """Query the status of the users `uids` once and log it to `log_path`."""
    tracker = StatusTracker(API())
    log_fmt = DatabaseFormat.LOG

    with log_fmt.open_output_file(log_path) as log_fd:
        tracker.add_database_writer(log_fmt.create_writer(log_fd))
        tracker.query_status(uids)
+
+
def main(args=None):
    """Script entry point: parse `args` and show the users' status."""
    options = _parse_args(args)
    track_status(**vars(options))


if __name__ == '__main__':
    main()
diff --git a/vk/tracking/track_status.py b/vk/tracking/track_status.py
new file mode 100644
index 0000000..2a974a5
--- /dev/null
+++ b/vk/tracking/track_status.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2015 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+import argparse
+import sys
+
+from vk.api import API
+from vk.tracking import StatusTracker
+from vk.tracking.db import Format as DatabaseFormat
+
+
# Defaults shared by the argument parser and track_status().
DEFAULT_TIMEOUT = StatusTracker.DEFAULT_TIMEOUT
DEFAULT_DB_FORMAT = DatabaseFormat.CSV
+
+
+def _parse_positive_integer(s):
+ try:
+ n = int(s)
+ except ValueError:
+ raise argparse.ArgumentTypeError('must be a positive integer: ' + s)
+ if n < 1:
+ raise argparse.ArgumentTypeError('must be a positive integer: ' + s)
+ return n
+
+
def _parse_database_format(s):
    """Convert a command line string to a `DatabaseFormat` member.

    Raises argparse.ArgumentTypeError if `s` is not a known format.
    """
    try:
        fmt = DatabaseFormat(s)
    except ValueError:
        raise argparse.ArgumentTypeError('invalid database format: ' + s)
    return fmt
+
+
def _parse_args(args=None):
    """Parse command line arguments (`sys.argv[1:]` if `args` is None)."""
    parser = argparse.ArgumentParser(
        description='Track when people go online/offline.')

    parser.add_argument(
        'uids', metavar='UID', nargs='+',
        help='user IDs or "screen names"')
    parser.add_argument(
        '-t', '--timeout', metavar='SECONDS',
        type=_parse_positive_integer,
        default=DEFAULT_TIMEOUT,
        help='set refresh interval')
    parser.add_argument(
        '-l', '--log', dest='log_path', metavar='PATH',
        help='set log file path (standard output by default)')
    parser.add_argument(
        '-f', '--format', dest='db_fmt',
        type=_parse_database_format,
        choices=DatabaseFormat,
        default=DEFAULT_DB_FORMAT,
        help='specify database format')
    parser.add_argument(
        '-o', '--output', dest='db_path', metavar='PATH',
        help='set database file path')

    argv = sys.argv[1:] if args is None else args
    return parser.parse_args(argv)
+
+
def track_status(
        uids, timeout=DEFAULT_TIMEOUT,
        log_path=None,
        db_path=None, db_fmt=DEFAULT_DB_FORMAT):
    """Track the status of the users `uids` until interrupted.

    Status changes are always logged to `log_path` in the LOG format and
    additionally written to `db_path` in the `db_fmt` format.
    """
    tracker = StatusTracker(API(), timeout)

    # The LOG format is already covered by the log writer, and without a
    # path there is nowhere to write a database; DatabaseFormat.NULL is
    # used in both cases (presumably a writer that discards its input --
    # confirm against vk.tracking.db).
    database_disabled = db_fmt is DatabaseFormat.LOG or db_path is None
    if database_disabled:
        db_fmt = DatabaseFormat.NULL

    log_fmt = DatabaseFormat.LOG
    with log_fmt.open_output_file(log_path) as log_fd:
        tracker.add_database_writer(log_fmt.create_writer(log_fd))
        with db_fmt.open_output_file(db_path) as db_fd:
            tracker.add_database_writer(db_fmt.create_writer(db_fd))

            tracker.loop(uids)
+
+
def main(args=None):
    """Script entry point: parse `args` and start tracking."""
    options = _parse_args(args)
    track_status(**vars(options))


if __name__ == '__main__':
    main()
diff --git a/vk/utils/__init__.py b/vk/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/vk/utils/__init__.py
diff --git a/vk/utils/bar_chart.py b/vk/utils/bar_chart.py
new file mode 100644
index 0000000..f051efc
--- /dev/null
+++ b/vk/utils/bar_chart.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+import matplotlib.pyplot as plt
+from matplotlib import ticker
+import numpy as np
+
+
class BarChartBuilder:
    """Builds a horizontal bar chart using matplotlib.

    Categories (the bars and their labels) are laid out along the y axis,
    values along the x axis.
    """

    _BAR_HEIGHT = .5

    THICK_BAR_HEIGHT = _BAR_HEIGHT
    THIN_BAR_HEIGHT = THICK_BAR_HEIGHT / 2

    def __init__(self, labels_align_middle=True):
        figure, axes = plt.subplots()
        self._fig = figure
        self._ax = axes
        self.labels_align_middle = labels_align_middle

    def set_title(self, title):
        """Set the title displayed above the chart."""
        self._ax.set_title(title)

    def _get_categories_axis(self):
        return self._ax.get_yaxis()

    def _get_values_axis(self):
        return self._ax.get_xaxis()

    def set_categories_axis_limits(self, start=None, end=None):
        """Change the y axis limits; a limit left as None is kept as is."""
        old_bottom, old_top = self._ax.get_ylim()
        self._ax.set_ylim(
            bottom=old_bottom if start is None else start,
            top=old_top if end is None else end)

    def set_values_axis_limits(self, start=None, end=None):
        """Change the x axis limits; a limit left as None is kept as is."""
        old_left, old_right = self._ax.get_xlim()
        self._ax.set_xlim(
            left=old_left if start is None else start,
            right=old_right if end is None else end)

    def enable_grid_for_categories(self):
        categories_axis = self._get_categories_axis()
        categories_axis.grid()

    def enable_grid_for_values(self):
        values_axis = self._get_values_axis()
        values_axis.grid()

    def get_categories_labels(self):
        categories_axis = self._get_categories_axis()
        return categories_axis.get_ticklabels()

    def get_values_labels(self):
        values_axis = self._get_values_axis()
        return values_axis.get_ticklabels()

    def hide_categories(self):
        """Remove the ticks (and their labels) from the categories axis."""
        no_ticks = ticker.NullLocator()
        self._get_categories_axis().set_major_locator(no_ticks)

    def set_value_label_formatter(self, fn):
        """Format the value tick labels with `fn` (via `FuncFormatter`)."""
        formatter = ticker.FuncFormatter(fn)
        self._get_values_axis().set_major_formatter(formatter)

    def any_values(self):
        """Place the value ticks using matplotlib's `AutoLocator`."""
        locator = ticker.AutoLocator()
        self._get_values_axis().set_major_locator(locator)

    def only_integer_values(self):
        """Place the value ticks at integer positions only."""
        locator = ticker.MaxNLocator(integer=True)
        self._get_values_axis().set_major_locator(locator)

    @staticmethod
    def set_property(*args, **kwargs):
        """Forward the arguments to `matplotlib.pyplot.setp()`."""
        plt.setp(*args, **kwargs)

    def _set_size(self, inches, dim=0):
        """Set one dimension of the figure (0 is width, 1 is height)."""
        size = self._fig.get_size_inches()
        assert len(size) == 2
        size[dim] = inches
        self._fig.set_size_inches(size, forward=True)

    def set_width(self, inches):
        self._set_size(inches)

    def set_height(self, inches):
        self._set_size(inches, dim=1)

    # Upper limit of the values axis when there is nothing (or nothing but
    # zeroes) to plot.
    _DEFAULT_VALUES_AXIS_MAX = 1
    assert _DEFAULT_VALUES_AXIS_MAX > 0

    def plot_bars(self, categories, values, bar_height=THICK_BAR_HEIGHT):
        """Plot one horizontal bar per category and return the bars.

        `categories` are the bar labels and `values` the bar lengths.  The
        figure height grows with the number of bars.  An empty list is
        returned if there are no categories.
        """
        bar_count = len(categories)
        # Every bar gets a slot twice as high as the bar itself.
        slot_height = 2 * bar_height
        # An empty chart still gets the height of a single slot.
        axis_length = slot_height * max(bar_count, 1)

        self.set_height(axis_length)
        self.set_categories_axis_limits(0, axis_length)

        if bar_count == 0:
            self.set_values_axis_limits(0, self._DEFAULT_VALUES_AXIS_MAX)
            self.hide_categories()
            return []

        half_slot = slot_height / 2
        bar_centers = slot_height * np.arange(bar_count) + half_slot

        # Labels go either next to the middle of the bars or next to the
        # start of their slots.
        if self.labels_align_middle:
            tick_positions = bar_centers
        else:
            tick_positions = bar_centers - half_slot
        categories_axis = self._get_categories_axis()
        categories_axis.set_ticks(tick_positions)
        categories_axis.set_ticklabels(categories)

        bars = self._ax.barh(
            bar_centers, values, align='center', height=bar_height)

        if min(values) >= 0:
            self.set_values_axis_limits(start=0)
            only_zeroes = np.isclose(max(values), 0.)
            if only_zeroes:
                self.set_values_axis_limits(end=self._DEFAULT_VALUES_AXIS_MAX)
        elif max(values) < 0:
            self.set_values_axis_limits(end=0)

        return bars

    @staticmethod
    def show():
        """Display the chart using `matplotlib.pyplot.show()`."""
        plt.show()

    def save(self, path):
        """Save the figure to `path` with a tight bounding box."""
        self._fig.savefig(path, bbox_inches='tight')
+
+
if __name__ == '__main__':
    # Small command line front end for trying the builder out.
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--categories', nargs='*', metavar='LABEL', default=[])
    parser.add_argument(
        '--values', nargs='*', metavar='N', default=[], type=float)

    parser.add_argument('--output', '-o', help='set output file path')

    parser.add_argument(
        '--align-middle', action='store_true',
        dest='labels_align_middle',
        help='align labels to the middle of the bars')

    # Two flags toggling the same setting.
    parser.add_argument(
        '--integer-values', action='store_true',
        dest='only_integer_values')
    parser.add_argument(
        '--any-values', action='store_false',
        dest='only_integer_values')

    parser.add_argument('--grid-categories', action='store_true')
    parser.add_argument('--grid-values', action='store_true')

    args = parser.parse_args()

    # Every value needs a category; categories without a value get zero.
    missing_values = len(args.categories) - len(args.values)
    if missing_values < 0:
        parser.error('too many bar values')
    if missing_values > 0:
        args.values.extend([0.] * missing_values)

    chart = BarChartBuilder(labels_align_middle=args.labels_align_middle)

    if args.only_integer_values:
        chart.only_integer_values()
    else:
        chart.any_values()

    if args.grid_categories:
        chart.enable_grid_for_categories()
    if args.grid_values:
        chart.enable_grid_for_values()

    chart.plot_bars(args.categories, args.values)

    if args.output is None:
        chart.show()
    else:
        chart.save(args.output)
diff --git a/vk/utils/io.py b/vk/utils/io.py
new file mode 100644
index 0000000..bb8eef9
--- /dev/null
+++ b/vk/utils/io.py
@@ -0,0 +1,51 @@
+# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com>
+# This file is part of the "VK scripts" project.
+# For details, see https://github.com/egor-tensin/vk-scripts.
+# Distributed under the MIT License.
+
+from contextlib import contextmanager
+import csv
+import json
+import sys
+
+
class FileWriterJSON:
    """Writes objects to a text stream as indented JSON documents."""

    def __init__(self, fd=sys.stdout):
        self._fd = fd

    def write(self, something):
        """Write `something` as JSON, followed by a newline.

        Non-ASCII characters are written as is rather than escaped.
        """
        document = json.dumps(something, indent=3, ensure_ascii=False)
        out = self._fd
        out.write(document)
        out.write('\n')
+
+
class FileWriterCSV:
    """Writes rows to a text stream as CSV, terminating lines with '\\n'."""

    def __init__(self, fd=sys.stdout):
        self._writer = csv.writer(fd, lineterminator='\n')

    @staticmethod
    def _convert_row_old_python(row):
        """Return `row` itself if it is a list/tuple, a list copy otherwise."""
        return row if isinstance(row, (list, tuple)) else list(row)

    def write_row(self, row):
        """Write the iterable `row` as a single CSV record."""
        # csv.writer.writerow() accepts arbitrary iterables since Python
        # 3.5 only; older versions get a proper sequence.
        needs_sequence = sys.version_info < (3, 5)
        if needs_sequence:
            row = self._convert_row_old_python(row)
        self._writer.writerow(row)
+
+
+@contextmanager
+def _open_file(path=None, default=None, **kwargs):
+ if path is None:
+ yield default
+ else:
+ with open(path, **kwargs) as fd:
+ yield fd
+
+
def open_output_text_file(path=None):
    """Open `path` for writing UTF-8 text; `sys.stdout` if `path` is None."""
    open_options = {'mode': 'w', 'encoding': 'utf-8'}
    return _open_file(path, default=sys.stdout, **open_options)
+
+
def open_output_binary_file(path=None):
    """Open `path` for writing bytes; yields `sys.stdout` if `path` is None.

    Note that the fallback is the `sys.stdout` object itself, not a binary
    stream.
    """
    open_options = {'mode': 'wb'}
    return _open_file(path, default=sys.stdout, **open_options)