diff options
Diffstat (limited to 'vk/utils/bar_chart.py')
-rw-r--r-- | vk/utils/bar_chart.py | 182 |
1 files changed, 182 insertions, 0 deletions
diff --git a/vk/utils/bar_chart.py b/vk/utils/bar_chart.py new file mode 100644 index 0000000..f051efc --- /dev/null +++ b/vk/utils/bar_chart.py @@ -0,0 +1,182 @@ +# Copyright (c) 2017 Egor Tensin <Egor.Tensin@gmail.com> +# This file is part of the "VK scripts" project. +# For details, see https://github.com/egor-tensin/vk-scripts. +# Distributed under the MIT License. + +import matplotlib.pyplot as plt +from matplotlib import ticker +import numpy as np + + +class BarChartBuilder: + _BAR_HEIGHT = .5 + + THICK_BAR_HEIGHT = _BAR_HEIGHT + THIN_BAR_HEIGHT = THICK_BAR_HEIGHT / 2 + + def __init__(self, labels_align_middle=True): + self._fig, self._ax = plt.subplots() + self.labels_align_middle = labels_align_middle + + def set_title(self, title): + self._ax.set_title(title) + + def _get_categories_axis(self): + return self._ax.get_yaxis() + + def _get_values_axis(self): + return self._ax.get_xaxis() + + def set_categories_axis_limits(self, start=None, end=None): + bottom, top = self._ax.get_ylim() + if start is not None: + bottom = start + if end is not None: + top = end + self._ax.set_ylim(bottom=bottom, top=top) + + def set_values_axis_limits(self, start=None, end=None): + left, right = self._ax.get_xlim() + if start is not None: + left = start + if end is not None: + right = end + self._ax.set_xlim(left=left, right=right) + + def enable_grid_for_categories(self): + self._get_categories_axis().grid() + + def enable_grid_for_values(self): + self._get_values_axis().grid() + + def get_categories_labels(self): + return self._get_categories_axis().get_ticklabels() + + def get_values_labels(self): + return self._get_values_axis().get_ticklabels() + + def hide_categories(self): + self._get_categories_axis().set_major_locator(ticker.NullLocator()) + + def set_value_label_formatter(self, fn): + self._get_values_axis().set_major_formatter(ticker.FuncFormatter(fn)) + + def any_values(self): + self._get_values_axis().set_major_locator(ticker.AutoLocator()) + + def only_integer_values(self): + self._get_values_axis().set_major_locator(ticker.MaxNLocator(integer=True)) + + @staticmethod + def set_property(*args, **kwargs): + plt.setp(*args, **kwargs) + + def _set_size(self, inches, dim=0): + fig_size = self._fig.get_size_inches() + assert len(fig_size) == 2 + fig_size[dim] = inches + self._fig.set_size_inches(fig_size, forward=True) + + def set_width(self, inches): + self._set_size(inches) + + def set_height(self, inches): + self._set_size(inches, dim=1) + + _DEFAULT_VALUES_AXIS_MAX = 1 + assert _DEFAULT_VALUES_AXIS_MAX > 0 + + def plot_bars(self, categories, values, bar_height=THICK_BAR_HEIGHT): + numof_bars = len(categories) + inches_per_bar = 2 * bar_height + categories_axis_max = inches_per_bar * numof_bars + + if not numof_bars: + categories_axis_max += inches_per_bar + + self.set_height(categories_axis_max) + self.set_categories_axis_limits(0, categories_axis_max) + + if not numof_bars: + self.set_values_axis_limits(0, self._DEFAULT_VALUES_AXIS_MAX) + self.hide_categories() + return [] + + bar_offset = inches_per_bar / 2 + bar_offsets = inches_per_bar * np.arange(numof_bars) + bar_offset + + if self.labels_align_middle: + self._get_categories_axis().set_ticks(bar_offsets) + else: + self._get_categories_axis().set_ticks(bar_offsets - bar_offset) + + self._get_categories_axis().set_ticklabels(categories) + + bars = self._ax.barh(bar_offsets, values, align='center', + height=bar_height) + + if min(values) >= 0: + self.set_values_axis_limits(start=0) + if np.isclose(max(values), 0.): + self.set_values_axis_limits(end=self._DEFAULT_VALUES_AXIS_MAX) + elif max(values) < 0: + self.set_values_axis_limits(end=0) + + return bars + + @staticmethod + def show(): + plt.show() + + def save(self, path): + self._fig.savefig(path, bbox_inches='tight') + + +if __name__ == '__main__': + import argparse + parser = argparse.ArgumentParser() + + parser.add_argument('--categories', nargs='*', metavar='LABEL', + default=[]) + parser.add_argument('--values', nargs='*', metavar='N', + default=[], type=float) + + parser.add_argument('--output', '-o', help='set output file path') + + parser.add_argument('--align-middle', action='store_true', + dest='labels_align_middle', + help='align labels to the middle of the bars') + + parser.add_argument('--integer-values', action='store_true', + dest='only_integer_values') + parser.add_argument('--any-values', action='store_false', + dest='only_integer_values') + + parser.add_argument('--grid-categories', action='store_true') + parser.add_argument('--grid-values', action='store_true') + + args = parser.parse_args() + + if len(args.categories) < len(args.values): + parser.error('too many bar values') + if len(args.categories) > len(args.values): + args.values.extend([0.] * (len(args.categories) - len(args.values))) + + builder = BarChartBuilder(labels_align_middle=args.labels_align_middle) + + if args.only_integer_values: + builder.only_integer_values() + else: + builder.any_values() + + if args.grid_categories: + builder.enable_grid_for_categories() + if args.grid_values: + builder.enable_grid_for_values() + + builder.plot_bars(args.categories, args.values) + + if args.output is None: + builder.show() + else: + builder.save(args.output) |